Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-06-17 14:28:53 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-06-17 14:28:53 +0800
Commit
b0f7ed3ee30818d347d6660b84bf9c8597200d7b
b0f7ed3e
1 parent
dd69a1b5
Add inverse text normalization for non-streaming ASR (#1017)
显示空白字符变更
内嵌
并排对比
正在显示
13 个修改的文件
包含
380 行增加
和
19 行删除
.github/scripts/test-python.sh
python-api-examples/inverse-text-normalization-offline-asr.py
sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
sherpa-onnx/csrc/offline-recognizer-impl.cc
sherpa-onnx/csrc/offline-recognizer-impl.h
sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
sherpa-onnx/csrc/offline-recognizer.cc
sherpa-onnx/csrc/offline-recognizer.h
sherpa-onnx/python/csrc/offline-recognizer.cc
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
.github/scripts/test-python.sh
查看文件 @
b0f7ed3
...
...
@@ -248,7 +248,7 @@ if [[ x$OS != x'windows-latest' ]]; then
python3 ./python-api-examples/online-decode-files.py
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder-epoch-99-avg-1.int8.onnx
\
--decoder
=
$repo
/decoder-epoch-99-avg-1.
int8.
onnx
\
--decoder
=
$repo
/decoder-epoch-99-avg-1.onnx
\
--joiner
=
$repo
/joiner-epoch-99-avg-1.int8.onnx
\
$repo
/test_wavs/0.wav
\
$repo
/test_wavs/1.wav
\
...
...
@@ -286,7 +286,7 @@ python3 ./python-api-examples/offline-decode-files.py \
python3 ./python-api-examples/offline-decode-files.py
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder-epoch-99-avg-1.int8.onnx
\
--decoder
=
$repo
/decoder-epoch-99-avg-1.
int8.
onnx
\
--decoder
=
$repo
/decoder-epoch-99-avg-1.onnx
\
--joiner
=
$repo
/joiner-epoch-99-avg-1.int8.onnx
\
$repo
/test_wavs/0.wav
\
$repo
/test_wavs/1.wav
\
...
...
@@ -330,6 +330,15 @@ if [[ x$OS != x'windows-latest' ]]; then
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
ln -s
$repo
$PWD
/
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
python3 ./python-api-examples/inverse-text-normalization-offline-asr.py
rm -rfv sherpa-onnx-paraformer-zh-2023-03-28
rm -rf
$repo
fi
...
...
python-api-examples/inverse-text-normalization-offline-asr.py
0 → 100755
查看文件 @
b0f7ed3
#!/usr/bin/env python3
#
# Copyright (c) 2024 Xiaomi Corporation
"""
This script shows how to use inverse text normalization with non-streaming ASR.
Usage:
(1) Download the test model
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
rm sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
(2) Download rule fst
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
Please refer to
https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb
for how itn_zh_number.fst is generated.
(3) Download test wave
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
(4) Run this script
python3 ./python-api-examples/inverse-text-normalization-offline-asr.py
"""
from
pathlib
import
Path
import
sherpa_onnx
import
soundfile
as
sf
def
create_recognizer
():
model
=
"./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx"
tokens
=
"./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt"
rule_fsts
=
"./itn_zh_number.fst"
if
(
not
Path
(
model
)
.
is_file
()
or
not
Path
(
tokens
)
.
is_file
()
or
not
Path
(
rule_fsts
)
.
is_file
()
):
raise
ValueError
(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
return
sherpa_onnx
.
OfflineRecognizer
.
from_paraformer
(
paraformer
=
model
,
tokens
=
tokens
,
debug
=
True
,
rule_fsts
=
rule_fsts
,
)
def
main
():
recognizer
=
create_recognizer
()
wave_filename
=
"./itn-zh-number.wav"
if
not
Path
(
wave_filename
)
.
is_file
():
raise
ValueError
(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
audio
,
sample_rate
=
sf
.
read
(
wave_filename
,
dtype
=
"float32"
,
always_2d
=
True
)
audio
=
audio
[:,
0
]
# only use the first channel
stream
=
recognizer
.
create_stream
()
stream
.
accept_waveform
(
sample_rate
,
audio
)
recognizer
.
decode_stream
(
stream
)
print
(
wave_filename
)
print
(
stream
.
result
)
if
__name__
==
"__main__"
:
main
()
...
...
sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
查看文件 @
b0f7ed3
...
...
@@ -73,7 +73,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
class
OfflineRecognizerCtcImpl
:
public
OfflineRecognizerImpl
{
public
:
explicit
OfflineRecognizerCtcImpl
(
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OfflineRecognizerImpl
(
config
),
config_
(
config
),
symbol_table_
(
config_
.
model_config
.
tokens
),
model_
(
OfflineCtcModel
::
Create
(
config_
.
model_config
))
{
Init
();
...
...
@@ -82,7 +83,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
OfflineRecognizerCtcImpl
(
AAssetManager
*
mgr
,
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OfflineRecognizerImpl
(
mgr
,
config
),
config_
(
config
),
symbol_table_
(
mgr
,
config_
.
model_config
.
tokens
),
model_
(
OfflineCtcModel
::
Create
(
mgr
,
config_
.
model_config
))
{
Init
();
...
...
@@ -205,6 +207,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
auto
r
=
Convert
(
results
[
i
],
symbol_table_
,
frame_shift_ms
,
model_
->
SubsamplingFactor
());
r
.
text
=
ApplyInverseTextNormalization
(
std
::
move
(
r
.
text
));
ss
[
i
]
->
SetResult
(
r
);
}
}
...
...
@@ -238,6 +241,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
auto
r
=
Convert
(
results
[
0
],
symbol_table_
,
frame_shift_ms
,
model_
->
SubsamplingFactor
());
r
.
text
=
ApplyInverseTextNormalization
(
std
::
move
(
r
.
text
));
s
->
SetResult
(
r
);
}
...
...
sherpa-onnx/csrc/offline-recognizer-impl.cc
查看文件 @
b0f7ed3
...
...
@@ -5,7 +5,18 @@
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "fst/extensions/far/far.h"
#include "kaldifst/csrc/kaldi-fst-io.h"
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
...
...
@@ -316,4 +327,111 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
}
#endif
OfflineRecognizerImpl
::
OfflineRecognizerImpl
(
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
)
{
if
(
!
config
.
rule_fsts
.
empty
())
{
std
::
vector
<
std
::
string
>
files
;
SplitStringToVector
(
config
.
rule_fsts
,
","
,
false
,
&
files
);
itn_list_
.
reserve
(
files
.
size
());
for
(
const
auto
&
f
:
files
)
{
if
(
config
.
model_config
.
debug
)
{
SHERPA_ONNX_LOGE
(
"rule fst: %s"
,
f
.
c_str
());
}
itn_list_
.
push_back
(
std
::
make_unique
<
kaldifst
::
TextNormalizer
>
(
f
));
}
}
if
(
!
config
.
rule_fars
.
empty
())
{
if
(
config
.
model_config
.
debug
)
{
SHERPA_ONNX_LOGE
(
"Loading FST archives"
);
}
std
::
vector
<
std
::
string
>
files
;
SplitStringToVector
(
config
.
rule_fars
,
","
,
false
,
&
files
);
itn_list_
.
reserve
(
files
.
size
()
+
itn_list_
.
size
());
for
(
const
auto
&
f
:
files
)
{
if
(
config
.
model_config
.
debug
)
{
SHERPA_ONNX_LOGE
(
"rule far: %s"
,
f
.
c_str
());
}
std
::
unique_ptr
<
fst
::
FarReader
<
fst
::
StdArc
>>
reader
(
fst
::
FarReader
<
fst
::
StdArc
>::
Open
(
f
));
for
(;
!
reader
->
Done
();
reader
->
Next
())
{
std
::
unique_ptr
<
fst
::
StdConstFst
>
r
(
fst
::
CastOrConvertToConstFst
(
reader
->
GetFst
()
->
Copy
()));
itn_list_
.
push_back
(
std
::
make_unique
<
kaldifst
::
TextNormalizer
>
(
std
::
move
(
r
)));
}
}
if
(
config
.
model_config
.
debug
)
{
SHERPA_ONNX_LOGE
(
"FST archives loaded!"
);
}
}
}
#if __ANDROID_API__ >= 9
OfflineRecognizerImpl
::
OfflineRecognizerImpl
(
AAssetManager
*
mgr
,
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
)
{
if
(
!
config
.
rule_fsts
.
empty
())
{
std
::
vector
<
std
::
string
>
files
;
SplitStringToVector
(
config
.
rule_fsts
,
","
,
false
,
&
files
);
itn_list_
.
reserve
(
files
.
size
());
for
(
const
auto
&
f
:
files
)
{
if
(
config
.
model_config
.
debug
)
{
SHERPA_ONNX_LOGE
(
"rule fst: %s"
,
f
.
c_str
());
}
auto
buf
=
ReadFile
(
mgr
,
f
);
std
::
istrstream
is
(
buf
.
data
(),
buf
.
size
());
itn_list_
.
push_back
(
std
::
make_unique
<
kaldifst
::
TextNormalizer
>
(
is
));
}
}
if
(
!
config
.
rule_fars
.
empty
())
{
std
::
vector
<
std
::
string
>
files
;
SplitStringToVector
(
config
.
rule_fars
,
","
,
false
,
&
files
);
itn_list_
.
reserve
(
files
.
size
()
+
itn_list_
.
size
());
for
(
const
auto
&
f
:
files
)
{
if
(
config
.
model_config
.
debug
)
{
SHERPA_ONNX_LOGE
(
"rule far: %s"
,
f
.
c_str
());
}
auto
buf
=
ReadFile
(
mgr
,
f
);
std
::
unique_ptr
<
std
::
istream
>
s
(
new
std
::
istrstream
(
buf
.
data
(),
buf
.
size
()));
std
::
unique_ptr
<
fst
::
FarReader
<
fst
::
StdArc
>>
reader
(
fst
::
FarReader
<
fst
::
StdArc
>::
Open
(
std
::
move
(
s
)));
for
(;
!
reader
->
Done
();
reader
->
Next
())
{
std
::
unique_ptr
<
fst
::
StdConstFst
>
r
(
fst
::
CastOrConvertToConstFst
(
reader
->
GetFst
()
->
Copy
()));
itn_list_
.
push_back
(
std
::
make_unique
<
kaldifst
::
TextNormalizer
>
(
std
::
move
(
r
)));
}
// for (; !reader->Done(); reader->Next())
}
// for (const auto &f : files)
}
// if (!config.rule_fars.empty())
}
#endif
std
::
string
OfflineRecognizerImpl
::
ApplyInverseTextNormalization
(
std
::
string
text
)
const
{
if
(
!
itn_list_
.
empty
())
{
for
(
const
auto
&
tn
:
itn_list_
)
{
text
=
tn
->
Normalize
(
text
);
if
(
config_
.
model_config
.
debug
)
{
SHERPA_ONNX_LOGE
(
"After inverse text normalization: %s"
,
text
.
c_str
());
}
}
}
return
text
;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-recognizer-impl.h
查看文件 @
b0f7ed3
...
...
@@ -14,6 +14,7 @@
#include "android/asset_manager_jni.h"
#endif
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h"
...
...
@@ -22,10 +23,15 @@ namespace sherpa_onnx {
class
OfflineRecognizerImpl
{
public
:
explicit
OfflineRecognizerImpl
(
const
OfflineRecognizerConfig
&
config
);
static
std
::
unique_ptr
<
OfflineRecognizerImpl
>
Create
(
const
OfflineRecognizerConfig
&
config
);
#if __ANDROID_API__ >= 9
OfflineRecognizerImpl
(
AAssetManager
*
mgr
,
const
OfflineRecognizerConfig
&
config
);
static
std
::
unique_ptr
<
OfflineRecognizerImpl
>
Create
(
AAssetManager
*
mgr
,
const
OfflineRecognizerConfig
&
config
);
#endif
...
...
@@ -41,6 +47,15 @@ class OfflineRecognizerImpl {
virtual
std
::
unique_ptr
<
OfflineStream
>
CreateStream
()
const
=
0
;
virtual
void
DecodeStreams
(
OfflineStream
**
ss
,
int32_t
n
)
const
=
0
;
std
::
string
ApplyInverseTextNormalization
(
std
::
string
text
)
const
;
private
:
OfflineRecognizerConfig
config_
;
// for inverse text normalization. Used only if
// config.rule_fsts is not empty or
// config.rule_fars is not empty
std
::
vector
<
std
::
unique_ptr
<
kaldifst
::
TextNormalizer
>>
itn_list_
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
查看文件 @
b0f7ed3
...
...
@@ -89,7 +89,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
public
:
explicit
OfflineRecognizerParaformerImpl
(
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OfflineRecognizerImpl
(
config
),
config_
(
config
),
symbol_table_
(
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineParaformerModel
>
(
config
.
model_config
))
{
if
(
config
.
decoding_method
==
"greedy_search"
)
{
...
...
@@ -109,7 +110,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
OfflineRecognizerParaformerImpl
(
AAssetManager
*
mgr
,
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OfflineRecognizerImpl
(
mgr
,
config
),
config_
(
config
),
symbol_table_
(
mgr
,
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineParaformerModel
>
(
mgr
,
config
.
model_config
))
{
...
...
@@ -204,6 +206,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
auto
r
=
Convert
(
results
[
i
],
symbol_table_
);
r
.
text
=
ApplyInverseTextNormalization
(
std
::
move
(
r
.
text
));
ss
[
i
]
->
SetResult
(
r
);
}
}
...
...
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
查看文件 @
b0f7ed3
...
...
@@ -74,7 +74,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
public
:
explicit
OfflineRecognizerTransducerImpl
(
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OfflineRecognizerImpl
(
config
),
config_
(
config
),
symbol_table_
(
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineTransducerModel
>
(
config_
.
model_config
))
{
if
(
config_
.
decoding_method
==
"greedy_search"
)
{
...
...
@@ -107,7 +108,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit
OfflineRecognizerTransducerImpl
(
AAssetManager
*
mgr
,
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OfflineRecognizerImpl
(
mgr
,
config
),
config_
(
config
),
symbol_table_
(
mgr
,
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineTransducerModel
>
(
mgr
,
config_
.
model_config
))
{
...
...
@@ -230,6 +232,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
auto
r
=
Convert
(
results
[
i
],
symbol_table_
,
frame_shift_ms
,
model_
->
SubsamplingFactor
());
r
.
text
=
ApplyInverseTextNormalization
(
std
::
move
(
r
.
text
));
ss
[
i
]
->
SetResult
(
r
);
}
...
...
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
查看文件 @
b0f7ed3
...
...
@@ -41,7 +41,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
public
:
explicit
OfflineRecognizerTransducerNeMoImpl
(
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OfflineRecognizerImpl
(
config
),
config_
(
config
),
symbol_table_
(
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineTransducerNeMoModel
>
(
config_
.
model_config
))
{
...
...
@@ -59,7 +60,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
explicit
OfflineRecognizerTransducerNeMoImpl
(
AAssetManager
*
mgr
,
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OfflineRecognizerImpl
(
mgr
,
config
),
config_
(
config
),
symbol_table_
(
mgr
,
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineTransducerNeMoModel
>
(
mgr
,
config_
.
model_config
))
{
...
...
@@ -131,6 +133,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
auto
r
=
Convert
(
results
[
i
],
symbol_table_
,
frame_shift_ms
,
model_
->
SubsamplingFactor
());
r
.
text
=
ApplyInverseTextNormalization
(
std
::
move
(
r
.
text
));
ss
[
i
]
->
SetResult
(
r
);
}
...
...
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
查看文件 @
b0f7ed3
...
...
@@ -52,7 +52,8 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
class
OfflineRecognizerWhisperImpl
:
public
OfflineRecognizerImpl
{
public
:
explicit
OfflineRecognizerWhisperImpl
(
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OfflineRecognizerImpl
(
config
),
config_
(
config
),
symbol_table_
(
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineWhisperModel
>
(
config
.
model_config
))
{
Init
();
...
...
@@ -61,7 +62,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
OfflineRecognizerWhisperImpl
(
AAssetManager
*
mgr
,
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
:
OfflineRecognizerImpl
(
mgr
,
config
),
config_
(
config
),
symbol_table_
(
mgr
,
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineWhisperModel
>
(
mgr
,
config
.
model_config
))
{
...
...
@@ -150,6 +152,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
std
::
move
(
cross_kv
.
second
));
auto
r
=
Convert
(
results
[
0
],
symbol_table_
);
r
.
text
=
ApplyInverseTextNormalization
(
std
::
move
(
r
.
text
));
s
->
SetResult
(
r
);
}
catch
(
const
Ort
::
Exception
&
ex
)
{
SHERPA_ONNX_LOGE
(
...
...
sherpa-onnx/csrc/offline-recognizer.cc
查看文件 @
b0f7ed3
...
...
@@ -10,7 +10,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace
sherpa_onnx
{
void
OfflineRecognizerConfig
::
Register
(
ParseOptions
*
po
)
{
...
...
@@ -44,6 +44,16 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po
->
Register
(
"hotwords-score"
,
&
hotwords_score
,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search"
);
po
->
Register
(
"rule-fsts"
,
&
rule_fsts
,
"If not empty, it specifies fsts for inverse text normalization. "
"If there are multiple fsts, they are separated by a comma."
);
po
->
Register
(
"rule-fars"
,
&
rule_fars
,
"If not empty, it specifies fst archives for inverse text normalization. "
"If there are multiple archives, they are separated by a comma."
);
}
bool
OfflineRecognizerConfig
::
Validate
()
const
{
...
...
@@ -61,7 +71,7 @@ bool OfflineRecognizerConfig::Validate() const {
if
(
!
hotwords_file
.
empty
()
&&
decoding_method
!=
"modified_beam_search"
)
{
SHERPA_ONNX_LOGE
(
"Please use --decoding-method=modified_beam_search if you"
" provide --hotwords-file. Given --decoding-method=
%s
"
,
" provide --hotwords-file. Given --decoding-method=
'%s'
"
,
decoding_method
.
c_str
());
return
false
;
}
...
...
@@ -72,6 +82,34 @@ bool OfflineRecognizerConfig::Validate() const {
return
false
;
}
if
(
!
hotwords_file
.
empty
()
&&
!
FileExists
(
hotwords_file
))
{
SHERPA_ONNX_LOGE
(
"--hotwords-file: '%s' does not exist"
,
hotwords_file
.
c_str
());
return
false
;
}
if
(
!
rule_fsts
.
empty
())
{
std
::
vector
<
std
::
string
>
files
;
SplitStringToVector
(
rule_fsts
,
","
,
false
,
&
files
);
for
(
const
auto
&
f
:
files
)
{
if
(
!
FileExists
(
f
))
{
SHERPA_ONNX_LOGE
(
"Rule fst '%s' does not exist. "
,
f
.
c_str
());
return
false
;
}
}
}
if
(
!
rule_fars
.
empty
())
{
std
::
vector
<
std
::
string
>
files
;
SplitStringToVector
(
rule_fars
,
","
,
false
,
&
files
);
for
(
const
auto
&
f
:
files
)
{
if
(
!
FileExists
(
f
))
{
SHERPA_ONNX_LOGE
(
"Rule far '%s' does not exist. "
,
f
.
c_str
());
return
false
;
}
}
}
return
model_config
.
Validate
();
}
...
...
@@ -87,7 +125,9 @@ std::string OfflineRecognizerConfig::ToString() const {
os
<<
"max_active_paths="
<<
max_active_paths
<<
", "
;
os
<<
"hotwords_file=
\"
"
<<
hotwords_file
<<
"
\"
, "
;
os
<<
"hotwords_score="
<<
hotwords_score
<<
", "
;
os
<<
"blank_penalty="
<<
blank_penalty
<<
")"
;
os
<<
"blank_penalty="
<<
blank_penalty
<<
", "
;
os
<<
"rule_fsts=
\"
"
<<
rule_fsts
<<
"
\"
, "
;
os
<<
"rule_fars=
\"
"
<<
rule_fars
<<
"
\"
)"
;
return
os
.
str
();
}
...
...
sherpa-onnx/csrc/offline-recognizer.h
查看文件 @
b0f7ed3
...
...
@@ -40,6 +40,12 @@ struct OfflineRecognizerConfig {
float
blank_penalty
=
0
.
0
;
// If there are multiple rules, they are applied from left to right.
std
::
string
rule_fsts
;
// If there are multiple FST archives, they are applied from left to right.
std
::
string
rule_fars
;
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
...
...
@@ -50,7 +56,8 @@ struct OfflineRecognizerConfig {
const
OfflineCtcFstDecoderConfig
&
ctc_fst_decoder_config
,
const
std
::
string
&
decoding_method
,
int32_t
max_active_paths
,
const
std
::
string
&
hotwords_file
,
float
hotwords_score
,
float
blank_penalty
)
float
blank_penalty
,
const
std
::
string
&
rule_fsts
,
const
std
::
string
&
rule_fars
)
:
feat_config
(
feat_config
),
model_config
(
model_config
),
lm_config
(
lm_config
),
...
...
@@ -59,7 +66,9 @@ struct OfflineRecognizerConfig {
max_active_paths
(
max_active_paths
),
hotwords_file
(
hotwords_file
),
hotwords_score
(
hotwords_score
),
blank_penalty
(
blank_penalty
)
{}
blank_penalty
(
blank_penalty
),
rule_fsts
(
rule_fsts
),
rule_fars
(
rule_fars
)
{}
void
Register
(
ParseOptions
*
po
);
bool
Validate
()
const
;
...
...
sherpa-onnx/python/csrc/offline-recognizer.cc
查看文件 @
b0f7ed3
...
...
@@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.
def
(
py
::
init
<
const
FeatureExtractorConfig
&
,
const
OfflineModelConfig
&
,
const
OfflineLMConfig
&
,
const
OfflineCtcFstDecoderConfig
&
,
const
std
::
string
&
,
int32_t
,
const
std
::
string
&
,
float
,
float
>
(),
float
,
const
std
::
string
&
,
const
std
::
string
&
>
(),
py
::
arg
(
"feat_config"
),
py
::
arg
(
"model_config"
),
py
::
arg
(
"lm_config"
)
=
OfflineLMConfig
(),
py
::
arg
(
"ctc_fst_decoder_config"
)
=
OfflineCtcFstDecoderConfig
(),
py
::
arg
(
"decoding_method"
)
=
"greedy_search"
,
py
::
arg
(
"max_active_paths"
)
=
4
,
py
::
arg
(
"hotwords_file"
)
=
""
,
py
::
arg
(
"hotwords_score"
)
=
1.5
,
py
::
arg
(
"blank_penalty"
)
=
0.0
)
py
::
arg
(
"hotwords_score"
)
=
1.5
,
py
::
arg
(
"blank_penalty"
)
=
0.0
,
py
::
arg
(
"rule_fsts"
)
=
""
,
py
::
arg
(
"rule_fars"
)
=
""
)
.
def_readwrite
(
"feat_config"
,
&
PyClass
::
feat_config
)
.
def_readwrite
(
"model_config"
,
&
PyClass
::
model_config
)
.
def_readwrite
(
"lm_config"
,
&
PyClass
::
lm_config
)
...
...
@@ -33,6 +34,8 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.
def_readwrite
(
"hotwords_file"
,
&
PyClass
::
hotwords_file
)
.
def_readwrite
(
"hotwords_score"
,
&
PyClass
::
hotwords_score
)
.
def_readwrite
(
"blank_penalty"
,
&
PyClass
::
blank_penalty
)
.
def_readwrite
(
"rule_fsts"
,
&
PyClass
::
rule_fsts
)
.
def_readwrite
(
"rule_fars"
,
&
PyClass
::
rule_fars
)
.
def
(
"__str__"
,
&
PyClass
::
ToString
);
}
...
...
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
查看文件 @
b0f7ed3
...
...
@@ -54,6 +54,8 @@ class OfflineRecognizer(object):
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
model_type
:
str
=
"transducer"
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -107,6 +109,12 @@ class OfflineRecognizer(object):
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
...
...
@@ -143,6 +151,8 @@ class OfflineRecognizer(object):
hotwords_file
=
hotwords_file
,
hotwords_score
=
hotwords_score
,
blank_penalty
=
blank_penalty
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
self
.
config
=
recognizer_config
...
...
@@ -159,6 +169,8 @@ class OfflineRecognizer(object):
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -186,6 +198,12 @@ class OfflineRecognizer(object):
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
...
...
@@ -206,6 +224,8 @@ class OfflineRecognizer(object):
feat_config
=
feat_config
,
model_config
=
model_config
,
decoding_method
=
decoding_method
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
self
.
config
=
recognizer_config
...
...
@@ -222,6 +242,8 @@ class OfflineRecognizer(object):
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -251,6 +273,12 @@ class OfflineRecognizer(object):
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
...
...
@@ -271,6 +299,8 @@ class OfflineRecognizer(object):
feat_config
=
feat_config
,
model_config
=
model_config
,
decoding_method
=
decoding_method
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
self
.
config
=
recognizer_config
...
...
@@ -287,6 +317,8 @@ class OfflineRecognizer(object):
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -315,6 +347,12 @@ class OfflineRecognizer(object):
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
...
...
@@ -335,6 +373,8 @@ class OfflineRecognizer(object):
feat_config
=
feat_config
,
model_config
=
model_config
,
decoding_method
=
decoding_method
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
self
.
config
=
recognizer_config
...
...
@@ -353,6 +393,8 @@ class OfflineRecognizer(object):
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
tail_paddings
:
int
=
-
1
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -389,6 +431,12 @@ class OfflineRecognizer(object):
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
...
...
@@ -415,6 +463,8 @@ class OfflineRecognizer(object):
feat_config
=
feat_config
,
model_config
=
model_config
,
decoding_method
=
decoding_method
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
self
.
config
=
recognizer_config
...
...
@@ -431,6 +481,8 @@ class OfflineRecognizer(object):
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -458,6 +510,12 @@ class OfflineRecognizer(object):
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
...
...
@@ -478,6 +536,8 @@ class OfflineRecognizer(object):
feat_config
=
feat_config
,
model_config
=
model_config
,
decoding_method
=
decoding_method
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
self
.
config
=
recognizer_config
...
...
@@ -494,6 +554,8 @@ class OfflineRecognizer(object):
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
rule_fsts
:
str
=
""
,
rule_fars
:
str
=
""
,
):
"""
Please refer to
...
...
@@ -522,6 +584,12 @@ class OfflineRecognizer(object):
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
...
...
@@ -542,6 +610,8 @@ class OfflineRecognizer(object):
feat_config
=
feat_config
,
model_config
=
model_config
,
decoding_method
=
decoding_method
,
rule_fsts
=
rule_fsts
,
rule_fars
=
rule_fars
,
)
self
.
recognizer
=
_Recognizer
(
recognizer_config
)
self
.
config
=
recognizer_config
...
...
请
注册
或
登录
后发表评论