Fangjun Kuang
Committed by GitHub

Add Python APIs for WeNet CTC models (#428)
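
A minimal usage sketch of the new offline API. Every call below (from_wenet_ctc, create_stream, accept_waveform, decode_streams, result.text) appears in this diff; the model directory is a placeholder for one of the repos listed in the CI script, and the small read_wave helper for 16-bit mono WAVs is written here for self-containment, not taken from this commit.

import wave

import numpy as np
import sherpa_onnx


def read_wave(filename: str):
    """Read a 16-bit mono WAV file and return (float32 samples, sample_rate)."""
    with wave.open(filename) as f:
        samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
        return samples.astype(np.float32) / 32768, f.getframerate()


# Paths are placeholders; see the model list in the CI script below.
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
    model="./sherpa-onnx-zh-wenet-wenetspeech/model.onnx",
    tokens="./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt",
)

samples, sample_rate = read_wave("./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav")
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
recognizer.decode_streams([stream])
print(stream.result.text)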

@@ -8,6 +8,51 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }
 
+wenet_models=(
+sherpa-onnx-zh-wenet-aishell
+sherpa-onnx-zh-wenet-aishell2
+sherpa-onnx-zh-wenet-wenetspeech
+sherpa-onnx-zh-wenet-multi-cn
+sherpa-onnx-en-wenet-librispeech
+sherpa-onnx-en-wenet-gigaspeech
+)
+
+mkdir -p /tmp/icefall-models
+dir=/tmp/icefall-models
+
+for name in ${wenet_models[@]}; do
+  repo_url=https://huggingface.co/csukuangfj/$name
+  log "Start testing ${repo_url}"
+  repo=$dir/$(basename $repo_url)
+  log "Download pretrained model and test-data from $repo_url"
+  pushd $dir
+  GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+  cd $repo
+  git lfs pull --include "*.onnx"
+  ls -lh *.onnx
+  popd
+
+  python3 ./python-api-examples/offline-decode-files.py \
+    --tokens=$repo/tokens.txt \
+    --wenet-ctc=$repo/model.onnx \
+    $repo/test_wavs/0.wav \
+    $repo/test_wavs/1.wav \
+    $repo/test_wavs/8k.wav
+
+  python3 ./python-api-examples/online-decode-files.py \
+    --tokens=$repo/tokens.txt \
+    --wenet-ctc=$repo/model-streaming.onnx \
+    $repo/test_wavs/0.wav \
+    $repo/test_wavs/1.wav \
+    $repo/test_wavs/8k.wav
+
+  python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
+
+  python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose
+
+  rm -rf $repo
+done
+
 log "Offline TTS test"
 # test waves are saved in ./tts
 mkdir ./tts
@@ -85,10 +85,19 @@ jobs:
           arch=${{ matrix.arch }}
 
           cd mfc-examples/$arch/Release
-          cp StreamingSpeechRecognition.exe sherpa-onnx-streaming-${SHERPA_ONNX_VERSION}.exe
-          cp NonStreamingSpeechRecognition.exe sherpa-onnx-non-streaming-${SHERPA_ONNX_VERSION}.exe
           ls -lh
 
+          cp -v StreamingSpeechRecognition.exe sherpa-onnx-streaming-${SHERPA_ONNX_VERSION}.exe
+          cp -v NonStreamingSpeechRecognition.exe sherpa-onnx-non-streaming-${SHERPA_ONNX_VERSION}.exe
+          cp -v NonStreamingTextToSpeech.exe ../sherpa-onnx-non-streaming-tts-${SHERPA_ONNX_VERSION}.exe
+          ls -lh
+
+      - name: Upload artifact tts
+        uses: actions/upload-artifact@v3
+        with:
+          name: non-streaming-tts-${{ matrix.arch }}
+          path: ./mfc-examples/${{ matrix.arch }}/Release/NonStreamingTextToSpeech.exe
+
       - name: Upload artifact
         uses: actions/upload-artifact@v3
         with:
@@ -116,3 +125,11 @@ jobs:
           file_glob: true
           overwrite: true
           file: ./mfc-examples/${{ matrix.arch }}/Release/sherpa-onnx-non-streaming-*.exe
+
+      - name: Release pre-compiled binaries and libs for Windows ${{ matrix.arch }}
+        if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && github.event_name == 'push' && contains(github.ref, 'refs/tags/')
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          overwrite: true
+          file: ./mfc-examples/${{ matrix.arch }}/sherpa-onnx-non-streaming-*.exe
@@ -10,6 +10,7 @@ on:
       - 'CMakeLists.txt'
      - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
+      - 'python-api-examples/**'
   pull_request:
     branches:
       - master
@@ -19,6 +20,7 @@ on:
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
+      - 'python-api-examples/**'
   workflow_dispatch:
 
 concurrency:
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)
 
-set(SHERPA_ONNX_VERSION "1.8.9")
+set(SHERPA_ONNX_VERSION "1.8.10")
 
 # Disable warning about
 #
@@ -58,6 +58,15 @@ wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
   --num-threads=2 \
   /path/to/test.mp4
 
+(4) For WeNet CTC models
+
+./python-api-examples/generate-subtitles.py \
+  --silero-vad-model=/path/to/silero_vad.onnx \
+  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
+  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
+  --num-threads=2 \
+  /path/to/test.mp4
+
 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/index.html
 to install sherpa-onnx and to download non-streaming pre-trained models
@@ -122,6 +131,13 @@ def get_args():
     )
 
     parser.add_argument(
+        "--wenet-ctc",
+        default="",
+        type=str,
+        help="Path to the CTC model.onnx from WeNet",
+    )
+
+    parser.add_argument(
         "--num-threads",
         type=int,
         default=1,
@@ -215,6 +231,7 @@ def assert_file_exists(filename: str):
 def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
     if args.encoder:
         assert len(args.paraformer) == 0, args.paraformer
+        assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
 
@@ -234,6 +251,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             debug=args.debug,
         )
     elif args.paraformer:
+        assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
 
@@ -248,6 +266,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             decoding_method=args.decoding_method,
             debug=args.debug,
         )
+    elif args.wenet_ctc:
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
+
+        assert_file_exists(args.wenet_ctc)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
+            model=args.wenet_ctc,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            sample_rate=args.sample_rate,
+            feature_dim=args.feature_dim,
+            decoding_method=args.decoding_method,
+            debug=args.debug,
+        )
     elif args.whisper_encoder:
         assert_file_exists(args.whisper_encoder)
         assert_file_exists(args.whisper_decoder)
@@ -58,7 +58,19 @@ python3 ./python-api-examples/non_streaming_server.py \
   --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
   --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt
 
-(4) Use a Whisper model
+(4) Use a non-streaming CTC model from WeNet
+
+cd /path/to/sherpa-onnx
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
+cd sherpa-onnx-zh-wenet-wenetspeech
+git lfs pull --include "*.onnx"
+cd ..
+
+python3 ./python-api-examples/non_streaming_server.py \
+  --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
+  --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
+
+(5) Use a Whisper model
 
 cd /path/to/sherpa-onnx
 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
@@ -210,6 +222,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
     )
 
 
+def add_wenet_ctc_model_args(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--wenet-ctc",
+        default="",
+        type=str,
+        help="Path to the model.onnx from WeNet CTC",
+    )
+
+
 def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--tdnn-model",
@@ -261,6 +282,7 @@ def add_model_args(parser: argparse.ArgumentParser):
     add_transducer_model_args(parser)
     add_paraformer_model_args(parser)
     add_nemo_ctc_model_args(parser)
+    add_wenet_ctc_model_args(parser)
     add_tdnn_ctc_model_args(parser)
     add_whisper_model_args(parser)
 
@@ -804,6 +826,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
     if args.encoder:
         assert len(args.paraformer) == 0, args.paraformer
         assert len(args.nemo_ctc) == 0, args.nemo_ctc
+        assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -827,6 +850,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         )
     elif args.paraformer:
         assert len(args.nemo_ctc) == 0, args.nemo_ctc
+        assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -842,6 +866,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             decoding_method=args.decoding_method,
         )
     elif args.nemo_ctc:
+        assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -856,6 +881,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             feature_dim=args.feat_dim,
             decoding_method=args.decoding_method,
         )
+    elif args.wenet_ctc:
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
+        assert len(args.tdnn_model) == 0, args.tdnn_model
+
+        assert_file_exists(args.wenet_ctc)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
+            model=args.wenet_ctc,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            sample_rate=args.sample_rate,
+            feature_dim=args.feat_dim,
+            decoding_method=args.decoding_method,
+        )
     elif args.whisper_encoder:
         assert len(args.tdnn_model) == 0, args.tdnn_model
         assert_file_exists(args.whisper_encoder)
@@ -59,7 +59,16 @@ python3 ./python-api-examples/offline-decode-files.py \
   ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
   ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
 
-(5) For tdnn models of the yesno recipe from icefall
+(5) For CTC models from WeNet
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
+  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
+
+(6) For tdnn models of the yesno recipe from icefall
 
 python3 ./python-api-examples/offline-decode-files.py \
   --sample-rate=8000 \
@@ -155,6 +164,13 @@ def get_args():
     )
 
     parser.add_argument(
+        "--wenet-ctc",
+        default="",
+        type=str,
+        help="Path to the model.onnx from WeNet CTC",
+    )
+
+    parser.add_argument(
         "--tdnn-model",
         default="",
         type=str,
@@ -254,6 +270,7 @@ def assert_file_exists(filename: str):
         "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
     )
 
+
 def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
     """
     Args:
@@ -287,6 +304,7 @@ def main():
     if args.encoder:
         assert len(args.paraformer) == 0, args.paraformer
         assert len(args.nemo_ctc) == 0, args.nemo_ctc
+        assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -310,6 +328,7 @@ def main():
         )
     elif args.paraformer:
         assert len(args.nemo_ctc) == 0, args.nemo_ctc
+        assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -326,6 +345,7 @@ def main():
             debug=args.debug,
         )
     elif args.nemo_ctc:
+        assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -341,6 +361,22 @@ def main():
             decoding_method=args.decoding_method,
             debug=args.debug,
         )
+    elif args.wenet_ctc:
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
+        assert len(args.tdnn_model) == 0, args.tdnn_model
+
+        assert_file_exists(args.wenet_ctc)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
+            model=args.wenet_ctc,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            sample_rate=args.sample_rate,
+            feature_dim=args.feature_dim,
+            decoding_method=args.decoding_method,
+            debug=args.debug,
+        )
     elif args.whisper_encoder:
         assert len(args.tdnn_model) == 0, args.tdnn_model
         assert_file_exists(args.whisper_encoder)
@@ -37,8 +37,25 @@ git lfs pull --include "*.onnx"
   ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
   ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav
 
+(3) Streaming Conformer CTC from WeNet
+
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
+cd sherpa-onnx-zh-wenet-wenetspeech
+git lfs pull --include "*.onnx"
+
+./python-api-examples/online-decode-files.py \
+  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
+  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
+
+
+
 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/index.html
+and
+https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
 to install sherpa-onnx and to download streaming pre-trained models.
 """
 import argparse
@@ -93,6 +110,26 @@ def get_args():
     )
 
     parser.add_argument(
+        "--wenet-ctc",
+        type=str,
+        help="Path to the WeNet CTC model",
+    )
+
+    parser.add_argument(
+        "--wenet-ctc-chunk-size",
+        type=int,
+        default=16,
+        help="The --chunk-size parameter for streaming WeNet models",
+    )
+
+    parser.add_argument(
+        "--wenet-ctc-num-left-chunks",
+        type=int,
+        default=4,
+        help="The --num-left-chunks parameter for streaming WeNet models",
+    )
+
+    parser.add_argument(
         "--num-threads",
         type=int,
         default=1,
@@ -249,6 +286,18 @@ def main():
             feature_dim=80,
             decoding_method="greedy_search",
         )
+    elif args.wenet_ctc:
+        recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
+            tokens=args.tokens,
+            model=args.wenet_ctc,
+            chunk_size=args.wenet_ctc_chunk_size,
+            num_left_chunks=args.wenet_ctc_num_left_chunks,
+            num_threads=args.num_threads,
+            provider=args.provider,
+            sample_rate=16000,
+            feature_dim=80,
+            decoding_method="greedy_search",
+        )
     else:
         raise ValueError("Please provide a model")
 
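
A note on the two new chunk parameters above: assuming WeNet's usual front end of a 10 ms frame shift and a 4x subsampling encoder (assumptions about the exported model, not stated in this diff), the defaults work out to roughly 640 ms of audio per decoding chunk and about 2.5 s of left attention context:

# Back-of-the-envelope estimate for the streaming WeNet defaults.
# ASSUMPTIONS: 10 ms frame shift, 4x subsampling (typical for WeNet Conformer).
FRAME_SHIFT_MS = 10
SUBSAMPLING_FACTOR = 4

def chunk_duration_ms(chunk_size: int) -> int:
    # chunk_size counts frames *after* subsampling.
    return chunk_size * SUBSAMPLING_FACTOR * FRAME_SHIFT_MS

print(chunk_duration_ms(16))      # 640 ms per chunk (--wenet-ctc-chunk-size=16)
print(chunk_duration_ms(16) * 4)  # 2560 ms left context (--wenet-ctc-num-left-chunks=4)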
@@ -40,10 +40,17 @@ python3 ./python-api-examples/streaming_server.py \
 
 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
+https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
 to download pre-trained models.
 
 The model in the above help messages is from
 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
+
+To use a WeNet streaming Conformer CTC model, please use
+
+python3 ./python-api-examples/streaming_server.py \
+  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
+  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx
 """
 
 import argparse
@@ -131,6 +138,12 @@ def add_model_args(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
+        "--wenet-ctc",
+        type=str,
+        help="Path to the model.onnx from WeNet",
+    )
+
+    parser.add_argument(
         "--paraformer-encoder",
         type=str,
         help="Path to the paraformer encoder model",
@@ -212,7 +225,6 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
     )
 
 
-
 def add_modified_beam_search_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-active-paths",
@@ -393,6 +405,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
             rule3_min_utterance_length=args.rule3_min_utterance_length,
             provider=args.provider,
         )
+    elif args.wenet_ctc:
+        recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
+            tokens=args.tokens,
+            model=args.wenet_ctc,
+            num_threads=args.num_threads,
+            sample_rate=args.sample_rate,
+            feature_dim=args.feat_dim,
+            decoding_method=args.decoding_method,
+            enable_endpoint_detection=args.use_endpoint != 0,
+            rule1_min_trailing_silence=args.rule1_min_trailing_silence,
+            rule2_min_trailing_silence=args.rule2_min_trailing_silence,
+            rule3_min_utterance_length=args.rule3_min_utterance_length,
+            provider=args.provider,
+        )
     else:
         raise ValueError("Please provide a model")
 
@@ -727,6 +753,8 @@ def check_args(args):
         assert Path(
             args.paraformer_decoder
         ).is_file(), f"{args.paraformer_decoder} does not exist"
+    elif args.wenet_ctc:
+        assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist"
     else:
         raise ValueError("Please provide a model")
 
@@ -9,15 +9,16 @@ from _sherpa_onnx import (
     OfflineModelConfig,
     OfflineNemoEncDecCtcModelConfig,
     OfflineParaformerModelConfig,
-    OfflineTdnnModelConfig,
-    OfflineWhisperModelConfig,
-    OfflineZipformerCtcModelConfig,
 )
 from _sherpa_onnx import OfflineRecognizer as _Recognizer
 from _sherpa_onnx import (
     OfflineRecognizerConfig,
     OfflineStream,
+    OfflineTdnnModelConfig,
     OfflineTransducerModelConfig,
+    OfflineWenetCtcModelConfig,
+    OfflineWhisperModelConfig,
+    OfflineZipformerCtcModelConfig,
 )
 
 
@@ -389,6 +390,70 @@ class OfflineRecognizer(object):
         self.config = recognizer_config
         return self
 
+    @classmethod
+    def from_wenet_ctc(
+        cls,
+        model: str,
+        tokens: str,
+        num_threads: int = 1,
+        sample_rate: int = 16000,
+        feature_dim: int = 80,
+        decoding_method: str = "greedy_search",
+        debug: bool = False,
+        provider: str = "cpu",
+    ):
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html>`_
+        to download pre-trained models for different languages, e.g., Chinese,
+        English, etc.
+
+        Args:
+          model:
+            Path to ``model.onnx``.
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          num_threads:
+            Number of threads for neural network computation.
+          sample_rate:
+            Sample rate of the training data used to train the model.
+          feature_dim:
+            Dimension of the feature used to train the model.
+          decoding_method:
+            The only valid value is greedy_search.
+          debug:
+            True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+        """
+        self = cls.__new__(cls)
+        model_config = OfflineModelConfig(
+            wenet_ctc=OfflineWenetCtcModelConfig(model=model),
+            tokens=tokens,
+            num_threads=num_threads,
+            debug=debug,
+            provider=provider,
+            model_type="wenet_ctc",
+        )
+
+        feat_config = OfflineFeatureExtractorConfig(
+            sampling_rate=sample_rate,
+            feature_dim=feature_dim,
+        )
+
+        recognizer_config = OfflineRecognizerConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            decoding_method=decoding_method,
+        )
+        self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
+        return self
+
     def create_stream(self, hotwords: Optional[str] = None):
         if hotwords is None:
             return self.recognizer.create_stream()
@@ -12,6 +12,7 @@ from _sherpa_onnx import (
     OnlineRecognizerConfig,
     OnlineStream,
     OnlineTransducerModelConfig,
+    OnlineWenetCtcModelConfig,
 )
 
 
@@ -140,13 +141,13 @@ class OnlineRecognizer(object):
                 "Please use --decoding-method=modified_beam_search when using "
                 f"--hotwords-file. Currently given: {decoding_method}"
             )
-
+
         if lm and decoding_method != "modified_beam_search":
             raise ValueError(
                 "Please use --decoding-method=modified_beam_search when using "
                 f"--lm. Currently given: {decoding_method}"
             )
-
+
         lm_config = OnlineLMConfig(
             model=lm,
             scale=lm_scale,
@@ -271,6 +272,112 @@ class OnlineRecognizer(object):
         self.config = recognizer_config
         return self
 
+    @classmethod
+    def from_wenet_ctc(
+        cls,
+        tokens: str,
+        model: str,
+        chunk_size: int = 16,
+        num_left_chunks: int = 4,
+        num_threads: int = 2,
+        sample_rate: float = 16000,
+        feature_dim: int = 80,
+        enable_endpoint_detection: bool = False,
+        rule1_min_trailing_silence: float = 2.4,
+        rule2_min_trailing_silence: float = 1.2,
+        rule3_min_utterance_length: float = 20.0,
+        decoding_method: str = "greedy_search",
+        provider: str = "cpu",
+    ):
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html>`_
+        to download pre-trained models for different languages, e.g., Chinese,
+        English, etc.
+
+        Args:
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          model:
+            Path to ``model.onnx``.
+          chunk_size:
+            The --chunk-size parameter from WeNet.
+          num_left_chunks:
+            The --num-left-chunks parameter from WeNet.
+          num_threads:
+            Number of threads for neural network computation.
+          sample_rate:
+            Sample rate of the training data used to train the model.
+          feature_dim:
+            Dimension of the feature used to train the model.
+          enable_endpoint_detection:
+            True to enable endpoint detection. False to disable endpoint
+            detection.
+          rule1_min_trailing_silence:
+            Used only when enable_endpoint_detection is True. If the duration
+            of trailing silence in seconds is larger than this value, we assume
+            an endpoint is detected.
+          rule2_min_trailing_silence:
+            Used only when enable_endpoint_detection is True. If we have decoded
+            something that is nonsilence and if the duration of trailing silence
+            in seconds is larger than this value, we assume an endpoint is
+            detected.
+          rule3_min_utterance_length:
+            Used only when enable_endpoint_detection is True. If the utterance
+            length in seconds is larger than this value, we assume an endpoint
+            is detected.
+          decoding_method:
+            The only valid value is greedy_search.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+        """
+        self = cls.__new__(cls)
+        _assert_file_exists(tokens)
+        _assert_file_exists(model)
+
+        assert num_threads > 0, num_threads
+
+        wenet_ctc_config = OnlineWenetCtcModelConfig(
+            model=model,
+            chunk_size=chunk_size,
+            num_left_chunks=num_left_chunks,
+        )
+
+        model_config = OnlineModelConfig(
+            wenet_ctc=wenet_ctc_config,
+            tokens=tokens,
+            num_threads=num_threads,
+            provider=provider,
+            model_type="wenet_ctc",
+        )
+
+        feat_config = FeatureExtractorConfig(
+            sampling_rate=sample_rate,
+            feature_dim=feature_dim,
+        )
+
+        endpoint_config = EndpointConfig(
+            rule1_min_trailing_silence=rule1_min_trailing_silence,
+            rule2_min_trailing_silence=rule2_min_trailing_silence,
+            rule3_min_utterance_length=rule3_min_utterance_length,
+        )
+
+        recognizer_config = OnlineRecognizerConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            endpoint_config=endpoint_config,
+            enable_endpoint=enable_endpoint_detection,
+            decoding_method=decoding_method,
+        )
+
+        self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
+        return self
+
     def create_stream(self, hotwords: Optional[str] = None):
         if hotwords is None:
             return self.recognizer.create_stream()
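
A matching sketch for the new streaming API. The model directory is the same placeholder as before; the is_ready/decode_streams/get_result loop mirrors the unit test added below, and the zero samples merely stand in for real audio:

import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
    tokens="./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt",
    model="./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx",
)

s = recognizer.create_stream()
sample_rate = 16000
# Feed 100 ms of audio at a time and decode whenever a chunk is ready.
for _ in range(50):
    s.accept_waveform(sample_rate, np.zeros(1600, dtype=np.float32))
    while recognizer.is_ready(s):
        recognizer.decode_streams([s])

s.input_finished()
while recognizer.is_ready(s):
    recognizer.decode_streams([s])
print(recognizer.get_result(s))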
@@ -267,6 +267,53 @@ class TestOfflineRecognizer(unittest.TestCase):
         print(s1.result.text)
         print(s2.result.text)
 
+    def test_wenet_ctc(self):
+        models = [
+            "sherpa-onnx-zh-wenet-aishell",
+            "sherpa-onnx-zh-wenet-aishell2",
+            "sherpa-onnx-zh-wenet-wenetspeech",
+            "sherpa-onnx-zh-wenet-multi-cn",
+            "sherpa-onnx-en-wenet-librispeech",
+            "sherpa-onnx-en-wenet-gigaspeech",
+        ]
+        for m in models:
+            for use_int8 in [True, False]:
+                name = "model.int8.onnx" if use_int8 else "model.onnx"
+                model = f"{d}/{m}/{name}"
+                tokens = f"{d}/{m}/tokens.txt"
+
+                wave0 = f"{d}/{m}/test_wavs/0.wav"
+                wave1 = f"{d}/{m}/test_wavs/1.wav"
+                wave2 = f"{d}/{m}/test_wavs/8k.wav"
+
+                if not Path(model).is_file():
+                    print("skipping test_wenet_ctc()")
+                    return
+
+                recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
+                    model=model,
+                    tokens=tokens,
+                    num_threads=1,
+                    provider="cpu",
+                )
+
+                s0 = recognizer.create_stream()
+                samples0, sample_rate0 = read_wave(wave0)
+                s0.accept_waveform(sample_rate0, samples0)
+
+                s1 = recognizer.create_stream()
+                samples1, sample_rate1 = read_wave(wave1)
+                s1.accept_waveform(sample_rate1, samples1)
+
+                s2 = recognizer.create_stream()
+                samples2, sample_rate2 = read_wave(wave2)
+                s2.accept_waveform(sample_rate2, samples2)
+
+                recognizer.decode_streams([s0, s1, s2])
+                print(s0.result.text)
+                print(s1.result.text)
+                print(s2.result.text)
+
 
 if __name__ == "__main__":
     unittest.main()
@@ -143,6 +143,64 @@ class TestOnlineRecognizer(unittest.TestCase):
             print(f"{wave_filename}\n{result}")
             print("-" * 10)
 
+    def test_wenet_ctc(self):
+        models = [
+            "sherpa-onnx-zh-wenet-aishell",
+            "sherpa-onnx-zh-wenet-aishell2",
+            "sherpa-onnx-zh-wenet-wenetspeech",
+            "sherpa-onnx-zh-wenet-multi-cn",
+            "sherpa-onnx-en-wenet-librispeech",
+            "sherpa-onnx-en-wenet-gigaspeech",
+        ]
+        for m in models:
+            for use_int8 in [True, False]:
+                name = (
+                    "model-streaming.int8.onnx" if use_int8 else "model-streaming.onnx"
+                )
+                model = f"{d}/{m}/{name}"
+                tokens = f"{d}/{m}/tokens.txt"
+
+                wave0 = f"{d}/{m}/test_wavs/0.wav"
+                wave1 = f"{d}/{m}/test_wavs/1.wav"
+                wave2 = f"{d}/{m}/test_wavs/8k.wav"
+
+                if not Path(model).is_file():
+                    print("skipping test_wenet_ctc()")
+                    return
+
+                recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
+                    model=model,
+                    tokens=tokens,
+                    num_threads=1,
+                    provider="cpu",
+                )
+
+                streams = []
+                waves = [wave0, wave1, wave2]
+                for wave in waves:
+                    s = recognizer.create_stream()
+                    samples, sample_rate = read_wave(wave)
+                    s.accept_waveform(sample_rate, samples)
+
+                    tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
+                    s.accept_waveform(sample_rate, tail_paddings)
+                    s.input_finished()
+                    streams.append(s)
+
+                while True:
+                    ready_list = []
+                    for s in streams:
+                        if recognizer.is_ready(s):
+                            ready_list.append(s)
+                    if len(ready_list) == 0:
+                        break
+                    recognizer.decode_streams(ready_list)
+
+                results = [recognizer.get_result(s) for s in streams]
+                for wave_filename, result in zip(waves, results):
+                    print(f"{wave_filename}\n{result}")
+                    print("-" * 10)
+
 
 if __name__ == "__main__":
     unittest.main()