Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-06-17 18:39:23 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-06-17 18:39:23 +0800
Commit
349d957da28f44121c22d25787d7ccd62c684c01
349d957d
1 parent
6e09933d
Add inverse text normalization for online ASR (#1020)
隐藏空白字符变更
内嵌
并排对比
正在显示
12 个修改的文件
包含
390 行增加
和
32 行删除
.github/scripts/test-python.sh
python-api-examples/inverse-text-normalization-online-asr.py
sherpa-onnx/csrc/online-recognizer-ctc-impl.h
sherpa-onnx/csrc/online-recognizer-impl.cc
sherpa-onnx/csrc/online-recognizer-impl.h
sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
sherpa-onnx/csrc/online-recognizer.cc
sherpa-onnx/csrc/online-recognizer.h
sherpa-onnx/python/csrc/online-recognizer.cc
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
.github/scripts/test-python.sh
查看文件 @
349d957
...
...
@@ -256,7 +256,18 @@ if [[ x$OS != x'windows-latest' ]]; then
$repo
/test_wavs/3.wav
\
$repo
/test_wavs/8k.wav
ln -s
$repo
$PWD
/
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
python3 ./python-api-examples/inverse-text-normalization-online-asr.py
python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose
rm -rfv sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
rm -rf
$repo
fi
log
"Test non-streaming transducer models"
...
...
python-api-examples/inverse-text-normalization-online-asr.py
0 → 100755
查看文件 @
349d957
#!/usr/bin/env python3
#
# Copyright (c) 2024 Xiaomi Corporation
"""
This script shows how to use inverse text normalization with streaming ASR.
Usage:
(1) Download the test model
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
(2) Download rule fst
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
Please refer to
https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb
for how itn_zh_number.fst is generated.
(3) Download test wave
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
(4) Run this script
python3 ./python-api-examples/inverse-text-normalization-online-asr.py
"""
from
pathlib
import
Path
import
sherpa_onnx
import
soundfile
as
sf
def create_recognizer():
    """Create a streaming transducer recognizer with a rule FST attached.

    All paths are relative to the current working directory. Every required
    file is checked before the recognizer is constructed.

    Returns:
      The object returned by ``sherpa_onnx.OnlineRecognizer.from_transducer``.

    Raises:
      ValueError: If any model file or the rule FST is not a regular file.
        The message lists every missing path.
    """
    model_dir = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
    encoder = f"{model_dir}/encoder-epoch-99-avg-1.int8.onnx"
    decoder = f"{model_dir}/decoder-epoch-99-avg-1.onnx"
    joiner = f"{model_dir}/joiner-epoch-99-avg-1.int8.onnx"
    tokens = f"{model_dir}/tokens.txt"
    rule_fsts = "./itn_zh_number.fst"

    # Collect all missing files so the user can fix everything in one go
    # instead of re-running the script once per missing file.
    missing = [
        path
        for path in (encoder, decoder, joiner, tokens, rule_fsts)
        if not Path(path).is_file()
    ]
    if missing:
        raise ValueError(
            "Please download model files from\n"
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models\n"
            "Missing files: " + ", ".join(missing)
        )

    return sherpa_onnx.OnlineRecognizer.from_transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        tokens=tokens,
        debug=True,
        rule_fsts=rule_fsts,
    )
def main():
    """Decode ./itn-zh-number.wav with the recognizer and print the result.

    Raises:
      ValueError: If the test wave file is not a regular file.
    """
    recognizer = create_recognizer()

    wave_filename = "./itn-zh-number.wav"
    if not Path(wave_filename).is_file():
        # The original message talked about "model files"; name the file that
        # is actually missing so the user knows what to download.
        raise ValueError(
            f"Test wave file {wave_filename} does not exist. "
            "Please download it from\n"
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models\n"
        )

    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio)

    # Feed 0.3 seconds of zeros after the audio.
    # NOTE(review): presumably this lets the final speech frames be decoded;
    # confirm against the streaming model's chunking requirements.
    tail_padding = [0] * int(0.3 * sample_rate)
    stream.accept_waveform(sample_rate, tail_padding)

    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

    print(wave_filename)
    print(recognizer.get_result_all(stream))
