Committed by
GitHub
Support TDNN models from the yesno recipe from icefall (#262)
正在显示
23 个修改的文件
包含
613 行增加
和
37 行删除
| @@ -14,6 +14,50 @@ echo "PATH: $PATH" | @@ -14,6 +14,50 @@ echo "PATH: $PATH" | ||
| 14 | which $EXE | 14 | which $EXE |
| 15 | 15 | ||
| 16 | log "------------------------------------------------------------" | 16 | log "------------------------------------------------------------" |
| 17 | +log "Run tdnn yesno (Hebrew)" | ||
| 18 | +log "------------------------------------------------------------" | ||
| 19 | +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno | ||
| 20 | +log "Start testing ${repo_url}" | ||
| 21 | +repo=$(basename $repo_url) | ||
| 22 | +log "Download pretrained model and test-data from $repo_url" | ||
| 23 | + | ||
| 24 | +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 25 | +pushd $repo | ||
| 26 | +git lfs pull --include "*.onnx" | ||
| 27 | +ls -lh *.onnx | ||
| 28 | +popd | ||
| 29 | + | ||
| 30 | +log "test float32 models" | ||
| 31 | +time $EXE \ | ||
| 32 | + --sample-rate=8000 \ | ||
| 33 | + --feat-dim=23 \ | ||
| 34 | + \ | ||
| 35 | + --tokens=$repo/tokens.txt \ | ||
| 36 | + --tdnn-model=$repo/model-epoch-14-avg-2.onnx \ | ||
| 37 | + $repo/test_wavs/0_0_0_1_0_0_0_1.wav \ | ||
| 38 | + $repo/test_wavs/0_0_1_0_0_0_1_0.wav \ | ||
| 39 | + $repo/test_wavs/0_0_1_0_0_1_1_1.wav \ | ||
| 40 | + $repo/test_wavs/0_0_1_0_1_0_0_1.wav \ | ||
| 41 | + $repo/test_wavs/0_0_1_1_0_0_0_1.wav \ | ||
| 42 | + $repo/test_wavs/0_0_1_1_0_1_1_0.wav | ||
| 43 | + | ||
| 44 | +log "test int8 models" | ||
| 45 | +time $EXE \ | ||
| 46 | + --sample-rate=8000 \ | ||
| 47 | + --feat-dim=23 \ | ||
| 48 | + \ | ||
| 49 | + --tokens=$repo/tokens.txt \ | ||
| 50 | + --tdnn-model=$repo/model-epoch-14-avg-2.int8.onnx \ | ||
| 51 | + $repo/test_wavs/0_0_0_1_0_0_0_1.wav \ | ||
| 52 | + $repo/test_wavs/0_0_1_0_0_0_1_0.wav \ | ||
| 53 | + $repo/test_wavs/0_0_1_0_0_1_1_1.wav \ | ||
| 54 | + $repo/test_wavs/0_0_1_0_1_0_0_1.wav \ | ||
| 55 | + $repo/test_wavs/0_0_1_1_0_0_0_1.wav \ | ||
| 56 | + $repo/test_wavs/0_0_1_1_0_1_1_0.wav | ||
| 57 | + | ||
| 58 | +rm -rf $repo | ||
| 59 | + | ||
| 60 | +log "------------------------------------------------------------" | ||
| 17 | log "Run Citrinet (stt_en_citrinet_512, English)" | 61 | log "Run Citrinet (stt_en_citrinet_512, English)" |
| 18 | log "------------------------------------------------------------" | 62 | log "------------------------------------------------------------" |
| 19 | 63 |
| @@ -24,7 +24,7 @@ jobs: | @@ -24,7 +24,7 @@ jobs: | ||
| 24 | matrix: | 24 | matrix: |
| 25 | os: [ubuntu-latest, windows-latest, macos-latest] | 25 | os: [ubuntu-latest, windows-latest, macos-latest] |
| 26 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] | 26 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] |
| 27 | - model_type: ["transducer", "paraformer", "nemo_ctc", "whisper"] | 27 | + model_type: ["transducer", "paraformer", "nemo_ctc", "whisper", "tdnn"] |
| 28 | 28 | ||
| 29 | steps: | 29 | steps: |
| 30 | - uses: actions/checkout@v2 | 30 | - uses: actions/checkout@v2 |
| @@ -172,3 +172,41 @@ jobs: | @@ -172,3 +172,41 @@ jobs: | ||
| 172 | ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \ | 172 | ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \ |
| 173 | ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \ | 173 | ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \ |
| 174 | ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav | 174 | ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav |
| 175 | + | ||
| 176 | + - name: Start server for tdnn models | ||
| 177 | + if: matrix.model_type == 'tdnn' | ||
| 178 | + shell: bash | ||
| 179 | + run: | | ||
| 180 | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno | ||
| 181 | + cd sherpa-onnx-tdnn-yesno | ||
| 182 | + git lfs pull --include "*.onnx" | ||
| 183 | + cd .. | ||
| 184 | + | ||
| 185 | + python3 ./python-api-examples/non_streaming_server.py \ | ||
| 186 | + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ | ||
| 187 | + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ | ||
| 188 | + --sample-rate=8000 \ | ||
| 189 | + --feat-dim=23 & | ||
| 190 | + | ||
| 191 | + echo "sleep 10 seconds to wait the server start" | ||
| 192 | + sleep 10 | ||
| 193 | + | ||
| 194 | + - name: Start client for tdnn models | ||
| 195 | + if: matrix.model_type == 'tdnn' | ||
| 196 | + shell: bash | ||
| 197 | + run: | | ||
| 198 | + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \ | ||
| 199 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ | ||
| 200 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ | ||
| 201 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav \ | ||
| 202 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_1_0_0_1.wav \ | ||
| 203 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_0_0_1.wav \ | ||
| 204 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_1_1_0.wav | ||
| 205 | + | ||
| 206 | + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \ | ||
| 207 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ | ||
| 208 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ | ||
| 209 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav \ | ||
| 210 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_1_0_0_1.wav \ | ||
| 211 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_0_0_1.wav \ | ||
| 212 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_1_1_0.wav |
| @@ -71,6 +71,20 @@ python3 ./python-api-examples/non_streaming_server.py \ | @@ -71,6 +71,20 @@ python3 ./python-api-examples/non_streaming_server.py \ | ||
| 71 | --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ | 71 | --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ |
| 72 | --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt | 72 | --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt |
| 73 | 73 | ||
| 74 | +(5) Use a tdnn model of the yesno recipe from icefall | ||
| 75 | + | ||
| 76 | +cd /path/to/sherpa-onnx | ||
| 77 | + | ||
| 78 | +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno | ||
| 79 | +cd sherpa-onnx-tdnn-yesno | ||
| 80 | +git lfs pull --include "*.onnx" | ||
| 81 | + | ||
| 82 | +python3 ./python-api-examples/non_streaming_server.py \ | ||
| 83 | + --sample-rate=8000 \ | ||
| 84 | + --feat-dim=23 \ | ||
| 85 | + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ | ||
| 86 | + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt | ||
| 87 | + | ||
| 74 | ---- | 88 | ---- |
| 75 | 89 | ||
| 76 | To use a certificate so that you can use https, please use | 90 | To use a certificate so that you can use https, please use |
| @@ -196,6 +210,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser): | @@ -196,6 +210,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser): | ||
| 196 | ) | 210 | ) |
| 197 | 211 | ||
| 198 | 212 | ||
| 213 | +def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser): | ||
| 214 | + parser.add_argument( | ||
| 215 | + "--tdnn-model", | ||
| 216 | + default="", | ||
| 217 | + type=str, | ||
| 218 | + help="Path to the model.onnx for the tdnn model of the yesno recipe", | ||
| 219 | + ) | ||
| 220 | + | ||
| 221 | + | ||
| 199 | def add_whisper_model_args(parser: argparse.ArgumentParser): | 222 | def add_whisper_model_args(parser: argparse.ArgumentParser): |
| 200 | parser.add_argument( | 223 | parser.add_argument( |
| 201 | "--whisper-encoder", | 224 | "--whisper-encoder", |
| @@ -216,6 +239,7 @@ def add_model_args(parser: argparse.ArgumentParser): | @@ -216,6 +239,7 @@ def add_model_args(parser: argparse.ArgumentParser): | ||
| 216 | add_transducer_model_args(parser) | 239 | add_transducer_model_args(parser) |
| 217 | add_paraformer_model_args(parser) | 240 | add_paraformer_model_args(parser) |
| 218 | add_nemo_ctc_model_args(parser) | 241 | add_nemo_ctc_model_args(parser) |
| 242 | + add_tdnn_ctc_model_args(parser) | ||
| 219 | add_whisper_model_args(parser) | 243 | add_whisper_model_args(parser) |
| 220 | 244 | ||
| 221 | parser.add_argument( | 245 | parser.add_argument( |
| @@ -730,6 +754,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -730,6 +754,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 730 | assert len(args.nemo_ctc) == 0, args.nemo_ctc | 754 | assert len(args.nemo_ctc) == 0, args.nemo_ctc |
| 731 | assert len(args.whisper_encoder) == 0, args.whisper_encoder | 755 | assert len(args.whisper_encoder) == 0, args.whisper_encoder |
| 732 | assert len(args.whisper_decoder) == 0, args.whisper_decoder | 756 | assert len(args.whisper_decoder) == 0, args.whisper_decoder |
| 757 | + assert len(args.tdnn_model) == 0, args.tdnn_model | ||
| 733 | 758 | ||
| 734 | assert_file_exists(args.encoder) | 759 | assert_file_exists(args.encoder) |
| 735 | assert_file_exists(args.decoder) | 760 | assert_file_exists(args.decoder) |
| @@ -750,6 +775,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -750,6 +775,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 750 | assert len(args.nemo_ctc) == 0, args.nemo_ctc | 775 | assert len(args.nemo_ctc) == 0, args.nemo_ctc |
| 751 | assert len(args.whisper_encoder) == 0, args.whisper_encoder | 776 | assert len(args.whisper_encoder) == 0, args.whisper_encoder |
| 752 | assert len(args.whisper_decoder) == 0, args.whisper_decoder | 777 | assert len(args.whisper_decoder) == 0, args.whisper_decoder |
| 778 | + assert len(args.tdnn_model) == 0, args.tdnn_model | ||
| 753 | 779 | ||
| 754 | assert_file_exists(args.paraformer) | 780 | assert_file_exists(args.paraformer) |
| 755 | 781 | ||
| @@ -764,6 +790,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -764,6 +790,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 764 | elif args.nemo_ctc: | 790 | elif args.nemo_ctc: |
| 765 | assert len(args.whisper_encoder) == 0, args.whisper_encoder | 791 | assert len(args.whisper_encoder) == 0, args.whisper_encoder |
| 766 | assert len(args.whisper_decoder) == 0, args.whisper_decoder | 792 | assert len(args.whisper_decoder) == 0, args.whisper_decoder |
| 793 | + assert len(args.tdnn_model) == 0, args.tdnn_model | ||
| 767 | 794 | ||
| 768 | assert_file_exists(args.nemo_ctc) | 795 | assert_file_exists(args.nemo_ctc) |
| 769 | 796 | ||
| @@ -776,6 +803,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -776,6 +803,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 776 | decoding_method=args.decoding_method, | 803 | decoding_method=args.decoding_method, |
| 777 | ) | 804 | ) |
| 778 | elif args.whisper_encoder: | 805 | elif args.whisper_encoder: |
| 806 | + assert len(args.tdnn_model) == 0, args.tdnn_model | ||
| 779 | assert_file_exists(args.whisper_encoder) | 807 | assert_file_exists(args.whisper_encoder) |
| 780 | assert_file_exists(args.whisper_decoder) | 808 | assert_file_exists(args.whisper_decoder) |
| 781 | 809 | ||
| @@ -786,6 +814,17 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -786,6 +814,17 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 786 | num_threads=args.num_threads, | 814 | num_threads=args.num_threads, |
| 787 | decoding_method=args.decoding_method, | 815 | decoding_method=args.decoding_method, |
| 788 | ) | 816 | ) |
| 817 | + elif args.tdnn_model: | ||
| 818 | + assert_file_exists(args.tdnn_model) | ||
| 819 | + | ||
| 820 | + recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc( | ||
| 821 | + model=args.tdnn_model, | ||
| 822 | + tokens=args.tokens, | ||
| 823 | + sample_rate=args.sample_rate, | ||
| 824 | + feature_dim=args.feat_dim, | ||
| 825 | + num_threads=args.num_threads, | ||
| 826 | + decoding_method=args.decoding_method, | ||
| 827 | + ) | ||
| 789 | else: | 828 | else: |
| 790 | raise ValueError("Please specify at least one model") | 829 | raise ValueError("Please specify at least one model") |
| 791 | 830 |
| @@ -8,6 +8,7 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe | @@ -8,6 +8,7 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe | ||
| 8 | file(s) with a non-streaming model. | 8 | file(s) with a non-streaming model. |
| 9 | 9 | ||
| 10 | (1) For paraformer | 10 | (1) For paraformer |
| 11 | + | ||
| 11 | ./python-api-examples/offline-decode-files.py \ | 12 | ./python-api-examples/offline-decode-files.py \ |
| 12 | --tokens=/path/to/tokens.txt \ | 13 | --tokens=/path/to/tokens.txt \ |
| 13 | --paraformer=/path/to/paraformer.onnx \ | 14 | --paraformer=/path/to/paraformer.onnx \ |
| @@ -20,6 +21,7 @@ file(s) with a non-streaming model. | @@ -20,6 +21,7 @@ file(s) with a non-streaming model. | ||
| 20 | /path/to/1.wav | 21 | /path/to/1.wav |
| 21 | 22 | ||
| 22 | (2) For transducer models from icefall | 23 | (2) For transducer models from icefall |
| 24 | + | ||
| 23 | ./python-api-examples/offline-decode-files.py \ | 25 | ./python-api-examples/offline-decode-files.py \ |
| 24 | --tokens=/path/to/tokens.txt \ | 26 | --tokens=/path/to/tokens.txt \ |
| 25 | --encoder=/path/to/encoder.onnx \ | 27 | --encoder=/path/to/encoder.onnx \ |
| @@ -56,9 +58,20 @@ python3 ./python-api-examples/offline-decode-files.py \ | @@ -56,9 +58,20 @@ python3 ./python-api-examples/offline-decode-files.py \ | ||
| 56 | ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ | 58 | ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ |
| 57 | ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav | 59 | ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav |
| 58 | 60 | ||
| 61 | +(5) For tdnn models of the yesno recipe from icefall | ||
| 62 | + | ||
| 63 | +python3 ./python-api-examples/offline-decode-files.py \ | ||
| 64 | + --sample-rate=8000 \ | ||
| 65 | + --feature-dim=23 \ | ||
| 66 | + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ | ||
| 67 | + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ | ||
| 68 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ | ||
| 69 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ | ||
| 70 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav | ||
| 71 | + | ||
| 59 | Please refer to | 72 | Please refer to |
| 60 | https://k2-fsa.github.io/sherpa/onnx/index.html | 73 | https://k2-fsa.github.io/sherpa/onnx/index.html |
| 61 | -to install sherpa-onnx and to download the pre-trained models | 74 | +to install sherpa-onnx and to download non-streaming pre-trained models |
| 62 | used in this file. | 75 | used in this file. |
| 63 | """ | 76 | """ |
| 64 | import argparse | 77 | import argparse |
| @@ -160,6 +173,13 @@ def get_args(): | @@ -160,6 +173,13 @@ def get_args(): | ||
| 160 | ) | 173 | ) |
| 161 | 174 | ||
| 162 | parser.add_argument( | 175 | parser.add_argument( |
| 176 | + "--tdnn-model", | ||
| 177 | + default="", | ||
| 178 | + type=str, | ||
| 179 | + help="Path to the model.onnx for the tdnn model of the yesno recipe", | ||
| 180 | + ) | ||
| 181 | + | ||
| 182 | + parser.add_argument( | ||
| 163 | "--num-threads", | 183 | "--num-threads", |
| 164 | type=int, | 184 | type=int, |
| 165 | default=1, | 185 | default=1, |
| @@ -285,6 +305,7 @@ def main(): | @@ -285,6 +305,7 @@ def main(): | ||
| 285 | assert len(args.nemo_ctc) == 0, args.nemo_ctc | 305 | assert len(args.nemo_ctc) == 0, args.nemo_ctc |
| 286 | assert len(args.whisper_encoder) == 0, args.whisper_encoder | 306 | assert len(args.whisper_encoder) == 0, args.whisper_encoder |
| 287 | assert len(args.whisper_decoder) == 0, args.whisper_decoder | 307 | assert len(args.whisper_decoder) == 0, args.whisper_decoder |
| 308 | + assert len(args.tdnn_model) == 0, args.tdnn_model | ||
| 288 | 309 | ||
| 289 | contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] | 310 | contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] |
| 290 | if contexts: | 311 | if contexts: |
| @@ -311,6 +332,7 @@ def main(): | @@ -311,6 +332,7 @@ def main(): | ||
| 311 | assert len(args.nemo_ctc) == 0, args.nemo_ctc | 332 | assert len(args.nemo_ctc) == 0, args.nemo_ctc |
| 312 | assert len(args.whisper_encoder) == 0, args.whisper_encoder | 333 | assert len(args.whisper_encoder) == 0, args.whisper_encoder |
| 313 | assert len(args.whisper_decoder) == 0, args.whisper_decoder | 334 | assert len(args.whisper_decoder) == 0, args.whisper_decoder |
| 335 | + assert len(args.tdnn_model) == 0, args.tdnn_model | ||
| 314 | 336 | ||
| 315 | assert_file_exists(args.paraformer) | 337 | assert_file_exists(args.paraformer) |
| 316 | 338 | ||
| @@ -326,6 +348,7 @@ def main(): | @@ -326,6 +348,7 @@ def main(): | ||
| 326 | elif args.nemo_ctc: | 348 | elif args.nemo_ctc: |
| 327 | assert len(args.whisper_encoder) == 0, args.whisper_encoder | 349 | assert len(args.whisper_encoder) == 0, args.whisper_encoder |
| 328 | assert len(args.whisper_decoder) == 0, args.whisper_decoder | 350 | assert len(args.whisper_decoder) == 0, args.whisper_decoder |
| 351 | + assert len(args.tdnn_model) == 0, args.tdnn_model | ||
| 329 | 352 | ||
| 330 | assert_file_exists(args.nemo_ctc) | 353 | assert_file_exists(args.nemo_ctc) |
| 331 | 354 | ||
| @@ -339,6 +362,7 @@ def main(): | @@ -339,6 +362,7 @@ def main(): | ||
| 339 | debug=args.debug, | 362 | debug=args.debug, |
| 340 | ) | 363 | ) |
| 341 | elif args.whisper_encoder: | 364 | elif args.whisper_encoder: |
| 365 | + assert len(args.tdnn_model) == 0, args.tdnn_model | ||
| 342 | assert_file_exists(args.whisper_encoder) | 366 | assert_file_exists(args.whisper_encoder) |
| 343 | assert_file_exists(args.whisper_decoder) | 367 | assert_file_exists(args.whisper_decoder) |
| 344 | 368 | ||
| @@ -347,6 +371,20 @@ def main(): | @@ -347,6 +371,20 @@ def main(): | ||
| 347 | decoder=args.whisper_decoder, | 371 | decoder=args.whisper_decoder, |
| 348 | tokens=args.tokens, | 372 | tokens=args.tokens, |
| 349 | num_threads=args.num_threads, | 373 | num_threads=args.num_threads, |
| 374 | + sample_rate=args.sample_rate, | ||
| 375 | + feature_dim=args.feature_dim, | ||
| 376 | + decoding_method=args.decoding_method, | ||
| 377 | + debug=args.debug, | ||
| 378 | + ) | ||
| 379 | + elif args.tdnn_model: | ||
| 380 | + assert_file_exists(args.tdnn_model) | ||
| 381 | + | ||
| 382 | + recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc( | ||
| 383 | + model=args.tdnn_model, | ||
| 384 | + tokens=args.tokens, | ||
| 385 | + sample_rate=args.sample_rate, | ||
| 386 | + feature_dim=args.feature_dim, | ||
| 387 | + num_threads=args.num_threads, | ||
| 350 | decoding_method=args.decoding_method, | 388 | decoding_method=args.decoding_method, |
| 351 | debug=args.debug, | 389 | debug=args.debug, |
| 352 | ) | 390 | ) |
| @@ -97,20 +97,18 @@ function onFileChange() { | @@ -97,20 +97,18 @@ function onFileChange() { | ||
| 97 | console.log('file.type ' + file.type); | 97 | console.log('file.type ' + file.type); |
| 98 | console.log('file.size ' + file.size); | 98 | console.log('file.size ' + file.size); |
| 99 | 99 | ||
| 100 | + let audioCtx = new AudioContext({sampleRate: 16000}); | ||
| 101 | + | ||
| 100 | let reader = new FileReader(); | 102 | let reader = new FileReader(); |
| 101 | reader.onload = function() { | 103 | reader.onload = function() { |
| 102 | console.log('reading file!'); | 104 | console.log('reading file!'); |
| 103 | - let view = new Int16Array(reader.result); | ||
| 104 | - // we assume the input file is a wav file. | ||
| 105 | - // TODO: add some checks here. | ||
| 106 | - let int16_samples = view.subarray(22); // header has 44 bytes == 22 shorts | ||
| 107 | - let num_samples = int16_samples.length; | ||
| 108 | - let float32_samples = new Float32Array(num_samples); | ||
| 109 | - console.log('num_samples ' + num_samples) | ||
| 110 | - | ||
| 111 | - for (let i = 0; i < num_samples; ++i) { | ||
| 112 | - float32_samples[i] = int16_samples[i] / 32768. | ||
| 113 | - } | 105 | + audioCtx.decodeAudioData(reader.result, decodedDone); |
| 106 | + }; | ||
| 107 | + | ||
| 108 | + function decodedDone(decoded) { | ||
| 109 | + let typedArray = new Float32Array(decoded.length); | ||
| 110 | + let float32_samples = decoded.getChannelData(0); | ||
| 111 | + let buf = float32_samples.buffer | ||
| 114 | 112 | ||
| 115 | // Send 1024 audio samples per request. | 113 | // Send 1024 audio samples per request. |
| 116 | // | 114 | // |
| @@ -119,14 +117,13 @@ function onFileChange() { | @@ -119,14 +117,13 @@ function onFileChange() { | ||
| 119 | // (2) There is a limit on the number of bytes in the payload that can be | 117 | // (2) There is a limit on the number of bytes in the payload that can be |
| 120 | // sent by websocket, which is 1MB, I think. We can send a large | 118 | // sent by websocket, which is 1MB, I think. We can send a large |
| 121 | // audio file for decoding in this approach. | 119 | // audio file for decoding in this approach. |
| 122 | - let buf = float32_samples.buffer | ||
| 123 | let n = 1024 * 4; // send this number of bytes per request. | 120 | let n = 1024 * 4; // send this number of bytes per request. |
| 124 | console.log('buf length, ' + buf.byteLength); | 121 | console.log('buf length, ' + buf.byteLength); |
| 125 | send_header(buf.byteLength); | 122 | send_header(buf.byteLength); |
| 126 | for (let start = 0; start < buf.byteLength; start += n) { | 123 | for (let start = 0; start < buf.byteLength; start += n) { |
| 127 | socket.send(buf.slice(start, start + n)); | 124 | socket.send(buf.slice(start, start + n)); |
| 128 | } | 125 | } |
| 129 | - }; | 126 | + } |
| 130 | 127 | ||
| 131 | reader.readAsArrayBuffer(file); | 128 | reader.readAsArrayBuffer(file); |
| 132 | } | 129 | } |
| @@ -32,6 +32,8 @@ set(sources | @@ -32,6 +32,8 @@ set(sources | ||
| 32 | offline-recognizer.cc | 32 | offline-recognizer.cc |
| 33 | offline-rnn-lm.cc | 33 | offline-rnn-lm.cc |
| 34 | offline-stream.cc | 34 | offline-stream.cc |
| 35 | + offline-tdnn-ctc-model.cc | ||
| 36 | + offline-tdnn-model-config.cc | ||
| 35 | offline-transducer-greedy-search-decoder.cc | 37 | offline-transducer-greedy-search-decoder.cc |
| 36 | offline-transducer-model-config.cc | 38 | offline-transducer-model-config.cc |
| 37 | offline-transducer-model.cc | 39 | offline-transducer-model.cc |
| @@ -11,12 +11,14 @@ | @@ -11,12 +11,14 @@ | ||
| 11 | 11 | ||
| 12 | #include "sherpa-onnx/csrc/macros.h" | 12 | #include "sherpa-onnx/csrc/macros.h" |
| 13 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" | 13 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" |
| 14 | +#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" | ||
| 14 | #include "sherpa-onnx/csrc/onnx-utils.h" | 15 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 15 | 16 | ||
| 16 | namespace { | 17 | namespace { |
| 17 | 18 | ||
| 18 | enum class ModelType { | 19 | enum class ModelType { |
| 19 | kEncDecCTCModelBPE, | 20 | kEncDecCTCModelBPE, |
| 21 | + kTdnn, | ||
| 20 | kUnkown, | 22 | kUnkown, |
| 21 | }; | 23 | }; |
| 22 | 24 | ||
| @@ -55,6 +57,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -55,6 +57,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 55 | 57 | ||
| 56 | if (model_type.get() == std::string("EncDecCTCModelBPE")) { | 58 | if (model_type.get() == std::string("EncDecCTCModelBPE")) { |
| 57 | return ModelType::kEncDecCTCModelBPE; | 59 | return ModelType::kEncDecCTCModelBPE; |
| 60 | + } else if (model_type.get() == std::string("tdnn")) { | ||
| 61 | + return ModelType::kTdnn; | ||
| 58 | } else { | 62 | } else { |
| 59 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 63 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); |
| 60 | return ModelType::kUnkown; | 64 | return ModelType::kUnkown; |
| @@ -65,8 +69,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -65,8 +69,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 65 | const OfflineModelConfig &config) { | 69 | const OfflineModelConfig &config) { |
| 66 | ModelType model_type = ModelType::kUnkown; | 70 | ModelType model_type = ModelType::kUnkown; |
| 67 | 71 | ||
| 72 | + std::string filename; | ||
| 73 | + if (!config.nemo_ctc.model.empty()) { | ||
| 74 | + filename = config.nemo_ctc.model; | ||
| 75 | + } else if (!config.tdnn.model.empty()) { | ||
| 76 | + filename = config.tdnn.model; | ||
| 77 | + } else { | ||
| 78 | + SHERPA_ONNX_LOGE("Please specify a CTC model"); | ||
| 79 | + exit(-1); | ||
| 80 | + } | ||
| 81 | + | ||
| 68 | { | 82 | { |
| 69 | - auto buffer = ReadFile(config.nemo_ctc.model); | 83 | + auto buffer = ReadFile(filename); |
| 70 | 84 | ||
| 71 | model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | 85 | model_type = GetModelType(buffer.data(), buffer.size(), config.debug); |
| 72 | } | 86 | } |
| @@ -75,6 +89,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -75,6 +89,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 75 | case ModelType::kEncDecCTCModelBPE: | 89 | case ModelType::kEncDecCTCModelBPE: |
| 76 | return std::make_unique<OfflineNemoEncDecCtcModel>(config); | 90 | return std::make_unique<OfflineNemoEncDecCtcModel>(config); |
| 77 | break; | 91 | break; |
| 92 | + case ModelType::kTdnn: | ||
| 93 | + return std::make_unique<OfflineTdnnCtcModel>(config); | ||
| 94 | + break; | ||
| 78 | case ModelType::kUnkown: | 95 | case ModelType::kUnkown: |
| 79 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); | 96 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); |
| 80 | return nullptr; | 97 | return nullptr; |
| @@ -39,10 +39,10 @@ class OfflineCtcModel { | @@ -39,10 +39,10 @@ class OfflineCtcModel { | ||
| 39 | 39 | ||
| 40 | /** SubsamplingFactor of the model | 40 | /** SubsamplingFactor of the model |
| 41 | * | 41 | * |
| 42 | - * For Citrinet, the subsampling factor is usually 4. | ||
| 43 | - * For Conformer CTC, the subsampling factor is usually 8. | 42 | + * For NeMo Citrinet, the subsampling factor is usually 4. |
| 43 | + * For NeMo Conformer CTC, the subsampling factor is usually 8. | ||
| 44 | */ | 44 | */ |
| 45 | - virtual int32_t SubsamplingFactor() const = 0; | 45 | + virtual int32_t SubsamplingFactor() const { return 1; } |
| 46 | 46 | ||
| 47 | /** Return an allocator for allocating memory | 47 | /** Return an allocator for allocating memory |
| 48 | */ | 48 | */ |
| @@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { | @@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { | ||
| 15 | paraformer.Register(po); | 15 | paraformer.Register(po); |
| 16 | nemo_ctc.Register(po); | 16 | nemo_ctc.Register(po); |
| 17 | whisper.Register(po); | 17 | whisper.Register(po); |
| 18 | + tdnn.Register(po); | ||
| 18 | 19 | ||
| 19 | po->Register("tokens", &tokens, "Path to tokens.txt"); | 20 | po->Register("tokens", &tokens, "Path to tokens.txt"); |
| 20 | 21 | ||
| @@ -29,7 +30,8 @@ void OfflineModelConfig::Register(ParseOptions *po) { | @@ -29,7 +30,8 @@ void OfflineModelConfig::Register(ParseOptions *po) { | ||
| 29 | 30 | ||
| 30 | po->Register("model-type", &model_type, | 31 | po->Register("model-type", &model_type, |
| 31 | "Specify it to reduce model initialization time. " | 32 | "Specify it to reduce model initialization time. " |
| 32 | - "Valid values are: transducer, paraformer, nemo_ctc, whisper." | 33 | + "Valid values are: transducer, paraformer, nemo_ctc, whisper, " |
| 34 | + "tdnn." | ||
| 33 | "All other values lead to loading the model twice."); | 35 | "All other values lead to loading the model twice."); |
| 34 | } | 36 | } |
| 35 | 37 | ||
| @@ -56,6 +58,10 @@ bool OfflineModelConfig::Validate() const { | @@ -56,6 +58,10 @@ bool OfflineModelConfig::Validate() const { | ||
| 56 | return whisper.Validate(); | 58 | return whisper.Validate(); |
| 57 | } | 59 | } |
| 58 | 60 | ||
| 61 | + if (!tdnn.model.empty()) { | ||
| 62 | + return tdnn.Validate(); | ||
| 63 | + } | ||
| 64 | + | ||
| 59 | return transducer.Validate(); | 65 | return transducer.Validate(); |
| 60 | } | 66 | } |
| 61 | 67 | ||
| @@ -67,6 +73,7 @@ std::string OfflineModelConfig::ToString() const { | @@ -67,6 +73,7 @@ std::string OfflineModelConfig::ToString() const { | ||
| 67 | os << "paraformer=" << paraformer.ToString() << ", "; | 73 | os << "paraformer=" << paraformer.ToString() << ", "; |
| 68 | os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; | 74 | os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; |
| 69 | os << "whisper=" << whisper.ToString() << ", "; | 75 | os << "whisper=" << whisper.ToString() << ", "; |
| 76 | + os << "tdnn=" << tdnn.ToString() << ", "; | ||
| 70 | os << "tokens=\"" << tokens << "\", "; | 77 | os << "tokens=\"" << tokens << "\", "; |
| 71 | os << "num_threads=" << num_threads << ", "; | 78 | os << "num_threads=" << num_threads << ", "; |
| 72 | os << "debug=" << (debug ? "True" : "False") << ", "; | 79 | os << "debug=" << (debug ? "True" : "False") << ", "; |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | 8 | ||
| 9 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" | 9 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" |
| 10 | #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" |
| 11 | +#include "sherpa-onnx/csrc/offline-tdnn-model-config.h" | ||
| 11 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" | 12 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" |
| 12 | #include "sherpa-onnx/csrc/offline-whisper-model-config.h" | 13 | #include "sherpa-onnx/csrc/offline-whisper-model-config.h" |
| 13 | 14 | ||
| @@ -18,6 +19,7 @@ struct OfflineModelConfig { | @@ -18,6 +19,7 @@ struct OfflineModelConfig { | ||
| 18 | OfflineParaformerModelConfig paraformer; | 19 | OfflineParaformerModelConfig paraformer; |
| 19 | OfflineNemoEncDecCtcModelConfig nemo_ctc; | 20 | OfflineNemoEncDecCtcModelConfig nemo_ctc; |
| 20 | OfflineWhisperModelConfig whisper; | 21 | OfflineWhisperModelConfig whisper; |
| 22 | + OfflineTdnnModelConfig tdnn; | ||
| 21 | 23 | ||
| 22 | std::string tokens; | 24 | std::string tokens; |
| 23 | int32_t num_threads = 2; | 25 | int32_t num_threads = 2; |
| @@ -40,12 +42,14 @@ struct OfflineModelConfig { | @@ -40,12 +42,14 @@ struct OfflineModelConfig { | ||
| 40 | const OfflineParaformerModelConfig ¶former, | 42 | const OfflineParaformerModelConfig ¶former, |
| 41 | const OfflineNemoEncDecCtcModelConfig &nemo_ctc, | 43 | const OfflineNemoEncDecCtcModelConfig &nemo_ctc, |
| 42 | const OfflineWhisperModelConfig &whisper, | 44 | const OfflineWhisperModelConfig &whisper, |
| 45 | + const OfflineTdnnModelConfig &tdnn, | ||
| 43 | const std::string &tokens, int32_t num_threads, bool debug, | 46 | const std::string &tokens, int32_t num_threads, bool debug, |
| 44 | const std::string &provider, const std::string &model_type) | 47 | const std::string &provider, const std::string &model_type) |
| 45 | : transducer(transducer), | 48 | : transducer(transducer), |
| 46 | paraformer(paraformer), | 49 | paraformer(paraformer), |
| 47 | nemo_ctc(nemo_ctc), | 50 | nemo_ctc(nemo_ctc), |
| 48 | whisper(whisper), | 51 | whisper(whisper), |
| 52 | + tdnn(tdnn), | ||
| 49 | tokens(tokens), | 53 | tokens(tokens), |
| 50 | num_threads(num_threads), | 54 | num_threads(num_threads), |
| 51 | debug(debug), | 55 | debug(debug), |
| @@ -27,6 +27,10 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | @@ -27,6 +27,10 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | ||
| 27 | std::string text; | 27 | std::string text; |
| 28 | 28 | ||
| 29 | for (int32_t i = 0; i != src.tokens.size(); ++i) { | 29 | for (int32_t i = 0; i != src.tokens.size(); ++i) { |
| 30 | + if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) { | ||
| 31 | + // tdnn models from yesno have a SIL token, we should remove it. | ||
| 32 | + continue; | ||
| 33 | + } | ||
| 30 | auto sym = sym_table[src.tokens[i]]; | 34 | auto sym = sym_table[src.tokens[i]]; |
| 31 | text.append(sym); | 35 | text.append(sym); |
| 32 | r.tokens.push_back(std::move(sym)); | 36 | r.tokens.push_back(std::move(sym)); |
| @@ -46,14 +50,22 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -46,14 +50,22 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 46 | model_->FeatureNormalizationMethod(); | 50 | model_->FeatureNormalizationMethod(); |
| 47 | 51 | ||
| 48 | if (config.decoding_method == "greedy_search") { | 52 | if (config.decoding_method == "greedy_search") { |
| 49 | - if (!symbol_table_.contains("<blk>")) { | 53 | + if (!symbol_table_.contains("<blk>") && |
| 54 | + !symbol_table_.contains("<eps>")) { | ||
| 50 | SHERPA_ONNX_LOGE( | 55 | SHERPA_ONNX_LOGE( |
| 51 | "We expect that tokens.txt contains " | 56 | "We expect that tokens.txt contains " |
| 52 | - "the symbol <blk> and its ID."); | 57 | + "the symbol <blk> or <eps> and its ID."); |
| 53 | exit(-1); | 58 | exit(-1); |
| 54 | } | 59 | } |
| 55 | 60 | ||
| 56 | - int32_t blank_id = symbol_table_["<blk>"]; | 61 | + int32_t blank_id = 0; |
| 62 | + if (symbol_table_.contains("<blk>")) { | ||
| 63 | + blank_id = symbol_table_["<blk>"]; | ||
| 64 | + } else if (symbol_table_.contains("<eps>")) { | ||
| 65 | + // for tdnn models of the yesno recipe from icefall | ||
| 66 | + blank_id = symbol_table_["<eps>"]; | ||
| 67 | + } | ||
| 68 | + | ||
| 57 | decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id); | 69 | decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id); |
| 58 | } else { | 70 | } else { |
| 59 | SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", | 71 | SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", |
| @@ -27,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -27,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 27 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); | 27 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); |
| 28 | } else if (model_type == "nemo_ctc") { | 28 | } else if (model_type == "nemo_ctc") { |
| 29 | return std::make_unique<OfflineRecognizerCtcImpl>(config); | 29 | return std::make_unique<OfflineRecognizerCtcImpl>(config); |
| 30 | + } else if (model_type == "tdnn") { | ||
| 31 | + return std::make_unique<OfflineRecognizerCtcImpl>(config); | ||
| 30 | } else if (model_type == "whisper") { | 32 | } else if (model_type == "whisper") { |
| 31 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); | 33 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); |
| 32 | } else { | 34 | } else { |
| @@ -46,6 +48,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -46,6 +48,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 46 | model_filename = config.model_config.paraformer.model; | 48 | model_filename = config.model_config.paraformer.model; |
| 47 | } else if (!config.model_config.nemo_ctc.model.empty()) { | 49 | } else if (!config.model_config.nemo_ctc.model.empty()) { |
| 48 | model_filename = config.model_config.nemo_ctc.model; | 50 | model_filename = config.model_config.nemo_ctc.model; |
| 51 | + } else if (!config.model_config.tdnn.model.empty()) { | ||
| 52 | + model_filename = config.model_config.tdnn.model; | ||
| 49 | } else if (!config.model_config.whisper.encoder.empty()) { | 53 | } else if (!config.model_config.whisper.encoder.empty()) { |
| 50 | model_filename = config.model_config.whisper.encoder; | 54 | model_filename = config.model_config.whisper.encoder; |
| 51 | } else { | 55 | } else { |
| @@ -84,6 +88,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -84,6 +88,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 84 | "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py" | 88 | "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py" |
| 85 | "\n " | 89 | "\n " |
| 86 | "(3) Whisper" | 90 | "(3) Whisper" |
| 91 | + "\n " | ||
| 92 | + "(4) Tdnn models of the yesno recipe from icefall" | ||
| 93 | + "\n " | ||
| 94 | + "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn" | ||
| 95 | + "\n" | ||
| 87 | "\n"); | 96 | "\n"); |
| 88 | exit(-1); | 97 | exit(-1); |
| 89 | } | 98 | } |
| @@ -102,6 +111,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -102,6 +111,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 102 | return std::make_unique<OfflineRecognizerCtcImpl>(config); | 111 | return std::make_unique<OfflineRecognizerCtcImpl>(config); |
| 103 | } | 112 | } |
| 104 | 113 | ||
| 114 | + if (model_type == "tdnn") { | ||
| 115 | + return std::make_unique<OfflineRecognizerCtcImpl>(config); | ||
| 116 | + } | ||
| 117 | + | ||
| 105 | if (strncmp(model_type.c_str(), "whisper", 7) == 0) { | 118 | if (strncmp(model_type.c_str(), "whisper", 7) == 0) { |
| 106 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); | 119 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); |
| 107 | } | 120 | } |
| @@ -112,7 +125,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -112,7 +125,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 112 | " - Non-streaming transducer models from icefall\n" | 125 | " - Non-streaming transducer models from icefall\n" |
| 113 | " - Non-streaming Paraformer models from FunASR\n" | 126 | " - Non-streaming Paraformer models from FunASR\n" |
| 114 | " - EncDecCTCModelBPE models from NeMo\n" | 127 | " - EncDecCTCModelBPE models from NeMo\n" |
| 115 | - " - Whisper models\n", | 128 | + " - Whisper models\n" |
| 129 | + " - Tdnn models\n", | ||
| 116 | model_type.c_str()); | 130 | model_type.c_str()); |
| 117 | 131 | ||
| 118 | exit(-1); | 132 | exit(-1); |
sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
0 → 100644
// sherpa-onnx/csrc/offline-tdnn-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

