Fangjun Kuang
Committed by GitHub

Support TDNN models from the yesno recipe from icefall (#262)

@@ -14,6 +14,50 @@ echo "PATH: $PATH" @@ -14,6 +14,50 @@ echo "PATH: $PATH"
14 which $EXE 14 which $EXE
15 15
16 log "------------------------------------------------------------" 16 log "------------------------------------------------------------"
  17 +log "Run tdnn yesno (Hebrew)"
  18 +log "------------------------------------------------------------"
  19 +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
  20 +log "Start testing ${repo_url}"
  21 +repo=$(basename $repo_url)
  22 +log "Download pretrained model and test-data from $repo_url"
  23 +
  24 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  25 +pushd $repo
  26 +git lfs pull --include "*.onnx"
  27 +ls -lh *.onnx
  28 +popd
  29 +
  30 +log "test float32 models"
  31 +time $EXE \
  32 + --sample-rate=8000 \
  33 + --feat-dim=23 \
  34 + \
  35 + --tokens=$repo/tokens.txt \
  36 + --tdnn-model=$repo/model-epoch-14-avg-2.onnx \
  37 + $repo/test_wavs/0_0_0_1_0_0_0_1.wav \
  38 + $repo/test_wavs/0_0_1_0_0_0_1_0.wav \
  39 + $repo/test_wavs/0_0_1_0_0_1_1_1.wav \
  40 + $repo/test_wavs/0_0_1_0_1_0_0_1.wav \
  41 + $repo/test_wavs/0_0_1_1_0_0_0_1.wav \
  42 + $repo/test_wavs/0_0_1_1_0_1_1_0.wav
  43 +
  44 +log "test int8 models"
  45 +time $EXE \
  46 + --sample-rate=8000 \
  47 + --feat-dim=23 \
  48 + \
  49 + --tokens=$repo/tokens.txt \
  50 + --tdnn-model=$repo/model-epoch-14-avg-2.int8.onnx \
  51 + $repo/test_wavs/0_0_0_1_0_0_0_1.wav \
  52 + $repo/test_wavs/0_0_1_0_0_0_1_0.wav \
  53 + $repo/test_wavs/0_0_1_0_0_1_1_1.wav \
  54 + $repo/test_wavs/0_0_1_0_1_0_0_1.wav \
  55 + $repo/test_wavs/0_0_1_1_0_0_0_1.wav \
  56 + $repo/test_wavs/0_0_1_1_0_1_1_0.wav
  57 +
  58 +rm -rf $repo
  59 +
  60 +log "------------------------------------------------------------"
17 log "Run Citrinet (stt_en_citrinet_512, English)" 61 log "Run Citrinet (stt_en_citrinet_512, English)"
18 log "------------------------------------------------------------" 62 log "------------------------------------------------------------"
19 63
@@ -24,7 +24,7 @@ jobs: @@ -24,7 +24,7 @@ jobs:
24 matrix: 24 matrix:
25 os: [ubuntu-latest, windows-latest, macos-latest] 25 os: [ubuntu-latest, windows-latest, macos-latest]
26 python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] 26 python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
27 - model_type: ["transducer", "paraformer", "nemo_ctc", "whisper"] 27 + model_type: ["transducer", "paraformer", "nemo_ctc", "whisper", "tdnn"]
28 28
29 steps: 29 steps:
30 - uses: actions/checkout@v2 30 - uses: actions/checkout@v2
@@ -172,3 +172,41 @@ jobs: @@ -172,3 +172,41 @@ jobs:
172 ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \ 172 ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
173 ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \ 173 ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
174 ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav 174 ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
  175 +
  176 + - name: Start server for tdnn models
  177 + if: matrix.model_type == 'tdnn'
  178 + shell: bash
  179 + run: |
  180 + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
  181 + cd sherpa-onnx-tdnn-yesno
  182 + git lfs pull --include "*.onnx"
  183 + cd ..
  184 +
  185 + python3 ./python-api-examples/non_streaming_server.py \
  186 + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
  187 + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
  188 + --sample-rate=8000 \
  189 + --feat-dim=23 &
  190 +
  191 + echo "sleep 10 seconds to wait the server start"
  192 + sleep 10
  193 +
  194 + - name: Start client for tdnn models
  195 + if: matrix.model_type == 'tdnn'
  196 + shell: bash
  197 + run: |
  198 + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
  199 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
  200 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
  201 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav \
  202 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_1_0_0_1.wav \
  203 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_0_0_1.wav \
  204 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_1_1_0.wav
  205 +
  206 + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
  207 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
  208 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
  209 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav \
  210 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_1_0_0_1.wav \
  211 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_0_0_1.wav \
  212 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_1_1_0.wav
1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 project(sherpa-onnx) 2 project(sherpa-onnx)
3 3
4 -set(SHERPA_ONNX_VERSION "1.7.2") 4 +set(SHERPA_ONNX_VERSION "1.7.3")
5 5
6 # Disable warning about 6 # Disable warning about
7 # 7 #
@@ -71,6 +71,20 @@ python3 ./python-api-examples/non_streaming_server.py \ @@ -71,6 +71,20 @@ python3 ./python-api-examples/non_streaming_server.py \
71 --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ 71 --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
72 --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt 72 --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
73 73
  74 +(5) Use a tdnn model of the yesno recipe from icefall
  75 +
  76 +cd /path/to/sherpa-onnx
  77 +
  78 +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
  79 +cd sherpa-onnx-tdnn-yesno
  80 +git lfs pull --include "*.onnx"
  81 +
  82 +python3 ./python-api-examples/non_streaming_server.py \
  83 + --sample-rate=8000 \
  84 + --feat-dim=23 \
  85 + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
  86 + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
  87 +
74 ---- 88 ----
75 89
76 To use a certificate so that you can use https, please use 90 To use a certificate so that you can use https, please use
@@ -196,6 +210,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser): @@ -196,6 +210,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
196 ) 210 )
197 211
198 212
  213 +def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
  214 + parser.add_argument(
  215 + "--tdnn-model",
  216 + default="",
  217 + type=str,
  218 + help="Path to the model.onnx for the tdnn model of the yesno recipe",
  219 + )
  220 +
  221 +
