Commit 049fb9f45139291d408d0c5b44058e2bb00c79ab (1 parent: fac4f6bc)
Authored by Fangjun Kuang, 2023-11-16 14:20:41 +0800
Committed by GitHub, 2023-11-16 14:20:41 +0800
Add Python APIs for WeNet CTC models (#428)
Showing 13 changed files with 536 additions and 9 deletions:
.github/scripts/test-python.sh
.github/workflows/mfc.yaml
.github/workflows/run-python-test.yaml
CMakeLists.txt
python-api-examples/generate-subtitles.py
python-api-examples/non_streaming_server.py
python-api-examples/offline-decode-files.py
python-api-examples/online-decode-files.py
python-api-examples/streaming_server.py
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
sherpa-onnx/python/tests/test_offline_recognizer.py
sherpa-onnx/python/tests/test_online_recognizer.py
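At a glance, the commit adds a --wenet-ctc option to the example scripts and two new Python constructors, OfflineRecognizer.from_wenet_ctc and OnlineRecognizer.from_wenet_ctc. The following is a rough sketch of the new entry points, not part of the commit itself; it assumes the sherpa-onnx-zh-wenet-wenetspeech model directory used throughout the diffs below has already been downloaded.

import sherpa_onnx

# Non-streaming (offline) WeNet CTC recognizer, added by this commit
offline_recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
    model="./sherpa-onnx-zh-wenet-wenetspeech/model.onnx",
    tokens="./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt",
)

# Streaming (online) WeNet CTC recognizer, added by this commit
online_recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
    tokens="./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt",
    model="./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx",
)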
.github/scripts/test-python.sh
@@ -8,6 +8,51 @@ log() {
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

wenet_models=(
  sherpa-onnx-zh-wenet-aishell
  sherpa-onnx-zh-wenet-aishell2
  sherpa-onnx-zh-wenet-wenetspeech
  sherpa-onnx-zh-wenet-multi-cn
  sherpa-onnx-en-wenet-librispeech
  sherpa-onnx-en-wenet-gigaspeech
)

mkdir -p /tmp/icefall-models
dir=/tmp/icefall-models

for name in ${wenet_models[@]}; do
  repo_url=https://huggingface.co/csukuangfj/$name
  log "Start testing ${repo_url}"
  repo=$dir/$(basename $repo_url)
  log "Download pretrained model and test-data from $repo_url"

  pushd $dir
  GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  cd $repo
  git lfs pull --include "*.onnx"
  ls -lh *.onnx
  popd

  python3 ./python-api-examples/offline-decode-files.py \
    --tokens=$repo/tokens.txt \
    --wenet-ctc=$repo/model.onnx \
    $repo/test_wavs/0.wav \
    $repo/test_wavs/1.wav \
    $repo/test_wavs/8k.wav

  python3 ./python-api-examples/online-decode-files.py \
    --tokens=$repo/tokens.txt \
    --wenet-ctc=$repo/model-streaming.onnx \
    $repo/test_wavs/0.wav \
    $repo/test_wavs/1.wav \
    $repo/test_wavs/8k.wav

  python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
  python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose

  rm -rf $repo
done

log "Offline TTS test"
# test waves are saved in ./tts
mkdir ./tts
.github/workflows/mfc.yaml
@@ -85,10 +85,19 @@ jobs:
        arch=${{ matrix.arch }}
        cd mfc-examples/$arch/Release

        cp StreamingSpeechRecognition.exe sherpa-onnx-streaming-${SHERPA_ONNX_VERSION}.exe
        cp NonStreamingSpeechRecognition.exe sherpa-onnx-non-streaming-${SHERPA_ONNX_VERSION}.exe
        ls -lh

        cp -v StreamingSpeechRecognition.exe sherpa-onnx-streaming-${SHERPA_ONNX_VERSION}.exe
        cp -v NonStreamingSpeechRecognition.exe sherpa-onnx-non-streaming-${SHERPA_ONNX_VERSION}.exe
        cp -v NonStreamingTextToSpeech.exe ../sherpa-onnx-non-streaming-tts-${SHERPA_ONNX_VERSION}.exe
        ls -lh

    - name: Upload artifact tts
      uses: actions/upload-artifact@v3
      with:
        name: non-streaming-tts-${{ matrix.arch }}
        path: ./mfc-examples/${{ matrix.arch }}/Release/NonStreamingTextToSpeech.exe

    - name: Upload artifact
      uses: actions/upload-artifact@v3
      with:

@@ -116,3 +125,11 @@ jobs:
        file_glob: true
        overwrite: true
        file: ./mfc-examples/${{ matrix.arch }}/Release/sherpa-onnx-non-streaming-*.exe

    - name: Release pre-compiled binaries and libs for Windows ${{ matrix.arch }}
      if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/')
      uses: svenstaro/upload-release-action@v2
      with:
        file_glob: true
        overwrite: true
        file: ./mfc-examples/${{ matrix.arch }}/sherpa-onnx-non-streaming-*.exe
.github/workflows/run-python-test.yaml
@@ -10,6 +10,7 @@ on:
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'sherpa-onnx/csrc/*'
      - 'python-api-examples/**'

  pull_request:
    branches:
      - master

@@ -19,6 +20,7 @@ on:
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'sherpa-onnx/csrc/*'
      - 'python-api-examples/**'

  workflow_dispatch:

concurrency:
CMakeLists.txt
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

set(SHERPA_ONNX_VERSION "1.8.9")
set(SHERPA_ONNX_VERSION "1.8.10")

# Disable warning about
#
python-api-examples/generate-subtitles.py
@@ -58,6 +58,15 @@ wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
  --num-threads=2 \
  /path/to/test.mp4

(4) For WeNet CTC models

./python-api-examples/generate-subtitles.py \
  --silero-vad-model=/path/to/silero_vad.onnx \
  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
  --num-threads=2 \
  /path/to/test.mp4

Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download non-streaming pre-trained models

@@ -122,6 +131,13 @@ def get_args():
    )

    parser.add_argument(
        "--wenet-ctc",
        default="",
        type=str,
        help="Path to the CTC model.onnx from WeNet",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,

@@ -215,6 +231,7 @@ def assert_file_exists(filename: str):
def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
    if args.encoder:
        assert len(args.paraformer) == 0, args.paraformer
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder

@@ -234,6 +251,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
            debug=args.debug,
        )
    elif args.paraformer:
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder

@@ -248,6 +266,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
            decoding_method=args.decoding_method,
            debug=args.debug,
        )
    elif args.wenet_ctc:
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder

        assert_file_exists(args.wenet_ctc)

        recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
            model=args.wenet_ctc,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feature_dim,
            decoding_method=args.decoding_method,
            debug=args.debug,
        )
    elif args.whisper_encoder:
        assert_file_exists(args.whisper_encoder)
        assert_file_exists(args.whisper_decoder)
python-api-examples/non_streaming_server.py
@@ -58,7 +58,19 @@ python3 ./python-api-examples/non_streaming_server.py \
  --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
  --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt

(4) Use a Whisper model
(4) Use a non-streaming CTC model from WeNet

cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
cd sherpa-onnx-zh-wenet-wenetspeech
git lfs pull --include "*.onnx"
cd ..

python3 ./python-api-examples/non_streaming_server.py \
  --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
  --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt

(5) Use a Whisper model

cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en

@@ -210,6 +222,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
    )


def add_wenet_ctc_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--wenet-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from WeNet CTC",
    )


def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--tdnn-model",

@@ -261,6 +282,7 @@ def add_model_args(parser: argparse.ArgumentParser):
    add_transducer_model_args(parser)
    add_paraformer_model_args(parser)
    add_nemo_ctc_model_args(parser)
    add_wenet_ctc_model_args(parser)
    add_tdnn_ctc_model_args(parser)
    add_whisper_model_args(parser)

@@ -804,6 +826,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
    if args.encoder:
        assert len(args.paraformer) == 0, args.paraformer
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model

@@ -827,6 +850,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
        )
    elif args.paraformer:
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model

@@ -842,6 +866,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
            decoding_method=args.decoding_method,
        )
    elif args.nemo_ctc:
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model

@@ -856,6 +881,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
        )
    elif args.wenet_ctc:
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model

        assert_file_exists(args.wenet_ctc)

        recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
            model=args.wenet_ctc,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
        )
    elif args.whisper_encoder:
        assert len(args.tdnn_model) == 0, args.tdnn_model
        assert_file_exists(args.whisper_encoder)
python-api-examples/offline-decode-files.py
@@ -59,7 +59,16 @@ python3 ./python-api-examples/offline-decode-files.py \
  ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
  ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav

(5) For tdnn models of the yesno recipe from icefall
(5) For CTC models from WeNet

python3 ./python-api-examples/offline-decode-files.py \
  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav

(6) For tdnn models of the yesno recipe from icefall

python3 ./python-api-examples/offline-decode-files.py \
  --sample-rate=8000 \

@@ -155,6 +164,13 @@ def get_args():
    )

    parser.add_argument(
        "--wenet-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from WeNet CTC",
    )

    parser.add_argument(
        "--tdnn-model",
        default="",
        type=str,

@@ -254,6 +270,7 @@ def assert_file_exists(filename: str):
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
    )


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """
    Args:

@@ -287,6 +304,7 @@ def main():
    if args.encoder:
        assert len(args.paraformer) == 0, args.paraformer
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model

@@ -310,6 +328,7 @@ def main():
        )
    elif args.paraformer:
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model

@@ -326,6 +345,7 @@ def main():
            debug=args.debug,
        )
    elif args.nemo_ctc:
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model

@@ -341,6 +361,22 @@ def main():
            decoding_method=args.decoding_method,
            debug=args.debug,
        )
    elif args.wenet_ctc:
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model

        assert_file_exists(args.wenet_ctc)

        recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
            model=args.wenet_ctc,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feature_dim,
            decoding_method=args.decoding_method,
            debug=args.debug,
        )
    elif args.whisper_encoder:
        assert len(args.tdnn_model) == 0, args.tdnn_model
        assert_file_exists(args.whisper_encoder)
python-api-examples/online-decode-files.py
@@ -37,8 +37,25 @@ git lfs pull --include "*.onnx"
  ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
  ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav

(3) Streaming Conformer CTC from WeNet

GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
cd sherpa-onnx-zh-wenet-wenetspeech
git lfs pull --include "*.onnx"

./python-api-examples/online-decode-files.py \
  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx \
  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav

Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
and
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
to install sherpa-onnx and to download streaming pre-trained models.
"""

import argparse

@@ -93,6 +110,26 @@ def get_args():
    )

    parser.add_argument(
        "--wenet-ctc",
        type=str,
        help="Path to the wenet ctc model model",
    )

    parser.add_argument(
        "--wenet-ctc-chunk-size",
        type=int,
        default=16,
        help="The --chunk-size parameter for streaming WeNet models",
    )

    parser.add_argument(
        "--wenet-ctc-num-left-chunks",
        type=int,
        default=4,
        help="The --num-left-chunks parameter for streaming WeNet models",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,

@@ -249,6 +286,18 @@ def main():
            feature_dim=80,
            decoding_method="greedy_search",
        )
    elif args.wenet_ctc:
        recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
            tokens=args.tokens,
            model=args.wenet_ctc,
            chunk_size=args.wenet_ctc_chunk_size,
            num_left_chunks=args.wenet_ctc_num_left_chunks,
            num_threads=args.num_threads,
            provider=args.provider,
            sample_rate=16000,
            feature_dim=80,
            decoding_method="greedy_search",
        )
    else:
        raise ValueError("Please provide a model")
python-api-examples/streaming_server.py
@@ -40,10 +40,17 @@ python3 ./python-api-examples/streaming_server.py \
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
to download pre-trained models.

The model in the above help messages is from
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english

To use a WeNet streaming Conformer CTC model, please use

python3 ./python-api-examples/streaming_server.py \
  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx
"""

import argparse

@@ -131,6 +138,12 @@ def add_model_args(parser: argparse.ArgumentParser):
    )

    parser.add_argument(
        "--wenet-ctc",
        type=str,
        help="Path to the model.onnx from WeNet",
    )

    parser.add_argument(
        "--paraformer-encoder",
        type=str,
        help="Path to the paraformer encoder model",

@@ -212,7 +225,6 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
    )


def add_modified_beam_search_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--num-active-paths",

@@ -393,6 +405,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
            rule3_min_utterance_length=args.rule3_min_utterance_length,
            provider=args.provider,
        )
    elif args.wenet_ctc:
        recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
            tokens=args.tokens,
            model=args.wenet_ctc,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
            enable_endpoint_detection=args.use_endpoint != 0,
            rule1_min_trailing_silence=args.rule1_min_trailing_silence,
            rule2_min_trailing_silence=args.rule2_min_trailing_silence,
            rule3_min_utterance_length=args.rule3_min_utterance_length,
            provider=args.provider,
        )
    else:
        raise ValueError("Please provide a model")

@@ -727,6 +753,8 @@ def check_args(args):
        assert Path(
            args.paraformer_decoder
        ).is_file(), f"{args.paraformer_decoder} does not exist"
    elif args.wenet_ctc:
        assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist"
    else:
        raise ValueError("Please provide a model")
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
@@ -9,15 +9,16 @@ from _sherpa_onnx import (
    OfflineModelConfig,
    OfflineNemoEncDecCtcModelConfig,
    OfflineParaformerModelConfig,
    OfflineTdnnModelConfig,
    OfflineWhisperModelConfig,
    OfflineZipformerCtcModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
    OfflineRecognizerConfig,
    OfflineStream,
    OfflineTdnnModelConfig,
    OfflineTransducerModelConfig,
    OfflineWenetCtcModelConfig,
    OfflineWhisperModelConfig,
    OfflineZipformerCtcModelConfig,
)

@@ -389,6 +390,70 @@ class OfflineRecognizer(object):
        self.config = recognizer_config
        return self

    @classmethod
    def from_wenet_ctc(
        cls,
        model: str,
        tokens: str,
        num_threads: int = 1,
        sample_rate: int = 16000,
        feature_dim: int = 80,
        decoding_method: str = "greedy_search",
        debug: bool = False,
        provider: str = "cpu",
    ):
        """
        Please refer to
        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html>`_
        to download pre-trained models for different languages, e.g., Chinese,
        English, etc.

        Args:
          model:
            Path to ``model.onnx``.
          tokens:
            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
            columns::

                symbol integer_id

          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          decoding_method:
            Valid values are greedy_search.
          debug:
            True to show debug messages.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
        """
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
            wenet_ctc=OfflineWenetCtcModelConfig(model=model),
            tokens=tokens,
            num_threads=num_threads,
            debug=debug,
            provider=provider,
            model_type="wenet_ctc",
        )

        feat_config = OfflineFeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )

        recognizer_config = OfflineRecognizerConfig(
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
        )
        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
        return self

    def create_stream(self, hotwords: Optional[str] = None):
        if hotwords is None:
            return self.recognizer.create_stream()
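The new classmethod above can be exercised directly from Python. The following is a minimal sketch, not part of the commit; it assumes sherpa-onnx 1.8.10 or later is installed, that the sherpa-onnx-zh-wenet-wenetspeech directory from the examples above has been downloaded, and that the test wave is 16-bit mono PCM (it uses the standard-library wave module instead of the read_wave helper defined in the example scripts).

import wave

import numpy as np
import sherpa_onnx

# Build a non-streaming recognizer from a WeNet CTC model (API added in this commit)
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
    model="./sherpa-onnx-zh-wenet-wenetspeech/model.onnx",
    tokens="./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt",
    num_threads=1,
    decoding_method="greedy_search",
)

# Load 16-bit PCM samples and scale them to float32 in [-1, 1]
with wave.open("./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav") as f:
    sample_rate = f.getframerate()
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768

stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
recognizer.decode_streams([stream])  # same call the new unit test uses
print(stream.result.text)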
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
@@ -12,6 +12,7 @@ from _sherpa_onnx import (
    OnlineRecognizerConfig,
    OnlineStream,
    OnlineTransducerModelConfig,
    OnlineWenetCtcModelConfig,
)

@@ -271,6 +272,112 @@ class OnlineRecognizer(object):
        self.config = recognizer_config
        return self

    @classmethod
    def from_wenet_ctc(
        cls,
        tokens: str,
        model: str,
        chunk_size: int = 16,
        num_left_chunks: int = 4,
        num_threads: int = 2,
        sample_rate: float = 16000,
        feature_dim: int = 80,
        enable_endpoint_detection: bool = False,
        rule1_min_trailing_silence: float = 2.4,
        rule2_min_trailing_silence: float = 1.2,
        rule3_min_utterance_length: float = 20.0,
        decoding_method: str = "greedy_search",
        provider: str = "cpu",
    ):
        """
        Please refer to
        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html>`_
        to download pre-trained models for different languages, e.g., Chinese,
        English, etc.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
            columns::

                symbol integer_id

          model:
            Path to ``model.onnx``.
          chunk_size:
            The --chunk-size parameter from WeNet.
          num_left_chunks:
            The --num-left-chunks parameter from WeNet.
          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          enable_endpoint_detection:
            True to enable endpoint detection. False to disable endpoint
            detection.
          rule1_min_trailing_silence:
            Used only when enable_endpoint_detection is True. If the duration
            of trailing silence in seconds is larger than this value, we assume
            an endpoint is detected.
          rule2_min_trailing_silence:
            Used only when enable_endpoint_detection is True. If we have decoded
            something that is nonsilence and if the duration of trailing silence
            in seconds is larger than this value, we assume an endpoint is
            detected.
          rule3_min_utterance_length:
            Used only when enable_endpoint_detection is True. If the utterance
            length in seconds is larger than this value, we assume an endpoint
            is detected.
          decoding_method:
            The only valid value is greedy_search.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
        """
        self = cls.__new__(cls)
        _assert_file_exists(tokens)
        _assert_file_exists(model)

        assert num_threads > 0, num_threads

        wenet_ctc_config = OnlineWenetCtcModelConfig(
            model=model,
            chunk_size=chunk_size,
            num_left_chunks=num_left_chunks,
        )

        model_config = OnlineModelConfig(
            wenet_ctc=wenet_ctc_config,
            tokens=tokens,
            num_threads=num_threads,
            provider=provider,
            model_type="wenet_ctc",
        )

        feat_config = FeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )

        endpoint_config = EndpointConfig(
            rule1_min_trailing_silence=rule1_min_trailing_silence,
            rule2_min_trailing_silence=rule2_min_trailing_silence,
            rule3_min_utterance_length=rule3_min_utterance_length,
        )

        recognizer_config = OnlineRecognizerConfig(
            feat_config=feat_config,
            model_config=model_config,
            endpoint_config=endpoint_config,
            enable_endpoint=enable_endpoint_detection,
            decoding_method=decoding_method,
        )

        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
        return self

    def create_stream(self, hotwords: Optional[str] = None):
        if hotwords is None:
            return self.recognizer.create_stream()
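A minimal streaming sketch of the new OnlineRecognizer.from_wenet_ctc API follows; it is not part of the commit and mirrors the decode loop used by the unit test below. It assumes the sherpa-onnx-zh-wenet-wenetspeech directory has been downloaded and that the test wave is 16-bit mono PCM, loaded here with the standard-library wave module.

import wave

import numpy as np
import sherpa_onnx

# Build a streaming recognizer from a WeNet CTC model (API added in this commit)
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
    tokens="./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt",
    model="./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx",
    chunk_size=16,
    num_left_chunks=4,
)

with wave.open("./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav") as f:
    sample_rate = f.getframerate()
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768

s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
# Add a short tail padding and mark the end of input, as the unit test does
s.accept_waveform(sample_rate, np.zeros(int(0.2 * sample_rate), dtype=np.float32))
s.input_finished()

# Decode chunk by chunk until no more frames are ready
while recognizer.is_ready(s):
    recognizer.decode_streams([s])

print(recognizer.get_result(s))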
sherpa-onnx/python/tests/test_offline_recognizer.py
@@ -267,6 +267,53 @@ class TestOfflineRecognizer(unittest.TestCase):
        print(s1.result.text)
        print(s2.result.text)

    def test_wenet_ctc(self):
        models = [
            "sherpa-onnx-zh-wenet-aishell",
            "sherpa-onnx-zh-wenet-aishell2",
            "sherpa-onnx-zh-wenet-wenetspeech",
            "sherpa-onnx-zh-wenet-multi-cn",
            "sherpa-onnx-en-wenet-librispeech",
            "sherpa-onnx-en-wenet-gigaspeech",
        ]
        for m in models:
            for use_int8 in [True, False]:
                name = "model.int8.onnx" if use_int8 else "model.onnx"
                model = f"{d}/{m}/{name}"
                tokens = f"{d}/{m}/tokens.txt"

                wave0 = f"{d}/{m}/test_wavs/0.wav"
                wave1 = f"{d}/{m}/test_wavs/1.wav"
                wave2 = f"{d}/{m}/test_wavs/8k.wav"

                if not Path(model).is_file():
                    print("skipping test_wenet_ctc()")
                    return

                recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
                    model=model,
                    tokens=tokens,
                    num_threads=1,
                    provider="cpu",
                )

                s0 = recognizer.create_stream()
                samples0, sample_rate0 = read_wave(wave0)
                s0.accept_waveform(sample_rate0, samples0)

                s1 = recognizer.create_stream()
                samples1, sample_rate1 = read_wave(wave1)
                s1.accept_waveform(sample_rate1, samples1)

                s2 = recognizer.create_stream()
                samples2, sample_rate2 = read_wave(wave2)
                s2.accept_waveform(sample_rate2, samples2)

                recognizer.decode_streams([s0, s1, s2])
                print(s0.result.text)
                print(s1.result.text)
                print(s2.result.text)


if __name__ == "__main__":
    unittest.main()
sherpa-onnx/python/tests/test_online_recognizer.py
@@ -143,6 +143,64 @@ class TestOnlineRecognizer(unittest.TestCase):
            print(f"{wave_filename}\n{result}")
            print("-" * 10)

    def test_wenet_ctc(self):
        models = [
            "sherpa-onnx-zh-wenet-aishell",
            "sherpa-onnx-zh-wenet-aishell2",
            "sherpa-onnx-zh-wenet-wenetspeech",
            "sherpa-onnx-zh-wenet-multi-cn",
            "sherpa-onnx-en-wenet-librispeech",
            "sherpa-onnx-en-wenet-gigaspeech",
        ]
        for m in models:
            for use_int8 in [True, False]:
                name = (
                    "model-streaming.int8.onnx"
                    if use_int8
                    else "model-streaming.onnx"
                )
                model = f"{d}/{m}/{name}"
                tokens = f"{d}/{m}/tokens.txt"

                wave0 = f"{d}/{m}/test_wavs/0.wav"
                wave1 = f"{d}/{m}/test_wavs/1.wav"
                wave2 = f"{d}/{m}/test_wavs/8k.wav"

                if not Path(model).is_file():
                    print("skipping test_wenet_ctc()")
                    return

                recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
                    model=model,
                    tokens=tokens,
                    num_threads=1,
                    provider="cpu",
                )

                streams = []
                waves = [wave0, wave1, wave2]
                for wave in waves:
                    s = recognizer.create_stream()
                    samples, sample_rate = read_wave(wave)
                    s.accept_waveform(sample_rate, samples)

                    tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
                    s.accept_waveform(sample_rate, tail_paddings)
                    s.input_finished()
                    streams.append(s)

                while True:
                    ready_list = []
                    for s in streams:
                        if recognizer.is_ready(s):
                            ready_list.append(s)
                    if len(ready_list) == 0:
                        break
                    recognizer.decode_streams(ready_list)

                results = [recognizer.get_result(s) for s in streams]
                for wave_filename, result in zip(waves, results):
                    print(f"{wave_filename}\n{result}")
                    print("-" * 10)


if __name__ == "__main__":
    unittest.main()