Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
chiiyeh
2024-01-25 15:00:09 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-01-25 15:00:09 +0800
Commit
3bb3849ec572f7aa8192e1bda0fe39d1f1737e4a
3bb3849e
1 parent
a9e77477
add blank_penalty for offline transducer (#542)
显示空白字符变更
内嵌
并排对比
正在显示
13 个修改的文件
包含
97 行增加
和
14 行删除
python-api-examples/non_streaming_server.py
python-api-examples/offline-decode-files.py
python-api-examples/vad-with-non-streaming-asr.py
sherpa-onnx/csrc/math.h
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
sherpa-onnx/csrc/offline-recognizer.cc
sherpa-onnx/csrc/offline-recognizer.h
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
sherpa-onnx/python/csrc/offline-recognizer.cc
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
python-api-examples/non_streaming_server.py
查看文件 @
3bb3849
...
...
@@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
"""
,
)
def
add_blank_penalty_args
(
parser
:
argparse
.
ArgumentParser
):
parser
.
add_argument
(
"--blank-penalty"
,
type
=
float
,
default
=
0.0
,
help
=
"""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
"""
,
)
def
check_args
(
args
):
if
not
Path
(
args
.
tokens
)
.
is_file
():
...
...
@@ -414,6 +427,7 @@ def get_args():
add_feature_config_args
(
parser
)
add_decoding_args
(
parser
)
add_hotwords_args
(
parser
)
add_blank_penalty_args
(
parser
)
parser
.
add_argument
(
"--port"
,
...
...
@@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
max_active_paths
=
args
.
max_active_paths
,
hotwords_file
=
args
.
hotwords_file
,
hotwords_score
=
args
.
hotwords_score
,
blank_penalty
=
args
.
blank_penalty
,
provider
=
args
.
provider
,
)
elif
args
.
paraformer
:
...
...
python-api-examples/offline-decode-files.py
查看文件 @
3bb3849
...
...
@@ -232,6 +232,18 @@ def get_args():
)
parser
.
add_argument
(
"--blank-penalty"
,
type
=
float
,
default
=
0.0
,
help
=
"""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
"""
,
)
parser
.
add_argument
(
"--decoding-method"
,
type
=
str
,
default
=
"greedy_search"
,
...
...
@@ -335,6 +347,7 @@ def main():
decoding_method
=
args
.
decoding_method
,
hotwords_file
=
args
.
hotwords_file
,
hotwords_score
=
args
.
hotwords_score
,
blank_penalty
=
args
.
blank_penalty
,
debug
=
args
.
debug
,
)
elif
args
.
paraformer
:
...
...
python-api-examples/vad-with-non-streaming-asr.py
查看文件 @
3bb3849
...
...
@@ -178,6 +178,18 @@ def get_args():
)
parser
.
add_argument
(
"--blank-penalty"
,
type
=
float
,
default
=
0.0
,
help
=
"""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
"""
,
)
parser
.
add_argument
(
"--decoding-method"
,
type
=
str
,
default
=
"greedy_search"
,
...
...
@@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
sample_rate
=
args
.
sample_rate
,
feature_dim
=
args
.
feature_dim
,
decoding_method
=
args
.
decoding_method
,
blank_penalty
=
args
.
blank_penalty
,
debug
=
args
.
debug
,
)
elif
args
.
paraformer
:
...
...
sherpa-onnx/csrc/math.h
查看文件 @
3bb3849
...
...
@@ -96,6 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
}
}
template
<
typename
T
>
void
SubtractBlank
(
T
*
in
,
int32_t
w
,
int32_t
h
,
int32_t
blank_idx
,
float
blank_penalty
)
{
for
(
int32_t
i
=
0
;
i
!=
h
;
++
i
)
{
in
[
blank_idx
]
-=
blank_penalty
;
in
+=
w
;
}
}
template
<
class
T
>
std
::
vector
<
int32_t
>
TopkIndex
(
const
T
*
vec
,
int32_t
size
,
int32_t
topk
)
{
std
::
vector
<
int32_t
>
vec_index
(
size
);
...
...
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
查看文件 @
3bb3849
...
...
@@ -79,7 +79,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
if
(
config_
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OfflineTransducerGreedySearchDecoder
>
(
model_
.
get
());
std
::
make_unique
<
OfflineTransducerGreedySearchDecoder
>
(
model_
.
get
(),
config_
.
blank_penalty
);
}
else
if
(
config_
.
decoding_method
==
"modified_beam_search"
)
{
if
(
!
config_
.
lm_config
.
model
.
empty
())
{
lm_
=
OfflineLM
::
Create
(
config
.
lm_config
);
...
...
@@ -87,7 +88,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_
=
std
::
make_unique
<
OfflineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
);
config_
.
lm_config
.
scale
,
config_
.
blank_penalty
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config_
.
decoding_method
.
c_str
());
...
...
@@ -104,7 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
config_
.
model_config
))
{
if
(
config_
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OfflineTransducerGreedySearchDecoder
>
(
model_
.
get
());
std
::
make_unique
<
OfflineTransducerGreedySearchDecoder
>
(
model_
.
get
(),
config_
.
blank_penalty
);
}
else
if
(
config_
.
decoding_method
==
"modified_beam_search"
)
{
if
(
!
config_
.
lm_config
.
model
.
empty
())
{
lm_
=
OfflineLM
::
Create
(
mgr
,
config
.
lm_config
);
...
...
@@ -112,7 +114,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_
=
std
::
make_unique
<
OfflineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
);
config_
.
lm_config
.
scale
,
config_
.
blank_penalty
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config_
.
decoding_method
.
c_str
());
...
...
sherpa-onnx/csrc/offline-recognizer.cc
查看文件 @
3bb3849
...
...
@@ -28,6 +28,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po
->
Register
(
"max-active-paths"
,
&
max_active_paths
,
"Used only when decoding_method is modified_beam_search"
);
po
->
Register
(
"blank-penalty"
,
&
blank_penalty
,
"The penalty applied on blank symbol during decoding. "
"Note: It is a positive value. "
"Increasing value will lead to lower deletion at the cost"
"of higher insertions. "
"Currently only applicable for transducer models."
);
po
->
Register
(
"hotwords-file"
,
&
hotwords_file
,
"The file containing hotwords, one words/phrases per line, and for each"
...
...
@@ -74,7 +81,8 @@ std::string OfflineRecognizerConfig::ToString() const {
os
<<
"decoding_method=
\"
"
<<
decoding_method
<<
"
\"
, "
;
os
<<
"max_active_paths="
<<
max_active_paths
<<
", "
;
os
<<
"hotwords_file=
\"
"
<<
hotwords_file
<<
"
\"
, "
;
os
<<
"hotwords_score="
<<
hotwords_score
<<
")"
;
os
<<
"hotwords_score="
<<
hotwords_score
<<
", "
;
os
<<
"blank_penalty="
<<
blank_penalty
<<
")"
;
return
os
.
str
();
}
...
...
sherpa-onnx/csrc/offline-recognizer.h
查看文件 @
3bb3849
...
...
@@ -37,6 +37,8 @@ struct OfflineRecognizerConfig {
std
::
string
hotwords_file
;
float
hotwords_score
=
1
.
5
;
float
blank_penalty
=
0
.
0
;
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
...
...
@@ -46,7 +48,8 @@ struct OfflineRecognizerConfig {
const
OfflineModelConfig
&
model_config
,
const
OfflineLMConfig
&
lm_config
,
const
OfflineCtcFstDecoderConfig
&
ctc_fst_decoder_config
,
const
std
::
string
&
decoding_method
,
int32_t
max_active_paths
,
const
std
::
string
&
hotwords_file
,
float
hotwords_score
)
const
std
::
string
&
hotwords_file
,
float
hotwords_score
,
float
blank_penalty
)
:
feat_config
(
feat_config
),
model_config
(
model_config
),
lm_config
(
lm_config
),
...
...
@@ -54,7 +57,8 @@ struct OfflineRecognizerConfig {
decoding_method
(
decoding_method
),
max_active_paths
(
max_active_paths
),
hotwords_file
(
hotwords_file
),
hotwords_score
(
hotwords_score
)
{}
hotwords_score
(
hotwords_score
),
blank_penalty
(
blank_penalty
)
{}
void
Register
(
ParseOptions
*
po
);
bool
Validate
()
const
;
...
...
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
查看文件 @
3bb3849
...
...
@@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
start
+=
n
;
Ort
::
Value
logit
=
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
std
::
move
(
cur_decoder_out
));
const
float
*
p_logit
=
logit
.
GetTensor
Data
<
float
>
();
float
*
p_logit
=
logit
.
GetTensorMutable
Data
<
float
>
();
bool
emitted
=
false
;
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
if
(
blank_penalty_
>
0.0
)
{
p_logit
[
0
]
-=
blank_penalty_
;
// assuming blank id is 0
}
auto
y
=
static_cast
<
int32_t
>
(
std
::
distance
(
static_cast
<
const
float
*>
(
p_logit
),
std
::
max_element
(
static_cast
<
const
float
*>
(
p_logit
),
...
...
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
查看文件 @
3bb3849
...
...
@@ -14,8 +14,10 @@ namespace sherpa_onnx {
class
OfflineTransducerGreedySearchDecoder
:
public
OfflineTransducerDecoder
{
public
:
explicit
OfflineTransducerGreedySearchDecoder
(
OfflineTransducerModel
*
model
)
:
model_
(
model
)
{}
explicit
OfflineTransducerGreedySearchDecoder
(
OfflineTransducerModel
*
model
,
float
blank_penalty
)
:
model_
(
model
),
blank_penalty_
(
blank_penalty
)
{}
std
::
vector
<
OfflineTransducerDecoderResult
>
Decode
(
Ort
::
Value
encoder_out
,
Ort
::
Value
encoder_out_length
,
...
...
@@ -23,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
private
:
OfflineTransducerModel
*
model_
;
// Not owned
float
blank_penalty_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
查看文件 @
3bb3849
...
...
@@ -97,6 +97,10 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
View
(
&
decoder_out
));
float
*
p_logit
=
logit
.
GetTensorMutableData
<
float
>
();
if
(
blank_penalty_
>
0.0
)
{
// assuming blank id is 0
SubtractBlank
(
p_logit
,
vocab_size
,
num_hyps
,
0
,
blank_penalty_
);
}
LogSoftmax
(
p_logit
,
vocab_size
,
num_hyps
);
// now p_logit contains log_softmax output, we rename it to p_logprob
...
...
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
查看文件 @
3bb3849
...
...
@@ -19,11 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
OfflineTransducerModifiedBeamSearchDecoder
(
OfflineTransducerModel
*
model
,
OfflineLM
*
lm
,
int32_t
max_active_paths
,
float
lm_scale
)
float
lm_scale
,
float
blank_penalty
)
:
model_
(
model
),
lm_
(
lm
),
max_active_paths_
(
max_active_paths
),
lm_scale_
(
lm_scale
)
{}
lm_scale_
(
lm_scale
),
blank_penalty_
(
blank_penalty
)
{}
std
::
vector
<
OfflineTransducerDecoderResult
>
Decode
(
Ort
::
Value
encoder_out
,
Ort
::
Value
encoder_out_length
,
...
...
@@ -35,6 +37,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
int32_t
max_active_paths_
;
float
lm_scale_
;
// used only when lm_ is not nullptr
float
blank_penalty_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/python/csrc/offline-recognizer.cc
查看文件 @
3bb3849
...
...
@@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.
def
(
py
::
init
<
const
OfflineFeatureExtractorConfig
&
,
const
OfflineModelConfig
&
,
const
OfflineLMConfig
&
,
const
OfflineCtcFstDecoderConfig
&
,
const
std
::
string
&
,
int32_t
,
const
std
::
string
&
,
float
>
(),
int32_t
,
const
std
::
string
&
,
float
,
float
>
(),
py
::
arg
(
"feat_config"
),
py
::
arg
(
"model_config"
),
py
::
arg
(
"lm_config"
)
=
OfflineLMConfig
(),
py
::
arg
(
"ctc_fst_decoder_config"
)
=
OfflineCtcFstDecoderConfig
(),
py
::
arg
(
"decoding_method"
)
=
"greedy_search"
,
py
::
arg
(
"max_active_paths"
)
=
4
,
py
::
arg
(
"hotwords_file"
)
=
""
,
py
::
arg
(
"hotwords_score"
)
=
1.5
)
py
::
arg
(
"hotwords_score"
)
=
1.5
,
py
::
arg
(
"blank_penalty"
)
=
0.0
)
.
def_readwrite
(
"feat_config"
,
&
PyClass
::
feat_config
)
.
def_readwrite
(
"model_config"
,
&
PyClass
::
model_config
)
.
def_readwrite
(
"lm_config"
,
&
PyClass
::
lm_config
)
...
...
@@ -32,6 +33,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.
def_readwrite
(
"max_active_paths"
,
&
PyClass
::
max_active_paths
)
.
def_readwrite
(
"hotwords_file"
,
&
PyClass
::
hotwords_file
)
.
def_readwrite
(
"hotwords_score"
,
&
PyClass
::
hotwords_score
)
.
def_readwrite
(
"blank_penalty"
,
&
PyClass
::
blank_penalty
)
.
def
(
"__str__"
,
&
PyClass
::
ToString
);
}
...
...
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
查看文件 @
3bb3849
...
...
@@ -48,6 +48,7 @@ class OfflineRecognizer(object):
max_active_paths
:
int
=
4
,
hotwords_file
:
str
=
""
,
hotwords_score
:
float
=
1.5
,
blank_penalty
:
float
=
0.0
,
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
):
...
...
@@ -81,6 +82,8 @@ class OfflineRecognizer(object):
max_active_paths:
Maximum number of active paths to keep. Used only when
decoding_method is modified_beam_search.
blank_penalty:
The penalty applied on blank symbol during decoding.
debug:
True to show debug messages.
provider:
...
...
@@ -117,6 +120,7 @@ class OfflineRecognizer(object):
decoding_method
=
decoding_method
,
hotwords_file
=
hotwords_file
,
hotwords_score
=
hotwords_score
,
blank_penalty
=
blank_penalty
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
self
.
config
=
recognizer_config
...
...
请
注册
或
登录
后发表评论