199 def add_whisper_model_args(parser: argparse.ArgumentParser): 222 def add_whisper_model_args(parser: argparse.ArgumentParser):
200 parser.add_argument( 223 parser.add_argument(
201 "--whisper-encoder", 224 "--whisper-encoder",
@@ -216,6 +239,7 @@ def add_model_args(parser: argparse.ArgumentParser): @@ -216,6 +239,7 @@ def add_model_args(parser: argparse.ArgumentParser):
216 add_transducer_model_args(parser) 239 add_transducer_model_args(parser)
217 add_paraformer_model_args(parser) 240 add_paraformer_model_args(parser)
218 add_nemo_ctc_model_args(parser) 241 add_nemo_ctc_model_args(parser)
  242 + add_tdnn_ctc_model_args(parser)
219 add_whisper_model_args(parser) 243 add_whisper_model_args(parser)
220 244
221 parser.add_argument( 245 parser.add_argument(
@@ -730,6 +754,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -730,6 +754,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
730 assert len(args.nemo_ctc) == 0, args.nemo_ctc 754 assert len(args.nemo_ctc) == 0, args.nemo_ctc
731 assert len(args.whisper_encoder) == 0, args.whisper_encoder 755 assert len(args.whisper_encoder) == 0, args.whisper_encoder
732 assert len(args.whisper_decoder) == 0, args.whisper_decoder 756 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  757 + assert len(args.tdnn_model) == 0, args.tdnn_model
733 758
734 assert_file_exists(args.encoder) 759 assert_file_exists(args.encoder)
735 assert_file_exists(args.decoder) 760 assert_file_exists(args.decoder)
@@ -750,6 +775,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -750,6 +775,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
750 assert len(args.nemo_ctc) == 0, args.nemo_ctc 775 assert len(args.nemo_ctc) == 0, args.nemo_ctc
751 assert len(args.whisper_encoder) == 0, args.whisper_encoder 776 assert len(args.whisper_encoder) == 0, args.whisper_encoder
752 assert len(args.whisper_decoder) == 0, args.whisper_decoder 777 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  778 + assert len(args.tdnn_model) == 0, args.tdnn_model
753 779
754 assert_file_exists(args.paraformer) 780 assert_file_exists(args.paraformer)
755 781
@@ -764,6 +790,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -764,6 +790,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
764 elif args.nemo_ctc: 790 elif args.nemo_ctc:
765 assert len(args.whisper_encoder) == 0, args.whisper_encoder 791 assert len(args.whisper_encoder) == 0, args.whisper_encoder
766 assert len(args.whisper_decoder) == 0, args.whisper_decoder 792 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  793 + assert len(args.tdnn_model) == 0, args.tdnn_model
767 794
768 assert_file_exists(args.nemo_ctc) 795 assert_file_exists(args.nemo_ctc)
769 796
@@ -776,6 +803,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -776,6 +803,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
776 decoding_method=args.decoding_method, 803 decoding_method=args.decoding_method,
777 ) 804 )
778 elif args.whisper_encoder: 805 elif args.whisper_encoder:
  806 + assert len(args.tdnn_model) == 0, args.tdnn_model
779 assert_file_exists(args.whisper_encoder) 807 assert_file_exists(args.whisper_encoder)
780 assert_file_exists(args.whisper_decoder) 808 assert_file_exists(args.whisper_decoder)
781 809
@@ -786,6 +814,17 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -786,6 +814,17 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
786 num_threads=args.num_threads, 814 num_threads=args.num_threads,
787 decoding_method=args.decoding_method, 815 decoding_method=args.decoding_method,
788 ) 816 )
  817 + elif args.tdnn_model:
  818 + assert_file_exists(args.tdnn_model)
  819 +
  820 + recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
  821 + model=args.tdnn_model,
  822 + tokens=args.tokens,
  823 + sample_rate=args.sample_rate,
  824 + feature_dim=args.feat_dim,
  825 + num_threads=args.num_threads,
  826 + decoding_method=args.decoding_method,
  827 + )
789 else: 828 else:
790 raise ValueError("Please specify at least one model") 829 raise ValueError("Please specify at least one model")
791 830
@@ -8,6 +8,7 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe @@ -8,6 +8,7 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe
8 file(s) with a non-streaming model. 8 file(s) with a non-streaming model.
9 9
10 (1) For paraformer 10 (1) For paraformer
  11 +
11 ./python-api-examples/offline-decode-files.py \ 12 ./python-api-examples/offline-decode-files.py \
12 --tokens=/path/to/tokens.txt \ 13 --tokens=/path/to/tokens.txt \
13 --paraformer=/path/to/paraformer.onnx \ 14 --paraformer=/path/to/paraformer.onnx \
@@ -20,6 +21,7 @@ file(s) with a non-streaming model. @@ -20,6 +21,7 @@ file(s) with a non-streaming model.
20 /path/to/1.wav 21 /path/to/1.wav
21 22
22 (2) For transducer models from icefall 23 (2) For transducer models from icefall
  24 +
23 ./python-api-examples/offline-decode-files.py \ 25 ./python-api-examples/offline-decode-files.py \
24 --tokens=/path/to/tokens.txt \ 26 --tokens=/path/to/tokens.txt \
25 --encoder=/path/to/encoder.onnx \ 27 --encoder=/path/to/encoder.onnx \
@@ -56,9 +58,20 @@ python3 ./python-api-examples/offline-decode-files.py \ @@ -56,9 +58,20 @@ python3 ./python-api-examples/offline-decode-files.py \
56 ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ 58 ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
57 ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav 59 ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
58 60
  61 +(5) For tdnn models of the yesno recipe from icefall
  62 +
  63 +python3 ./python-api-examples/offline-decode-files.py \
  64 + --sample-rate=8000 \
  65 + --feature-dim=23 \
  66 + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
  67 + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
  68 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
  69 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
  70 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
  71 +
