Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Zhong-Yi Li
2024-06-19 20:52:42 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-06-19 20:52:42 +0800
Commit
675fb1574f82d591df362f6a0a52f69a1a9fc647
675fb157
1 parent
a11c8599
offline transducer: treat unk as blank (#1005)
Co-authored-by: chungyi.li <chungyi.li@ailabs.tw>
隐藏空白字符变更
内嵌
并排对比
正在显示
5 个修改的文件
包含
25 行增加
和
9 行删除
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
查看文件 @
675fb15
...
...
@@ -78,9 +78,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
config_
(
config
),
symbol_table_
(
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineTransducerModel
>
(
config_
.
model_config
))
{
if
(
symbol_table_
.
Contains
(
"<unk>"
))
{
unk_id_
=
symbol_table_
[
"<unk>"
];
}
if
(
config_
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OfflineTransducerGreedySearchDecoder
>
(
model_
.
get
(),
config_
.
blank_penalty
);
model_
.
get
(),
unk_id_
,
config_
.
blank_penalty
);
}
else
if
(
config_
.
decoding_method
==
"modified_beam_search"
)
{
if
(
!
config_
.
lm_config
.
model
.
empty
())
{
lm_
=
OfflineLM
::
Create
(
config
.
lm_config
);
...
...
@@ -97,7 +101,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_
=
std
::
make_unique
<
OfflineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
,
config_
.
blank_penalty
);
config_
.
lm_config
.
scale
,
unk_id_
,
config_
.
blank_penalty
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config_
.
decoding_method
.
c_str
());
...
...
@@ -113,9 +117,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
symbol_table_
(
mgr
,
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineTransducerModel
>
(
mgr
,
config_
.
model_config
))
{
if
(
symbol_table_
.
Contains
(
"<unk>"
))
{
unk_id_
=
symbol_table_
[
"<unk>"
];
}
if
(
config_
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OfflineTransducerGreedySearchDecoder
>
(
model_
.
get
(),
config_
.
blank_penalty
);
model_
.
get
(),
unk_id_
,
config_
.
blank_penalty
);
}
else
if
(
config_
.
decoding_method
==
"modified_beam_search"
)
{
if
(
!
config_
.
lm_config
.
model
.
empty
())
{
lm_
=
OfflineLM
::
Create
(
mgr
,
config
.
lm_config
);
...
...
@@ -133,7 +141,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_
=
std
::
make_unique
<
OfflineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
,
config_
.
blank_penalty
);
config_
.
lm_config
.
scale
,
unk_id_
,
config_
.
blank_penalty
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config_
.
decoding_method
.
c_str
());
...
...
@@ -293,6 +301,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
std
::
unique_ptr
<
OfflineTransducerModel
>
model_
;
std
::
unique_ptr
<
OfflineTransducerDecoder
>
decoder_
;
std
::
unique_ptr
<
OfflineLM
>
lm_
;
int32_t
unk_id_
=
-
1
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
查看文件 @
675fb15
...
...
@@ -57,7 +57,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
std
::
max_element
(
static_cast
<
const
float
*>
(
p_logit
),
static_cast
<
const
float
*>
(
p_logit
)
+
vocab_size
)));
p_logit
+=
vocab_size
;
if
(
y
!=
0
)
{
// blank id is hardcoded to 0
// also, it treats unk as blank
if
(
y
!=
0
&&
y
!=
unk_id_
)
{
ans
[
i
].
tokens
.
push_back
(
y
);
ans
[
i
].
timestamps
.
push_back
(
t
);
emitted
=
true
;
...
...
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
查看文件 @
675fb15
...
...
@@ -15,8 +15,9 @@ namespace sherpa_onnx {
class
OfflineTransducerGreedySearchDecoder
:
public
OfflineTransducerDecoder
{
public
:
OfflineTransducerGreedySearchDecoder
(
OfflineTransducerModel
*
model
,
int32_t
unk_id
,
float
blank_penalty
)
:
model_
(
model
),
blank_penalty_
(
blank_penalty
)
{}
:
model_
(
model
),
unk_id_
(
unk_id
),
blank_penalty_
(
blank_penalty
)
{}
std
::
vector
<
OfflineTransducerDecoderResult
>
Decode
(
Ort
::
Value
encoder_out
,
Ort
::
Value
encoder_out_length
,
...
...
@@ -24,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
private
:
OfflineTransducerModel
*
model_
;
// Not owned
int32_t
unk_id_
;
float
blank_penalty_
;
};
...
...
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
查看文件 @
675fb15
...
...
@@ -131,8 +131,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
float
context_score
=
0
;
auto
context_state
=
new_hyp
.
context_state
;
if
(
new_token
!=
0
)
{
// blank id is fixed to 0
// blank is hardcoded to 0
// also, it treats unk as blank
if
(
new_token
!=
0
&&
new_token
!=
unk_id_
)
{
new_hyp
.
ys
.
push_back
(
new_token
);
new_hyp
.
timestamps
.
push_back
(
t
);
if
(
context_graphs
[
i
]
!=
nullptr
)
{
...
...
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
查看文件 @
675fb15
...
...
@@ -19,12 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
OfflineTransducerModifiedBeamSearchDecoder
(
OfflineTransducerModel
*
model
,
OfflineLM
*
lm
,
int32_t
max_active_paths
,
float
lm_scale
,
float
lm_scale
,
int32_t
unk_id
,
float
blank_penalty
)
:
model_
(
model
),
lm_
(
lm
),
max_active_paths_
(
max_active_paths
),
lm_scale_
(
lm_scale
),
unk_id_
(
unk_id
),
blank_penalty_
(
blank_penalty
)
{}
std
::
vector
<
OfflineTransducerDecoderResult
>
Decode
(
...
...
@@ -37,6 +38,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
int32_t
max_active_paths_
;
float
lm_scale_
;
// used only when lm_ is not nullptr
int32_t
unk_id_
;
float
blank_penalty_
;
};
...
...
请
注册
或
登录
后发表评论