// Pimpl class owning the onnxruntime session for the tdnn CTC model
// of the yesno recipe from icefall.
class OfflineTdnnCtcModel::Impl {
 public:
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    Init();
  }

  // Run the network on a batch of features.
  //
  // The tdnn model has a single input and a single output, so the output
  // length tensor is synthesized here: every utterance in the batch is
  // assigned nnet_out_shape[1] frames, i.e., the padded output length.
  // NOTE(review): this assumes all utterances in a batch produce the same
  // (padded) number of output frames -- confirm for batch size > 1.
  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
    auto nnet_out =
        sess_->Run({}, input_names_ptr_.data(), &features, 1,
                   output_names_ptr_.data(), output_names_ptr_.size());

    std::vector<int64_t> nnet_out_shape =
        nnet_out[0].GetTensorTypeAndShapeInfo().GetShape();

    // N entries (one per utterance), each set to T = nnet_out_shape[1].
    std::vector<int64_t> out_length_vec(nnet_out_shape[0], nnet_out_shape[1]);
    std::vector<int64_t> out_length_shape(1, nnet_out_shape[0]);

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // This tensor only borrows the memory of out_length_vec (a stack
    // local), which is why it is cloned into allocator-owned memory
    // before being returned below.
    Ort::Value nnet_out_length = Ort::Value::CreateTensor(
        memory_info, out_length_vec.data(), out_length_vec.size(),
        out_length_shape.data(), out_length_shape.size());

    return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)};
  }

  int32_t VocabSize() const { return vocab_size_; }

  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // Load the onnx model from disk, cache its input/output names, and
  // read vocab_size from the model metadata.
  void Init() {
    auto buf = ReadFile(config_.tdnn.model);

    sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t vocab_size_ = 0;
};

OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;

// features_length is ignored: the tdnn model derives output lengths
// from the padded output shape (see Impl::Forward).
std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(
    Ort::Value features, Ort::Value /*features_length*/) {
  return impl_->Forward(std::move(features));
}

int32_t OfflineTdnnCtcModel::VocabSize() const { return impl_->VocabSize(); }

OrtAllocator *OfflineTdnnCtcModel::Allocator() const {
  return impl_->Allocator();
}

}  // namespace sherpa_onnx
sherpa-onnx/csrc/offline-tdnn-ctc-model.h
0 → 100644
// sherpa-onnx/csrc/offline-tdnn-ctc-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"

namespace sherpa_onnx {

/** This class implements the tdnn model of the yesno recipe from icefall.
 *
 * See
 * https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn
 */
class OfflineTdnnCtcModel : public OfflineCtcModel {
 public:
  explicit OfflineTdnnCtcModel(const OfflineModelConfig &config);
  ~OfflineTdnnCtcModel() override;

  /** Run the forward method of the model.
   *
   * @param features A tensor of shape (N, T, C). It is changed in-place.
   * @param features_length A 1-D tensor of shape (N,) containing number of
   *                        valid frames in `features` before padding.
   *                        Its dtype is int64_t.
   *                        Note: it is unused by this model; output lengths
   *                        are derived from the network output shape.
   *
   * @return Return a pair containing:
   *  - log_probs: A 3-D tensor of shape (N, T', vocab_size).
   *  - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
   */
  std::pair<Ort::Value, Ort::Value> Forward(
      Ort::Value features, Ort::Value /*features_length*/) override;