59 Please refer to 72 Please refer to
60 https://k2-fsa.github.io/sherpa/onnx/index.html 73 https://k2-fsa.github.io/sherpa/onnx/index.html
61 -to install sherpa-onnx and to download the pre-trained models 74 +to install sherpa-onnx and to download non-streaming pre-trained models
62 used in this file. 75 used in this file.
63 """ 76 """
64 import argparse 77 import argparse
@@ -160,6 +173,13 @@ def get_args(): @@ -160,6 +173,13 @@ def get_args():
160 ) 173 )
161 174
162 parser.add_argument( 175 parser.add_argument(
  176 + "--tdnn-model",
  177 + default="",
  178 + type=str,
  179 + help="Path to the model.onnx for the tdnn model of the yesno recipe",
  180 + )
  181 +
  182 + parser.add_argument(
163 "--num-threads", 183 "--num-threads",
164 type=int, 184 type=int,
165 default=1, 185 default=1,
@@ -285,6 +305,7 @@ def main(): @@ -285,6 +305,7 @@ def main():
285 assert len(args.nemo_ctc) == 0, args.nemo_ctc 305 assert len(args.nemo_ctc) == 0, args.nemo_ctc
286 assert len(args.whisper_encoder) == 0, args.whisper_encoder 306 assert len(args.whisper_encoder) == 0, args.whisper_encoder
287 assert len(args.whisper_decoder) == 0, args.whisper_decoder 307 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  308 + assert len(args.tdnn_model) == 0, args.tdnn_model
288 309
289 contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] 310 contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
290 if contexts: 311 if contexts:
@@ -311,6 +332,7 @@ def main(): @@ -311,6 +332,7 @@ def main():
311 assert len(args.nemo_ctc) == 0, args.nemo_ctc 332 assert len(args.nemo_ctc) == 0, args.nemo_ctc
312 assert len(args.whisper_encoder) == 0, args.whisper_encoder 333 assert len(args.whisper_encoder) == 0, args.whisper_encoder
313 assert len(args.whisper_decoder) == 0, args.whisper_decoder 334 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  335 + assert len(args.tdnn_model) == 0, args.tdnn_model
314 336
315 assert_file_exists(args.paraformer) 337 assert_file_exists(args.paraformer)
316 338
@@ -326,6 +348,7 @@ def main(): @@ -326,6 +348,7 @@ def main():
326 elif args.nemo_ctc: 348 elif args.nemo_ctc:
327 assert len(args.whisper_encoder) == 0, args.whisper_encoder 349 assert len(args.whisper_encoder) == 0, args.whisper_encoder
328 assert len(args.whisper_decoder) == 0, args.whisper_decoder 350 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  351 + assert len(args.tdnn_model) == 0, args.tdnn_model
329 352
330 assert_file_exists(args.nemo_ctc) 353 assert_file_exists(args.nemo_ctc)
331 354
@@ -339,6 +362,7 @@ def main(): @@ -339,6 +362,7 @@ def main():
339 debug=args.debug, 362 debug=args.debug,
340 ) 363 )
341 elif args.whisper_encoder: 364 elif args.whisper_encoder:
  365 + assert len(args.tdnn_model) == 0, args.tdnn_model
342 assert_file_exists(args.whisper_encoder) 366 assert_file_exists(args.whisper_encoder)
343 assert_file_exists(args.whisper_decoder) 367 assert_file_exists(args.whisper_decoder)
344 368
@@ -347,6 +371,20 @@ def main(): @@ -347,6 +371,20 @@ def main():
347 decoder=args.whisper_decoder, 371 decoder=args.whisper_decoder,
348 tokens=args.tokens, 372 tokens=args.tokens,
349 num_threads=args.num_threads, 373 num_threads=args.num_threads,
  374 + sample_rate=args.sample_rate,
  375 + feature_dim=args.feature_dim,
  376 + decoding_method=args.decoding_method,
  377 + debug=args.debug,
  378 + )
  379 + elif args.tdnn_model:
  380 + assert_file_exists(args.tdnn_model)
  381 +
  382 + recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
  383 + model=args.tdnn_model,
  384 + tokens=args.tokens,
  385 + sample_rate=args.sample_rate,
  386 + feature_dim=args.feature_dim,
  387 + num_threads=args.num_threads,
