Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-09-07 15:12:29 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-09-07 15:12:29 +0800
Commit
a12ebfab2294f4ede5cd4c85e3e63d994851caba
a12ebfab
1 parent
ffeff3b8
treat unk as blank (#299)
隐藏空白字符变更
内嵌
并排对比
正在显示
5 个修改的文件
包含
29 行增加
和
12 行删除
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
查看文件 @
a12ebfa
...
...
@@ -57,6 +57,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_
(
OnlineTransducerModel
::
Create
(
config
.
model_config
)),
sym_
(
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
if
(
sym_
.
contains
(
"<unk>"
))
{
unk_id_
=
sym_
[
"<unk>"
];
}
if
(
config
.
decoding_method
==
"modified_beam_search"
)
{
if
(
!
config_
.
lm_config
.
model
.
empty
())
{
lm_
=
OnlineLM
::
Create
(
config
.
lm_config
);
...
...
@@ -64,10 +68,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
decoder_
=
std
::
make_unique
<
OnlineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
);
config_
.
lm_config
.
scale
,
unk_id_
);
}
else
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
());
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
(),
unk_id_
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config
.
decoding_method
.
c_str
());
...
...
@@ -82,13 +86,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_
(
OnlineTransducerModel
::
Create
(
mgr
,
config
.
model_config
)),
sym_
(
mgr
,
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
if
(
sym_
.
contains
(
"<unk>"
))
{
unk_id_
=
sym_
[
"<unk>"
];
}
if
(
config
.
decoding_method
==
"modified_beam_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
);
config_
.
lm_config
.
scale
,
unk_id_
);
}
else
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
());
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
(),
unk_id_
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config
.
decoding_method
.
c_str
());
...
...
@@ -268,6 +276,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
std
::
unique_ptr
<
OnlineTransducerDecoder
>
decoder_
;
SymbolTable
sym_
;
Endpoint
endpoint_
;
int32_t
unk_id_
=
-
1
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
查看文件 @
a12ebfa
...
...
@@ -108,7 +108,9 @@ void OnlineTransducerGreedySearchDecoder::Decode(
static_cast
<
const
float
*>
(
p_logit
),
std
::
max_element
(
static_cast
<
const
float
*>
(
p_logit
),
static_cast
<
const
float
*>
(
p_logit
)
+
vocab_size
)));
if
(
y
!=
0
)
{
// blank id is hardcoded to 0
// also, it treats unk as blank
if
(
y
!=
0
&&
y
!=
unk_id_
)
{
emitted
=
true
;
r
.
tokens
.
push_back
(
y
);
r
.
timestamps
.
push_back
(
t
+
r
.
frame_offset
);
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
查看文件 @
a12ebfa
...
...
@@ -14,8 +14,9 @@ namespace sherpa_onnx {
class
OnlineTransducerGreedySearchDecoder
:
public
OnlineTransducerDecoder
{
public
:
explicit
OnlineTransducerGreedySearchDecoder
(
OnlineTransducerModel
*
model
)
:
model_
(
model
)
{}
OnlineTransducerGreedySearchDecoder
(
OnlineTransducerModel
*
model
,
int32_t
unk_id
)
:
model_
(
model
),
unk_id_
(
unk_id
)
{}
OnlineTransducerDecoderResult
GetEmptyResult
()
const
override
;
...
...
@@ -26,6 +27,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
private
:
OnlineTransducerModel
*
model_
;
// Not owned
int32_t
unk_id_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
查看文件 @
a12ebfa
...
...
@@ -155,7 +155,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
float
context_score
=
0
;
auto
context_state
=
new_hyp
.
context_state
;
if
(
new_token
!=
0
)
{
// blank is hardcoded to 0
// also, it treats unk as blank
if
(
new_token
!=
0
&&
new_token
!=
unk_id_
)
{
new_hyp
.
ys
.
push_back
(
new_token
);
new_hyp
.
timestamps
.
push_back
(
t
+
frame_offset
);
new_hyp
.
num_trailing_blanks
=
0
;
...
...
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
查看文件 @
a12ebfa
...
...
@@ -21,11 +21,12 @@ class OnlineTransducerModifiedBeamSearchDecoder
OnlineTransducerModifiedBeamSearchDecoder
(
OnlineTransducerModel
*
model
,
OnlineLM
*
lm
,
int32_t
max_active_paths
,
float
lm_scale
)
float
lm_scale
,
int32_t
unk_id
)
:
model_
(
model
),
lm_
(
lm
),
max_active_paths_
(
max_active_paths
),
lm_scale_
(
lm_scale
)
{}
lm_scale_
(
lm_scale
),
unk_id_
(
unk_id
)
{}
OnlineTransducerDecoderResult
GetEmptyResult
()
const
override
;
...
...
@@ -45,6 +46,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
int32_t
max_active_paths_
;
float
lm_scale_
;
// used only when lm_ is not nullptr
int32_t
unk_id_
;
};
}
// namespace sherpa_onnx
...
...
请
注册
或
登录
后发表评论