if
__name__
==
"__main__"
:
main
()
...
...
sherpa-onnx/csrc/online-recognizer-ctc-impl.h
查看文件 @
349d957
...
...
@@ -68,7 +68,8 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
class
OnlineRecognizerCtcImpl
:
public
OnlineRecognizerImpl
{
public
:
explicit
OnlineRecognizerCtcImpl
(
const
OnlineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OnlineRecognizerImpl
(
config
),
config_
(
config
),
model_
(
OnlineCtcModel
::
Create
(
config
.
model_config
)),
sym_
(
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
...
...
@@ -84,7 +85,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit
OnlineRecognizerCtcImpl
(
AAssetManager
*
mgr
,
const
OnlineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OnlineRecognizerImpl
(
mgr
,
config
),
config_
(
config
),
model_
(
OnlineCtcModel
::
Create
(
mgr
,
config
.
model_config
)),
sym_
(
mgr
,
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
...
...
@@ -182,8 +184,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t
frame_shift_ms
=
10
;
int32_t
subsampling_factor
=
4
;
return
Convert
(
decoder_result
,
sym_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
(),
s
->
GetNumFramesSinceStart
());
auto
r
=
Convert
(
decoder_result
,
sym_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
(),
s
->
GetNumFramesSinceStart
());
r
.
text
=
ApplyInverseTextNormalization
(
r
.
text
);
return
r
;
}
bool
IsEndpoint
(
OnlineStream
*
s
)
const
override
{
...
...
sherpa-onnx/csrc/online-recognizer-impl.cc
查看文件 @
349d957
...
...
@@ -4,11 +4,22 @@
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "fst/extensions/far/far.h"
#include "kaldifst/csrc/kaldi-fst-io.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace
sherpa_onnx
{
...
...
@@ -78,4 +89,110 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
}
#endif
// Loads the inverse-text-normalization (ITN) rules named in the config.
//
// config.rule_fsts and config.rule_fars are comma-separated lists of paths.
// Rules from rule_fsts are appended to itn_list_ first, followed by every FST
// found in each archive of rule_fars, in the order given.
//
// @param config  Recognizer configuration; a copy is kept in config_.
OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
    : config_(config) {
  if (!config.rule_fsts.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fsts, ",", false, &files);
    itn_list_.reserve(files.size());
    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
      }
      itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
    }
  }

  if (!config.rule_fars.empty()) {
    if (config.model_config.debug) {
      SHERPA_ONNX_LOGE("Loading FST archives");
    }
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fars, ",", false, &files);

    itn_list_.reserve(files.size() + itn_list_.size());

    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
      }
      std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
          fst::FarReader<fst::StdArc>::Open(f));
      if (!reader) {
        // Open() yields a null reader when the archive cannot be read;
        // dereferencing it below would crash, so report it and move on.
        SHERPA_ONNX_LOGE("Failed to open rule far '%s'. Skipping it.",
                         f.c_str());
        continue;
      }
      for (; !reader->Done(); reader->Next()) {
        std::unique_ptr<fst::StdConstFst> r(
            fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));

        itn_list_.push_back(
            std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
      }
    }

    if (config.model_config.debug) {
      SHERPA_ONNX_LOGE("FST archives loaded!");
    }
  }
}
#if __ANDROID_API__ >= 9
// Android variant: loads the inverse-text-normalization (ITN) rules named in
// the config, reading each file through ReadFile(mgr, path).
//
// config.rule_fsts and config.rule_fars are comma-separated lists of paths.
// Rules from rule_fsts are appended to itn_list_ first, followed by every FST
// found in each archive of rule_fars, in the order given.
//
// @param mgr     Asset manager passed through to ReadFile().
// @param config  Recognizer configuration; a copy is kept in config_.
OnlineRecognizerImpl::OnlineRecognizerImpl(AAssetManager *mgr,
                                           const OnlineRecognizerConfig &config)
    : config_(config) {
  if (!config.rule_fsts.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fsts, ",", false, &files);
    itn_list_.reserve(files.size());
    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
      }
      auto buf = ReadFile(mgr, f);
      // istrstream reads directly from buf without copying it, so buf must
      // stay alive while `is` is in use.
      std::istrstream is(buf.data(), buf.size());
      itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
    }
  }

  if (!config.rule_fars.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fars, ",", false, &files);
    itn_list_.reserve(files.size() + itn_list_.size());

    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
      }

      // buf is declared before the reader so it outlives the stream that
      // refers to its bytes.
      auto buf = ReadFile(mgr, f);

      std::unique_ptr<std::istream> s(
          new std::istrstream(buf.data(), buf.size()));

      std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
          fst::FarReader<fst::StdArc>::Open(std::move(s)));

      if (!reader) {
        // Open() yields a null reader when the archive cannot be read;
        // dereferencing it below would crash, so report it and move on.
        SHERPA_ONNX_LOGE("Failed to open rule far '%s'. Skipping it.",
                         f.c_str());
        continue;
      }

      for (; !reader->Done(); reader->Next()) {
        std::unique_ptr<fst::StdConstFst> r(
            fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));

        itn_list_.push_back(
            std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
      }  // for (; !reader->Done(); reader->Next())
    }    // for (const auto &f : files)
  }      // if (!config.rule_fars.empty())
}
#endif
// Passes `text` through every rule in itn_list_, in list order, feeding the
// output of one rule into the next. When the list is empty the input is
// returned unchanged. In debug mode the intermediate text is logged after
// each rule.
//
// @param text  Recognized text to normalize.
// @return      The text after all rules have been applied.
std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
    std::string text) const {
  const bool verbose = config_.model_config.debug;

  // A range-for over an empty list is a no-op, so no separate emptiness
  // check is needed.
  for (const auto &normalizer : itn_list_) {
    text = normalizer->Normalize(text);

    if (verbose) {
      SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str());
    }
  }

  return text;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-recognizer-impl.h