350 decoding_method=args.decoding_method, 388 decoding_method=args.decoding_method,
351 debug=args.debug, 389 debug=args.debug,
352 ) 390 )
@@ -97,20 +97,18 @@ function onFileChange() { @@ -97,20 +97,18 @@ function onFileChange() {
97 console.log('file.type ' + file.type); 97 console.log('file.type ' + file.type);
98 console.log('file.size ' + file.size); 98 console.log('file.size ' + file.size);
99 99
  100 + let audioCtx = new AudioContext({sampleRate: 16000});
  101 +
100 let reader = new FileReader(); 102 let reader = new FileReader();
101 reader.onload = function() { 103 reader.onload = function() {
102 console.log('reading file!'); 104 console.log('reading file!');
103 - let view = new Int16Array(reader.result);  
104 - // we assume the input file is a wav file.  
105 - // TODO: add some checks here.  
106 - let int16_samples = view.subarray(22); // header has 44 bytes == 22 shorts  
107 - let num_samples = int16_samples.length;  
108 - let float32_samples = new Float32Array(num_samples);  
109 - console.log('num_samples ' + num_samples)  
110 -  
111 - for (let i = 0; i < num_samples; ++i) {  
112 - float32_samples[i] = int16_samples[i] / 32768.  
113 - } 105 + audioCtx.decodeAudioData(reader.result, decodedDone);
  106 + };
  107 +
  108 + function decodedDone(decoded) {
  109 + let typedArray = new Float32Array(decoded.length);
  110 + let float32_samples = decoded.getChannelData(0);
  111 + let buf = float32_samples.buffer
114 112
115 // Send 1024 audio samples per request. 113 // Send 1024 audio samples per request.
116 // 114 //
@@ -119,14 +117,13 @@ function onFileChange() { @@ -119,14 +117,13 @@ function onFileChange() {
119 // (2) There is a limit on the number of bytes in the payload that can be 117 // (2) There is a limit on the number of bytes in the payload that can be
120 // sent by websocket, which is 1MB, I think. We can send a large 118 // sent by websocket, which is 1MB, I think. We can send a large
121 // audio file for decoding in this approach. 119 // audio file for decoding in this approach.
122 - let buf = float32_samples.buffer  
123 let n = 1024 * 4; // send this number of bytes per request. 120 let n = 1024 * 4; // send this number of bytes per request.
124 console.log('buf length, ' + buf.byteLength); 121 console.log('buf length, ' + buf.byteLength);
125 send_header(buf.byteLength); 122 send_header(buf.byteLength);
126 for (let start = 0; start < buf.byteLength; start += n) { 123 for (let start = 0; start < buf.byteLength; start += n) {
127 socket.send(buf.slice(start, start + n)); 124 socket.send(buf.slice(start, start + n));
128 } 125 }
129 - }; 126 + }
130 127
131 reader.readAsArrayBuffer(file); 128 reader.readAsArrayBuffer(file);
132 } 129 }
@@ -32,6 +32,8 @@ set(sources @@ -32,6 +32,8 @@ set(sources
32 offline-recognizer.cc 32 offline-recognizer.cc
33 offline-rnn-lm.cc 33 offline-rnn-lm.cc
34 offline-stream.cc 34 offline-stream.cc
  35 + offline-tdnn-ctc-model.cc
  36 + offline-tdnn-model-config.cc
35 offline-transducer-greedy-search-decoder.cc 37 offline-transducer-greedy-search-decoder.cc
36 offline-transducer-model-config.cc 38 offline-transducer-model-config.cc
37 offline-transducer-model.cc 39 offline-transducer-model.cc
@@ -11,12 +11,14 @@ @@ -11,12 +11,14 @@
11 11
12 #include "sherpa-onnx/csrc/macros.h" 12 #include "sherpa-onnx/csrc/macros.h"
13 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" 13 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
  14 +#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
14 #include "sherpa-onnx/csrc/onnx-utils.h" 15 #include "sherpa-onnx/csrc/onnx-utils.h"
15 16
16 namespace { 17 namespace {
17 18
18 enum class ModelType { 19 enum class ModelType {
19 kEncDecCTCModelBPE, 20 kEncDecCTCModelBPE,
  21 + kTdnn,
20 kUnkown, 22 kUnkown,
21 }; 23 };
22 24
@@ -55,6 +57,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -55,6 +57,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
55 57
56 if (model_type.get() == std::string("EncDecCTCModelBPE")) { 58 if (model_type.get() == std::string("EncDecCTCModelBPE")) {
57 return ModelType::kEncDecCTCModelBPE; 59 return ModelType::kEncDecCTCModelBPE;
  60 + } else if (model_type.get() == std::string("tdnn")) {
  61 + return ModelType::kTdnn;
58 } else { 62 } else {
59 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 63 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
60 return ModelType::kUnkown; 64 return ModelType::kUnkown;
@@ -65,8 +69,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -65,8 +69,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
65 const OfflineModelConfig &config) { 69 const OfflineModelConfig &config) {
66 ModelType model_type = ModelType::kUnkown; 70 ModelType model_type = ModelType::kUnkown;
67 71
  72 + std::string filename;
  73 + if (!config.nemo_ctc.model.empty()) {
  74 + filename = config.nemo_ctc.model;
  75 + } else if (!config.tdnn.model.empty()) {
  76 + filename = config.tdnn.model;
  77 + } else {
  78 + SHERPA_ONNX_LOGE("Please specify a CTC model");
  79 + exit(-1);
  80 + }
  81 +
68 { 82 {
69 - auto buffer = ReadFile(config.nemo_ctc.model); 83 + auto buffer = ReadFile(filename);
70 84
71 model_type = GetModelType(buffer.data(), buffer.size(), config.debug); 85 model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
72 } 86 }
@@ -75,6 +89,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -75,6 +89,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
75 case ModelType::kEncDecCTCModelBPE: 89 case ModelType::kEncDecCTCModelBPE:
76 return std::make_unique<OfflineNemoEncDecCtcModel>(config); 90 return std::make_unique<OfflineNemoEncDecCtcModel>(config);
77 break; 91 break;
  92 + case ModelType::kTdnn:
  93 + return std::make_unique<OfflineTdnnCtcModel>(config);
  94 + break;
78 case ModelType::kUnkown: 95 case ModelType::kUnkown:
79 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); 96 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
80 return nullptr; 97 return nullptr;
@@ -39,10 +39,10 @@ class OfflineCtcModel { @@ -39,10 +39,10 @@ class OfflineCtcModel {
39 39
40 /** SubsamplingFactor of the model 40 /** SubsamplingFactor of the model
41 * 41 *
42 - * For Citrinet, the subsampling factor is usually 4.  
43 - * For Conformer CTC, the subsampling factor is usually 8. 42 + * For NeMo Citrinet, the subsampling factor is usually 4.
  43 + * For NeMo Conformer CTC, the subsampling factor is usually 8.
44 */ 44 */
45 - virtual int32_t SubsamplingFactor() const = 0; 45 + virtual int32_t SubsamplingFactor() const { return 1; }
46 46
47 /** Return an allocator for allocating memory 47 /** Return an allocator for allocating memory
48 */ 48 */
@@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { @@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
15 paraformer.Register(po); 15 paraformer.Register(po);
16 nemo_ctc.Register(po); 16 nemo_ctc.Register(po);
17 whisper.Register(po); 17 whisper.Register(po);
  18 + tdnn.Register(po);
18 19
19 po->Register("tokens", &tokens, "Path to tokens.txt"); 20 po->Register("tokens", &tokens, "Path to tokens.txt");
20 21
@@ -29,7 +30,8 @@ void OfflineModelConfig::Register(ParseOptions *po) { @@ -29,7 +30,8 @@ void OfflineModelConfig::Register(ParseOptions *po) {
29 30
30 po->Register("model-type", &model_type, 31 po->Register("model-type", &model_type,
31 "Specify it to reduce model initialization time. " 32 "Specify it to reduce model initialization time. "
32 - "Valid values are: transducer, paraformer, nemo_ctc, whisper." 33 + "Valid values are: transducer, paraformer, nemo_ctc, whisper, "
  34 + "tdnn."
33 "All other values lead to loading the model twice."); 35 "All other values lead to loading the model twice.");
34 } 36 }
35 37
@@ -56,6 +58,10 @@ bool OfflineModelConfig::Validate() const { @@ -56,6 +58,10 @@ bool OfflineModelConfig::Validate() const {
56 return whisper.Validate(); 58 return whisper.Validate();
57 } 59 }
58 60
  61 + if (!tdnn.model.empty()) {
  62 + return tdnn.Validate();
  63 + }
  64 +
59 return transducer.Validate(); 65 return transducer.Validate();
60 } 66 }
61 67
@@ -67,6 +73,7 @@ std::string OfflineModelConfig::ToString() const { @@ -67,6 +73,7 @@ std::string OfflineModelConfig::ToString() const {
67 os << "paraformer=" << paraformer.ToString() << ", "; 73 os << "paraformer=" << paraformer.ToString() << ", ";
68 os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; 74 os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
69 os << "whisper=" << whisper.ToString() << ", "; 75 os << "whisper=" << whisper.ToString() << ", ";
  76 + os << "tdnn=" << tdnn.ToString() << ", ";
70 os << "tokens=\"" << tokens << "\", "; 77 os << "tokens=\"" << tokens << "\", ";
71 os << "num_threads=" << num_threads << ", "; 78 os << "num_threads=" << num_threads << ", ";
72 os << "debug=" << (debug ? "True" : "False") << ", "; 79 os << "debug=" << (debug ? "True" : "False") << ", ";
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 8
9 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" 9 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
10 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" 10 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
  11 +#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
11 #include "sherpa-onnx/csrc/offline-transducer-model-config.h" 12 #include "sherpa-onnx/csrc/offline-transducer-model-config.h"
12 #include "sherpa-onnx/csrc/offline-whisper-model-config.h" 13 #include "sherpa-onnx/csrc/offline-whisper-model-config.h"
13 14
@@ -18,6 +19,7 @@ struct OfflineModelConfig { @@ -18,6 +19,7 @@ struct OfflineModelConfig {
18 OfflineParaformerModelConfig paraformer; 19 OfflineParaformerModelConfig paraformer;
19 OfflineNemoEncDecCtcModelConfig nemo_ctc; 20 OfflineNemoEncDecCtcModelConfig nemo_ctc;
20 OfflineWhisperModelConfig whisper; 21 OfflineWhisperModelConfig whisper;
  22 + OfflineTdnnModelConfig tdnn;
21 23
22 std::string tokens; 24 std::string tokens;
23 int32_t num_threads = 2; 25 int32_t num_threads = 2;
@@ -40,12 +42,14 @@ struct OfflineModelConfig { @@ -40,12 +42,14 @@ struct OfflineModelConfig {
40 const OfflineParaformerModelConfig &paraformer, 42 const OfflineParaformerModelConfig &paraformer,
41 const OfflineNemoEncDecCtcModelConfig &nemo_ctc, 43 const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
42 const OfflineWhisperModelConfig &whisper, 44 const OfflineWhisperModelConfig &whisper,
  45 + const OfflineTdnnModelConfig &tdnn,
43 const std::string &tokens, int32_t num_threads, bool debug, 46 const std::string &tokens, int32_t num_threads, bool debug,
44 const std::string &provider, const std::string &model_type) 47 const std::string &provider, const std::string &model_type)
45 : transducer(transducer), 48 : transducer(transducer),
46 paraformer(paraformer), 49 paraformer(paraformer),
47 nemo_ctc(nemo_ctc), 50 nemo_ctc(nemo_ctc),
48 whisper(whisper), 51 whisper(whisper),
  52 + tdnn(tdnn),
49 tokens(tokens), 53 tokens(tokens),
50 num_threads(num_threads), 54 num_threads(num_threads),
51 debug(debug), 55 debug(debug),
@@ -27,6 +27,10 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, @@ -27,6 +27,10 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
27 std::string text; 27 std::string text;
28 28
29 for (int32_t i = 0; i != src.tokens.size(); ++i) { 29 for (int32_t i = 0; i != src.tokens.size(); ++i) {
  30 + if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) {
  31 + // tdnn models from yesno have a SIL token, we should remove it.
  32 + continue;
  33 + }
30 auto sym = sym_table[src.tokens[i]]; 34 auto sym = sym_table[src.tokens[i]];
31 text.append(sym); 35 text.append(sym);
32 r.tokens.push_back(std::move(sym)); 36 r.tokens.push_back(std::move(sym));
@@ -46,14 +50,22 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -46,14 +50,22 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
46 model_->FeatureNormalizationMethod(); 50 model_->FeatureNormalizationMethod();
47 51
48 if (config.decoding_method == "greedy_search") { 52 if (config.decoding_method == "greedy_search") {
49 - if (!symbol_table_.contains("<blk>")) { 53 + if (!symbol_table_.contains("<blk>") &&
  54 + !symbol_table_.contains("<eps>")) {
50 SHERPA_ONNX_LOGE( 55 SHERPA_ONNX_LOGE(
51 "We expect that tokens.txt contains " 56 "We expect that tokens.txt contains "
52 - "the symbol <blk> and its ID."); 57 + "the symbol <blk> or <eps> and its ID.");
53 exit(-1); 58 exit(-1);
54 } 59 }
55 60
56 - int32_t blank_id = symbol_table_["<blk>"]; 61 + int32_t blank_id = 0;
  62 + if (symbol_table_.contains("<blk>")) {
  63 + blank_id = symbol_table_["<blk>"];
  64 + } else if (symbol_table_.contains("<eps>")) {
  65 + // for tdnn models of the yesno recipe from icefall
  66 + blank_id = symbol_table_["<eps>"];
  67 + }
  68 +
57 decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id); 69 decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
58 } else { 70 } else {
59 SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", 71 SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
@@ -27,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -27,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
27 return std::make_unique<OfflineRecognizerParaformerImpl>(config); 27 return std::make_unique<OfflineRecognizerParaformerImpl>(config);
28 } else if (model_type == "nemo_ctc") { 28 } else if (model_type == "nemo_ctc") {
29 return std::make_unique<OfflineRecognizerCtcImpl>(config); 29 return std::make_unique<OfflineRecognizerCtcImpl>(config);
  30 + } else if (model_type == "tdnn") {
  31 + return std::make_unique<OfflineRecognizerCtcImpl>(config);
30 } else if (model_type == "whisper") { 32 } else if (model_type == "whisper") {
31 return std::make_unique<OfflineRecognizerWhisperImpl>(config); 33 return std::make_unique<OfflineRecognizerWhisperImpl>(config);
32 } else { 34 } else {
@@ -46,6 +48,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -46,6 +48,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
46 model_filename = config.model_config.paraformer.model; 48 model_filename = config.model_config.paraformer.model;
47 } else if (!config.model_config.nemo_ctc.model.empty()) { 49 } else if (!config.model_config.nemo_ctc.model.empty()) {
48 model_filename = config.model_config.nemo_ctc.model; 50 model_filename = config.model_config.nemo_ctc.model;
  51 + } else if (!config.model_config.tdnn.model.empty()) {
  52 + model_filename = config.model_config.tdnn.model;
49 } else if (!config.model_config.whisper.encoder.empty()) { 53 } else if (!config.model_config.whisper.encoder.empty()) {
50 model_filename = config.model_config.whisper.encoder; 54 model_filename = config.model_config.whisper.encoder;
51 } else { 55 } else {
@@ -84,6 +88,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -84,6 +88,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
84 "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py" 88 "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
85 "\n " 89 "\n "
86 "(3) Whisper" 90 "(3) Whisper"
  91 + "\n "
  92 + "(4) Tdnn models of the yesno recipe from icefall"
  93 + "\n "
  94 + "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
  95 + "\n"
87 "\n"); 96 "\n");
88 exit(-1); 97 exit(-1);
89 } 98 }
@@ -102,6 +111,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -102,6 +111,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
102 return std::make_unique<OfflineRecognizerCtcImpl>(config); 111 return std::make_unique<OfflineRecognizerCtcImpl>(config);
103 } 112 }
104 113
  114 + if (model_type == "tdnn") {
  115 + return std::make_unique<OfflineRecognizerCtcImpl>(config);
  116 + }
  117 +
105 if (strncmp(model_type.c_str(), "whisper", 7) == 0) { 118 if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
106 return std::make_unique<OfflineRecognizerWhisperImpl>(config); 119 return std::make_unique<OfflineRecognizerWhisperImpl>(config);
107 } 120 }
@@ -112,7 +125,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -112,7 +125,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
112 " - Non-streaming transducer models from icefall\n" 125 " - Non-streaming transducer models from icefall\n"
113 " - Non-streaming Paraformer models from FunASR\n" 126 " - Non-streaming Paraformer models from FunASR\n"
114 " - EncDecCTCModelBPE models from NeMo\n" 127 " - EncDecCTCModelBPE models from NeMo\n"
115 - " - Whisper models\n", 128 + " - Whisper models\n"
  129 + " - Tdnn models\n",
116 model_type.c_str()); 130 model_type.c_str());
117 131
118 exit(-1); 132 exit(-1);
  1 +// sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/csrc/onnx-utils.h"
  9 +#include "sherpa-onnx/csrc/session.h"
  10 +#include "sherpa-onnx/csrc/text-utils.h"
  11 +#include "sherpa-onnx/csrc/transpose.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
// Pimpl holding the onnxruntime session and model metadata for the tdnn model.
class OfflineTdnnCtcModel::Impl {
 public:
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    Init();
  }

  // Run the network on `features`.
  //
  // The tdnn model has a single input (features) and no input for feature
  // lengths, so the returned length tensor is synthesized here: every
  // utterance in the batch is assumed to produce nnet_out_shape[1] valid
  // output frames.
  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
    auto nnet_out =
        sess_->Run({}, input_names_ptr_.data(), &features, 1,
                   output_names_ptr_.data(), output_names_ptr_.size());

    std::vector<int64_t> nnet_out_shape =
        nnet_out[0].GetTensorTypeAndShapeInfo().GetShape();

    // One entry per batch element (nnet_out_shape[0]), each set to the
    // number of output frames (nnet_out_shape[1]).
    std::vector<int64_t> out_length_vec(nnet_out_shape[0], nnet_out_shape[1]);
    std::vector<int64_t> out_length_shape(1, nnet_out_shape[0]);

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // This tensor only *references* out_length_vec, which lives on this
    // function's stack, so it must be cloned before being returned.
    Ort::Value nnet_out_length = Ort::Value::CreateTensor(
        memory_info, out_length_vec.data(), out_length_vec.size(),
        out_length_shape.data(), out_length_shape.size());

    return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)};
  }

  int32_t VocabSize() const { return vocab_size_; }

  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // Load the onnx model from config_.tdnn.model, cache its input/output
  // names, and read vocab_size from the model metadata.
  void Init() {
    auto buf = ReadFile(config_.tdnn.model);

    sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t vocab_size_ = 0;  // read from the model's metadata in Init()
};
  89 +
OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;

// The tdnn model has no length input, so features_length is ignored; the
// returned length tensor is derived from the network output shape.
std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(
    Ort::Value features, Ort::Value /*features_length*/) {
  return impl_->Forward(std::move(features));
}

int32_t OfflineTdnnCtcModel::VocabSize() const { return impl_->VocabSize(); }

OrtAllocator *OfflineTdnnCtcModel::Allocator() const {
  return impl_->Allocator();
}
  1 +// sherpa-onnx/csrc/offline-tdnn-ctc-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
  6 +#include <memory>
  7 +#include <string>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "onnxruntime_cxx_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/offline-ctc-model.h"
  13 +#include "sherpa-onnx/csrc/offline-model-config.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
/** This class implements the tdnn model of the yesno recipe from icefall.
 *
 * See
 * https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn
 */
class OfflineTdnnCtcModel : public OfflineCtcModel {
 public:
  explicit OfflineTdnnCtcModel(const OfflineModelConfig &config);
  ~OfflineTdnnCtcModel() override;

  /** Run the forward method of the model.
   *
   * @param features A tensor of shape (N, T, C). It is changed in-place.
   * @param features_length A 1-D tensor of shape (N,) containing number of
   *                        valid frames in `features` before padding.
   *                        Its dtype is int64_t.
   *                        Note: it is ignored by this model, which has no
   *                        length input; all frames are treated as valid.
   *
   * @return Return a pair containing:
   *  - log_probs: A 3-D tensor of shape (N, T', vocab_size).
   *  - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
   */
  std::pair<Ort::Value, Ort::Value> Forward(
      Ort::Value features, Ort::Value /*features_length*/) override;

  /** Return the vocabulary size of the model
   */
  int32_t VocabSize() const override;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const override;

 private:
  class Impl;  // pimpl: keeps onnxruntime details out of this header
  std::unique_ptr<Impl> impl_;
};
  53 +
  54 +} // namespace sherpa_onnx
  55 +
  56 +#endif // SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
  1 +// sherpa-onnx/csrc/offline-tdnn-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
// Register the command-line option --tdnn-model with the option parser.
void OfflineTdnnModelConfig::Register(ParseOptions *po) {
  po->Register("tdnn-model", &model, "Path to onnx model");
}
  15 +
  16 +bool OfflineTdnnModelConfig::Validate() const {
  17 + if (!FileExists(model)) {
  18 + SHERPA_ONNX_LOGE("tdnn model file %s does not exist", model.c_str());
  19 + return false;
  20 + }
  21 +
  22 + return true;
  23 +}
  24 +
  25 +std::string OfflineTdnnModelConfig::ToString() const {
  26 + std::ostringstream os;
  27 +
  28 + os << "OfflineTdnnModelConfig(";
  29 + os << "model=\"" << model << "\")";
  30 +
  31 + return os.str();
  32 +}
  33 +
  34 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-tdnn-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
// Configuration for tdnn CTC models of the yesno recipe from icefall.
// for https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn
struct OfflineTdnnModelConfig {
  // Path to the onnx model file, e.g., model-epoch-14-avg-2.onnx
  std::string model;

  OfflineTdnnModelConfig() = default;
  explicit OfflineTdnnModelConfig(const std::string &model) : model(model) {}

  void Register(ParseOptions *po);
  bool Validate() const;

  std::string ToString() const;
};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
@@ -14,10 +14,14 @@ @@ -14,10 +14,14 @@
14 14
15 int main(int32_t argc, char *argv[]) { 15 int main(int32_t argc, char *argv[]) {
16 const char *kUsageMessage = R"usage( 16 const char *kUsageMessage = R"usage(
  17 +Speech recognition using non-streaming models with sherpa-onnx.
  18 +
17 Usage: 19 Usage:
18 20
19 (1) Transducer from icefall 21 (1) Transducer from icefall
20 22
  23 +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
  24 +
21 ./bin/sherpa-onnx-offline \ 25 ./bin/sherpa-onnx-offline \
22 --tokens=/path/to/tokens.txt \ 26 --tokens=/path/to/tokens.txt \
23 --encoder=/path/to/encoder.onnx \ 27 --encoder=/path/to/encoder.onnx \
@@ -30,6 +34,8 @@ Usage: @@ -30,6 +34,8 @@ Usage:
30 34
31 (2) Paraformer from FunASR 35 (2) Paraformer from FunASR
32 36
  37 +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
  38 +
33 ./bin/sherpa-onnx-offline \ 39 ./bin/sherpa-onnx-offline \
34 --tokens=/path/to/tokens.txt \ 40 --tokens=/path/to/tokens.txt \
35 --paraformer=/path/to/model.onnx \ 41 --paraformer=/path/to/model.onnx \
@@ -39,6 +45,8 @@ Usage: @@ -39,6 +45,8 @@ Usage:
39 45
40 (3) Whisper models 46 (3) Whisper models
41 47
  48 +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
  49 +
42 ./bin/sherpa-onnx-offline \ 50 ./bin/sherpa-onnx-offline \
43 --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ 51 --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
44 --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ 52 --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
@@ -46,6 +54,31 @@ Usage: @@ -46,6 +54,31 @@ Usage:
46 --num-threads=1 \ 54 --num-threads=1 \
47 /path/to/foo.wav [bar.wav foobar.wav ...] 55 /path/to/foo.wav [bar.wav foobar.wav ...]
48 56
  57 +(4) NeMo CTC models
  58 +
  59 +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
  60 +
  61 + ./bin/sherpa-onnx-offline \
  62 + --tokens=./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt \
  63 + --nemo-ctc-model=./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
  64 + --num-threads=2 \
  65 + --decoding-method=greedy_search \
  66 + --debug=false \
  67 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
  68 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
  69 + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
  70 +
  71 +(5) TDNN CTC model for the yesno recipe from icefall
  72 +
  73 +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
  74 +
  75 + ./build/bin/sherpa-onnx-offline \
  76 + --sample-rate=8000 \
  77 + --feat-dim=23 \
  78 + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
  79 + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
  80 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
  81 + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav
49 82
50 Note: It supports decoding multiple files in batches 83 Note: It supports decoding multiple files in batches
51 84
@@ -10,6 +10,7 @@ pybind11_add_module(_sherpa_onnx @@ -10,6 +10,7 @@ pybind11_add_module(_sherpa_onnx
10 offline-paraformer-model-config.cc 10 offline-paraformer-model-config.cc
11 offline-recognizer.cc 11 offline-recognizer.cc
12 offline-stream.cc 12 offline-stream.cc
  13 + offline-tdnn-model-config.cc
13 offline-transducer-model-config.cc 14 offline-transducer-model-config.cc
14 offline-whisper-model-config.cc 15 offline-whisper-model-config.cc
15 online-lm-config.cc 16 online-lm-config.cc
@@ -10,6 +10,7 @@ @@ -10,6 +10,7 @@
10 #include "sherpa-onnx/csrc/offline-model-config.h" 10 #include "sherpa-onnx/csrc/offline-model-config.h"
11 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" 11 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
12 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" 12 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
  13 +#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
13 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" 14 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
14 #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" 15 #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
15 16
@@ -20,24 +21,28 @@ void PybindOfflineModelConfig(py::module *m) { @@ -20,24 +21,28 @@ void PybindOfflineModelConfig(py::module *m) {
20 PybindOfflineParaformerModelConfig(m); 21 PybindOfflineParaformerModelConfig(m);
21 PybindOfflineNemoEncDecCtcModelConfig(m); 22 PybindOfflineNemoEncDecCtcModelConfig(m);
22 PybindOfflineWhisperModelConfig(m); 23 PybindOfflineWhisperModelConfig(m);
  24 + PybindOfflineTdnnModelConfig(m);
23 25
24 using PyClass = OfflineModelConfig; 26 using PyClass = OfflineModelConfig;
25 py::class_<PyClass>(*m, "OfflineModelConfig") 27 py::class_<PyClass>(*m, "OfflineModelConfig")
26 .def(py::init<const OfflineTransducerModelConfig &, 28 .def(py::init<const OfflineTransducerModelConfig &,
27 const OfflineParaformerModelConfig &, 29 const OfflineParaformerModelConfig &,
28 const OfflineNemoEncDecCtcModelConfig &, 30 const OfflineNemoEncDecCtcModelConfig &,
29 - const OfflineWhisperModelConfig &, const std::string &, 31 + const OfflineWhisperModelConfig &,
  32 + const OfflineTdnnModelConfig &, const std::string &,
30 int32_t, bool, const std::string &, const std::string &>(), 33 int32_t, bool, const std::string &, const std::string &>(),
31 py::arg("transducer") = OfflineTransducerModelConfig(), 34 py::arg("transducer") = OfflineTransducerModelConfig(),
32 py::arg("paraformer") = OfflineParaformerModelConfig(), 35 py::arg("paraformer") = OfflineParaformerModelConfig(),
33 py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), 36 py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
34 - py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"), 37 + py::arg("whisper") = OfflineWhisperModelConfig(),
  38 + py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("tokens"),
35 py::arg("num_threads"), py::arg("debug") = false, 39 py::arg("num_threads"), py::arg("debug") = false,
36 py::arg("provider") = "cpu", py::arg("model_type") = "") 40 py::arg("provider") = "cpu", py::arg("model_type") = "")
37 .def_readwrite("transducer", &PyClass::transducer) 41 .def_readwrite("transducer", &PyClass::transducer)
38 .def_readwrite("paraformer", &PyClass::paraformer) 42 .def_readwrite("paraformer", &PyClass::paraformer)
39 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) 43 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
40 .def_readwrite("whisper", &PyClass::whisper) 44 .def_readwrite("whisper", &PyClass::whisper)
  45 + .def_readwrite("tdnn", &PyClass::tdnn)
41 .def_readwrite("tokens", &PyClass::tokens) 46 .def_readwrite("tokens", &PyClass::tokens)
42 .def_readwrite("num_threads", &PyClass::num_threads) 47 .def_readwrite("num_threads", &PyClass::num_threads)
43 .def_readwrite("debug", &PyClass::debug) 48 .def_readwrite("debug", &PyClass::debug)
  1 +// sherpa-onnx/python/csrc/offline-tdnn-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
// Expose OfflineTdnnModelConfig to Python as _sherpa_onnx.OfflineTdnnModelConfig
// with a read/write `model` attribute and a __str__ based on ToString().
void PybindOfflineTdnnModelConfig(py::module *m) {
  using PyClass = OfflineTdnnModelConfig;
  py::class_<PyClass>(*m, "OfflineTdnnModelConfig")
      .def(py::init<const std::string &>(), py::arg("model"))
      .def_readwrite("model", &PyClass::model)
      .def("__str__", &PyClass::ToString);
}
  21 +
  22 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-tdnn-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineTdnnModelConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
@@ -8,6 +8,7 @@ from _sherpa_onnx import ( @@ -8,6 +8,7 @@ from _sherpa_onnx import (
8 OfflineModelConfig, 8 OfflineModelConfig,
9 OfflineNemoEncDecCtcModelConfig, 9 OfflineNemoEncDecCtcModelConfig,
10 OfflineParaformerModelConfig, 10 OfflineParaformerModelConfig,
  11 + OfflineTdnnModelConfig,
11 OfflineWhisperModelConfig, 12 OfflineWhisperModelConfig,
12 ) 13 )
13 from _sherpa_onnx import OfflineRecognizer as _Recognizer 14 from _sherpa_onnx import OfflineRecognizer as _Recognizer
@@ -37,7 +38,7 @@ class OfflineRecognizer(object): @@ -37,7 +38,7 @@ class OfflineRecognizer(object):
37 decoder: str, 38 decoder: str,
38 joiner: str, 39 joiner: str,
39 tokens: str, 40 tokens: str,
40 - num_threads: int, 41 + num_threads: int = 1,
41 sample_rate: int = 16000, 42 sample_rate: int = 16000,
42 feature_dim: int = 80, 43 feature_dim: int = 80,
43 decoding_method: str = "greedy_search", 44 decoding_method: str = "greedy_search",
@@ -48,7 +49,7 @@ class OfflineRecognizer(object): @@ -48,7 +49,7 @@ class OfflineRecognizer(object):
48 ): 49 ):
49 """ 50 """
50 Please refer to 51 Please refer to
51 - `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ 52 + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html>`_
52 to download pre-trained models for different languages, e.g., Chinese, 53 to download pre-trained models for different languages, e.g., Chinese,
53 English, etc. 54 English, etc.
54 55
@@ -115,7 +116,7 @@ class OfflineRecognizer(object): @@ -115,7 +116,7 @@ class OfflineRecognizer(object):
115 cls, 116 cls,
116 paraformer: str, 117 paraformer: str,
117 tokens: str, 118 tokens: str,
118 - num_threads: int, 119 + num_threads: int = 1,
119 sample_rate: int = 16000, 120 sample_rate: int = 16000,
120 feature_dim: int = 80, 121 feature_dim: int = 80,
121 decoding_method: str = "greedy_search", 122 decoding_method: str = "greedy_search",
@@ -124,9 +125,8 @@ class OfflineRecognizer(object): @@ -124,9 +125,8 @@ class OfflineRecognizer(object):
124 ): 125 ):
125 """ 126 """
126 Please refer to 127 Please refer to
127 - `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_  
128 - to download pre-trained models for different languages, e.g., Chinese,  
129 - English, etc. 128 + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html>`_
  129 + to download pre-trained models.
130 130
131 Args: 131 Args:
132 tokens: 132 tokens:
@@ -179,7 +179,7 @@ class OfflineRecognizer(object): @@ -179,7 +179,7 @@ class OfflineRecognizer(object):
179 cls, 179 cls,
180 model: str, 180 model: str,
181 tokens: str, 181 tokens: str,
182 - num_threads: int, 182 + num_threads: int = 1,
183 sample_rate: int = 16000, 183 sample_rate: int = 16000,
184 feature_dim: int = 80, 184 feature_dim: int = 80,
185 decoding_method: str = "greedy_search", 185 decoding_method: str = "greedy_search",
@@ -188,7 +188,7 @@ class OfflineRecognizer(object): @@ -188,7 +188,7 @@ class OfflineRecognizer(object):
188 ): 188 ):
189 """ 189 """
190 Please refer to 190 Please refer to
191 - `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ 191 + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html>`_
192 to download pre-trained models for different languages, e.g., Chinese, 192 to download pre-trained models for different languages, e.g., Chinese,
193 English, etc. 193 English, etc.
194 194
@@ -244,14 +244,14 @@ class OfflineRecognizer(object): @@ -244,14 +244,14 @@ class OfflineRecognizer(object):
244 encoder: str, 244 encoder: str,
245 decoder: str, 245 decoder: str,
246 tokens: str, 246 tokens: str,
247 - num_threads: int, 247 + num_threads: int = 1,
248 decoding_method: str = "greedy_search", 248 decoding_method: str = "greedy_search",
249 debug: bool = False, 249 debug: bool = False,
250 provider: str = "cpu", 250 provider: str = "cpu",
251 ): 251 ):
252 """ 252 """
253 Please refer to 253 Please refer to
254 - `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ 254 + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html>`_
255 to download pre-trained models for different kinds of whisper models, 255 to download pre-trained models for different kinds of whisper models,
256 e.g., tiny, tiny.en, base, base.en, etc. 256 e.g., tiny, tiny.en, base, base.en, etc.
257 257
@@ -301,6 +301,69 @@ class OfflineRecognizer(object): @@ -301,6 +301,69 @@ class OfflineRecognizer(object):
301 self.config = recognizer_config 301 self.config = recognizer_config
302 return self 302 return self
303 303
  304 + @classmethod
  305 + def from_tdnn_ctc(
  306 + cls,
  307 + model: str,
  308 + tokens: str,
  309 + num_threads: int = 1,
  310 + sample_rate: int = 8000,
  311 + feature_dim: int = 23,
  312 + decoding_method: str = "greedy_search",
  313 + debug: bool = False,
  314 + provider: str = "cpu",
  315 + ):
  316 + """
  317 + Please refer to
  318 + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html>`_
  319 + to download pre-trained models.
  320 +
  321 + Args:
  322 + model:
  323 + Path to ``model.onnx``.
  324 + tokens:
  325 + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
  326 + columns::
  327 +
  328 + symbol integer_id
  329 +
  330 + num_threads:
  331 + Number of threads for neural network computation.
  332 + sample_rate:
  333 + Sample rate of the training data used to train the model.
  334 + feature_dim:
  335 + Dimension of the feature used to train the model.
  336 + decoding_method:
  337 + Valid values are greedy_search.
  338 + debug:
  339 + True to show debug messages.
  340 + provider:
  341 + onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  342 + """
  343 + self = cls.__new__(cls)
  344 + model_config = OfflineModelConfig(
  345 + tdnn=OfflineTdnnModelConfig(model=model),
  346 + tokens=tokens,
  347 + num_threads=num_threads,
  348 + debug=debug,
  349 + provider=provider,
  350 + model_type="tdnn",
  351 + )
  352 +
  353 + feat_config = OfflineFeatureExtractorConfig(
  354 + sampling_rate=sample_rate,
  355 + feature_dim=feature_dim,
  356 + )
  357 +
  358 + recognizer_config = OfflineRecognizerConfig(
  359 + feat_config=feat_config,
  360 + model_config=model_config,
  361 + decoding_method=decoding_method,
  362 + )
  363 + self.recognizer = _Recognizer(recognizer_config)
  364 + self.config = recognizer_config
  365 + return self
  366 +
304 def create_stream(self, contexts_list: Optional[List[List[int]]] = None): 367 def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
305 if contexts_list is None: 368 if contexts_list is None:
306 return self.recognizer.create_stream() 369 return self.recognizer.create_stream()