Authored by Wei Kang on 2024-05-31 12:34:30 +0800
Committed by GitHub on 2024-05-31 12:34:30 +0800
Commit a38881817c1c68946d5a344621ee271107e5dbcc (a3888181), 1 parent a689249f
Support customize scores for hotwords (#926)
* Support customize scores for hotwords
* Skip blank lines
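Judging from the updated tests and the recognizer changes below, a hotwords list is still one phrase per line (or one phrase per "/"-separated field when passed as a single string at runtime), and a phrase may now carry its own boost score by appending ":<float>"; phrases without a suffix keep the global hotwords_score from the config. A sketch of such a list (the phrases themselves are made up, not from this commit):

    HELLO WORLD
    I LOVE YOU :2.0
    世界人民 GOES TOGETHER :1.5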
Showing 6 changed files with 103 additions and 35 deletions
sherpa-onnx/csrc/context-graph.h
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
sherpa-onnx/csrc/text2token-test.cc
sherpa-onnx/csrc/utils.cc
sherpa-onnx/csrc/utils.h
sherpa-onnx/csrc/context-graph.h
@@ -61,10 +61,9 @@ class ContextGraph {
   }
   ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
-               float context_score, const std::vector<float> &scores = {},
-               const std::vector<std::string> &phrases = {})
-      : ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
-                     std::vector<float>()) {}
+               float context_score, const std::vector<float> &scores = {})
+      : ContextGraph(token_ids, context_score, 0.0f, scores,
+                     std::vector<std::string>(), std::vector<float>()) {}

   std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
       const ContextState *state, int32_t token_id,
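For readers unfamiliar with the delegating-constructor pattern used above, here is a minimal, self-contained toy (not sherpa-onnx code; the class name, the extra float parameter, and the last vector are placeholders) showing how the slimmer overload forwards to the full constructor with an empty phrase list:

    #include <string>
    #include <vector>

    struct ToyGraph {
      // Convenience overload: token ids, a global score, and optional
      // per-phrase scores, mirroring the new signature in the diff above.
      ToyGraph(const std::vector<std::vector<int>> &token_ids, float context_score,
               const std::vector<float> &scores = {})
          : ToyGraph(token_ids, context_score, 0.0f, scores,
                     std::vector<std::string>(), std::vector<float>()) {}

      // Full constructor; parameters after `scores` stand in for whatever the
      // real class takes in those positions.
      ToyGraph(const std::vector<std::vector<int>> & /*token_ids*/,
               float /*context_score*/, float /*extra*/,
               const std::vector<float> & /*scores*/,
               const std::vector<std::string> & /*phrases*/,
               const std::vector<float> & /*extras*/) {}
    };

    int main() {
      ToyGraph g({{1, 2, 3}}, /*context_score=*/1.5f, /*scores=*/{2.0f});
      return 0;
    }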
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
@@ -145,15 +145,35 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
     auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
     std::istringstream is(hws);
     std::vector<std::vector<int32_t>> current;
+    std::vector<float> current_scores;
     if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
-                        bpe_encoder_.get(), &current)) {
+                        bpe_encoder_.get(), &current, &current_scores)) {
       SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
                        hotwords.c_str());
     }
+
+    int32_t num_default_hws = hotwords_.size();
+    int32_t num_hws = current.size();
+
     current.insert(current.end(), hotwords_.begin(), hotwords_.end());

-    auto context_graph =
-        std::make_shared<ContextGraph>(current, config_.hotwords_score);
+    if (!current_scores.empty() && !boost_scores_.empty()) {
+      current_scores.insert(current_scores.end(), boost_scores_.begin(),
+                            boost_scores_.end());
+    } else if (!current_scores.empty() && boost_scores_.empty()) {
+      current_scores.insert(current_scores.end(), num_default_hws,
+                            config_.hotwords_score);
+    } else if (current_scores.empty() && !boost_scores_.empty()) {
+      current_scores.insert(current_scores.end(), num_hws,
+                            config_.hotwords_score);
+      current_scores.insert(current_scores.end(), boost_scores_.begin(),
+                            boost_scores_.end());
+    } else {
+      // Do nothing.
+    }
+
+    auto context_graph = std::make_shared<ContextGraph>(
+        current, config_.hotwords_score, current_scores);
     return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
   }

@@ -226,13 +246,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
     }

     if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
-                        bpe_encoder_.get(), &hotwords_)) {
+                        bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
       SHERPA_ONNX_LOGE(
           "Failed to encode some hotwords, skip them already, see logs above "
           "for details.");
     }
-    hotwords_graph_ =
-        std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
+    hotwords_graph_ = std::make_shared<ContextGraph>(
+        hotwords_, config_.hotwords_score, boost_scores_);
   }

 #if __ANDROID_API__ >= 9
@@ -250,13 +270,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
     }

     if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
-                        bpe_encoder_.get(), &hotwords_)) {
+                        bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
       SHERPA_ONNX_LOGE(
           "Failed to encode some hotwords, skip them already, see logs above "
           "for details.");
     }
-    hotwords_graph_ =
-        std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
+    hotwords_graph_ = std::make_shared<ContextGraph>(
+        hotwords_, config_.hotwords_score, boost_scores_);
   }
 #endif
@@ -264,6 +284,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
   OfflineRecognizerConfig config_;
   SymbolTable symbol_table_;
   std::vector<std::vector<int32_t>> hotwords_;
+  std::vector<float> boost_scores_;
   ContextGraphPtr hotwords_graph_;
   std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
   std::unique_ptr<OfflineTransducerModel> model_;
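The four-way branch added above merges the per-stream scores (current_scores) with the recognizer-level defaults (boost_scores_), padding whichever side has no explicit scores with the global config_.hotwords_score. A self-contained sketch of that rule follows; the helper name and the example values are made up and are not part of the commit:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Merge per-stream and default boost scores; a side with no explicit
    // scores is padded with the global score, one entry per hotword.
    std::vector<float> MergeBoostScores(const std::vector<float> &stream_scores,
                                        int32_t num_stream_hotwords,
                                        const std::vector<float> &default_scores,
                                        int32_t num_default_hotwords,
                                        float global_score) {
      std::vector<float> merged = stream_scores;
      if (!merged.empty() && !default_scores.empty()) {
        merged.insert(merged.end(), default_scores.begin(), default_scores.end());
      } else if (!merged.empty() && default_scores.empty()) {
        merged.insert(merged.end(), num_default_hotwords, global_score);
      } else if (merged.empty() && !default_scores.empty()) {
        merged.insert(merged.end(), num_stream_hotwords, global_score);
        merged.insert(merged.end(), default_scores.begin(), default_scores.end());
      }  // else: neither side has custom scores; leave it empty.
      return merged;
    }

    int main() {
      // One stream hotword with an explicit score, two defaults without any.
      auto merged = MergeBoostScores({2.0f}, 1, {}, 2, 1.5f);
      for (float s : merged) std::printf("%.1f ", s);  // prints "2.0 1.5 1.5"
      return 0;
    }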
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
@@ -182,14 +182,35 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
     auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
     std::istringstream is(hws);
     std::vector<std::vector<int32_t>> current;
+    std::vector<float> current_scores;
     if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
-                        bpe_encoder_.get(), &current)) {
+                        bpe_encoder_.get(), &current, &current_scores)) {
       SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
                        hotwords.c_str());
     }
+
+    int32_t num_default_hws = hotwords_.size();
+    int32_t num_hws = current.size();
+
     current.insert(current.end(), hotwords_.begin(), hotwords_.end());

-    auto context_graph =
-        std::make_shared<ContextGraph>(current, config_.hotwords_score);
+    if (!current_scores.empty() && !boost_scores_.empty()) {
+      current_scores.insert(current_scores.end(), boost_scores_.begin(),
+                            boost_scores_.end());
+    } else if (!current_scores.empty() && boost_scores_.empty()) {
+      current_scores.insert(current_scores.end(), num_default_hws,
+                            config_.hotwords_score);
+    } else if (current_scores.empty() && !boost_scores_.empty()) {
+      current_scores.insert(current_scores.end(), num_hws,
+                            config_.hotwords_score);
+      current_scores.insert(current_scores.end(), boost_scores_.begin(),
+                            boost_scores_.end());
+    } else {
+      // Do nothing.
+    }
+
+    auto context_graph = std::make_shared<ContextGraph>(
+        current, config_.hotwords_score, current_scores);
     auto stream =
         std::make_unique<OnlineStream>(config_.feat_config, context_graph);
     InitOnlineStream(stream.get());

@@ -376,13 +397,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
     }

     if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
-                        bpe_encoder_.get(), &hotwords_)) {
+                        bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
       SHERPA_ONNX_LOGE(
           "Failed to encode some hotwords, skip them already, see logs above "
           "for details.");
     }
-    hotwords_graph_ =
-        std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
+    hotwords_graph_ = std::make_shared<ContextGraph>(
+        hotwords_, config_.hotwords_score, boost_scores_);
   }

 #if __ANDROID_API__ >= 9
@@ -400,13 +421,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
     }

     if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
-                        bpe_encoder_.get(), &hotwords_)) {
+                        bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
       SHERPA_ONNX_LOGE(
           "Failed to encode some hotwords, skip them already, see logs above "
           "for details.");
     }
-    hotwords_graph_ =
-        std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
+    hotwords_graph_ = std::make_shared<ContextGraph>(
+        hotwords_, config_.hotwords_score, boost_scores_);
   }
 #endif
@@ -428,6 +449,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
  private:
   OnlineRecognizerConfig config_;
   std::vector<std::vector<int32_t>> hotwords_;
+  std::vector<float> boost_scores_;
   ContextGraphPtr hotwords_graph_;
   std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
   std::unique_ptr<OnlineTransducerModel> model_;
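Both CreateStream overloads shown above accept the per-stream hotwords as a single string and rewrite the "/" separators into newlines before encoding, so a caller can pass several phrases (each optionally ending in ":<score>") in one argument. A self-contained illustration of just that transformation; the phrase strings here are made up:

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
      // One string, several phrases, optional per-phrase ":<score>" suffixes.
      std::string hotwords = "HELLO WORLD :2.5/I LOVE YOU";
      // Same rewrite as in the diff: "/" becomes a newline, one phrase per line.
      std::string hws = std::regex_replace(hotwords, std::regex("/"), "\n");
      std::cout << hws << "\n";
      return 0;
    }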
sherpa-onnx/csrc/text2token-test.cc
@@ -35,17 +35,21 @@ TEST(TEXT2TOKEN, TEST_cjkchar) {
   auto sym_table = SymbolTable(tokens);

-  std::string text = "世界人民大团结\n中国 V S 美国";
+  std::string text =
+      "世界人民大团结\n中国 V S 美国\n\n";  // Test blank lines also

   std::istringstream iss(text);

   std::vector<std::vector<int32_t>> ids;
+  std::vector<float> scores;

-  auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids);
+  auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores);

   std::vector<std::vector<int32_t>> expected_ids(
       {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
   EXPECT_EQ(ids, expected_ids);
+
+  EXPECT_EQ(scores.size(), 0);
 }

 TEST(TEXT2TOKEN, TEST_bpe) {
@@ -68,17 +72,22 @@ TEST(TEXT2TOKEN, TEST_bpe) {
   auto sym_table = SymbolTable(tokens);
   auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);

-  std::string text = "HELLO WORLD\nI LOVE YOU";
+  std::string text = "HELLO WORLD\nI LOVE YOU :2.0";

   std::istringstream iss(text);

   std::vector<std::vector<int32_t>> ids;
+  std::vector<float> scores;

-  auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
+  auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids,
+                          &scores);

   std::vector<std::vector<int32_t>> expected_ids(
       {{22, 58, 24, 425}, {19, 370, 47}});
   EXPECT_EQ(ids, expected_ids);
+
+  std::vector<float> expected_scores({0, 2.0});
+  EXPECT_EQ(scores, expected_scores);
 }

 TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
@@ -101,19 +110,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
   auto sym_table = SymbolTable(tokens);
   auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);

-  std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国";
+  std::string text =
+      "世界人民 GOES TOGETHER :1.5\n中国 GOES WITH 美国 :0.5";

   std::istringstream iss(text);

   std::vector<std::vector<int32_t>> ids;
+  std::vector<float> scores;

-  auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(),
-                          &ids);
+  auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(),
+                          &ids, &scores);

   std::vector<std::vector<int32_t>> expected_ids(
       {{1368, 1392, 557, 680, 275, 178, 475},
        {685, 736, 275, 178, 179, 921, 736}});
   EXPECT_EQ(ids, expected_ids);
+
+  std::vector<float> expected_scores({1.5, 0.5});
+  EXPECT_EQ(scores, expected_scores);
 }

 TEST(TEXT2TOKEN, TEST_bbpe) {
@@ -136,17 +149,22 @@ TEST(TEXT2TOKEN, TEST_bbpe) {
   auto sym_table = SymbolTable(tokens);
   auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);

-  std::string text = "频繁\n李鞑靼";
+  std::string text = "频繁 :1.0\n李鞑靼";

   std::istringstream iss(text);

   std::vector<std::vector<int32_t>> ids;
+  std::vector<float> scores;

-  auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
+  auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids,
+                          &scores);

   std::vector<std::vector<int32_t>> expected_ids(
       {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
   EXPECT_EQ(ids, expected_ids);
+
+  std::vector<float> expected_scores({1.0, 0});
+  EXPECT_EQ(scores, expected_scores);
 }

 }  // namespace sherpa_onnx
sherpa-onnx/csrc/utils.cc
@@ -103,7 +103,8 @@ static bool EncodeBase(const std::vector<std::string> &lines,
 bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
                     const SymbolTable &symbol_table,
                     const ssentencepiece::Ssentencepiece *bpe_encoder,
-                    std::vector<std::vector<int32_t>> *hotwords) {
+                    std::vector<std::vector<int32_t>> *hotwords,
+                    std::vector<float> *boost_scores) {
   std::vector<std::string> lines;
   std::string line;
   std::string word;
@@ -131,7 +132,12 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
         break;
       }
     }
-    phrase = oss.str().substr(1);
+    phrase = oss.str();
+    if (phrase.empty()) {
+      continue;
+    } else {
+      phrase = phrase.substr(1);
+    }
     std::istringstream piss(phrase);
     oss.clear();
     oss.str("");
@@ -177,7 +183,8 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
     }
     lines.push_back(oss.str());
   }
-  return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr);
+  return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores,
+                    nullptr);
 }

 bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
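EncodeHotwords now also fills a per-phrase boost_scores vector; the tests above expect 0 for a phrase with no ":<float>" suffix and the given value otherwise, and blank lines are skipped. The real parsing lives in utils.cc; the following is only a small self-contained sketch of that line format (the helper name and the fallback behavior for score 0 are assumptions, not the library's implementation):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // Split each non-blank line into a phrase and an optional trailing
    // " :<float>" boost score; 0 stands for "no per-phrase score given".
    void ParseHotwordLines(std::istream &is, std::vector<std::string> *phrases,
                           std::vector<float> *scores) {
      std::string line;
      while (std::getline(is, line)) {
        if (line.find_first_not_of(" \t") == std::string::npos) continue;  // blank
        float score = 0.0f;
        auto pos = line.rfind(" :");
        if (pos != std::string::npos) {
          score = std::stof(line.substr(pos + 2));
          line = line.substr(0, pos);
        }
        phrases->push_back(line);
        scores->push_back(score);
      }
    }

    int main() {
      std::istringstream iss("HELLO WORLD\nI LOVE YOU :2.0\n\n");
      std::vector<std::string> phrases;
      std::vector<float> scores;
      ParseHotwordLines(iss, &phrases, &scores);
      for (size_t i = 0; i != phrases.size(); ++i) {
        std::cout << phrases[i] << " -> " << scores[i] << "\n";
      }
      return 0;
    }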
sherpa-onnx/csrc/utils.h
@@ -29,7 +29,8 @@ namespace sherpa_onnx {
 bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
                     const SymbolTable &symbol_table,
                     const ssentencepiece::Ssentencepiece *bpe_encoder_,
-                    std::vector<std::vector<int32_t>> *hotwords_id);
+                    std::vector<std::vector<int32_t>> *hotwords_id,
+                    std::vector<float> *boost_scores);

 /* Encode the keywords in an input stream to be tokens ids.
  *