查看文件 @
349d957
...
...
@@ -9,6 +9,12 @@
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h"
...
...
@@ -17,10 +23,15 @@ namespace sherpa_onnx {
class
OnlineRecognizerImpl
{
public
:
explicit
OnlineRecognizerImpl
(
const
OnlineRecognizerConfig
&
config
);
static
std
::
unique_ptr
<
OnlineRecognizerImpl
>
Create
(
const
OnlineRecognizerConfig
&
config
);
#if __ANDROID_API__ >= 9
OnlineRecognizerImpl
(
AAssetManager
*
mgr
,
const
OnlineRecognizerConfig
&
config
);
static
std
::
unique_ptr
<
OnlineRecognizerImpl
>
Create
(
AAssetManager
*
mgr
,
const
OnlineRecognizerConfig
&
config
);
#endif
...
...
@@ -50,6 +61,15 @@ class OnlineRecognizerImpl {
virtual
bool
IsEndpoint
(
OnlineStream
*
s
)
const
=
0
;
virtual
void
Reset
(
OnlineStream
*
s
)
const
=
0
;
std
::
string
ApplyInverseTextNormalization
(
std
::
string
text
)
const
;
private
:
OnlineRecognizerConfig
config_
;
// for inverse text normalization. Used only if
// config.rule_fsts is not empty or
// config.rule_fars is not empty
std
::
vector
<
std
::
unique_ptr
<
kaldifst
::
TextNormalizer
>>
itn_list_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
查看文件 @
349d957
...
...
@@ -96,7 +96,8 @@ static void Scale(const float *x, int32_t n, float scale, float *y) {
class
OnlineRecognizerParaformerImpl
:
public
OnlineRecognizerImpl
{
public
:
explicit
OnlineRecognizerParaformerImpl
(
const
OnlineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OnlineRecognizerImpl
(
config
),
config_
(
config
),
model_
(
config
.
model_config
),
sym_
(
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
...
...
@@ -116,7 +117,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit
OnlineRecognizerParaformerImpl
(
AAssetManager
*
mgr
,
const
OnlineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OnlineRecognizerImpl
(
mgr
,
config
),
config_
(
config
),
model_
(
mgr
,
config
.
model_config
),
sym_
(
mgr
,
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
...
...
@@ -160,7 +162,9 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
OnlineRecognizerResult
GetResult
(
OnlineStream
*
s
)
const
override
{
auto
decoder_result
=
s
->
GetParaformerResult
();
return
Convert
(
decoder_result
,
sym_
);
auto
r
=
Convert
(
decoder_result
,
sym_
);
r
.
text
=
ApplyInverseTextNormalization
(
r
.
text
);
return
r
;
}
bool
IsEndpoint
(
OnlineStream
*
s
)
const
override
{
...
...
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
查看文件 @
349d957
...
...
@@ -80,7 +80,8 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
class
OnlineRecognizerTransducerImpl
:
public
OnlineRecognizerImpl
{
public
:
explicit
OnlineRecognizerTransducerImpl
(
const
OnlineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OnlineRecognizerImpl
(
config
),
config_
(
config
),
model_
(
OnlineTransducerModel
::
Create
(
config
.
model_config
)),
sym_
(
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
...
...
@@ -124,7 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit
OnlineRecognizerTransducerImpl
(
AAssetManager
*
mgr
,
const
OnlineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OnlineRecognizerImpl
(
mgr
,
config
),
config_
(
config
),
model_
(
OnlineTransducerModel
::
Create
(
mgr
,
config
.
model_config
)),
sym_
(
mgr
,
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
...
...
@@ -332,8 +334,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t
frame_shift_ms
=
10
;
int32_t
subsampling_factor
=
4
;
return
Convert
(
decoder_result
,
sym_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
(),
s
->
GetNumFramesSinceStart
());
auto
r
=
Convert
(
decoder_result
,
sym_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
(),
s
->
GetNumFramesSinceStart
());
r
.
text
=
ApplyInverseTextNormalization
(
std
::
move
(
r
.
text
));
return
r
;
}
bool
IsEndpoint
(
OnlineStream
*
s
)
const
override
{
...
...
sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
查看文件 @
349d957
...
...
@@ -42,7 +42,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
public
:
explicit
OnlineRecognizerTransducerNeMoImpl
(
const
OnlineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OnlineRecognizerImpl
(
config
),
config_
(
config
),
symbol_table_
(
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
),
model_
(
...
...
@@ -61,7 +62,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit
OnlineRecognizerTransducerNeMoImpl
(
AAssetManager
*
mgr
,
const
OnlineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OnlineRecognizerImpl
(
mgr
,
config
),
config_
(
config
),
symbol_table_
(
mgr
,
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
),
model_
(
std
::
make_unique
<
OnlineTransducerNeMoModel
>
(
...
...
@@ -94,9 +96,11 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// TODO(fangjun): Remember to change these constants if needed
int32_t
frame_shift_ms
=
10
;
int32_t
subsampling_factor
=
model_
->
SubsamplingFactor
();
return
Convert
(
s
->
GetResult
(),
symbol_table_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
(),
s
->
GetNumFramesSinceStart
());
auto
r
=
Convert
(
s
->
GetResult
(),
symbol_table_
,
frame_shift_ms
,
subsampling_factor
,
s
->
GetCurrentSegment
(),
s
->
GetNumFramesSinceStart
());
r
.
text
=
ApplyInverseTextNormalization
(
std
::
move
(
r
.
text
));
return
r
;
}
bool
IsEndpoint
(
OnlineStream
*
s
)
const
override
{
...
...
sherpa-onnx/csrc/online-recognizer.cc
查看文件 @
349d957
...
...
@@ -14,7 +14,9 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace
sherpa_onnx
{
...
...
@@ -100,6 +102,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"now support greedy_search and modified_beam_search."
);
po
->
Register
(
"temperature-scale"
,
&
temperature_scale
,
"Temperature scale for confidence computation in decoding."
);
po
->
Register
(
"rule-fsts"
,
&
rule_fsts
,
"If not empty, it specifies fsts for inverse text normalization. "
"If there are multiple fsts, they are separated by a comma."
);
po
->
Register
(
"rule-fars"
,
&
rule_fars
,
"If not empty, it specifies fst archives for inverse text normalization. "
"If there are multiple archives, they are separated by a comma."
);
}
bool
OnlineRecognizerConfig
::
Validate
()
const
{
...
...
@@ -129,6 +140,34 @@ bool OnlineRecognizerConfig::Validate() const {
return
false
;
}
if
(
!
hotwords_file
.
empty
()
&&
!
FileExists
(
hotwords_file
))
{
SHERPA_ONNX_LOGE
(
"--hotwords-file: '%s' does not exist"
,
hotwords_file
.
c_str
());
return
false
;
}
if
(
!
rule_fsts
.
empty
())
{
std
::
vector
<
std
::
string
>
files
;
SplitStringToVector
(
rule_fsts
,
","
,
false
,
&
files
);
for
(
const
auto
&
f
:
files
)
{
if
(
!
FileExists
(
f
))
{
SHERPA_ONNX_LOGE
(
"Rule fst '%s' does not exist. "
,
f
.
c_str
());
return
false
;
}
}
}
if
(
!
rule_fars
.
empty
())
{
std
::
vector
<
std
::
string
>
files
;
SplitStringToVector
(
rule_fars
,
","
,
false
,
&
files
);
for
(
const
auto
&
f
:
files
)
{
if
(
!
FileExists
(
f
))
{
SHERPA_ONNX_LOGE
(
"Rule far '%s' does not exist. "
,
f
.
c_str
());
return
false
;
}
}
}
return
model_config
.
Validate
();
}
...
...
@@ -147,7 +186,9 @@ std::string OnlineRecognizerConfig::ToString() const {
os
<<
"hotwords_file=
\"
"
<<
hotwords_file
<<
"
\"
, "
;
os
<<
"decoding_method=
\"
"
<<
decoding_method
<<
"
\"
, "
;
os
<<
"blank_penalty="
<<
blank_penalty
<<
", "
;
os
<<
"temperature_scale="
<<
temperature_scale
<<
")"
;
os
<<
"temperature_scale="
<<
temperature_scale
<<
", "
;
os
<<
"rule_fsts=
\"
"
<<
rule_fsts
<<
"
\"
, "
;
os
<<
"rule_fars=
\"
"
<<
rule_fars
<<
"
\"
)"
;
return
os
.
str
();
}
...
...
sherpa-onnx/csrc/online-recognizer.h
查看文件 @
349d957
...
...
@@ -100,6 +100,12 @@ struct OnlineRecognizerConfig {
float
temperature_scale
=
2
.
0
;
// If there are multiple rules, they are applied from left to right.
std
::
string
rule_fsts
;
// If there are multiple FST archives, they are applied from left to right.
std
::
string
rule_fars
;
OnlineRecognizerConfig
()
=
default
;
OnlineRecognizerConfig
(
...
...
@@ -109,7 +115,8 @@ struct OnlineRecognizerConfig {
const
OnlineCtcFstDecoderConfig
&
ctc_fst_decoder_config
,
bool
enable_endpoint
,
const
std
::
string
&
decoding_method
,
int32_t
max_active_paths
,
const
std
::
string
&
hotwords_file
,
float
hotwords_score
,
float
blank_penalty
,
float
temperature_scale
)
float
hotwords_score
,
float
blank_penalty
,
float
temperature_scale
,
const
std
::
string
&
rule_fsts
,
const
std
::
string
&
rule_fars
)
:
feat_config
(
feat_config
),
model_config
(
model_config
),
lm_config
(
lm_config
),
...
...
@@ -121,7 +128,9 @@ struct OnlineRecognizerConfig {
hotwords_file
(
hotwords_file
),
hotwords_score
(
hotwords_score
),
blank_penalty
(
blank_penalty
),
temperature_scale
(
temperature_scale
)
{}
temperature_scale
(
temperature_scale
),
rule_fsts
(
rule_fsts
),
rule_fars
(
rule_fars
)
{}
void
Register
(
ParseOptions
*
po
);
bool
Validate
()
const
;
...
...
sherpa-onnx/python/csrc/online-recognizer.cc
查看文件 @
349d957
...
...
@@ -54,19 +54,20 @@ static void PybindOnlineRecognizerResult(py::module *m) {
static
void
PybindOnlineRecognizerConfig
(
py
::
module
*
m
)
{
using
PyClass
=
OnlineRecognizerConfig
;
py
::
class_
<
PyClass
>
(
*
m
,
"OnlineRecognizerConfig"
)
.
def
(
py
::
init
<
const
FeatureExtractorConfig
&
,
const
OnlineModelConfig
&
,
const
OnlineLMConfig
&
,
const
EndpointConfig
&
,
const
OnlineCtcFstDecoderConfig
&
,
bool
,
const
std
::
string
&
,
int32_t
,
const
std
::
string
&
,
float
,
float
,
float
>
(),
py
::
arg
(
"feat_config"
),
py
::
arg
(
"model_config"
),
py
::
arg
(
"lm_config"
)
=
OnlineLMConfig
(),
py
::
arg
(
"endpoint_config"
)
=
EndpointConfig
(),
py
::
arg
(
"ctc_fst_decoder_config"
)
=
OnlineCtcFstDecoderConfig
(),
py
::
arg
(
"enable_endpoint"
),
py
::
arg
(
"decoding_method"
),
py
::
arg
(
"max_active_paths"
)
=
4
,
py
::
arg
(
"hotwords_file"
)
=
""
,
py
::
arg
(
"hotwords_score"
)
=
0
,
py
::
arg
(
"blank_penalty"
)
=
0.0
,
py
::
arg
(
"temperature_scale"
)
=
2.0
)
.
def
(
py
::
init
<
const
FeatureExtractorConfig
&
,
const
OnlineModelConfig
&
,
const
OnlineLMConfig
&
,
const
EndpointConfig
&
,
const
OnlineCtcFstDecoderConfig
&
,
bool
,
const
std
::
string
&
,
int32_t
,
const
std
::
string
&
,
float
,
float
,
float
,
const
std
::
string
&
,
const
std
::
string
&>
(),
py
::
arg
(
"feat_config"
),
py
::
arg
(
"model_config"
),
py
::
arg
(
"lm_config"
)
=
OnlineLMConfig
(),
py
::
arg
(
"endpoint_config"
)
=
EndpointConfig
(),
py
::
arg
(
"ctc_fst_decoder_config"
)
=
OnlineCtcFstDecoderConfig
(),
py
::
arg
(
"enable_endpoint"
),
py
::
arg
(
"decoding_method"
),
py
::
arg
(
"max_active_paths"
)
=
4
,
py
::
arg
(
"hotwords_file"
)
=
""
,
py
::
arg
(
"hotwords_score"
)
=
0
,
py
::
arg
(
"blank_penalty"
)
=
0.0
,
py
::
arg
(
"temperature_scale"
)
=
2.0
,
py
::
arg
(
"rule_fsts"
)
=
""
,
py
::
arg
(
"rule_fars"
)
=
""
)
.
def_readwrite
(
"feat_config"
,
&
PyClass
::
feat_config
)
.
def_readwrite
(
"model_config"
,
&
PyClass
::
model_config
)
.
def_readwrite
(
"lm_config"
,
&
PyClass
::
lm_config
)
...
...
@@ -79,6 +80,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.
def_readwrite
(
"hotwords_score"
,
&
PyClass
::
hotwords_score
)
.
def_readwrite
(
"blank_penalty"
,
&
PyClass
::
blank_penalty
)
.
def_readwrite
(
"temperature_scale"
,
&
PyClass
::
temperature_scale
)
.
def_readwrite
(
"rule_fsts"
,
&
PyClass
::
rule_fsts
)
.
def_readwrite
(
"rule_fars"
,
&
PyClass
::
rule_fars
)
.
def
(
"__str__"
,
&
PyClass
::
ToString
);
}
...
...
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
查看文件 @
349d957
...
...
@@ -64,6 +64,8 @@ class OnlineRecognizer(object):
lm_scale
:
float
=
0.1
,
temperature_scale
:
float
=
2.0
,
debug
:
bool
=
False
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -148,6 +150,12 @@ class OnlineRecognizer(object):
the log probability, you can get it from the directory where
your bpe model is generated. Only used when hotwords provided
and the modeling unit is bpe or cjkchar+bpe.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
_assert_file_exists
(
tokens
)
...
...
@@ -217,6 +225,8 @@ class OnlineRecognizer(object):
hotwords_file
=
hotwords_file
,
blank_penalty
=
blank_penalty
,
temperature_scale
=
temperature_scale
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
...
...
@@ -239,6 +249,8 @@ class OnlineRecognizer(object):
decoding_method
:
str
=
"greedy_search"
,
provider
:
str
=
"cpu"
,
debug
:
bool
=
False
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -283,6 +295,12 @@ class OnlineRecognizer(object):
The only valid value is greedy_search.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
_assert_file_exists
(
tokens
)
...
...
@@ -322,6 +340,8 @@ class OnlineRecognizer(object):
endpoint_config
=
endpoint_config
,
enable_endpoint
=
enable_endpoint_detection
,
decoding_method
=
decoding_method
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
...
...
@@ -345,6 +365,8 @@ class OnlineRecognizer(object):
ctc_max_active
:
int
=
3000
,
provider
:
str
=
"cpu"
,
debug
:
bool
=
False
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -393,6 +415,12 @@ class OnlineRecognizer(object):
active paths at a time.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
_assert_file_exists
(
tokens
)
...
...
@@ -433,6 +461,8 @@ class OnlineRecognizer(object):
ctc_fst_decoder_config
=
ctc_fst_decoder_config
,
enable_endpoint
=
enable_endpoint_detection
,
decoding_method
=
decoding_method
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
...
...
@@ -454,6 +484,8 @@ class OnlineRecognizer(object):
decoding_method
:
str
=
"greedy_search"
,
provider
:
str
=
"cpu"
,
debug
:
bool
=
False
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -497,6 +529,12 @@ class OnlineRecognizer(object):
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
debug:
True to show meta data in the model.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
_assert_file_exists
(
tokens
)
...
...
@@ -533,6 +571,8 @@ class OnlineRecognizer(object):
endpoint_config
=
endpoint_config
,
enable_endpoint
=
enable_endpoint_detection
,
decoding_method
=
decoding_method
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
...
...
@@ -556,6 +596,8 @@ class OnlineRecognizer(object):
decoding_method
:
str
=
"greedy_search"
,
provider
:
str
=
"cpu"
,
debug
:
bool
=
False
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -602,6 +644,12 @@ class OnlineRecognizer(object):
The only valid value is greedy_search.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
_assert_file_exists
(
tokens
)
...
...
@@ -640,6 +688,8 @@ class OnlineRecognizer(object):
endpoint_config
=
endpoint_config
,
enable_endpoint
=
enable_endpoint_detection
,
decoding_method
=
decoding_method
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
...
...
请
注册
或
登录
后发表评论