Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-03-01 12:18:20 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-03-01 12:18:20 +0800
Commit
e0b76655c823c8f20896087e899ca3634ece3189
e0b76655
1 parent
ebf58552
Fix batch decoding for greedy search (#71)
隐藏空白字符变更
内嵌
并排对比
正在显示
1 个修改的文件
包含
25 行增加
和
33 行删除
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
查看文件 @
e0b7665
...
...
@@ -10,47 +10,39 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace
sherpa_onnx
{
static
Ort
::
Value
GetFrame
(
Ort
::
Value
*
encoder_out
,
int32_t
t
)
{
static
Ort
::
Value
GetFrame
(
OrtAllocator
*
allocator
,
Ort
::
Value
*
encoder_out
,
int32_t
t
)
{
std
::
vector
<
int64_t
>
encoder_out_shape
=
encoder_out
->
GetTensorTypeAndShapeInfo
().
GetShape
();
assert
(
encoder_out_shape
[
0
]
==
1
);
int32_t
encoder_out_dim
=
encoder_out_shape
[
2
];
auto
batch_size
=
encoder_out_shape
[
0
];
auto
num_frames
=
encoder_out_shape
[
1
];
assert
(
t
<
num_frames
);
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
std
::
array
<
int64_t
,
2
>
shape
{
1
,
encoder_out_dim
};
return
Ort
::
Value
::
CreateTensor
(
memory_info
,
encoder_out
->
GetTensorMutableData
<
float
>
()
+
t
*
encoder_out_dim
,
encoder_out_dim
,
shape
.
data
(),
shape
.
size
());
}
auto
encoder_out_dim
=
encoder_out_shape
[
2
];
static
Ort
::
Value
Repeat
(
OrtAllocator
*
allocator
,
Ort
::
Value
*
cur_encoder_out
,
int32_t
n
)
{
if
(
n
==
1
)
{
return
std
::
move
(
*
cur_encoder_out
);
}
auto
offset
=
num_frames
*
encoder_out_dim
;
std
::
vector
<
int64_t
>
cur_encoder_out_shape
=
cur_encoder_out
->
GetTensorTypeAndShapeInfo
().
GetShape
();
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
std
::
array
<
int64_t
,
2
>
ans_shape
{
n
,
cur_encoder_out_shape
[
1
]
};
std
::
array
<
int64_t
,
2
>
shape
{
batch_size
,
encoder_out_dim
};
Ort
::
Value
ans
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator
,
ans_shape
.
data
(),
ans_shape
.
size
());
Ort
::
Value
ans
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator
,
shape
.
data
(),
shape
.
size
());
const
float
*
src
=
cur_encoder_out
->
GetTensorData
<
float
>
();
float
*
dst
=
ans
.
GetTensorMutableData
<
float
>
();
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
std
::
copy
(
src
,
src
+
cur_encoder_out_shape
[
1
],
dst
);
dst
+=
cur_encoder_out_shape
[
1
];
const
float
*
src
=
encoder_out
->
GetTensorData
<
float
>
();
for
(
int32_t
i
=
0
;
i
!=
batch_size
;
++
i
)
{
std
::
copy
(
src
+
t
*
encoder_out_dim
,
src
+
(
t
+
1
)
*
encoder_out_dim
,
dst
);
src
+=
offset
;
dst
+=
encoder_out_dim
;
}
return
ans
;
...
...
@@ -83,10 +75,10 @@ void OnlineTransducerGreedySearchDecoder::Decode(
encoder_out
.
GetTensorTypeAndShapeInfo
().
GetShape
();
if
(
encoder_out_shape
[
0
]
!=
result
->
size
())
{
fprintf
(
stderr
,
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d
\n
"
,
static_cast
<
int32_t
>
(
encoder_out_shape
[
0
]),
static_cast
<
int32_t
>
(
result
->
size
()));
SHERPA_ONNX_LOGE
(
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d"
,
static_cast
<
int32_t
>
(
encoder_out_shape
[
0
]),
static_cast
<
int32_t
>
(
result
->
size
()));
exit
(
-
1
);
}
...
...
@@ -98,10 +90,10 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort
::
Value
decoder_out
=
model_
->
RunDecoder
(
std
::
move
(
decoder_input
));
for
(
int32_t
t
=
0
;
t
!=
num_frames
;
++
t
)
{
Ort
::
Value
cur_encoder_out
=
GetFrame
(
&
encoder_out
,
t
);
cur_encoder_out
=
Repeat
(
model_
->
Allocator
(),
&
cur_encoder_out
,
batch_size
);
Ort
::
Value
cur_encoder_out
=
GetFrame
(
model_
->
Allocator
(),
&
encoder_out
,
t
);
Ort
::
Value
logit
=
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
Clone
(
model_
->
Allocator
(),
&
decoder_out
));
const
float
*
p_logit
=
logit
.
GetTensorData
<
float
>
();
bool
emitted
=
false
;
...
...
请
注册
或
登录
后发表评论