Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
chiiyeh
2024-01-26 12:12:13 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-01-26 12:12:13 +0800
Commit
e7b18a2139d865ccf21a452e09a7aa2fd44f15da
e7b18a21
1 parent
466a6855
add blank_penalty for online transducer (#548)
隐藏空白字符变更
内嵌
并排对比
正在显示
13 个修改的文件
包含
94 行增加
和
14 行删除
python-api-examples/online-decode-files.py
python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py
python-api-examples/speech-recognition-from-microphone.py
python-api-examples/streaming_server.py
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
sherpa-onnx/csrc/online-recognizer.cc
sherpa-onnx/csrc/online-recognizer.h
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
sherpa-onnx/python/csrc/online-recognizer.cc
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
python-api-examples/online-decode-files.py
查看文件 @
e7b18a2
...
...
@@ -217,6 +217,18 @@ def get_args():
)
parser
.
add_argument
(
"--blank-penalty"
,
type
=
float
,
default
=
0.0
,
help
=
"""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
"""
,
)
parser
.
add_argument
(
"sound_files"
,
type
=
str
,
nargs
=
"+"
,
...
...
@@ -290,6 +302,7 @@ def main():
lm_scale
=
args
.
lm_scale
,
hotwords_file
=
args
.
hotwords_file
,
hotwords_score
=
args
.
hotwords_score
,
blank_penalty
=
args
.
blank_penalty
,
)
elif
args
.
zipformer2_ctc
:
recognizer
=
sherpa_onnx
.
OnlineRecognizer
.
from_zipformer2_ctc
(
...
...
python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py
查看文件 @
e7b18a2
...
...
@@ -102,6 +102,17 @@ def get_args():
"""
,
)
parser
.
add_argument
(
"--blank-penalty"
,
type
=
float
,
default
=
0.0
,
help
=
"""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
"""
,
)
return
parser
.
parse_args
()
...
...
@@ -130,6 +141,7 @@ def create_recognizer(args):
provider
=
args
.
provider
,
hotwords_file
=
args
.
hotwords_file
,
hotwords_score
=
args
.
hotwords_score
,
blank_penalty
=
args
.
blank_penalty
,
)
return
recognizer
...
...
python-api-examples/speech-recognition-from-microphone.py
查看文件 @
e7b18a2
...
...
@@ -111,6 +111,17 @@ def get_args():
"""
,
)
parser
.
add_argument
(
"--blank-penalty"
,
type
=
float
,
default
=
0.0
,
help
=
"""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
"""
,
)
return
parser
.
parse_args
()
...
...
@@ -136,6 +147,7 @@ def create_recognizer(args):
provider
=
args
.
provider
,
hotwords_file
=
args
.
hotwords_file
,
hotwords_score
=
args
.
hotwords_score
,
blank_penalty
=
args
.
blank_penalty
,
)
return
recognizer
...
...
python-api-examples/streaming_server.py
查看文件 @
e7b18a2
...
...
@@ -241,6 +241,18 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
"""
,
)
def add_blank_penalty_args(parser: argparse.ArgumentParser):
    """Register the ``--blank-penalty`` command-line option on ``parser``.

    The option value is parsed as a float and defaults to 0.0.
    """
    penalty_help = """
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
"""
    option_spec = dict(type=float, default=0.0, help=penalty_help)
    parser.add_argument("--blank-penalty", **option_spec)
def
add_endpointing_args
(
parser
:
argparse
.
ArgumentParser
):
parser
.
add_argument
(
...
...
@@ -284,6 +296,7 @@ def get_args():
add_decoding_args
(
parser
)
add_endpointing_args
(
parser
)
add_hotwords_args
(
parser
)
add_blank_penalty_args
(
parser
)
parser
.
add_argument
(
"--port"
,
...
...
@@ -390,6 +403,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
max_active_paths
=
args
.
num_active_paths
,
hotwords_score
=
args
.
hotwords_score
,
hotwords_file
=
args
.
hotwords_file
,
blank_penalty
=
args
.
blank_penalty
,
enable_endpoint_detection
=
args
.
use_endpoint
!=
0
,
rule1_min_trailing_silence
=
args
.
rule1_min_trailing_silence
,
rule2_min_trailing_silence
=
args
.
rule2_min_trailing_silence
,
...
...
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
查看文件 @
e7b18a2
...
...
@@ -95,10 +95,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
decoder_
=
std
::
make_unique
<
OnlineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
,
unk_id_
);
config_
.
lm_config
.
scale
,
unk_id_
,
config_
.
blank_penalty
);
}
else
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
(),
unk_id_
);
model_
.
get
(),
unk_id_
,
config_
.
blank_penalty
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config
.
decoding_method
.
c_str
());
...
...
@@ -131,10 +131,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
decoder_
=
std
::
make_unique
<
OnlineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
,
unk_id_
);
config_
.
lm_config
.
scale
,
unk_id_
,
config_
.
blank_penalty
);
}
else
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
(),
unk_id_
);
model_
.
get
(),
unk_id_
,
config_
.
blank_penalty
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config
.
decoding_method
.
c_str
());
...
...
sherpa-onnx/csrc/online-recognizer.cc
查看文件 @
e7b18a2
...
...
@@ -81,6 +81,12 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"True to enable endpoint detection. False to disable it."
);
po
->
Register
(
"max-active-paths"
,
&
max_active_paths
,
"beam size used in modified beam search."
);
// Register the --blank-penalty option. The help text is built from adjacent
// string literals, which the compiler concatenates verbatim, so every
// fragment except the last must end with a space.
po->Register("blank-penalty", &blank_penalty,
             "The penalty applied on blank symbol during decoding. "
             "Note: It is a positive value. "
             "Increasing value will lead to lower deletion at the cost "
             "of higher insertions. "
             "Currently only applicable for transducer models.");
po
->
Register
(
"hotwords-score"
,
&
hotwords_score
,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search"
);
...
...
@@ -131,7 +137,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os
<<
"max_active_paths="
<<
max_active_paths
<<
", "
;
os
<<
"hotwords_score="
<<
hotwords_score
<<
", "
;
os
<<
"hotwords_file=
\"
"
<<
hotwords_file
<<
"
\"
, "
;
os
<<
"decoding_method=
\"
"
<<
decoding_method
<<
"
\"
)"
;
os
<<
"decoding_method=
\"
"
<<
decoding_method
<<
"
\"
, "
;
os
<<
"blank_penalty="
<<
blank_penalty
<<
")"
;
return
os
.
str
();
}
...
...
sherpa-onnx/csrc/online-recognizer.h
查看文件 @
e7b18a2
...
...
@@ -83,6 +83,8 @@ struct OnlineRecognizerConfig {
float
hotwords_score
=
1
.
5
;
std
::
string
hotwords_file
;
float
blank_penalty
=
0
.
0
;
OnlineRecognizerConfig
()
=
default
;
OnlineRecognizerConfig
(
const
FeatureExtractorConfig
&
feat_config
,
...
...
@@ -92,7 +94,8 @@ struct OnlineRecognizerConfig {
bool
enable_endpoint
,
const
std
::
string
&
decoding_method
,
int32_t
max_active_paths
,
const
std
::
string
&
hotwords_file
,
float
hotwords_score
)
const
std
::
string
&
hotwords_file
,
float
hotwords_score
,
float
blank_penalty
)
:
feat_config
(
feat_config
),
model_config
(
model_config
),
lm_config
(
lm_config
),
...
...
@@ -101,7 +104,8 @@ struct OnlineRecognizerConfig {
decoding_method
(
decoding_method
),
max_active_paths
(
max_active_paths
),
hotwords_score
(
hotwords_score
),
hotwords_file
(
hotwords_file
)
{}
hotwords_file
(
hotwords_file
),
blank_penalty
(
blank_penalty
)
{}
void
Register
(
ParseOptions
*
po
);
bool
Validate
()
const
;
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
查看文件 @
e7b18a2
...
...
@@ -116,11 +116,14 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort
::
Value
logit
=
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
View
(
&
decoder_out
));
const
float
*
p_logit
=
logit
.
GetTensor
Data
<
float
>
();
float
*
p_logit
=
logit
.
GetTensorMutable
Data
<
float
>
();
bool
emitted
=
false
;
for
(
int32_t
i
=
0
;
i
<
batch_size
;
++
i
,
p_logit
+=
vocab_size
)
{
auto
&
r
=
(
*
result
)[
i
];
if
(
blank_penalty_
>
0.0
)
{
p_logit
[
0
]
-=
blank_penalty_
;
// assuming blank id is 0
}
auto
y
=
static_cast
<
int32_t
>
(
std
::
distance
(
static_cast
<
const
float
*>
(
p_logit
),
std
::
max_element
(
static_cast
<
const
float
*>
(
p_logit
),
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
查看文件 @
e7b18a2
...
...
@@ -15,8 +15,9 @@ namespace sherpa_onnx {
class
OnlineTransducerGreedySearchDecoder
:
public
OnlineTransducerDecoder
{
public
:
OnlineTransducerGreedySearchDecoder
(
OnlineTransducerModel
*
model
,
int32_t
unk_id
)
:
model_
(
model
),
unk_id_
(
unk_id
)
{}
int32_t
unk_id
,
float
blank_penalty
)
:
model_
(
model
),
unk_id_
(
unk_id
),
blank_penalty_
(
blank_penalty
)
{}
OnlineTransducerDecoderResult
GetEmptyResult
()
const
override
;
...
...
@@ -28,6 +29,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
private
:
OnlineTransducerModel
*
model_
;
// Not owned
int32_t
unk_id_
;
float
blank_penalty_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
查看文件 @
e7b18a2
...
...
@@ -123,6 +123,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
View
(
&
decoder_out
));
float
*
p_logit
=
logit
.
GetTensorMutableData
<
float
>
();
if
(
blank_penalty_
>
0.0
)
{
// assuming blank id is 0
SubtractBlank
(
p_logit
,
vocab_size
,
num_hyps
,
0
,
blank_penalty_
);
}
LogSoftmax
(
p_logit
,
vocab_size
,
num_hyps
);
// now p_logit contains log_softmax output, we rename it to p_logprob
...
...
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
查看文件 @
e7b18a2
...
...
@@ -21,12 +21,14 @@ class OnlineTransducerModifiedBeamSearchDecoder
OnlineTransducerModifiedBeamSearchDecoder
(
OnlineTransducerModel
*
model
,
OnlineLM
*
lm
,
int32_t
max_active_paths
,
float
lm_scale
,
int32_t
unk_id
)
float
lm_scale
,
int32_t
unk_id
,
float
blank_penalty
)
:
model_
(
model
),
lm_
(
lm
),
max_active_paths_
(
max_active_paths
),
lm_scale_
(
lm_scale
),
unk_id_
(
unk_id
)
{}
unk_id_
(
unk_id
),
blank_penalty_
(
blank_penalty
)
{}
OnlineTransducerDecoderResult
GetEmptyResult
()
const
override
;
...
...
@@ -47,6 +49,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
int32_t
max_active_paths_
;
float
lm_scale_
;
// used only when lm_ is not nullptr
int32_t
unk_id_
;
float
blank_penalty_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/python/csrc/online-recognizer.cc
查看文件 @
e7b18a2
...
...
@@ -33,12 +33,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py
::
class_
<
PyClass
>
(
*
m
,
"OnlineRecognizerConfig"
)
.
def
(
py
::
init
<
const
FeatureExtractorConfig
&
,
const
OnlineModelConfig
&
,
const
OnlineLMConfig
&
,
const
EndpointConfig
&
,
bool
,
const
std
::
string
&
,
int32_t
,
const
std
::
string
&
,
float
>
(),
const
std
::
string
&
,
int32_t
,
const
std
::
string
&
,
float
,
float
>
(),
py
::
arg
(
"feat_config"
),
py
::
arg
(
"model_config"
),
py
::
arg
(
"lm_config"
)
=
OnlineLMConfig
(),
py
::
arg
(
"endpoint_config"
),
py
::
arg
(
"enable_endpoint"
),
py
::
arg
(
"decoding_method"
),
py
::
arg
(
"max_active_paths"
)
=
4
,
py
::
arg
(
"hotwords_file"
)
=
""
,
py
::
arg
(
"hotwords_score"
)
=
0
)
py
::
arg
(
"hotwords_score"
)
=
0
,
py
::
arg
(
"blank_penalty"
)
=
0.0
)
.
def_readwrite
(
"feat_config"
,
&
PyClass
::
feat_config
)
.
def_readwrite
(
"model_config"
,
&
PyClass
::
model_config
)
.
def_readwrite
(
"lm_config"
,
&
PyClass
::
lm_config
)
...
...
@@ -48,6 +49,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.
def_readwrite
(
"max_active_paths"
,
&
PyClass
::
max_active_paths
)
.
def_readwrite
(
"hotwords_file"
,
&
PyClass
::
hotwords_file
)
.
def_readwrite
(
"hotwords_score"
,
&
PyClass
::
hotwords_score
)
.
def_readwrite
(
"blank_penalty"
,
&
PyClass
::
blank_penalty
)
.
def
(
"__str__"
,
&
PyClass
::
ToString
);
}
...
...
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
查看文件 @
e7b18a2
...
...
@@ -48,6 +48,7 @@ class OnlineRecognizer(object):
decoding_method
:
str
=
"greedy_search"
,
max_active_paths
:
int
=
4
,
hotwords_score
:
float
=
1.5
,
blank_penalty
:
float
=
0.0
,
hotwords_file
:
str
=
""
,
provider
:
str
=
"cpu"
,
model_type
:
str
=
""
,
...
...
@@ -100,6 +101,8 @@ class OnlineRecognizer(object):
max_active_paths:
Use only when decoding_method is modified_beam_search. It specifies
the maximum number of active paths during beam search.
blank_penalty:
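The penalty applied to the blank symbol during decoding. It is a positive
value that is subtracted from the blank logit, i.e.
`logits[:, 0] -= blank_penalty` (assuming the blank id is 0).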
hotwords_file:
The file containing hotwords, one word/phrase per line, and for each
phrase the bpe/cjkchar are separated by a space.
...
...
@@ -172,6 +175,7 @@ class OnlineRecognizer(object):
max_active_paths
=
max_active_paths
,
hotwords_score
=
hotwords_score
,
hotwords_file
=
hotwords_file
,
blank_penalty
=
blank_penalty
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
...
...
请
注册
或
登录
后发表评论