  /** Return the vocabulary size of the model
   */
  int32_t VocabSize() const override;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const override;

 private:
  class Impl;  // pimpl: hides onnxruntime session details from this header
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
| 1 | +// sherpa-onnx/csrc/offline-tdnn-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-tdnn-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineTdnnModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("tdnn-model", &model, "Path to onnx model"); | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +bool OfflineTdnnModelConfig::Validate() const { | ||
| 17 | + if (!FileExists(model)) { | ||
| 18 | + SHERPA_ONNX_LOGE("tdnn model file %s does not exist", model.c_str()); | ||
| 19 | + return false; | ||
| 20 | + } | ||
| 21 | + | ||
| 22 | + return true; | ||
| 23 | +} | ||
| 24 | + | ||
| 25 | +std::string OfflineTdnnModelConfig::ToString() const { | ||
| 26 | + std::ostringstream os; | ||
| 27 | + | ||
| 28 | + os << "OfflineTdnnModelConfig("; | ||
| 29 | + os << "model=\"" << model << "\")"; | ||
| 30 | + | ||
| 31 | + return os.str(); | ||
| 32 | +} | ||
| 33 | + | ||
| 34 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-tdnn-model-config.h
0 → 100644
// sherpa-onnx/csrc/offline-tdnn-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// Configuration for the tdnn CTC model of the yesno recipe from icefall.
// See https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn
struct OfflineTdnnModelConfig {
  // Path to the exported onnx model file (set via --tdnn-model).
  std::string model;

