Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
HieDean
2023-11-20 09:20:50 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-11-20 09:20:50 +0800
Commit
e6a2d0da3bc09b307ff8b753ee9b9e833d8ae6aa
e6a2d0da
1 parent
ac00edab
Replace Clone() with View() (#432)
Co-authored-by: hiedean <hiedean@tju.edu.cn>
显示空白字符变更
内嵌
并排对比
正在显示
5 个修改的文件
包含
14 行增加
和
12 行删除
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/online-rnn-lm.cc
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/online-wenet-ctc-model.cc
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
查看文件 @
e6a2d0d
...
...
@@ -94,7 +94,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
Ort
::
Value
logit
=
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
Clone
(
model_
->
Allocator
(),
&
decoder_out
));
std
::
move
(
cur_encoder_out
),
View
(
&
decoder_out
));
float
*
p_logit
=
logit
.
GetTensorMutableData
<
float
>
();
LogSoftmax
(
p_logit
,
vocab_size
,
num_hyps
);
...
...
sherpa-onnx/csrc/online-rnn-lm.cc
查看文件 @
e6a2d0d
...
...
@@ -67,13 +67,13 @@ class OnlineRnnLM::Impl {
return
{
std
::
move
(
out
[
0
]),
std
::
move
(
next_states
)};
}
std
::
pair
<
Ort
::
Value
,
std
::
vector
<
Ort
::
Value
>>
GetInitStates
()
const
{
std
::
pair
<
Ort
::
Value
,
std
::
vector
<
Ort
::
Value
>>
GetInitStates
()
{
std
::
vector
<
Ort
::
Value
>
ans
;
ans
.
reserve
(
init_states_
.
size
());
for
(
const
auto
&
s
:
init_states_
)
{
ans
.
emplace_back
(
Clone
(
allocator_
,
&
s
));
for
(
auto
&
s
:
init_states_
)
{
ans
.
emplace_back
(
View
(
&
s
));
}
return
{
std
::
move
(
Clone
(
allocator_
,
&
init_scores_
.
value
)
),
std
::
move
(
ans
)};
return
{
View
(
&
init_scores_
.
value
),
std
::
move
(
ans
)};
}
private
:
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
查看文件 @
e6a2d0d
...
...
@@ -99,9 +99,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
}
if
(
is_batch_decoder_out_cached
)
{
auto
&
r
=
result
->
front
();
std
::
vector
<
int64_t
>
decoder_out_shape
=
r
.
decoder_out
.
GetTensorTypeAndShapeInfo
().
GetShape
();
std
::
vector
<
int64_t
>
decoder_out_shape
=
r
.
decoder_out
.
GetTensorTypeAndShapeInfo
().
GetShape
();
decoder_out_shape
[
0
]
=
batch_size
;
decoder_out
=
Ort
::
Value
::
CreateTensor
<
float
>
(
model_
->
Allocator
(),
decoder_out_shape
.
data
(),
decoder_out_shape
.
size
());
decoder_out
=
Ort
::
Value
::
CreateTensor
<
float
>
(
model_
->
Allocator
(),
decoder_out_shape
.
data
(),
decoder_out_shape
.
size
());
UseCachedDecoderOut
(
*
result
,
&
decoder_out
);
}
else
{
Ort
::
Value
decoder_input
=
model_
->
BuildDecoderInput
(
*
result
);
...
...
@@ -112,7 +114,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort
::
Value
cur_encoder_out
=
GetEncoderOutFrame
(
model_
->
Allocator
(),
&
encoder_out
,
t
);
Ort
::
Value
logit
=
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
Clone
(
model_
->
Allocator
(),
&
decoder_out
));
std
::
move
(
cur_encoder_out
),
View
(
&
decoder_out
));
const
float
*
p_logit
=
logit
.
GetTensorData
<
float
>
();
...
...
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
查看文件 @
e6a2d0d
...
...
@@ -120,7 +120,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
cur_encoder_out
=
Repeat
(
model_
->
Allocator
(),
&
cur_encoder_out
,
hyps_row_splits
);
Ort
::
Value
logit
=
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
Clone
(
model_
->
Allocator
(),
&
decoder_out
));
std
::
move
(
cur_encoder_out
),
View
(
&
decoder_out
));
float
*
p_logit
=
logit
.
GetTensorMutableData
<
float
>
();
LogSoftmax
(
p_logit
,
vocab_size
,
num_hyps
);
...
...
sherpa-onnx/csrc/online-wenet-ctc-model.cc
查看文件 @
e6a2d0d
...
...
@@ -105,11 +105,11 @@ class OnlineWenetCtcModel::Impl {
// - attn_cache
// - conv_cache
// - offset
std
::
vector
<
Ort
::
Value
>
GetInitStates
()
const
{
std
::
vector
<
Ort
::
Value
>
GetInitStates
()
{
std
::
vector
<
Ort
::
Value
>
ans
;
ans
.
reserve
(
3
);
ans
.
push_back
(
Clone
(
Allocator
(),
&
attn_cache_
));
ans
.
push_back
(
Clone
(
Allocator
(),
&
conv_cache_
));
ans
.
push_back
(
View
(
&
attn_cache_
));
ans
.
push_back
(
View
(
&
conv_cache_
));
int64_t
offset_shape
=
1
;
...
...
请
注册
或
登录
后发表评论