Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-03-05 22:02:50 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-03-05 22:02:50 +0800
Commit
da5c80cc74b80b52f825db34e39625121f1c6d6f
da5c80cc
1 parent
7cae7107
add pad_sequence (#84)
隐藏空白字符变更
内嵌
并排对比
正在显示
4 个修改的文件
包含
129 行增加
和
0 行删除
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/pad-sequence-test.cc
sherpa-onnx/csrc/pad-sequence.cc
sherpa-onnx/csrc/pad-sequence.h
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
da5c80c
...
...
@@ -16,6 +16,7 @@ set(sources
online-transducer-modified-beam-search-decoder.cc
online-zipformer-transducer-model.cc
onnx-utils.cc
pad-sequence.cc
parse-options.cc
resample.cc
slice.cc
...
...
@@ -122,6 +123,7 @@ endif()
if
(
SHERPA_ONNX_ENABLE_TESTS
)
set
(
sherpa_onnx_test_srcs
cat-test.cc
pad-sequence-test.cc
slice-test.cc
transpose-test.cc
unbind-test.cc
...
...
sherpa-onnx/csrc/pad-sequence-test.cc
0 → 100644
查看文件 @
da5c80c
// sherpa-onnx/csrc/pad-sequence-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/pad-sequence.h"
#include <numeric>
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {

// Pads three 2-D float tensors of shapes (3, 5), (4, 5) and (2, 5) into one
// batch and verifies both the shape and the contents of the result.
//
// Each input is filled with 0, 1, 2, ... in row-major order, so inside every
// batch entry the element at flat offset k must equal k while k lies within
// the original data, and must equal the padding value (-1) afterwards.
TEST(PadSequence, ThreeTensors) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 2> shape1{3, 5};
  Ort::Value v1 =
      Ort::Value::CreateTensor<float>(allocator, shape1.data(), shape1.size());
  float *p1 = v1.GetTensorMutableData<float>();
  std::iota(p1, p1 + shape1[0] * shape1[1], 0);

  std::array<int64_t, 2> shape2{4, 5};
  Ort::Value v2 =
      Ort::Value::CreateTensor<float>(allocator, shape2.data(), shape2.size());
  float *p2 = v2.GetTensorMutableData<float>();
  std::iota(p2, p2 + shape2[0] * shape2[1], 0);

  std::array<int64_t, 2> shape3{2, 5};
  Ort::Value v3 =
      Ort::Value::CreateTensor<float>(allocator, shape3.data(), shape3.size());
  float *p3 = v3.GetTensorMutableData<float>();
  std::iota(p3, p3 + shape3[0] * shape3[1], 0);

  auto ans = PadSequence(allocator, {&v1, &v2, &v3}, -1);

  // Kept from the original test: dump inputs and output for manual inspection.
  Print2D(&v1);
  Print2D(&v2);
  Print2D(&v3);
  Print3D(&ans);

  // Expected output shape is (batch, max_T, feature_dim) == (3, 4, 5).
  // ASSERT_* is used so that the element loop below never indexes out of
  // bounds when the shape is wrong.
  const auto ans_shape = ans.GetTensorTypeAndShapeInfo().GetShape();
  ASSERT_EQ(ans_shape.size(), 3u);
  ASSERT_EQ(ans_shape[0], 3);
  ASSERT_EQ(ans_shape[1], 4);
  ASSERT_EQ(ans_shape[2], 5);

  const float *out = ans.GetTensorData<float>();
  const int64_t num_frames[] = {shape1[0], shape2[0], shape3[0]};
  const int64_t stride = ans_shape[1] * ans_shape[2];

  for (int64_t b = 0; b != ans_shape[0]; ++b) {
    // Number of floats copied from the b-th input; the rest is padding.
    const int64_t valid = num_frames[b] * ans_shape[2];
    for (int64_t k = 0; k != stride; ++k) {
      const float expected = k < valid ? static_cast<float>(k) : -1.0f;
      EXPECT_EQ(out[b * stride + k], expected)
          << "batch " << b << ", offset " << k;
    }
  }
}

}  // namespace sherpa_onnx
...
...
sherpa-onnx/csrc/pad-sequence.cc
0 → 100644
查看文件 @
da5c80c
// sherpa-onnx/csrc/pad-sequence.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/pad-sequence.h"
#include <assert.h>
#include <algorithm>
#include <array>
#include <vector>
namespace sherpa_onnx {

/** Pad a list of 2-D float tensors to a common length and stack them.
 *
 * Every input has shape (T_i, C). The result has shape (B, max_T, C), where
 * B == values.size() and max_T is the largest T_i. Entry i of the result
 * starts with the T_i * C floats of values[i]; the remaining positions of
 * that entry hold padding_value.
 *
 * Preconditions (checked with assert, i.e. only in debug builds):
 *  - values is not empty and contains no null pointers
 *  - every tensor is 2-D
 *  - every tensor has the same second dimension
 *
 * The tensor data is read as float; the element type is not verified here.
 *
 * TODO(fangjun): Check that the returned value is correct.
 *
 * @param allocator      Allocator used to create the returned tensor.
 * @param values         Tensors to pad; they are only read.
 * @param padding_value  Value written to every position not covered by input
 *                       data.
 * @return A newly allocated 3-D tensor of shape (B, max_T, C).
 */
Ort::Value PadSequence(OrtAllocator *allocator,
                       const std::vector<const Ort::Value *> &values,
                       float padding_value) {
  // An empty batch provides no feature dimension for the output shape, and
  // values[0] below would be an out-of-bounds access.
  assert(!values.empty());
  assert(values[0] != nullptr);

  const int32_t batch_size = static_cast<int32_t>(values.size());

  const std::vector<int64_t> shape0 =
      values[0]->GetTensorTypeAndShapeInfo().GetShape();
  assert(shape0.size() == 2);

  const int64_t feature_dim = shape0[1];
  int64_t max_T = shape0[0];

  // Number of floats held by each input. Recorded while validating so that
  // the copy loop does not have to query every shape a second time.
  std::vector<int64_t> num_elements;
  num_elements.reserve(values.size());
  num_elements.push_back(shape0[0] * shape0[1]);

  for (int32_t i = 1; i != batch_size; ++i) {
    assert(values[i] != nullptr);

    const std::vector<int64_t> shape =
        values[i]->GetTensorTypeAndShapeInfo().GetShape();
    assert(shape.size() == 2);
    assert(shape[1] == feature_dim);

    max_T = std::max(max_T, shape[0]);
    num_elements.push_back(shape[0] * shape[1]);
  }

  std::array<int64_t, 3> ans_shape{batch_size, max_T, feature_dim};
  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
                                                   ans_shape.size());

  // Floats per batch entry in the output.
  const int64_t stride = max_T * feature_dim;

  // Fill everything with the padding value first, then overwrite the leading
  // part of each batch entry with the corresponding input data.
  float *dst = ans.GetTensorMutableData<float>();
  std::fill(dst, dst + batch_size * stride, padding_value);

  for (int32_t i = 0; i != batch_size; ++i) {
    const float *src = values[i]->GetTensorData<float>();
    std::copy(src, src + num_elements[i], dst);
    dst += stride;
  }

  return ans;
}

}  // namespace sherpa_onnx
...
...
sherpa-onnx/csrc/pad-sequence.h
0 → 100644
查看文件 @
da5c80c
// sherpa-onnx/csrc/pad-sequence.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_
#define SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {

/** Stack 2-D float tensors of different lengths into one padded 3-D tensor.
 *
 * This mirrors torch.nn.utils.rnn.pad_sequence, restricted to the
 * batch_first=true layout.
 *
 * @param allocator      Allocator used to create the returned tensor.
 * @param values         The 2-D tensors to pad. All of them must share the
 *                       same second dimension, and their data type should be
 *                       float.
 * @param padding_value  Value placed in the padded positions. For log-fbank
 *                       features, -23.025850929940457f is the usual choice.
 *
 * @return A 3-D tensor of shape (B, max_T, C), where B is the number of input
 *         tensors, max_T is the largest first dimension among them and C is
 *         their common second dimension.
 */
Ort::Value PadSequence(OrtAllocator *allocator,
                       const std::vector<const Ort::Value *> &values,
                       float padding_value);

}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_
...
...
请
注册
或
登录
后发表评论