  OfflineTdnnModelConfig() = default;
  explicit OfflineTdnnModelConfig(const std::string &model) : model(model) {}

  // Register the --tdnn-model command-line option.
  void Register(ParseOptions *po);

  // Return true if `model` is non-empty and points to an existing file.
  bool Validate() const;

  // Human-readable representation for logging/debugging.
  std::string ToString() const;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
| @@ -14,10 +14,14 @@ | @@ -14,10 +14,14 @@ | ||
| 14 | 14 | ||
| 15 | int main(int32_t argc, char *argv[]) { | 15 | int main(int32_t argc, char *argv[]) { |
| 16 | const char *kUsageMessage = R"usage( | 16 | const char *kUsageMessage = R"usage( |
| 17 | +Speech recognition using non-streaming models with sherpa-onnx. | ||
| 18 | + | ||
| 17 | Usage: | 19 | Usage: |
| 18 | 20 | ||
| 19 | (1) Transducer from icefall | 21 | (1) Transducer from icefall |
| 20 | 22 | ||
| 23 | +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html | ||
| 24 | + | ||
| 21 | ./bin/sherpa-onnx-offline \ | 25 | ./bin/sherpa-onnx-offline \ |
| 22 | --tokens=/path/to/tokens.txt \ | 26 | --tokens=/path/to/tokens.txt \ |
| 23 | --encoder=/path/to/encoder.onnx \ | 27 | --encoder=/path/to/encoder.onnx \ |
| @@ -30,6 +34,8 @@ Usage: | @@ -30,6 +34,8 @@ Usage: | ||
| 30 | 34 | ||
| 31 | (2) Paraformer from FunASR | 35 | (2) Paraformer from FunASR |
| 32 | 36 | ||
| 37 | +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html | ||
| 38 | + | ||
| 33 | ./bin/sherpa-onnx-offline \ | 39 | ./bin/sherpa-onnx-offline \ |
| 34 | --tokens=/path/to/tokens.txt \ | 40 | --tokens=/path/to/tokens.txt \ |
| 35 | --paraformer=/path/to/model.onnx \ | 41 | --paraformer=/path/to/model.onnx \ |
| @@ -39,6 +45,8 @@ Usage: | @@ -39,6 +45,8 @@ Usage: | ||
| 39 | 45 | ||
| 40 | (3) Whisper models | 46 | (3) Whisper models |
| 41 | 47 | ||
| 48 | +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html | ||
| 49 | + | ||
| 42 | ./bin/sherpa-onnx-offline \ | 50 | ./bin/sherpa-onnx-offline \ |
| 43 | --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ | 51 | --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ |
| 44 | --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ | 52 | --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ |
| @@ -46,6 +54,31 @@ Usage: | @@ -46,6 +54,31 @@ Usage: | ||
| 46 | --num-threads=1 \ | 54 | --num-threads=1 \ |
| 47 | /path/to/foo.wav [bar.wav foobar.wav ...] | 55 | /path/to/foo.wav [bar.wav foobar.wav ...] |
| 48 | 56 | ||
| 57 | +(4) NeMo CTC models | ||
| 58 | + | ||
| 59 | +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html | ||
| 60 | + | ||
| 61 | + ./bin/sherpa-onnx-offline \ | ||
| 62 | + --tokens=./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt \ | ||
| 63 | + --nemo-ctc-model=./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \ | ||
| 64 | + --num-threads=2 \ | ||
| 65 | + --decoding-method=greedy_search \ | ||
| 66 | + --debug=false \ | ||
| 67 | + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \ | ||
| 68 | + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \ | ||
| 69 | + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav | ||
| 70 | + | ||
| 71 | +(5) TDNN CTC model for the yesno recipe from icefall | ||
| 72 | + | ||
| 73 | +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html | ||
| 74 | + |
| 75 | + ./build/bin/sherpa-onnx-offline \ | ||
| 76 | + --sample-rate=8000 \ | ||
| 77 | + --feat-dim=23 \ | ||
| 78 | + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ | ||
| 79 | + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ | ||
| 80 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ | ||
| 81 | + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav | ||
| 49 | 82 | ||
| 50 | Note: It supports decoding multiple files in batches | 83 | Note: It supports decoding multiple files in batches |
| 51 | 84 |
| @@ -10,6 +10,7 @@ pybind11_add_module(_sherpa_onnx | @@ -10,6 +10,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 10 | offline-paraformer-model-config.cc | 10 | offline-paraformer-model-config.cc |
| 11 | offline-recognizer.cc | 11 | offline-recognizer.cc |
| 12 | offline-stream.cc | 12 | offline-stream.cc |
| 13 | + offline-tdnn-model-config.cc | ||
| 13 | offline-transducer-model-config.cc | 14 | offline-transducer-model-config.cc |
| 14 | offline-whisper-model-config.cc | 15 | offline-whisper-model-config.cc |
| 15 | online-lm-config.cc | 16 | online-lm-config.cc |
| @@ -10,6 +10,7 @@ | @@ -10,6 +10,7 @@ | ||
| 10 | #include "sherpa-onnx/csrc/offline-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 11 | #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" | 11 | #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" |
| 12 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" | 12 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" |
| 13 | +#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" | ||
| 13 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" | 14 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" |
| 14 | #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" | 15 | #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" |
| 15 | 16 | ||
| @@ -20,24 +21,28 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -20,24 +21,28 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 20 | PybindOfflineParaformerModelConfig(m); | 21 | PybindOfflineParaformerModelConfig(m); |
| 21 | PybindOfflineNemoEncDecCtcModelConfig(m); | 22 | PybindOfflineNemoEncDecCtcModelConfig(m); |
| 22 | PybindOfflineWhisperModelConfig(m); | 23 | PybindOfflineWhisperModelConfig(m); |
| 24 | + PybindOfflineTdnnModelConfig(m); | ||
| 23 | 25 | ||
| 24 | using PyClass = OfflineModelConfig; | 26 | using PyClass = OfflineModelConfig; |
| 25 | py::class_<PyClass>(*m, "OfflineModelConfig") | 27 | py::class_<PyClass>(*m, "OfflineModelConfig") |
| 26 | .def(py::init<const OfflineTransducerModelConfig &, | 28 | .def(py::init<const OfflineTransducerModelConfig &, |
| 27 | const OfflineParaformerModelConfig &, | 29 | const OfflineParaformerModelConfig &, |
| 28 | const OfflineNemoEncDecCtcModelConfig &, | 30 | const OfflineNemoEncDecCtcModelConfig &, |
| 29 | - const OfflineWhisperModelConfig &, const std::string &, | 31 | + const OfflineWhisperModelConfig &, |
| 32 | + const OfflineTdnnModelConfig &, const std::string &, | ||
| 30 | int32_t, bool, const std::string &, const std::string &>(), | 33 | int32_t, bool, const std::string &, const std::string &>(), |
| 31 | py::arg("transducer") = OfflineTransducerModelConfig(), | 34 | py::arg("transducer") = OfflineTransducerModelConfig(), |
| 32 | py::arg("paraformer") = OfflineParaformerModelConfig(), | 35 | py::arg("paraformer") = OfflineParaformerModelConfig(), |
| 33 | py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), | 36 | py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), |
| 34 | - py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"), | 37 | + py::arg("whisper") = OfflineWhisperModelConfig(), |
| 38 | + py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("tokens"), | ||
| 35 | py::arg("num_threads"), py::arg("debug") = false, | 39 | py::arg("num_threads"), py::arg("debug") = false, |
| 36 | py::arg("provider") = "cpu", py::arg("model_type") = "") | 40 | py::arg("provider") = "cpu", py::arg("model_type") = "") |
| 37 | .def_readwrite("transducer", &PyClass::transducer) | 41 | .def_readwrite("transducer", &PyClass::transducer) |
| 38 | .def_readwrite("paraformer", &PyClass::paraformer) | 42 | .def_readwrite("paraformer", &PyClass::paraformer) |
| 39 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) | 43 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) |
| 40 | .def_readwrite("whisper", &PyClass::whisper) | 44 | .def_readwrite("whisper", &PyClass::whisper) |
| 45 | + .def_readwrite("tdnn", &PyClass::tdnn) | ||
| 41 | .def_readwrite("tokens", &PyClass::tokens) | 46 | .def_readwrite("tokens", &PyClass::tokens) |
| 42 | .def_readwrite("num_threads", &PyClass::num_threads) | 47 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 43 | .def_readwrite("debug", &PyClass::debug) | 48 | .def_readwrite("debug", &PyClass::debug) |
// sherpa-onnx/python/csrc/offline-tdnn-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"

#include <string>
#include <vector>

#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"

namespace sherpa_onnx {

// Expose OfflineTdnnModelConfig to Python as
// sherpa_onnx.OfflineTdnnModelConfig(model=...).
void PybindOfflineTdnnModelConfig(py::module *m) {
  using PyClass = OfflineTdnnModelConfig;
  py::class_<PyClass>(*m, "OfflineTdnnModelConfig")
      .def(py::init<const std::string &>(), py::arg("model"))
      .def_readwrite("model", &PyClass::model)
      // __str__ delegates to the C++ ToString() for readable repr/logging
      .def("__str__", &PyClass::ToString);
}

}  // namespace sherpa_onnx
// sherpa-onnx/python/csrc/offline-tdnn-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

// Register the OfflineTdnnModelConfig binding on the given module.
void PybindOfflineTdnnModelConfig(py::module *m);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_
| @@ -8,6 +8,7 @@ from _sherpa_onnx import ( | @@ -8,6 +8,7 @@ from _sherpa_onnx import ( | ||
| 8 | OfflineModelConfig, | 8 | OfflineModelConfig, |
| 9 | OfflineNemoEncDecCtcModelConfig, | 9 | OfflineNemoEncDecCtcModelConfig, |
| 10 | OfflineParaformerModelConfig, | 10 | OfflineParaformerModelConfig, |
| 11 | + OfflineTdnnModelConfig, | ||
| 11 | OfflineWhisperModelConfig, | 12 | OfflineWhisperModelConfig, |
| 12 | ) | 13 | ) |
| 13 | from _sherpa_onnx import OfflineRecognizer as _Recognizer | 14 | from _sherpa_onnx import OfflineRecognizer as _Recognizer |
| @@ -37,7 +38,7 @@ class OfflineRecognizer(object): | @@ -37,7 +38,7 @@ class OfflineRecognizer(object): | ||
| 37 | decoder: str, | 38 | decoder: str, |
| 38 | joiner: str, | 39 | joiner: str, |
| 39 | tokens: str, | 40 | tokens: str, |
| 40 | - num_threads: int, | 41 | + num_threads: int = 1, |
| 41 | sample_rate: int = 16000, | 42 | sample_rate: int = 16000, |
| 42 | feature_dim: int = 80, | 43 | feature_dim: int = 80, |
| 43 | decoding_method: str = "greedy_search", | 44 | decoding_method: str = "greedy_search", |
| @@ -48,7 +49,7 @@ class OfflineRecognizer(object): | @@ -48,7 +49,7 @@ class OfflineRecognizer(object): | ||
| 48 | ): | 49 | ): |
| 49 | """ | 50 | """ |
| 50 | Please refer to | 51 | Please refer to |
| 51 | - `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ | 52 | + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html>`_ |
| 52 | to download pre-trained models for different languages, e.g., Chinese, | 53 | to download pre-trained models for different languages, e.g., Chinese, |
| 53 | English, etc. | 54 | English, etc. |
| 54 | 55 | ||
| @@ -115,7 +116,7 @@ class OfflineRecognizer(object): | @@ -115,7 +116,7 @@ class OfflineRecognizer(object): | ||
| 115 | cls, | 116 | cls, |
| 116 | paraformer: str, | 117 | paraformer: str, |
| 117 | tokens: str, | 118 | tokens: str, |
| 118 | - num_threads: int, | 119 | + num_threads: int = 1, |
| 119 | sample_rate: int = 16000, | 120 | sample_rate: int = 16000, |
| 120 | feature_dim: int = 80, | 121 | feature_dim: int = 80, |
| 121 | decoding_method: str = "greedy_search", | 122 | decoding_method: str = "greedy_search", |
| @@ -124,9 +125,8 @@ class OfflineRecognizer(object): | @@ -124,9 +125,8 @@ class OfflineRecognizer(object): | ||
| 124 | ): | 125 | ): |
| 125 | """ | 126 | """ |
| 126 | Please refer to | 127 | Please refer to |
| 127 | - `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ | ||
| 128 | - to download pre-trained models for different languages, e.g., Chinese, | ||
| 129 | - English, etc. | 128 | + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html>`_ |
| 129 | + to download pre-trained models. | ||
| 130 | 130 | ||
| 131 | Args: | 131 | Args: |
| 132 | tokens: | 132 | tokens: |
| @@ -179,7 +179,7 @@ class OfflineRecognizer(object): | @@ -179,7 +179,7 @@ class OfflineRecognizer(object): | ||
| 179 | cls, | 179 | cls, |
| 180 | model: str, | 180 | model: str, |
| 181 | tokens: str, | 181 | tokens: str, |
| 182 | - num_threads: int, | 182 | + num_threads: int = 1, |
| 183 | sample_rate: int = 16000, | 183 | sample_rate: int = 16000, |
| 184 | feature_dim: int = 80, | 184 | feature_dim: int = 80, |
| 185 | decoding_method: str = "greedy_search", | 185 | decoding_method: str = "greedy_search", |
| @@ -188,7 +188,7 @@ class OfflineRecognizer(object): | @@ -188,7 +188,7 @@ class OfflineRecognizer(object): | ||
| 188 | ): | 188 | ): |
| 189 | """ | 189 | """ |
| 190 | Please refer to | 190 | Please refer to |
| 191 | - `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ | 191 | + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html>`_ |
| 192 | to download pre-trained models for different languages, e.g., Chinese, | 192 | to download pre-trained models for different languages, e.g., Chinese, |
| 193 | English, etc. | 193 | English, etc. |
| 194 | 194 | ||
| @@ -244,14 +244,14 @@ class OfflineRecognizer(object): | @@ -244,14 +244,14 @@ class OfflineRecognizer(object): | ||
| 244 | encoder: str, | 244 | encoder: str, |
| 245 | decoder: str, | 245 | decoder: str, |
| 246 | tokens: str, | 246 | tokens: str, |
| 247 | - num_threads: int, | 247 | + num_threads: int = 1, |
| 248 | decoding_method: str = "greedy_search", | 248 | decoding_method: str = "greedy_search", |
| 249 | debug: bool = False, | 249 | debug: bool = False, |
| 250 | provider: str = "cpu", | 250 | provider: str = "cpu", |
| 251 | ): | 251 | ): |
| 252 | """ | 252 | """ |
| 253 | Please refer to | 253 | Please refer to |
| 254 | - `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ | 254 | + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html>`_ |
| 255 | to download pre-trained models for different kinds of whisper models, | 255 | to download pre-trained models for different kinds of whisper models, |
| 256 | e.g., tiny, tiny.en, base, base.en, etc. | 256 | e.g., tiny, tiny.en, base, base.en, etc. |
| 257 | 257 | ||
| @@ -301,6 +301,69 @@ class OfflineRecognizer(object): | @@ -301,6 +301,69 @@ class OfflineRecognizer(object): | ||
| 301 | self.config = recognizer_config | 301 | self.config = recognizer_config |
| 302 | return self | 302 | return self |
| 303 | 303 | ||
    @classmethod
    def from_tdnn_ctc(
        cls,
        model: str,
        tokens: str,
        num_threads: int = 1,
        sample_rate: int = 8000,
        feature_dim: int = 23,
        decoding_method: str = "greedy_search",
        debug: bool = False,
        provider: str = "cpu",
    ) -> "OfflineRecognizer":
        """Create an offline recognizer from a TDNN CTC model, e.g., the
        tdnn model of the yesno recipe from icefall.

        Please refer to
        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html>`_
        to download pre-trained models.

        Args:
          model:
            Path to ``model.onnx``.
          tokens:
            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
            columns::

                symbol integer_id

          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
            The default 8000 matches the yesno recipe's 8 kHz data.
          feature_dim:
            Dimension of the feature used to train the model.
          decoding_method:
            Valid values are greedy_search.
          debug:
            True to show debug messages.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda,
            coreml.
        """
        # Deliberately bypass __init__: this factory assembles the
        # underlying recognizer directly from a config object.
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
            tdnn=OfflineTdnnModelConfig(model=model),
            tokens=tokens,
            num_threads=num_threads,
            debug=debug,
            provider=provider,
            model_type="tdnn",
        )

        feat_config = OfflineFeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )

        recognizer_config = OfflineRecognizerConfig(
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
        )
        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
        return self
| 366 | + | ||
| 304 | def create_stream(self, contexts_list: Optional[List[List[int]]] = None): | 367 | def create_stream(self, contexts_list: Optional[List[List[int]]] = None): |
| 305 | if contexts_list is None: | 368 | if contexts_list is None: |
| 306 | return self.recognizer.create_stream() | 369 | return self.recognizer.create_stream() |
-
请 注册 或 登录 后发表评论