Committed by GitHub
Add C++ runtime for non-streaming faster conformer transducer from NeMo. (#854)
Showing 31 changed files with 1,093 additions and 153 deletions.
| @@ -13,6 +13,105 @@ echo "PATH: $PATH" | ||
| 13 | 13 | ||
| 14 | which $EXE | 14 | which $EXE |
| 15 | 15 | ||
| 16 | +log "------------------------------------------------------------------------" | ||
| 17 | +log "Run NeMo fast conformer hybrid transducer CTC models (transducer branch)" | ||
| 18 | +log "------------------------------------------------------------------------" | ||
| 19 | + | ||
| 20 | +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k.tar.bz2 | ||
| 21 | +name=$(basename $url) | ||
| 22 | +curl -SL -O $url | ||
| 23 | +tar xvf $name | ||
| 24 | +rm $name | ||
| 25 | +repo=$(basename -s .tar.bz2 $name) | ||
| 26 | +ls -lh $repo | ||
| 27 | + | ||
| 28 | +log "Test $repo" | ||
| 29 | +test_wavs=( | ||
| 30 | +de-german.wav | ||
| 31 | +es-spanish.wav | ||
| 32 | +hr-croatian.wav | ||
| 33 | +po-polish.wav | ||
| 34 | +uk-ukrainian.wav | ||
| 35 | +en-english.wav | ||
| 36 | +fr-french.wav | ||
| 37 | +it-italian.wav | ||
| 38 | +ru-russian.wav | ||
| 39 | +) | ||
| 40 | +for w in "${test_wavs[@]}"; do | ||
| 41 | + time $EXE \ | ||
| 42 | + --tokens=$repo/tokens.txt \ | ||
| 43 | + --encoder=$repo/encoder.onnx \ | ||
| 44 | + --decoder=$repo/decoder.onnx \ | ||
| 45 | + --joiner=$repo/joiner.onnx \ | ||
| 46 | + --debug=1 \ | ||
| 47 | + $repo/test_wavs/$w | ||
| 48 | +done | ||
| 49 | + | ||
| 50 | +rm -rf $repo | ||
| 51 | + | ||
| 52 | +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-24500.tar.bz2 | ||
| 53 | +name=$(basename $url) | ||
| 54 | +curl -SL -O $url | ||
| 55 | +tar xvf $name | ||
| 56 | +rm $name | ||
| 57 | +repo=$(basename -s .tar.bz2 $name) | ||
| 58 | +ls -lh $repo | ||
| 59 | + | ||
| 60 | +log "Test $repo" | ||
| 61 | + | ||
| 62 | +time $EXE \ | ||
| 63 | + --tokens=$repo/tokens.txt \ | ||
| 64 | + --encoder=$repo/encoder.onnx \ | ||
| 65 | + --decoder=$repo/decoder.onnx \ | ||
| 66 | + --joiner=$repo/joiner.onnx \ | ||
| 67 | + --debug=1 \ | ||
| 68 | + $repo/test_wavs/en-english.wav | ||
| 69 | + | ||
| 70 | +rm -rf $repo | ||
| 71 | + | ||
| 72 | +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-es-1424.tar.bz2 | ||
| 73 | +name=$(basename $url) | ||
| 74 | +curl -SL -O $url | ||
| 75 | +tar xvf $name | ||
| 76 | +rm $name | ||
| 77 | +repo=$(basename -s .tar.bz2 $name) | ||
| 78 | +ls -lh $repo | ||
| 79 | + | ||
| 80 | +log "Test $repo" | ||
| 81 | + | ||
| 82 | +time $EXE \ | ||
| 83 | + --tokens=$repo/tokens.txt \ | ||
| 84 | + --encoder=$repo/encoder.onnx \ | ||
| 85 | + --decoder=$repo/decoder.onnx \ | ||
| 86 | + --joiner=$repo/joiner.onnx \ | ||
| 87 | + --debug=1 \ | ||
| 88 | + $repo/test_wavs/es-spanish.wav | ||
| 89 | + | ||
| 90 | +rm -rf $repo | ||
| 91 | + | ||
| 92 | +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288.tar.bz2 | ||
| 93 | +name=$(basename $url) | ||
| 94 | +curl -SL -O $url | ||
| 95 | +tar xvf $name | ||
| 96 | +rm $name | ||
| 97 | +repo=$(basename -s .tar.bz2 $name) | ||
| 98 | +ls -lh $repo | ||
| 99 | + | ||
| 100 | +log "Test $repo" | ||
| 101 | + | ||
| 102 | +time $EXE \ | ||
| 103 | + --tokens=$repo/tokens.txt \ | ||
| 104 | + --encoder=$repo/encoder.onnx \ | ||
| 105 | + --decoder=$repo/decoder.onnx \ | ||
| 106 | + --joiner=$repo/joiner.onnx \ | ||
| 107 | + --debug=1 \ | ||
| 108 | + $repo/test_wavs/en-english.wav \ | ||
| 109 | + $repo/test_wavs/de-german.wav \ | ||
| 110 | + $repo/test_wavs/fr-french.wav \ | ||
| 111 | + $repo/test_wavs/es-spanish.wav | ||
| 112 | + | ||
| 113 | +rm -rf $repo | ||
| 114 | + | ||
| 16 | log "------------------------------------------------------------" | 115 | log "------------------------------------------------------------" |
| 17 | log "Run Conformer transducer (English)" | 116 | log "Run Conformer transducer (English)" |
| 18 | log "------------------------------------------------------------" | 117 | log "------------------------------------------------------------" |
| @@ -128,6 +128,14 @@ jobs: | ||
| 128 | name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} | 128 | name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} |
| 129 | path: install/* | 129 | path: install/* |
| 130 | 130 | ||
| 131 | + - name: Test offline transducer | ||
| 132 | + shell: bash | ||
| 133 | + run: | | ||
| 134 | + export PATH=$PWD/build/bin:$PATH | ||
| 135 | + export EXE=sherpa-onnx-offline | ||
| 136 | + | ||
| 137 | + .github/scripts/test-offline-transducer.sh | ||
| 138 | + | ||
| 131 | - name: Test spoken language identification (C++ API) | 139 | - name: Test spoken language identification (C++ API) |
| 132 | shell: bash | 140 | shell: bash |
| 133 | run: | | 141 | run: | |
| @@ -215,14 +223,6 @@ jobs: | ||
| 215 | 223 | ||
| 216 | .github/scripts/test-online-paraformer.sh | 224 | .github/scripts/test-online-paraformer.sh |
| 217 | 225 | ||
| 218 | - - name: Test offline transducer | ||
| 219 | - shell: bash | ||
| 220 | - run: | | ||
| 221 | - export PATH=$PWD/build/bin:$PATH | ||
| 222 | - export EXE=sherpa-onnx-offline | ||
| 223 | - | ||
| 224 | - .github/scripts/test-offline-transducer.sh | ||
| 225 | - | ||
| 226 | - name: Test online transducer | 226 | - name: Test online transducer |
| 227 | shell: bash | 227 | shell: bash |
| 228 | run: | | 228 | run: | |
| @@ -107,6 +107,14 @@ jobs: | ||
| 107 | otool -L build/bin/sherpa-onnx | 107 | otool -L build/bin/sherpa-onnx |
| 108 | otool -l build/bin/sherpa-onnx | 108 | otool -l build/bin/sherpa-onnx |
| 109 | 109 | ||
| 110 | + - name: Test offline transducer | ||
| 111 | + shell: bash | ||
| 112 | + run: | | ||
| 113 | + export PATH=$PWD/build/bin:$PATH | ||
| 114 | + export EXE=sherpa-onnx-offline | ||
| 115 | + | ||
| 116 | + .github/scripts/test-offline-transducer.sh | ||
| 117 | + | ||
| 110 | - name: Test online CTC | 118 | - name: Test online CTC |
| 111 | shell: bash | 119 | shell: bash |
| 112 | run: | | 120 | run: | |
| @@ -192,14 +200,6 @@ jobs: | ||
| 192 | 200 | ||
| 193 | .github/scripts/test-offline-ctc.sh | 201 | .github/scripts/test-offline-ctc.sh |
| 194 | 202 | ||
| 195 | - - name: Test offline transducer | ||
| 196 | - shell: bash | ||
| 197 | - run: | | ||
| 198 | - export PATH=$PWD/build/bin:$PATH | ||
| 199 | - export EXE=sherpa-onnx-offline | ||
| 200 | - | ||
| 201 | - .github/scripts/test-offline-transducer.sh | ||
| 202 | - | ||
| 203 | - name: Test online transducer | 203 | - name: Test online transducer |
| 204 | shell: bash | 204 | shell: bash |
| 205 | run: | | 205 | run: | |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +""" | ||
| 4 | +This file shows how to use a non-streaming CTC model from NeMo | ||
| 5 | +to decode files. | ||
| 6 | + | ||
| 7 | +Please download model files from | ||
| 8 | +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 9 | + | ||
| 10 | + | ||
| 11 | +The example model supports 10 languages and is converted from | ||
| 12 | +https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc | ||
| 13 | +""" | ||
| 14 | + | ||
| 15 | +from pathlib import Path | ||
| 16 | + | ||
| 17 | +import sherpa_onnx | ||
| 18 | +import soundfile as sf | ||
| 19 | + | ||
| 20 | + | ||
| 21 | +def create_recognizer(): | ||
| 22 | + model = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/model.onnx" | ||
| 23 | + tokens = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt" | ||
| 24 | + | ||
| 25 | + test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav" | ||
| 26 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav" | ||
| 27 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav" | ||
| 28 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav" | ||
| 29 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav" | ||
| 30 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav" | ||
| 31 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav" | ||
| 32 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav" | ||
| 33 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav" | ||
| 34 | + | ||
| 35 | + if not Path(model).is_file() or not Path(test_wav).is_file(): | ||
| 36 | + raise ValueError( | ||
| 37 | + """Please download model files from | ||
| 38 | + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 39 | + """ | ||
| 40 | + ) | ||
| 41 | + return ( | ||
| 42 | + sherpa_onnx.OfflineRecognizer.from_nemo_ctc( | ||
| 43 | + model=model, | ||
| 44 | + tokens=tokens, | ||
| 45 | + debug=True, | ||
| 46 | + ), | ||
| 47 | + test_wav, | ||
| 48 | + ) | ||
| 49 | + | ||
| 50 | + | ||
| 51 | +def main(): | ||
| 52 | + recognizer, wave_filename = create_recognizer() | ||
| 53 | + | ||
| 54 | + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) | ||
| 55 | + audio = audio[:, 0] # only use the first channel | ||
| 56 | + | ||
| 57 | + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] | ||
| 58 | + # sample_rate does not need to be 16000 Hz | ||
| 59 | + | ||
| 60 | + stream = recognizer.create_stream() | ||
| 61 | + stream.accept_waveform(sample_rate, audio) | ||
| 62 | + recognizer.decode_stream(stream) | ||
| 63 | + print(wave_filename) | ||
| 64 | + print(stream.result) | ||
| 65 | + | ||
| 66 | + | ||
| 67 | +if __name__ == "__main__": | ||
| 68 | + main() |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +""" | ||
| 4 | +This file shows how to use a non-streaming transducer model from NeMo | ||
| 5 | +to decode files. | ||
| 6 | + | ||
| 7 | +Please download model files from | ||
| 8 | +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 9 | + | ||
| 10 | + | ||
| 11 | +The example model supports 10 languages and is converted from | ||
| 12 | +https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc | ||
| 13 | +""" | ||
| 14 | + | ||
| 15 | +from pathlib import Path | ||
| 16 | + | ||
| 17 | +import sherpa_onnx | ||
| 18 | +import soundfile as sf | ||
| 19 | + | ||
| 20 | + | ||
| 21 | +def create_recognizer(): | ||
| 22 | + encoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/encoder.onnx" | ||
| 23 | + decoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/decoder.onnx" | ||
| 24 | + joiner = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/joiner.onnx" | ||
| 25 | + tokens = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt" | ||
| 26 | + | ||
| 27 | + test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav" | ||
| 28 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav" | ||
| 29 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav" | ||
| 30 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav" | ||
| 31 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav" | ||
| 32 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav" | ||
| 33 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav" | ||
| 34 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav" | ||
| 35 | + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav" | ||
| 36 | + | ||
| 37 | + if not Path(encoder).is_file() or not Path(test_wav).is_file(): | ||
| 38 | + raise ValueError( | ||
| 39 | + """Please download model files from | ||
| 40 | + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 41 | + """ | ||
| 42 | + ) | ||
| 43 | + return ( | ||
| 44 | + sherpa_onnx.OfflineRecognizer.from_transducer( | ||
| 45 | + encoder=encoder, | ||
| 46 | + decoder=decoder, | ||
| 47 | + joiner=joiner, | ||
| 48 | + tokens=tokens, | ||
| 49 | + model_type="nemo_transducer", | ||
| 50 | + debug=True, | ||
| 51 | + ), | ||
| 52 | + test_wav, | ||
| 53 | + ) | ||
| 54 | + | ||
| 55 | + | ||
| 56 | +def main(): | ||
| 57 | + recognizer, wave_filename = create_recognizer() | ||
| 58 | + | ||
| 59 | + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) | ||
| 60 | + audio = audio[:, 0] # only use the first channel | ||
| 61 | + | ||
| 62 | + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] | ||
| 63 | + # sample_rate does not need to be 16000 Hz | ||
| 64 | + | ||
| 65 | + stream = recognizer.create_stream() | ||
| 66 | + stream.accept_waveform(sample_rate, audio) | ||
| 67 | + recognizer.decode_stream(stream) | ||
| 68 | + print(wave_filename) | ||
| 69 | + print(stream.result) | ||
| 70 | + | ||
| 71 | + | ||
| 72 | +if __name__ == "__main__": | ||
| 73 | + main() |
| @@ -40,9 +40,11 @@ set(sources | ||
| 40 | offline-tdnn-ctc-model.cc | 40 | offline-tdnn-ctc-model.cc |
| 41 | offline-tdnn-model-config.cc | 41 | offline-tdnn-model-config.cc |
| 42 | offline-transducer-greedy-search-decoder.cc | 42 | offline-transducer-greedy-search-decoder.cc |
| 43 | + offline-transducer-greedy-search-nemo-decoder.cc | ||
| 43 | offline-transducer-model-config.cc | 44 | offline-transducer-model-config.cc |
| 44 | offline-transducer-model.cc | 45 | offline-transducer-model.cc |
| 45 | offline-transducer-modified-beam-search-decoder.cc | 46 | offline-transducer-modified-beam-search-decoder.cc |
| 47 | + offline-transducer-nemo-model.cc | ||
| 46 | offline-wenet-ctc-model-config.cc | 48 | offline-wenet-ctc-model-config.cc |
| 47 | offline-wenet-ctc-model.cc | 49 | offline-wenet-ctc-model.cc |
| 48 | offline-whisper-greedy-search-decoder.cc | 50 | offline-whisper-greedy-search-decoder.cc |
| @@ -56,6 +56,19 @@ struct FeatureExtractorConfig { | ||
| 56 | bool remove_dc_offset = true; // Subtract mean of wave before FFT. | 56 | bool remove_dc_offset = true; // Subtract mean of wave before FFT. |
| 57 | std::string window_type = "povey"; // e.g. Hamming window | 57 | std::string window_type = "povey"; // e.g. Hamming window |
| 58 | 58 | ||
| 59 | + // For models from NeMo | ||
| 60 | + // This option is not exposed and is set internally when loading models. | ||
| 61 | + // Possible values: | ||
| 62 | + // - per_feature | ||
| 63 | + // - all_features (not implemented yet) | ||
| 64 | + // - fixed_mean (not implemented) | ||
| 65 | + // - fixed_std (not implemented) | ||
| 66 | + // - or leave it empty | ||
| 67 | + // See | ||
| 68 | + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 | ||
| 69 | + // for details | ||
| 70 | + std::string nemo_normalize_type; | ||
| 71 | + | ||
| 59 | std::string ToString() const; | 72 | std::string ToString() const; |
| 60 | 73 | ||
| 61 | void Register(ParseOptions *po); | 74 | void Register(ParseOptions *po); |
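The `per_feature` option above standardizes each feature dimension independently over the frames of a single utterance. A minimal sketch of the math, assuming row-major (num_frames, feat_dim) features (illustrative helper, not the code added by this PR; NeMo's own implementation differs in details such as using an unbiased std estimate):

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical helper: normalize a row-major (num_frames, feat_dim) feature
// matrix in place so that each feature dimension has zero mean and unit
// variance over time, i.e., NeMo's "per_feature" normalization.
void NormalizePerFeature(float *p, int32_t num_frames, int32_t feat_dim) {
  for (int32_t c = 0; c != feat_dim; ++c) {
    float sum = 0.0f;
    float sum_sq = 0.0f;
    for (int32_t t = 0; t != num_frames; ++t) {
      float v = p[t * feat_dim + c];
      sum += v;
      sum_sq += v * v;
    }
    float mean = sum / num_frames;
    float variance = sum_sq / num_frames - mean * mean;
    float inv_std = 1.0f / std::sqrt(variance + 1e-5f);  // epsilon for stability
    for (int32_t t = 0; t != num_frames; ++t) {
      p[t * feat_dim + c] = (p[t * feat_dim + c] - mean) * inv_std;
    }
  }
}
```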
| @@ -68,7 +68,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | ||
| 68 | : config_(config), | 68 | : config_(config), |
| 69 | model_(OnlineTransducerModel::Create(config.model_config)), | 69 | model_(OnlineTransducerModel::Create(config.model_config)), |
| 70 | sym_(config.model_config.tokens) { | 70 | sym_(config.model_config.tokens) { |
| 71 | - if (sym_.contains("<unk>")) { | 71 | + if (sym_.Contains("<unk>")) { |
| 72 | unk_id_ = sym_["<unk>"]; | 72 | unk_id_ = sym_["<unk>"]; |
| 73 | } | 73 | } |
| 74 | 74 | ||
| @@ -87,7 +87,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | ||
| 87 | : config_(config), | 87 | : config_(config), |
| 88 | model_(OnlineTransducerModel::Create(mgr, config.model_config)), | 88 | model_(OnlineTransducerModel::Create(mgr, config.model_config)), |
| 89 | sym_(mgr, config.model_config.tokens) { | 89 | sym_(mgr, config.model_config.tokens) { |
| 90 | - if (sym_.contains("<unk>")) { | 90 | + if (sym_.Contains("<unk>")) { |
| 91 | unk_id_ = sym_["<unk>"]; | 91 | unk_id_ = sym_["<unk>"]; |
| 92 | } | 92 | } |
| 93 | 93 |
| 1 | // sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc | 1 | // sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 Xiaomi Corporation | 3 | +// Copyright (c) 2023-2024 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" | 5 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" |
| 6 | 6 |
| @@ -38,7 +38,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | ||
| 38 | std::string text; | 38 | std::string text; |
| 39 | 39 | ||
| 40 | for (int32_t i = 0; i != src.tokens.size(); ++i) { | 40 | for (int32_t i = 0; i != src.tokens.size(); ++i) { |
| 41 | - if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) { | 41 | + if (sym_table.Contains("SIL") && src.tokens[i] == sym_table["SIL"]) { |
| 42 | // tdnn models from yesno have a SIL token, we should remove it. | 42 | // tdnn models from yesno have a SIL token, we should remove it. |
| 43 | continue; | 43 | continue; |
| 44 | } | 44 | } |
| @@ -103,9 +103,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 103 | decoder_ = std::make_unique<OfflineCtcFstDecoder>( | 103 | decoder_ = std::make_unique<OfflineCtcFstDecoder>( |
| 104 | config_.ctc_fst_decoder_config); | 104 | config_.ctc_fst_decoder_config); |
| 105 | } else if (config_.decoding_method == "greedy_search") { | 105 | } else if (config_.decoding_method == "greedy_search") { |
| 106 | - if (!symbol_table_.contains("<blk>") && | ||
| 107 | - !symbol_table_.contains("<eps>") && | ||
| 108 | - !symbol_table_.contains("<blank>")) { | 106 | + if (!symbol_table_.Contains("<blk>") && |
| 107 | + !symbol_table_.Contains("<eps>") && | ||
| 108 | + !symbol_table_.Contains("<blank>")) { | ||
| 109 | SHERPA_ONNX_LOGE( | 109 | SHERPA_ONNX_LOGE( |
| 110 | "We expect that tokens.txt contains " | 110 | "We expect that tokens.txt contains " |
| 111 | "the symbol <blk> or <eps> or <blank> and its ID."); | 111 | "the symbol <blk> or <eps> or <blank> and its ID."); |
| @@ -113,12 +113,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 113 | } | 113 | } |
| 114 | 114 | ||
| 115 | int32_t blank_id = 0; | 115 | int32_t blank_id = 0; |
| 116 | - if (symbol_table_.contains("<blk>")) { | 116 | + if (symbol_table_.Contains("<blk>")) { |
| 117 | blank_id = symbol_table_["<blk>"]; | 117 | blank_id = symbol_table_["<blk>"]; |
| 118 | - } else if (symbol_table_.contains("<eps>")) { | 118 | + } else if (symbol_table_.Contains("<eps>")) { |
| 119 | // for tdnn models of the yesno recipe from icefall | 119 | // for tdnn models of the yesno recipe from icefall |
| 120 | blank_id = symbol_table_["<eps>"]; | 120 | blank_id = symbol_table_["<eps>"]; |
| 121 | - } else if (symbol_table_.contains("<blank>")) { | 121 | + } else if (symbol_table_.Contains("<blank>")) { |
| 122 | // for Wenet CTC models | 122 | // for Wenet CTC models |
| 123 | blank_id = symbol_table_["<blank>"]; | 123 | blank_id = symbol_table_["<blank>"]; |
| 124 | } | 124 | } |
| @@ -11,6 +11,7 @@ | ||
| 11 | #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" | 11 | #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" |
| 12 | #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" | 12 | #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" |
| 13 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" | 13 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" |
| 14 | +#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h" | ||
| 14 | #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h" | 15 | #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h" |
| 15 | #include "sherpa-onnx/csrc/onnx-utils.h" | 16 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 16 | #include "sherpa-onnx/csrc/text-utils.h" | 17 | #include "sherpa-onnx/csrc/text-utils.h" |
| @@ -23,6 +24,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 23 | const auto &model_type = config.model_config.model_type; | 24 | const auto &model_type = config.model_config.model_type; |
| 24 | if (model_type == "transducer") { | 25 | if (model_type == "transducer") { |
| 25 | return std::make_unique<OfflineRecognizerTransducerImpl>(config); | 26 | return std::make_unique<OfflineRecognizerTransducerImpl>(config); |
| 27 | + } else if (model_type == "nemo_transducer") { | ||
| 28 | + return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config); | ||
| 26 | } else if (model_type == "paraformer") { | 29 | } else if (model_type == "paraformer") { |
| 27 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); | 30 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); |
| 28 | } else if (model_type == "nemo_ctc" || model_type == "tdnn" || | 31 | } else if (model_type == "nemo_ctc" || model_type == "tdnn" || |
| @@ -122,6 +125,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 122 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); | 125 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); |
| 123 | } | 126 | } |
| 124 | 127 | ||
| 128 | + if (model_type == "EncDecHybridRNNTCTCBPEModel" && | ||
| 129 | + !config.model_config.transducer.decoder_filename.empty() && | ||
| 130 | + !config.model_config.transducer.joiner_filename.empty()) { | ||
| 131 | + return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config); | ||
| 132 | + } | ||
| 133 | + | ||
| 125 | if (model_type == "EncDecCTCModelBPE" || | 134 | if (model_type == "EncDecCTCModelBPE" || |
| 126 | model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || | 135 | model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || |
| 127 | model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { | 136 | model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { |
| @@ -155,6 +164,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 155 | const auto &model_type = config.model_config.model_type; | 164 | const auto &model_type = config.model_config.model_type; |
| 156 | if (model_type == "transducer") { | 165 | if (model_type == "transducer") { |
| 157 | return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config); | 166 | return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config); |
| 167 | + } else if (model_type == "nemo_transducer") { | ||
| 168 | + return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config); | ||
| 158 | } else if (model_type == "paraformer") { | 169 | } else if (model_type == "paraformer") { |
| 159 | return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); | 170 | return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); |
| 160 | } else if (model_type == "nemo_ctc" || model_type == "tdnn" || | 171 | } else if (model_type == "nemo_ctc" || model_type == "tdnn" || |
| @@ -254,6 +265,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 254 | return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); | 265 | return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); |
| 255 | } | 266 | } |
| 256 | 267 | ||
| 268 | + if (model_type == "EncDecHybridRNNTCTCBPEModel" && | ||
| 269 | + !config.model_config.transducer.decoder_filename.empty() && | ||
| 270 | + !config.model_config.transducer.joiner_filename.empty()) { | ||
| 271 | + return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config); | ||
| 272 | + } | ||
| 273 | + | ||
| 257 | if (model_type == "EncDecCTCModelBPE" || | 274 | if (model_type == "EncDecCTCModelBPE" || |
| 258 | model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || | 275 | model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || |
| 259 | model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { | 276 | model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { |
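Dispatch to the new implementation therefore happens in two ways: an explicit `model_type` of `nemo_transducer` selects it directly, and a model whose metadata carries the `EncDecHybridRNNTCTCBPEModel` tag is routed to it automatically once decoder and joiner files are supplied. A minimal C++ usage sketch under those assumptions (method names mirror the existing sherpa-onnx offline API; loading the wave file is omitted):

```cpp
#include <cstdint>

#include "sherpa-onnx/csrc/offline-recognizer.h"

// Hypothetical driver: decode num_samples float samples in [-1, 1] using the
// NeMo transducer branch added in this PR.
void DecodeWithNeMoTransducer(const float *samples, int32_t num_samples,
                              int32_t sample_rate) {
  sherpa_onnx::OfflineRecognizerConfig config;
  config.model_config.transducer.encoder_filename = "encoder.onnx";
  config.model_config.transducer.decoder_filename = "decoder.onnx";
  config.model_config.transducer.joiner_filename = "joiner.onnx";
  config.model_config.tokens = "tokens.txt";
  // Optional: forces the NeMo branch. If left empty, the dispatcher falls
  // back to the EncDecHybridRNNTCTCBPEModel tag in the encoder metadata.
  config.model_config.model_type = "nemo_transducer";

  sherpa_onnx::OfflineRecognizer recognizer(config);
  auto stream = recognizer.CreateStream();
  stream->AcceptWaveform(sample_rate, samples, num_samples);
  recognizer.DecodeStream(stream.get());
  // stream->GetResult().text holds the transcript.
}
```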
| 1 | +// sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <fstream> | ||
| 9 | +#include <ios> | ||
| 10 | +#include <memory> | ||
| 11 | +#include <regex> // NOLINT | ||
| 12 | +#include <sstream> | ||
| 13 | +#include <string> | ||
| 14 | +#include <utility> | ||
| 15 | +#include <vector> | ||
| 16 | + | ||
| 17 | +#if __ANDROID_API__ >= 9 | ||
| 18 | +#include "android/asset_manager.h" | ||
| 19 | +#include "android/asset_manager_jni.h" | ||
| 20 | +#endif | ||
| 21 | + | ||
| 22 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 23 | +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" | ||
| 24 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 25 | +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h" | ||
| 26 | +#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h" | ||
| 27 | +#include "sherpa-onnx/csrc/pad-sequence.h" | ||
| 28 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 29 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 30 | +#include "sherpa-onnx/csrc/utils.h" | ||
| 31 | + | ||
| 32 | +namespace sherpa_onnx { | ||
| 33 | + | ||
| 34 | +// defined in ./offline-recognizer-transducer-impl.h | ||
| 35 | +OfflineRecognitionResult Convert(const OfflineTransducerDecoderResult &src, | ||
| 36 | + const SymbolTable &sym_table, | ||
| 37 | + int32_t frame_shift_ms, | ||
| 38 | + int32_t subsampling_factor); | ||
| 39 | + | ||
| 40 | +class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { | ||
| 41 | + public: | ||
| 42 | + explicit OfflineRecognizerTransducerNeMoImpl( | ||
| 43 | + const OfflineRecognizerConfig &config) | ||
| 44 | + : config_(config), | ||
| 45 | + symbol_table_(config_.model_config.tokens), | ||
| 46 | + model_(std::make_unique<OfflineTransducerNeMoModel>( | ||
| 47 | + config_.model_config)) { | ||
| 48 | + if (config_.decoding_method == "greedy_search") { | ||
| 49 | + decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>( | ||
| 50 | + model_.get(), config_.blank_penalty); | ||
| 51 | + } else { | ||
| 52 | + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | ||
| 53 | + config_.decoding_method.c_str()); | ||
| 54 | + exit(-1); | ||
| 55 | + } | ||
| 56 | + PostInit(); | ||
| 57 | + } | ||
| 58 | + | ||
| 59 | +#if __ANDROID_API__ >= 9 | ||
| 60 | + explicit OfflineRecognizerTransducerNeMoImpl( | ||
| 61 | + AAssetManager *mgr, const OfflineRecognizerConfig &config) | ||
| 62 | + : config_(config), | ||
| 63 | + symbol_table_(mgr, config_.model_config.tokens), | ||
| 64 | + model_(std::make_unique<OfflineTransducerNeMoModel>( | ||
| 65 | + mgr, config_.model_config)) { | ||
| 66 | + if (config_.decoding_method == "greedy_search") { | ||
| 67 | + decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>( | ||
| 68 | + model_.get(), config_.blank_penalty); | ||
| 69 | + } else { | ||
| 70 | + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | ||
| 71 | + config_.decoding_method.c_str()); | ||
| 72 | + exit(-1); | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + PostInit(); | ||
| 76 | + } | ||
| 77 | +#endif | ||
| 78 | + | ||
| 79 | + std::unique_ptr<OfflineStream> CreateStream() const override { | ||
| 80 | + return std::make_unique<OfflineStream>(config_.feat_config); | ||
| 81 | + } | ||
| 82 | + | ||
| 83 | + void DecodeStreams(OfflineStream **ss, int32_t n) const override { | ||
| 84 | + auto memory_info = | ||
| 85 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 86 | + | ||
| 87 | + int32_t feat_dim = ss[0]->FeatureDim(); | ||
| 88 | + | ||
| 89 | + std::vector<Ort::Value> features; | ||
| 90 | + | ||
| 91 | + features.reserve(n); | ||
| 92 | + | ||
| 93 | + std::vector<std::vector<float>> features_vec(n); | ||
| 94 | + std::vector<int64_t> features_length_vec(n); | ||
| 95 | + for (int32_t i = 0; i != n; ++i) { | ||
| 96 | + auto f = ss[i]->GetFrames(); | ||
| 97 | + int32_t num_frames = f.size() / feat_dim; | ||
| 98 | + | ||
| 99 | + features_length_vec[i] = num_frames; | ||
| 100 | + features_vec[i] = std::move(f); | ||
| 101 | + | ||
| 102 | + std::array<int64_t, 2> shape = {num_frames, feat_dim}; | ||
| 103 | + | ||
| 104 | + Ort::Value x = Ort::Value::CreateTensor( | ||
| 105 | + memory_info, features_vec[i].data(), features_vec[i].size(), | ||
| 106 | + shape.data(), shape.size()); | ||
| 107 | + features.push_back(std::move(x)); | ||
| 108 | + } | ||
| 109 | + | ||
| 110 | + std::vector<const Ort::Value *> features_pointer(n); | ||
| 111 | + for (int32_t i = 0; i != n; ++i) { | ||
| 112 | + features_pointer[i] = &features[i]; | ||
| 113 | + } | ||
| 114 | + | ||
| 115 | + std::array<int64_t, 1> features_length_shape = {n}; | ||
| 116 | + Ort::Value x_length = Ort::Value::CreateTensor( | ||
| 117 | + memory_info, features_length_vec.data(), n, | ||
| 118 | + features_length_shape.data(), features_length_shape.size()); | ||
| 119 | + | ||
| 120 | + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0); | ||
| 121 | + | ||
| 122 | + auto t = model_->RunEncoder(std::move(x), std::move(x_length)); | ||
| 123 | + // t[0] encoder_out, float tensor, (batch_size, dim, T) | ||
| 124 | + // t[1] encoder_out_length, int64 tensor, (batch_size,) | ||
| 125 | + | ||
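+ // Make each row one frame: (batch_size, dim, T) -> (batch_size, T, dim)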
| 126 | + Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); | ||
| 127 | + | ||
| 128 | + auto results = decoder_->Decode(std::move(encoder_out), std::move(t[1])); | ||
| 129 | + | ||
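+ // The fbank frontend uses a 10 ms frame shift (FeatureExtractorConfig
+ // default); Convert() needs it to map token frame indices to timestamps.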
| 130 | + int32_t frame_shift_ms = 10; | ||
| 131 | + for (int32_t i = 0; i != n; ++i) { | ||
| 132 | + auto r = Convert(results[i], symbol_table_, frame_shift_ms, | ||
| 133 | + model_->SubsamplingFactor()); | ||
| 134 | + | ||
| 135 | + ss[i]->SetResult(r); | ||
| 136 | + } | ||
| 137 | + } | ||
| 138 | + | ||
| 139 | + private: | ||
| 140 | + void PostInit() { | ||
| 141 | + config_.feat_config.nemo_normalize_type = | ||
| 142 | + model_->FeatureNormalizationMethod(); | ||
| 143 | + | ||
| 144 | + config_.feat_config.low_freq = 0; | ||
| 145 | + // config_.feat_config.high_freq = 8000; | ||
| 146 | + config_.feat_config.is_librosa = true; | ||
| 147 | + config_.feat_config.remove_dc_offset = false; | ||
| 148 | + // config_.feat_config.window_type = "hann"; | ||
| 149 | + config_.feat_config.dither = 0; | ||
| 152 | + | ||
| 153 | + int32_t vocab_size = model_->VocabSize(); | ||
| 154 | + | ||
| 155 | + // check the blank ID | ||
| 156 | + if (!symbol_table_.Contains("<blk>")) { | ||
| 157 | + SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>"); | ||
| 158 | + exit(-1); | ||
| 159 | + } | ||
| 160 | + | ||
| 161 | + if (symbol_table_["<blk>"] != vocab_size - 1) { | ||
| 162 | + SHERPA_ONNX_LOGE("<blk> is not the last token!"); | ||
| 163 | + exit(-1); | ||
| 164 | + } | ||
| 165 | + | ||
| 166 | + if (symbol_table_.NumSymbols() != vocab_size) { | ||
| 167 | + SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)", | ||
| 168 | + symbol_table_.NumSymbols(), vocab_size); | ||
| 169 | + exit(-1); | ||
| 170 | + } | ||
| 171 | + } | ||
| 172 | + | ||
| 173 | + private: | ||
| 174 | + OfflineRecognizerConfig config_; | ||
| 175 | + SymbolTable symbol_table_; | ||
| 176 | + std::unique_ptr<OfflineTransducerNeMoModel> model_; | ||
| 177 | + std::unique_ptr<OfflineTransducerDecoder> decoder_; | ||
| 178 | +}; | ||
| 179 | + | ||
| 180 | +} // namespace sherpa_onnx | ||
| 181 | + | ||
| 182 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ |
| @@ -35,7 +35,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, | ||
| 35 | 35 | ||
| 36 | std::string text; | 36 | std::string text; |
| 37 | for (auto i : src.tokens) { | 37 | for (auto i : src.tokens) { |
| 38 | - if (!sym_table.contains(i)) { | 38 | + if (!sym_table.Contains(i)) { |
| 39 | continue; | 39 | continue; |
| 40 | } | 40 | } |
| 41 | 41 |
| @@ -14,6 +14,7 @@ | ||
| 14 | #include "android/asset_manager_jni.h" | 14 | #include "android/asset_manager_jni.h" |
| 15 | #endif | 15 | #endif |
| 16 | 16 | ||
| 17 | +#include "sherpa-onnx/csrc/features.h" | ||
| 17 | #include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" | 18 | #include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" |
| 18 | #include "sherpa-onnx/csrc/offline-lm-config.h" | 19 | #include "sherpa-onnx/csrc/offline-lm-config.h" |
| 19 | #include "sherpa-onnx/csrc/offline-model-config.h" | 20 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| @@ -26,7 +27,7 @@ namespace sherpa_onnx { | ||
| 26 | struct OfflineRecognitionResult; | 27 | struct OfflineRecognitionResult; |
| 27 | 28 | ||
| 28 | struct OfflineRecognizerConfig { | 29 | struct OfflineRecognizerConfig { |
| 29 | - OfflineFeatureExtractorConfig feat_config; | 30 | + FeatureExtractorConfig feat_config; |
| 30 | OfflineModelConfig model_config; | 31 | OfflineModelConfig model_config; |
| 31 | OfflineLMConfig lm_config; | 32 | OfflineLMConfig lm_config; |
| 32 | OfflineCtcFstDecoderConfig ctc_fst_decoder_config; | 33 | OfflineCtcFstDecoderConfig ctc_fst_decoder_config; |
| @@ -44,7 +45,7 @@ struct OfflineRecognizerConfig { | ||
| 44 | 45 | ||
| 45 | OfflineRecognizerConfig() = default; | 46 | OfflineRecognizerConfig() = default; |
| 46 | OfflineRecognizerConfig( | 47 | OfflineRecognizerConfig( |
| 47 | - const OfflineFeatureExtractorConfig &feat_config, | 48 | + const FeatureExtractorConfig &feat_config, |
| 48 | const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, | 49 | const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, |
| 49 | const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, | 50 | const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, |
| 50 | const std::string &decoding_method, int32_t max_active_paths, | 51 | const std::string &decoding_method, int32_t max_active_paths, |
| @@ -52,42 +52,25 @@ static void ComputeMeanAndInvStd(const float *p, int32_t num_rows, | ||
| 52 | } | 52 | } |
| 53 | } | 53 | } |
| 54 | 54 | ||
| 55 | -void OfflineFeatureExtractorConfig::Register(ParseOptions *po) { | ||
| 56 | - po->Register("sample-rate", &sampling_rate, | ||
| 57 | - "Sampling rate of the input waveform. " | ||
| 58 | - "Note: You can have a different " | ||
| 59 | - "sample rate for the input waveform. We will do resampling " | ||
| 60 | - "inside the feature extractor"); | ||
| 61 | - | ||
| 62 | - po->Register("feat-dim", &feature_dim, | ||
| 63 | - "Feature dimension. Must match the one expected by the model."); | ||
| 64 | -} | ||
| 65 | - | ||
| 66 | -std::string OfflineFeatureExtractorConfig::ToString() const { | ||
| 67 | - std::ostringstream os; | ||
| 68 | - | ||
| 69 | - os << "OfflineFeatureExtractorConfig("; | ||
| 70 | - os << "sampling_rate=" << sampling_rate << ", "; | ||
| 71 | - os << "feature_dim=" << feature_dim << ")"; | ||
| 72 | - | ||
| 73 | - return os.str(); | ||
| 74 | -} | ||
| 75 | - | ||
| 76 | class OfflineStream::Impl { | 55 | class OfflineStream::Impl { |
| 77 | public: | 56 | public: |
| 78 | - explicit Impl(const OfflineFeatureExtractorConfig &config, | 57 | + explicit Impl(const FeatureExtractorConfig &config, |
| 79 | ContextGraphPtr context_graph) | 58 | ContextGraphPtr context_graph) |
| 80 | : config_(config), context_graph_(context_graph) { | 59 | : config_(config), context_graph_(context_graph) { |
| 81 | - opts_.frame_opts.dither = 0; | ||
| 82 | - opts_.frame_opts.snip_edges = false; | 60 | + opts_.frame_opts.dither = config.dither; |
| 61 | + opts_.frame_opts.snip_edges = config.snip_edges; | ||
| 83 | opts_.frame_opts.samp_freq = config.sampling_rate; | 62 | opts_.frame_opts.samp_freq = config.sampling_rate; |
| 63 | + opts_.frame_opts.frame_shift_ms = config.frame_shift_ms; | ||
| 64 | + opts_.frame_opts.frame_length_ms = config.frame_length_ms; | ||
| 65 | + opts_.frame_opts.remove_dc_offset = config.remove_dc_offset; | ||
| 66 | + opts_.frame_opts.window_type = config.window_type; | ||
| 67 | + | ||
| 84 | opts_.mel_opts.num_bins = config.feature_dim; | 68 | opts_.mel_opts.num_bins = config.feature_dim; |
| 85 | 69 | ||
| 86 | - // Please see | ||
| 87 | - // https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27 | ||
| 88 | - // and | ||
| 89 | - // https://github.com/k2-fsa/sherpa-onnx/issues/514 | ||
| 90 | - opts_.mel_opts.high_freq = -400; | 70 | + opts_.mel_opts.high_freq = config.high_freq; |
| 71 | + opts_.mel_opts.low_freq = config.low_freq; | ||
| 72 | + | ||
| 73 | + opts_.mel_opts.is_librosa = config.is_librosa; | ||
| 91 | 74 | ||
| 92 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 75 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| 93 | } | 76 | } |
| @@ -237,7 +220,7 @@ class OfflineStream::Impl { | ||
| 237 | } | 220 | } |
| 238 | 221 | ||
| 239 | private: | 222 | private: |
| 240 | - OfflineFeatureExtractorConfig config_; | 223 | + FeatureExtractorConfig config_; |
| 241 | std::unique_ptr<knf::OnlineFbank> fbank_; | 224 | std::unique_ptr<knf::OnlineFbank> fbank_; |
| 242 | std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_; | 225 | std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_; |
| 243 | knf::FbankOptions opts_; | 226 | knf::FbankOptions opts_; |
| @@ -245,9 +228,8 @@ class OfflineStream::Impl { | ||
| 245 | ContextGraphPtr context_graph_; | 228 | ContextGraphPtr context_graph_; |
| 246 | }; | 229 | }; |
| 247 | 230 | ||
| 248 | -OfflineStream::OfflineStream( | ||
| 249 | - const OfflineFeatureExtractorConfig &config /*= {}*/, | ||
| 250 | - ContextGraphPtr context_graph /*= nullptr*/) | 231 | +OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/, |
| 232 | + ContextGraphPtr context_graph /*= nullptr*/) | ||
| 251 | : impl_(std::make_unique<Impl>(config, context_graph)) {} | 233 | : impl_(std::make_unique<Impl>(config, context_graph)) {} |
| 252 | 234 | ||
| 253 | OfflineStream::OfflineStream(WhisperTag tag) | 235 | OfflineStream::OfflineStream(WhisperTag tag) |
| @@ -11,6 +11,7 @@ | ||
| 11 | #include <vector> | 11 | #include <vector> |
| 12 | 12 | ||
| 13 | #include "sherpa-onnx/csrc/context-graph.h" | 13 | #include "sherpa-onnx/csrc/context-graph.h" |
| 14 | +#include "sherpa-onnx/csrc/features.h" | ||
| 14 | #include "sherpa-onnx/csrc/parse-options.h" | 15 | #include "sherpa-onnx/csrc/parse-options.h" |
| 15 | 16 | ||
| 16 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| @@ -32,46 +33,12 @@ struct OfflineRecognitionResult { | ||
| 32 | std::string AsJsonString() const; | 33 | std::string AsJsonString() const; |
| 33 | }; | 34 | }; |
| 34 | 35 | ||
| 35 | -struct OfflineFeatureExtractorConfig { | ||
| 36 | - // Sampling rate used by the feature extractor. If it is different from | ||
| 37 | - // the sampling rate of the input waveform, we will do resampling inside. | ||
| 38 | - int32_t sampling_rate = 16000; | ||
| 39 | - | ||
| 40 | - // Feature dimension | ||
| 41 | - int32_t feature_dim = 80; | ||
| 42 | - | ||
| 43 | - // Set internally by some models, e.g., paraformer and wenet CTC models set | ||
| 44 | - // it to false. | ||
| 45 | - // This parameter is not exposed to users from the commandline | ||
| 46 | - // If true, the feature extractor expects inputs to be normalized to | ||
| 47 | - // the range [-1, 1]. | ||
| 48 | - // If false, we will multiply the inputs by 32768 | ||
| 49 | - bool normalize_samples = true; | ||
| 50 | - | ||
| 51 | - // For models from NeMo | ||
| 52 | - // This option is not exposed and is set internally when loading models. | ||
| 53 | - // Possible values: | ||
| 54 | - // - per_feature | ||
| 55 | - // - all_features (not implemented yet) | ||
| 56 | - // - fixed_mean (not implemented) | ||
| 57 | - // - fixed_std (not implemented) | ||
| 58 | - // - or just leave it to empty | ||
| 59 | - // See | ||
| 60 | - // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 | ||
| 61 | - // for details | ||
| 62 | - std::string nemo_normalize_type; | ||
| 63 | - | ||
| 64 | - std::string ToString() const; | ||
| 65 | - | ||
| 66 | - void Register(ParseOptions *po); | ||
| 67 | -}; | ||
| 68 | - | ||
| 69 | struct WhisperTag {}; | 36 | struct WhisperTag {}; |
| 70 | struct CEDTag {}; | 37 | struct CEDTag {}; |
| 71 | 38 | ||
| 72 | class OfflineStream { | 39 | class OfflineStream { |
| 73 | public: | 40 | public: |
| 74 | - explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, | 41 | + explicit OfflineStream(const FeatureExtractorConfig &config = {}, |
| 75 | ContextGraphPtr context_graph = {}); | 42 | ContextGraphPtr context_graph = {}); |
| 76 | 43 | ||
| 77 | explicit OfflineStream(WhisperTag tag); | 44 | explicit OfflineStream(WhisperTag tag); |
| @@ -14,8 +14,8 @@ namespace sherpa_onnx { | ||
| 14 | 14 | ||
| 15 | class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | 15 | class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { |
| 16 | public: | 16 | public: |
| 17 | - explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, | ||
| 18 | - float blank_penalty) | 17 | + OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, |
| 18 | + float blank_penalty) | ||
| 19 | : model_(model), blank_penalty_(blank_penalty) {} | 19 | : model_(model), blank_penalty_(blank_penalty) {} |
| 20 | 20 | ||
| 21 | std::vector<OfflineTransducerDecoderResult> Decode( | 21 | std::vector<OfflineTransducerDecoderResult> Decode( |
| 1 | +// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <iterator> | ||
| 9 | +#include <utility> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +static std::pair<Ort::Value, Ort::Value> BuildDecoderInput( | ||
| 17 | + int32_t token, OrtAllocator *allocator) { | ||
| 18 | + std::array<int64_t, 2> shape{1, 1}; | ||
| 19 | + | ||
| 20 | + Ort::Value decoder_input = | ||
| 21 | + Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size()); | ||
| 22 | + | ||
| 23 | + std::array<int64_t, 1> length_shape{1}; | ||
| 24 | + Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>( | ||
| 25 | + allocator, length_shape.data(), length_shape.size()); | ||
| 26 | + | ||
| 27 | + int32_t *p = decoder_input.GetTensorMutableData<int32_t>(); | ||
| 28 | + | ||
| 29 | + int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>(); | ||
| 30 | + | ||
| 31 | + p[0] = token; | ||
| 32 | + | ||
| 33 | + p_length[0] = 1; | ||
| 34 | + | ||
| 35 | + return {std::move(decoder_input), std::move(decoder_input_length)}; | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +static OfflineTransducerDecoderResult DecodeOne( | ||
| 39 | + const float *p, int32_t num_rows, int32_t num_cols, | ||
| 40 | + OfflineTransducerNeMoModel *model, float blank_penalty) { | ||
| 41 | + auto memory_info = | ||
| 42 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 43 | + | ||
| 44 | + OfflineTransducerDecoderResult ans; | ||
| 45 | + | ||
| 46 | + int32_t vocab_size = model->VocabSize(); | ||
| 47 | + int32_t blank_id = vocab_size - 1; | ||
| 48 | + | ||
| 49 | + auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); | ||
| 50 | + | ||
| 51 | + std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair = | ||
| 52 | + model->RunDecoder(std::move(decoder_input_pair.first), | ||
| 53 | + std::move(decoder_input_pair.second), | ||
| 54 | + model->GetDecoderInitStates(1)); | ||
| 55 | + | ||
| 56 | + std::array<int64_t, 3> encoder_shape{1, num_cols, 1}; | ||
| 57 | + | ||
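+ // Greedy RNN-T search over the exported NeMo model: run the joiner once
+ // per encoder frame and emit at most one non-blank symbol per frame; the
+ // prediction network (decoder) state advances only on non-blank emissions.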
| 58 | + for (int32_t t = 0; t != num_rows; ++t) { | ||
| 59 | + Ort::Value cur_encoder_out = Ort::Value::CreateTensor( | ||
| 60 | + memory_info, const_cast<float *>(p) + t * num_cols, num_cols, | ||
| 61 | + encoder_shape.data(), encoder_shape.size()); | ||
| 62 | + | ||
| 63 | + Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), | ||
| 64 | + View(&decoder_output_pair.first)); | ||
| 65 | + | ||
| 66 | + float *p_logit = logit.GetTensorMutableData<float>(); | ||
| 67 | + if (blank_penalty > 0) { | ||
| 68 | + p_logit[blank_id] -= blank_penalty; | ||
| 69 | + } | ||
| 70 | + | ||
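+ // Greedy choice for this frame: argmax over the vocabulary (blank included).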
| 71 | + auto y = static_cast<int32_t>(std::distance( | ||
| 72 | + static_cast<const float *>(p_logit), | ||
| 73 | + std::max_element(static_cast<const float *>(p_logit), | ||
| 74 | + static_cast<const float *>(p_logit) + vocab_size))); | ||
| 75 | + | ||
| 76 | + if (y != blank_id) { | ||
| 77 | + ans.tokens.push_back(y); | ||
| 78 | + ans.timestamps.push_back(t); | ||
| 79 | + | ||
| 80 | + decoder_input_pair = BuildDecoderInput(y, model->Allocator()); | ||
| 81 | + | ||
| 82 | + decoder_output_pair = | ||
| 83 | + model->RunDecoder(std::move(decoder_input_pair.first), | ||
| 84 | + std::move(decoder_input_pair.second), | ||
| 85 | + std::move(decoder_output_pair.second)); | ||
| 86 | + } // if (y != blank_id) | ||
| 87 | + } // for (int32_t t = 0; t != num_rows; ++t) | ||
| 88 | + | ||
| 89 | + return ans; | ||
| 90 | +} | ||
| 91 | + | ||
| 92 | +std::vector<OfflineTransducerDecoderResult> | ||
| 93 | +OfflineTransducerGreedySearchNeMoDecoder::Decode( | ||
| 94 | + Ort::Value encoder_out, Ort::Value encoder_out_length, | ||
| 95 | + OfflineStream ** /*ss = nullptr*/, int32_t /*n = 0*/) { | ||
| 96 | + auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 97 | + | ||
| 98 | + int32_t batch_size = static_cast<int32_t>(shape[0]); | ||
| 99 | + int32_t dim1 = static_cast<int32_t>(shape[1]); | ||
| 100 | + int32_t dim2 = static_cast<int32_t>(shape[2]); | ||
| 101 | + | ||
| 102 | + const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>(); | ||
| 103 | + const float *p = encoder_out.GetTensorData<float>(); | ||
| 104 | + | ||
| 105 | + std::vector<OfflineTransducerDecoderResult> ans(batch_size); | ||
| 106 | + | ||
| 107 | + for (int32_t i = 0; i != batch_size; ++i) { | ||
| 108 | + const float *this_p = p + dim1 * dim2 * i; | ||
| 109 | + int32_t this_len = static_cast<int32_t>(p_length[i]); | ||
| 110 | + | ||
| 111 | + ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_); | ||
| 112 | + } | ||
| 113 | + | ||
| 114 | + return ans; | ||
| 115 | +} | ||
| 116 | + | ||
| 117 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" | ||
| 11 | +#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class OfflineTransducerGreedySearchNeMoDecoder | ||
| 16 | + : public OfflineTransducerDecoder { | ||
| 17 | + public: | ||
| 18 | + OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model, | ||
| 19 | + float blank_penalty) | ||
| 20 | + : model_(model), blank_penalty_(blank_penalty) {} | ||
| 21 | + | ||
| 22 | + std::vector<OfflineTransducerDecoderResult> Decode( | ||
| 23 | + Ort::Value encoder_out, Ort::Value encoder_out_length, | ||
| 24 | + OfflineStream **ss = nullptr, int32_t n = 0) override; | ||
| 25 | + | ||
| 26 | + private: | ||
| 27 | + OfflineTransducerNeMoModel *model_; // Not owned | ||
| 28 | + float blank_penalty_; | ||
| 29 | +}; | ||
| 30 | + | ||
| 31 | +} // namespace sherpa_onnx | ||
| 32 | + | ||
| 33 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/offline-transducer-nemo-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" | ||
| 14 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 15 | +#include "sherpa-onnx/csrc/session.h" | ||
| 16 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 17 | + | ||
| 18 | +namespace sherpa_onnx { | ||
| 19 | + | ||
| 20 | +class OfflineTransducerNeMoModel::Impl { | ||
| 21 | + public: | ||
| 22 | + explicit Impl(const OfflineModelConfig &config) | ||
| 23 | + : config_(config), | ||
| 24 | + env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 25 | + sess_opts_(GetSessionOptions(config)), | ||
| 26 | + allocator_{} { | ||
| 27 | + { | ||
| 28 | + auto buf = ReadFile(config.transducer.encoder_filename); | ||
| 29 | + InitEncoder(buf.data(), buf.size()); | ||
| 30 | + } | ||
| 31 | + | ||
| 32 | + { | ||
| 33 | + auto buf = ReadFile(config.transducer.decoder_filename); | ||
| 34 | + InitDecoder(buf.data(), buf.size()); | ||
| 35 | + } | ||
| 36 | + | ||
| 37 | + { | ||
| 38 | + auto buf = ReadFile(config.transducer.joiner_filename); | ||
| 39 | + InitJoiner(buf.data(), buf.size()); | ||
| 40 | + } | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | +#if __ANDROID_API__ >= 9 | ||
| 44 | + Impl(AAssetManager *mgr, const OfflineModelConfig &config) | ||
| 45 | + : config_(config), | ||
| 46 | + env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 47 | + sess_opts_(GetSessionOptions(config)), | ||
| 48 | + allocator_{} { | ||
| 49 | + { | ||
| 50 | + auto buf = ReadFile(mgr, config.transducer.encoder_filename); | ||
| 51 | + InitEncoder(buf.data(), buf.size()); | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + { | ||
| 55 | + auto buf = ReadFile(mgr, config.transducer.decoder_filename); | ||
| 56 | + InitDecoder(buf.data(), buf.size()); | ||
| 57 | + } | ||
| 58 | + | ||
| 59 | + { | ||
| 60 | + auto buf = ReadFile(mgr, config.transducer.joiner_filename); | ||
| 61 | + InitJoiner(buf.data(), buf.size()); | ||
| 62 | + } | ||
| 63 | + } | ||
| 64 | +#endif | ||
| 65 | + | ||
| 66 | + std::vector<Ort::Value> RunEncoder(Ort::Value features, | ||
| 67 | + Ort::Value features_length) { | ||
| 68 | + // (B, T, C) -> (B, C, T) | ||
| 69 | + features = Transpose12(allocator_, &features); | ||
| 70 | + | ||
| 71 | + std::array<Ort::Value, 2> encoder_inputs = {std::move(features), | ||
| 72 | + std::move(features_length)}; | ||
| 73 | + | ||
| 74 | + auto encoder_out = encoder_sess_->Run( | ||
| 75 | + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), | ||
| 76 | + encoder_inputs.size(), encoder_output_names_ptr_.data(), | ||
| 77 | + encoder_output_names_ptr_.size()); | ||
| 78 | + | ||
| 79 | + return encoder_out; | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | + std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder( | ||
| 83 | + Ort::Value targets, Ort::Value targets_length, | ||
| 84 | + std::vector<Ort::Value> states) { | ||
| 85 | + std::vector<Ort::Value> decoder_inputs; | ||
| 86 | + decoder_inputs.reserve(2 + states.size()); | ||
| 87 | + | ||
| 88 | + decoder_inputs.push_back(std::move(targets)); | ||
| 89 | + decoder_inputs.push_back(std::move(targets_length)); | ||
| 90 | + | ||
| 91 | + for (auto &s : states) { | ||
| 92 | + decoder_inputs.push_back(std::move(s)); | ||
| 93 | + } | ||
| 94 | + | ||
| 95 | + auto decoder_out = decoder_sess_->Run( | ||
| 96 | + {}, decoder_input_names_ptr_.data(), decoder_inputs.data(), | ||
| 97 | + decoder_inputs.size(), decoder_output_names_ptr_.data(), | ||
| 98 | + decoder_output_names_ptr_.size()); | ||
| 99 | + | ||
| 100 | + std::vector<Ort::Value> states_next; | ||
| 101 | + states_next.reserve(states.size()); | ||
| 102 | + | ||
| 103 | + // decoder_out[0]: decoder_output | ||
| 104 | + // decoder_out[1]: decoder_output_length | ||
| 105 | + // decoder_out[2:] states_next | ||
| 106 | + | ||
| 107 | + for (int32_t i = 0; i != states.size(); ++i) { | ||
| 108 | + states_next.push_back(std::move(decoder_out[i + 2])); | ||
| 109 | + } | ||
| 110 | + | ||
| 111 | + // we discard decoder_out[1] | ||
| 112 | + return {std::move(decoder_out[0]), std::move(states_next)}; | ||
| 113 | + } | ||
| 114 | + | ||
| 115 | + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { | ||
| 116 | + std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out), | ||
| 117 | + std::move(decoder_out)}; | ||
| 118 | + auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(), | ||
| 119 | + joiner_input.data(), joiner_input.size(), | ||
| 120 | + joiner_output_names_ptr_.data(), | ||
| 121 | + joiner_output_names_ptr_.size()); | ||
| 122 | + | ||
| 123 | + return std::move(logit[0]); | ||
| 124 | + } | ||
| 125 | + | ||
| 126 | + std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const { | ||
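+ // Zero-initialized decoder states: presumably the hidden and cell states
+ // of NeMo's prediction-network LSTM, each of shape
+ // (pred_rnn_layers, batch_size, pred_hidden).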
| 127 | + std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; | ||
| 128 | + Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(), | ||
| 129 | + s0_shape.size()); | ||
| 130 | + | ||
| 131 | + Fill<float>(&s0, 0); | ||
| 132 | + | ||
| 133 | + std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_}; | ||
| 134 | + | ||
| 135 | + Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(), | ||
| 136 | + s1_shape.size()); | ||
| 137 | + | ||
| 138 | + Fill<float>(&s1, 0); | ||
| 139 | + | ||
| 140 | + std::vector<Ort::Value> states; | ||
| 141 | + | ||
| 142 | + states.reserve(2); | ||
| 143 | + states.push_back(std::move(s0)); | ||
| 144 | + states.push_back(std::move(s1)); | ||
| 145 | + | ||
| 146 | + return states; | ||
| 147 | + } | ||
| 148 | + | ||
| 149 | + int32_t SubsamplingFactor() const { return subsampling_factor_; } | ||
| 150 | + int32_t VocabSize() const { return vocab_size_; } | ||
| 151 | + | ||
| 152 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 153 | + | ||
| 154 | + std::string FeatureNormalizationMethod() const { return normalize_type_; } | ||
| 155 | + | ||
| 156 | + private: | ||
| 157 | + void InitEncoder(void *model_data, size_t model_data_length) { | ||
| 158 | + encoder_sess_ = std::make_unique<Ort::Session>( | ||
| 159 | + env_, model_data, model_data_length, sess_opts_); | ||
| 160 | + | ||
| 161 | + GetInputNames(encoder_sess_.get(), &encoder_input_names_, | ||
| 162 | + &encoder_input_names_ptr_); | ||
| 163 | + | ||
| 164 | + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, | ||
| 165 | + &encoder_output_names_ptr_); | ||
| 166 | + | ||
| 167 | + // get meta data | ||
| 168 | + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); | ||
| 169 | + if (config_.debug) { | ||
| 170 | + std::ostringstream os; | ||
| 171 | + os << "---encoder---\n"; | ||
| 172 | + PrintModelMetadata(os, meta_data); | ||
| 173 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 174 | + } | ||
| 175 | + | ||
| 176 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 177 | + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); | ||
| 178 | + | ||
| 179 | + // We need to increase it by 1 since the blank token is not included | ||
| 180 | + // when computing vocab_size in NeMo. | ||
| 181 | + vocab_size_ += 1; | ||
| 182 | + | ||
| 183 | + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); | ||
| 184 | + SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); | ||
| 185 | + SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); | ||
| 186 | + SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); | ||
| 187 | + | ||
| 188 | + if (normalize_type_ == "NA") { | ||
| 189 | + normalize_type_ = ""; | ||
| 190 | + } | ||
| 191 | + } | ||
| 192 | + | ||
| 193 | + void InitDecoder(void *model_data, size_t model_data_length) { | ||
| 194 | + decoder_sess_ = std::make_unique<Ort::Session>( | ||
| 195 | + env_, model_data, model_data_length, sess_opts_); | ||
| 196 | + | ||
| 197 | + GetInputNames(decoder_sess_.get(), &decoder_input_names_, | ||
| 198 | + &decoder_input_names_ptr_); | ||
| 199 | + | ||
| 200 | + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, | ||
| 201 | + &decoder_output_names_ptr_); | ||
| 202 | + } | ||
| 203 | + | ||
| 204 | + void InitJoiner(void *model_data, size_t model_data_length) { | ||
| 205 | + joiner_sess_ = std::make_unique<Ort::Session>( | ||
| 206 | + env_, model_data, model_data_length, sess_opts_); | ||
| 207 | + | ||
| 208 | + GetInputNames(joiner_sess_.get(), &joiner_input_names_, | ||
| 209 | + &joiner_input_names_ptr_); | ||
| 210 | + | ||
| 211 | + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, | ||
| 212 | + &joiner_output_names_ptr_); | ||
| 213 | + } | ||
| 214 | + | ||
| 215 | + private: | ||
| 216 | + OfflineModelConfig config_; | ||
| 217 | + Ort::Env env_; | ||
| 218 | + Ort::SessionOptions sess_opts_; | ||
| 219 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 220 | + | ||
| 221 | + std::unique_ptr<Ort::Session> encoder_sess_; | ||
| 222 | + std::unique_ptr<Ort::Session> decoder_sess_; | ||
| 223 | + std::unique_ptr<Ort::Session> joiner_sess_; | ||
| 224 | + | ||
| 225 | + std::vector<std::string> encoder_input_names_; | ||
| 226 | + std::vector<const char *> encoder_input_names_ptr_; | ||
| 227 | + | ||
| 228 | + std::vector<std::string> encoder_output_names_; | ||
| 229 | + std::vector<const char *> encoder_output_names_ptr_; | ||
| 230 | + | ||
| 231 | + std::vector<std::string> decoder_input_names_; | ||
| 232 | + std::vector<const char *> decoder_input_names_ptr_; | ||
| 233 | + | ||
| 234 | + std::vector<std::string> decoder_output_names_; | ||
| 235 | + std::vector<const char *> decoder_output_names_ptr_; | ||
| 236 | + | ||
| 237 | + std::vector<std::string> joiner_input_names_; | ||
| 238 | + std::vector<const char *> joiner_input_names_ptr_; | ||
| 239 | + | ||
| 240 | + std::vector<std::string> joiner_output_names_; | ||
| 241 | + std::vector<const char *> joiner_output_names_ptr_; | ||
| 242 | + | ||
| 243 | + int32_t vocab_size_ = 0; | ||
| 244 | + int32_t subsampling_factor_ = 8; | ||
| 245 | + std::string normalize_type_; | ||
| 246 | + int32_t pred_rnn_layers_ = -1; | ||
| 247 | + int32_t pred_hidden_ = -1; | ||
| 248 | +}; | ||
| 249 | + | ||
| 250 | +OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( | ||
| 251 | + const OfflineModelConfig &config) | ||
| 252 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 253 | + | ||
| 254 | +#if __ANDROID_API__ >= 9 | ||
| 255 | +OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( | ||
| 256 | + AAssetManager *mgr, const OfflineModelConfig &config) | ||
| 257 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 258 | +#endif | ||
| 259 | + | ||
| 260 | +OfflineTransducerNeMoModel::~OfflineTransducerNeMoModel() = default; | ||
| 261 | + | ||
| 262 | +std::vector<Ort::Value> OfflineTransducerNeMoModel::RunEncoder( | ||
| 263 | + Ort::Value features, Ort::Value features_length) const { | ||
| 264 | + return impl_->RunEncoder(std::move(features), std::move(features_length)); | ||
| 265 | +} | ||
| 266 | + | ||
| 267 | +std::pair<Ort::Value, std::vector<Ort::Value>> | ||
| 268 | +OfflineTransducerNeMoModel::RunDecoder(Ort::Value targets, | ||
| 269 | + Ort::Value targets_length, | ||
| 270 | + std::vector<Ort::Value> states) const { | ||
| 271 | + return impl_->RunDecoder(std::move(targets), std::move(targets_length), | ||
| 272 | + std::move(states)); | ||
| 273 | +} | ||
| 274 | + | ||
| 275 | +std::vector<Ort::Value> OfflineTransducerNeMoModel::GetDecoderInitStates( | ||
| 276 | + int32_t batch_size) const { | ||
| 277 | + return impl_->GetDecoderInitStates(batch_size); | ||
| 278 | +} | ||
| 279 | + | ||
| 280 | +Ort::Value OfflineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, | ||
| 281 | + Ort::Value decoder_out) const { | ||
| 282 | + return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); | ||
| 283 | +} | ||
| 284 | + | ||
| 285 | +int32_t OfflineTransducerNeMoModel::SubsamplingFactor() const { | ||
| 286 | + return impl_->SubsamplingFactor(); | ||
| 287 | +} | ||
| 288 | + | ||
| 289 | +int32_t OfflineTransducerNeMoModel::VocabSize() const { | ||
| 290 | + return impl_->VocabSize(); | ||
| 291 | +} | ||
| 292 | + | ||
| 293 | +OrtAllocator *OfflineTransducerNeMoModel::Allocator() const { | ||
| 294 | + return impl_->Allocator(); | ||
| 295 | +} | ||
| 296 | + | ||
| 297 | +std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { | ||
| 298 | + return impl_->FeatureNormalizationMethod(); | ||
| 299 | +} | ||
| 300 | + | ||
| 301 | +} // namespace sherpa_onnx |
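For orientation, here is a minimal construction sketch for the class this file implements (declared in the header below). It is illustrative only: the model paths are placeholders, and the config field names are taken from the constructor code above.

#include <cstdint>
#include <string>

#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"

int main() {
  sherpa_onnx::OfflineModelConfig config;
  config.transducer.encoder_filename = "./encoder.onnx";
  config.transducer.decoder_filename = "./decoder.onnx";
  config.transducer.joiner_filename = "./joiner.onnx";
  config.debug = true;  // prints the encoder metadata, see InitEncoder()

  sherpa_onnx::OfflineTransducerNeMoModel model(config);

  // InitEncoder() adds 1 to the exported vocab_size because NeMo does not
  // count the blank token; by NeMo's convention the blank is the last entry.
  int32_t blank_id = model.VocabSize() - 1;

  // "" means no normalization; "per_feature" is the usual NeMo setting.
  std::string normalize = model.FeatureNormalizationMethod();
  (void)blank_id;
  (void)normalize;
  return 0;
}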
| 1 | +// sherpa-onnx/csrc/offline-transducer-nemo-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 17 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 18 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +// see | ||
| 23 | +// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40 | ||
| 24 | +// Its decoder is stateful, not stateless. | ||
| 25 | +class OfflineTransducerNeMoModel { | ||
| 26 | + public: | ||
| 27 | + explicit OfflineTransducerNeMoModel(const OfflineModelConfig &config); | ||
| 28 | + | ||
| 29 | +#if __ANDROID_API__ >= 9 | ||
| 30 | + OfflineTransducerNeMoModel(AAssetManager *mgr, | ||
| 31 | + const OfflineModelConfig &config); | ||
| 32 | +#endif | ||
| 33 | + | ||
| 34 | + ~OfflineTransducerNeMoModel(); | ||
| 35 | + | ||
| 36 | + /** Run the encoder. | ||
| 37 | + * | ||
| 38 | + * @param features A tensor of shape (N, T, C). It is changed in-place. | ||
| 39 | + * @param features_length A 1-D tensor of shape (N,) containing the | ||
| 40 | + * number of valid frames in `features` before padding. | ||
| 41 | + * Its dtype is int64_t. | ||
| 42 | + * | ||
| 43 | + * @return Return a vector containing: | ||
| 44 | + * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim) | ||
| 45 | + * - encoder_out_length: A 1-D tensor of shape (N,) containing the | ||
| 46 | + * number of frames in `encoder_out` before padding. | ||
| 47 | + */ | ||
| 48 | + std::vector<Ort::Value> RunEncoder(Ort::Value features, | ||
| 49 | + Ort::Value features_length) const; | ||
| 50 | + | ||
| 51 | + /** Run the decoder network. | ||
| 52 | + * | ||
| 53 | + * @param targets An int32 tensor of shape (batch_size, 1) | ||
| 54 | + * @param targets_length An int32 tensor of shape (batch_size,) | ||
| 55 | + * @param states The states for the decoder model. | ||
| 56 | + * @return Return a pair containing: | ||
| 57 | + * - first: decoder_out (a float tensor) | ||
| 58 | + * - second: states_next, the updated decoder states | ||
| 59 | + * Note: the decoder_out_length output of the model is discarded. | ||
| 60 | + */ | ||
| 61 | + std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder( | ||
| 62 | + Ort::Value targets, Ort::Value targets_length, | ||
| 63 | + std::vector<Ort::Value> states) const; | ||
| 64 | + | ||
| 65 | + std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const; | ||
| 66 | + | ||
| 67 | + /** Run the joint network. | ||
| 68 | + * | ||
| 69 | + * @param encoder_out Output of the encoder network. | ||
| 70 | + * @param decoder_out Output of the decoder network. | ||
| 71 | + * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. | ||
| 72 | + */ | ||
| 73 | + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const; | ||
| 74 | + | ||
| 75 | + /** Return the subsampling factor of the model. | ||
| 76 | + */ | ||
| 77 | + int32_t SubsamplingFactor() const; | ||
| 78 | + | ||
| 79 | + int32_t VocabSize() const; | ||
| 80 | + | ||
| 81 | + /** Return an allocator for allocating memory | ||
| 82 | + */ | ||
| 83 | + OrtAllocator *Allocator() const; | ||
| 84 | + | ||
| 85 | + // Possible values: | ||
| 86 | + // - per_feature | ||
| 87 | + // - all_features (not implemented yet) | ||
| 88 | + // - fixed_mean (not implemented) | ||
| 89 | + // - fixed_std (not implemented) | ||
| 90 | + // - or just leave it empty | ||
| 91 | + // See | ||
| 92 | + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 | ||
| 93 | + // for details | ||
| 94 | + std::string FeatureNormalizationMethod() const; | ||
| 95 | + | ||
| 96 | + private: | ||
| 97 | + class Impl; | ||
| 98 | + std::unique_ptr<Impl> impl_; | ||
| 99 | +}; | ||
| 100 | + | ||
| 101 | +} // namespace sherpa_onnx | ||
| 102 | + | ||
| 103 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_ |
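To show how the encoder, stateful decoder, and joiner sessions interact, here is a minimal single-utterance greedy-search sketch built on the API above. It is illustrative only, not the decoder this PR ships: it emits at most one symbol per frame and re-runs the decoder every frame, the (1, encoder_dim, 1) frame layout fed to the joiner is an assumption, and blank is taken to be the last token id (consistent with InitEncoder() appending the blank to vocab_size).

#include <algorithm>
#include <array>
#include <cstdint>
#include <utility>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"

static std::vector<int32_t> GreedySearchSketch(
    const sherpa_onnx::OfflineTransducerNeMoModel &model, Ort::Value features,
    Ort::Value features_length) {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  auto enc = model.RunEncoder(std::move(features), std::move(features_length));
  // enc[0]: (1, T', encoder_dim); enc[1]: valid frames (equal to T' for N==1)
  auto enc_shape = enc[0].GetTensorTypeAndShapeInfo().GetShape();
  int64_t num_frames = enc_shape[1];
  int64_t encoder_dim = enc_shape[2];
  const float *enc_data = enc[0].GetTensorData<float>();

  int32_t blank_id = model.VocabSize() - 1;  // assumption: blank is last
  std::vector<int32_t> hyp;
  auto states = model.GetDecoderInitStates(/*batch_size=*/1);

  for (int64_t t = 0; t != num_frames; ++t) {
    // Feed the previously emitted token (or blank at the start).
    int32_t last_token = hyp.empty() ? blank_id : hyp.back();
    int32_t one = 1;
    std::array<int64_t, 2> targets_shape{1, 1};
    std::array<int64_t, 1> length_shape{1};
    Ort::Value targets = Ort::Value::CreateTensor<int32_t>(
        memory_info, &last_token, 1, targets_shape.data(),
        targets_shape.size());
    Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>(
        memory_info, &one, 1, length_shape.data(), length_shape.size());

    auto dec = model.RunDecoder(std::move(targets), std::move(targets_length),
                                std::move(states));
    states = std::move(dec.second);

    // Wrap the t-th encoder frame; the exact layout expected by the joiner
    // is an assumption here.
    std::array<int64_t, 3> frame_shape{1, encoder_dim, 1};
    Ort::Value frame = Ort::Value::CreateTensor<float>(
        memory_info, const_cast<float *>(enc_data + t * encoder_dim),
        encoder_dim, frame_shape.data(), frame_shape.size());

    Ort::Value logit = model.RunJoiner(std::move(frame), std::move(dec.first));
    const float *p = logit.GetTensorData<float>();
    int32_t best =
        static_cast<int32_t>(std::max_element(p, p + model.VocabSize()) - p);
    if (best != blank_id) {
      hyp.push_back(best);
    }
  }
  return hyp;
}

In the real implementation these per-frame steps belong to a dedicated greedy-search decoder class; the sketch only shows how the three sessions and the decoder states are threaded together.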
| @@ -223,8 +223,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -223,8 +223,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 223 | 223 | ||
| 224 | private: | 224 | private: |
| 225 | void InitDecoder() { | 225 | void InitDecoder() { |
| 226 | - if (!sym_.contains("<blk>") && !sym_.contains("<eps>") && | ||
| 227 | - !sym_.contains("<blank>")) { | 226 | + if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") && |
| 227 | + !sym_.Contains("<blank>")) { | ||
| 228 | SHERPA_ONNX_LOGE( | 228 | SHERPA_ONNX_LOGE( |
| 229 | "We expect that tokens.txt contains " | 229 | "We expect that tokens.txt contains " |
| 230 | "the symbol <blk> or <eps> or <blank> and its ID."); | 230 | "the symbol <blk> or <eps> or <blank> and its ID."); |
| @@ -232,12 +232,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -232,12 +232,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 232 | } | 232 | } |
| 233 | 233 | ||
| 234 | int32_t blank_id = 0; | 234 | int32_t blank_id = 0; |
| 235 | - if (sym_.contains("<blk>")) { | 235 | + if (sym_.Contains("<blk>")) { |
| 236 | blank_id = sym_["<blk>"]; | 236 | blank_id = sym_["<blk>"]; |
| 237 | - } else if (sym_.contains("<eps>")) { | 237 | + } else if (sym_.Contains("<eps>")) { |
| 238 | // for tdnn models of the yesno recipe from icefall | 238 | // for tdnn models of the yesno recipe from icefall |
| 239 | blank_id = sym_["<eps>"]; | 239 | blank_id = sym_["<eps>"]; |
| 240 | - } else if (sym_.contains("<blank>")) { | 240 | + } else if (sym_.Contains("<blank>")) { |
| 241 | // for WeNet CTC models | 241 | // for WeNet CTC models |
| 242 | blank_id = sym_["<blank>"]; | 242 | blank_id = sym_["<blank>"]; |
| 243 | } | 243 | } |
| @@ -87,7 +87,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -87,7 +87,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 87 | model_(OnlineTransducerModel::Create(config.model_config)), | 87 | model_(OnlineTransducerModel::Create(config.model_config)), |
| 88 | sym_(config.model_config.tokens), | 88 | sym_(config.model_config.tokens), |
| 89 | endpoint_(config_.endpoint_config) { | 89 | endpoint_(config_.endpoint_config) { |
| 90 | - if (sym_.contains("<unk>")) { | 90 | + if (sym_.Contains("<unk>")) { |
| 91 | unk_id_ = sym_["<unk>"]; | 91 | unk_id_ = sym_["<unk>"]; |
| 92 | } | 92 | } |
| 93 | 93 | ||
| @@ -103,19 +103,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -103,19 +103,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 103 | } | 103 | } |
| 104 | 104 | ||
| 105 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 105 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 106 | - model_.get(), | ||
| 107 | - lm_.get(), | ||
| 108 | - config_.max_active_paths, | ||
| 109 | - config_.lm_config.scale, | ||
| 110 | - unk_id_, | ||
| 111 | - config_.blank_penalty, | 106 | + model_.get(), lm_.get(), config_.max_active_paths, |
| 107 | + config_.lm_config.scale, unk_id_, config_.blank_penalty, | ||
| 112 | config_.temperature_scale); | 108 | config_.temperature_scale); |
| 113 | 109 | ||
| 114 | } else if (config.decoding_method == "greedy_search") { | 110 | } else if (config.decoding_method == "greedy_search") { |
| 115 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( | 111 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( |
| 116 | - model_.get(), | ||
| 117 | - unk_id_, | ||
| 118 | - config_.blank_penalty, | 112 | + model_.get(), unk_id_, config_.blank_penalty, |
| 119 | config_.temperature_scale); | 113 | config_.temperature_scale); |
| 120 | 114 | ||
| 121 | } else { | 115 | } else { |
| @@ -132,7 +126,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -132,7 +126,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 132 | model_(OnlineTransducerModel::Create(mgr, config.model_config)), | 126 | model_(OnlineTransducerModel::Create(mgr, config.model_config)), |
| 133 | sym_(mgr, config.model_config.tokens), | 127 | sym_(mgr, config.model_config.tokens), |
| 134 | endpoint_(config_.endpoint_config) { | 128 | endpoint_(config_.endpoint_config) { |
| 135 | - if (sym_.contains("<unk>")) { | 129 | + if (sym_.Contains("<unk>")) { |
| 136 | unk_id_ = sym_["<unk>"]; | 130 | unk_id_ = sym_["<unk>"]; |
| 137 | } | 131 | } |
| 138 | 132 | ||
| @@ -151,19 +145,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -151,19 +145,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 151 | } | 145 | } |
| 152 | 146 | ||
| 153 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 147 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 154 | - model_.get(), | ||
| 155 | - lm_.get(), | ||
| 156 | - config_.max_active_paths, | ||
| 157 | - config_.lm_config.scale, | ||
| 158 | - unk_id_, | ||
| 159 | - config_.blank_penalty, | 148 | + model_.get(), lm_.get(), config_.max_active_paths, |
| 149 | + config_.lm_config.scale, unk_id_, config_.blank_penalty, | ||
| 160 | config_.temperature_scale); | 150 | config_.temperature_scale); |
| 161 | 151 | ||
| 162 | } else if (config.decoding_method == "greedy_search") { | 152 | } else if (config.decoding_method == "greedy_search") { |
| 163 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( | 153 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( |
| 164 | - model_.get(), | ||
| 165 | - unk_id_, | ||
| 166 | - config_.blank_penalty, | 154 | + model_.get(), unk_id_, config_.blank_penalty, |
| 167 | config_.temperature_scale); | 155 | config_.temperature_scale); |
| 168 | 156 | ||
| 169 | } else { | 157 | } else { |
| @@ -13,7 +13,7 @@ namespace sherpa_onnx { | @@ -13,7 +13,7 @@ namespace sherpa_onnx { | ||
| 13 | * It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :] | 13 | * It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :] |
| 14 | * | 14 | * |
| 15 | * @param allocator | 15 | * @param allocator |
| 16 | - * @param v A 2-D tensor. Its data type is T. | 16 | + * @param v A 3-D tensor. Its data type is T. |
| 17 | * @param dim0_start Start index of the first dimension. | 17 | * @param dim0_start Start index of the first dimension. |
| 18 | * @param dim0_end End index of the first dimension. | 18 | * @param dim0_end End index of the first dimension. |
| 19 | * @param dim1_start Start index of the second dimension. | 19 | * @param dim1_start Start index of the second dimension. |
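A quick usage sketch for the corrected doc comment above: Slice() views v as a 3-D tensor and copies out v[dim0_start:dim0_end, dim1_start:dim1_end, :]. The include path and the exact Slice() signature are assumed from the surrounding header.

#include <array>
#include <numeric>

#include "onnxruntime_cxx_api.h"     // NOLINT
#include "sherpa-onnx/csrc/slice.h"  // assumed location of Slice()

void SliceDemo() {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 3> shape{4, 10, 8};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();
  std::iota(p, p + 4 * 10 * 8, 0.0f);  // fill with 0, 1, 2, ...

  // v[1:3, 0:5, :] -> a new tensor of shape (2, 5, 8)
  Ort::Value s = sherpa_onnx::Slice<float>(allocator, &v, 1, 3, 0, 5);
  (void)s;
}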
| @@ -100,9 +100,9 @@ int32_t SymbolTable::operator[](const std::string &sym) const { | @@ -100,9 +100,9 @@ int32_t SymbolTable::operator[](const std::string &sym) const { | ||
| 100 | return sym2id_.at(sym); | 100 | return sym2id_.at(sym); |
| 101 | } | 101 | } |
| 102 | 102 | ||
| 103 | -bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; } | 103 | +bool SymbolTable::Contains(int32_t id) const { return id2sym_.count(id) != 0; } |
| 104 | 104 | ||
| 105 | -bool SymbolTable::contains(const std::string &sym) const { | 105 | +bool SymbolTable::Contains(const std::string &sym) const { |
| 106 | return sym2id_.count(sym) != 0; | 106 | return sym2id_.count(sym) != 0; |
| 107 | } | 107 | } |
| 108 | 108 |
| @@ -40,14 +40,16 @@ class SymbolTable { | @@ -40,14 +40,16 @@ class SymbolTable { | ||
| 40 | int32_t operator[](const std::string &sym) const; | 40 | int32_t operator[](const std::string &sym) const; |
| 41 | 41 | ||
| 42 | /// Return true if there is a symbol with the given ID. | 42 | /// Return true if there is a symbol with the given ID. |
| 43 | - bool contains(int32_t id) const; | 43 | + bool Contains(int32_t id) const; |
| 44 | 44 | ||
| 45 | /// Return true if there is a given symbol in the symbol table. | 45 | /// Return true if there is a given symbol in the symbol table. |
| 46 | - bool contains(const std::string &sym) const; | 46 | + bool Contains(const std::string &sym) const; |
| 47 | 47 | ||
| 48 | // for tokens.txt from Whisper | 48 | // for tokens.txt from Whisper |
| 49 | void ApplyBase64Decode(); | 49 | void ApplyBase64Decode(); |
| 50 | 50 | ||
| 51 | + int32_t NumSymbols() const { return id2sym_.size(); } | ||
| 52 | + | ||
| 51 | private: | 53 | private: |
| 52 | void Init(std::istream &is); | 54 | void Init(std::istream &is); |
| 53 | 55 |
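The rename from contains() to Contains() matches the method naming used elsewhere in the codebase, and NumSymbols() is new. A small call-site sketch mirroring the blank-id lookup in the CTC recognizer above (the tokens.txt path and the filename constructor are assumptions):

#include <cstdint>

#include "sherpa-onnx/csrc/symbol-table.h"

void Demo() {
  sherpa_onnx::SymbolTable sym("tokens.txt");

  int32_t blank_id = 0;
  if (sym.Contains("<blk>")) {
    blank_id = sym["<blk>"];  // icefall-style tokens
  } else if (sym.Contains("<eps>")) {
    blank_id = sym["<eps>"];  // tdnn models of the yesno recipe
  } else if (sym.Contains("<blank>")) {
    blank_id = sym["<blank>"];  // WeNet CTC models
  }

  // NumSymbols() should match the model's vocabulary size (NeMo models
  // add one extra blank that is not listed in tokens.txt).
  int32_t num = sym.NumSymbols();
  (void)blank_id;
  (void)num;
}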
| @@ -49,7 +49,7 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, | @@ -49,7 +49,7 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, | ||
| 49 | word = word.replace(0, 3, " "); | 49 | word = word.replace(0, 3, " "); |
| 50 | } | 50 | } |
| 51 | } | 51 | } |
| 52 | - if (symbol_table.contains(word)) { | 52 | + if (symbol_table.Contains(word)) { |
| 53 | int32_t id = symbol_table[word]; | 53 | int32_t id = symbol_table[word]; |
| 54 | tmp_ids.push_back(id); | 54 | tmp_ids.push_back(id); |
| 55 | } else { | 55 | } else { |
| @@ -14,10 +14,10 @@ namespace sherpa_onnx { | @@ -14,10 +14,10 @@ namespace sherpa_onnx { | ||
| 14 | static void PybindOfflineRecognizerConfig(py::module *m) { | 14 | static void PybindOfflineRecognizerConfig(py::module *m) { |
| 15 | using PyClass = OfflineRecognizerConfig; | 15 | using PyClass = OfflineRecognizerConfig; |
| 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") | 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") |
| 17 | - .def(py::init<const OfflineFeatureExtractorConfig &, | ||
| 18 | - const OfflineModelConfig &, const OfflineLMConfig &, | ||
| 19 | - const OfflineCtcFstDecoderConfig &, const std::string &, | ||
| 20 | - int32_t, const std::string &, float, float>(), | 17 | + .def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &, |
| 18 | + const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &, | ||
| 19 | + const std::string &, int32_t, const std::string &, float, | ||
| 20 | + float>(), | ||
| 21 | py::arg("feat_config"), py::arg("model_config"), | 21 | py::arg("feat_config"), py::arg("model_config"), |
| 22 | py::arg("lm_config") = OfflineLMConfig(), | 22 | py::arg("lm_config") = OfflineLMConfig(), |
| 23 | py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), | 23 | py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), |
| @@ -25,6 +25,7 @@ Args: | @@ -25,6 +25,7 @@ Args: | ||
| 25 | static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT | 25 | static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT |
| 26 | using PyClass = OfflineRecognitionResult; | 26 | using PyClass = OfflineRecognitionResult; |
| 27 | py::class_<PyClass>(*m, "OfflineRecognitionResult") | 27 | py::class_<PyClass>(*m, "OfflineRecognitionResult") |
| 28 | + .def("__str__", &PyClass::AsJsonString) | ||
| 28 | .def_property_readonly( | 29 | .def_property_readonly( |
| 29 | "text", | 30 | "text", |
| 30 | [](const PyClass &self) -> py::str { | 31 | [](const PyClass &self) -> py::str { |
| @@ -37,18 +38,7 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT | @@ -37,18 +38,7 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT | ||
| 37 | "timestamps", [](const PyClass &self) { return self.timestamps; }); | 38 | "timestamps", [](const PyClass &self) { return self.timestamps; }); |
| 38 | } | 39 | } |
| 39 | 40 | ||
| 40 | -static void PybindOfflineFeatureExtractorConfig(py::module *m) { | ||
| 41 | - using PyClass = OfflineFeatureExtractorConfig; | ||
| 42 | - py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig") | ||
| 43 | - .def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000, | ||
| 44 | - py::arg("feature_dim") = 80) | ||
| 45 | - .def_readwrite("sampling_rate", &PyClass::sampling_rate) | ||
| 46 | - .def_readwrite("feature_dim", &PyClass::feature_dim) | ||
| 47 | - .def("__str__", &PyClass::ToString); | ||
| 48 | -} | ||
| 49 | - | ||
| 50 | void PybindOfflineStream(py::module *m) { | 41 | void PybindOfflineStream(py::module *m) { |
| 51 | - PybindOfflineFeatureExtractorConfig(m); | ||
| 52 | PybindOfflineRecognitionResult(m); | 42 | PybindOfflineRecognitionResult(m); |
| 53 | 43 | ||
| 54 | using PyClass = OfflineStream; | 44 | using PyClass = OfflineStream; |
| @@ -4,8 +4,8 @@ from pathlib import Path | @@ -4,8 +4,8 @@ from pathlib import Path | ||
| 4 | from typing import List, Optional | 4 | from typing import List, Optional |
| 5 | 5 | ||
| 6 | from _sherpa_onnx import ( | 6 | from _sherpa_onnx import ( |
| 7 | + FeatureExtractorConfig, | ||
| 7 | OfflineCtcFstDecoderConfig, | 8 | OfflineCtcFstDecoderConfig, |
| 8 | - OfflineFeatureExtractorConfig, | ||
| 9 | OfflineModelConfig, | 9 | OfflineModelConfig, |
| 10 | OfflineNemoEncDecCtcModelConfig, | 10 | OfflineNemoEncDecCtcModelConfig, |
| 11 | OfflineParaformerModelConfig, | 11 | OfflineParaformerModelConfig, |
| @@ -51,6 +51,7 @@ class OfflineRecognizer(object): | @@ -51,6 +51,7 @@ class OfflineRecognizer(object): | ||
| 51 | blank_penalty: float = 0.0, | 51 | blank_penalty: float = 0.0, |
| 52 | debug: bool = False, | 52 | debug: bool = False, |
| 53 | provider: str = "cpu", | 53 | provider: str = "cpu", |
| 54 | + model_type: str = "transducer", | ||
| 54 | ): | 55 | ): |
| 55 | """ | 56 | """ |
| 56 | Please refer to | 57 | Please refer to |
| @@ -106,10 +107,10 @@ class OfflineRecognizer(object): | @@ -106,10 +107,10 @@ class OfflineRecognizer(object): | ||
| 106 | num_threads=num_threads, | 107 | num_threads=num_threads, |
| 107 | debug=debug, | 108 | debug=debug, |
| 108 | provider=provider, | 109 | provider=provider, |
| 109 | - model_type="transducer", | 110 | + model_type=model_type, |
| 110 | ) | 111 | ) |
| 111 | 112 | ||
| 112 | - feat_config = OfflineFeatureExtractorConfig( | 113 | + feat_config = FeatureExtractorConfig( |
| 113 | sampling_rate=sample_rate, | 114 | sampling_rate=sample_rate, |
| 114 | feature_dim=feature_dim, | 115 | feature_dim=feature_dim, |
| 115 | ) | 116 | ) |
| @@ -182,7 +183,7 @@ class OfflineRecognizer(object): | @@ -182,7 +183,7 @@ class OfflineRecognizer(object): | ||
| 182 | model_type="paraformer", | 183 | model_type="paraformer", |
| 183 | ) | 184 | ) |
| 184 | 185 | ||
| 185 | - feat_config = OfflineFeatureExtractorConfig( | 186 | + feat_config = FeatureExtractorConfig( |
| 186 | sampling_rate=sample_rate, | 187 | sampling_rate=sample_rate, |
| 187 | feature_dim=feature_dim, | 188 | feature_dim=feature_dim, |
| 188 | ) | 189 | ) |
| @@ -246,7 +247,7 @@ class OfflineRecognizer(object): | @@ -246,7 +247,7 @@ class OfflineRecognizer(object): | ||
| 246 | model_type="nemo_ctc", | 247 | model_type="nemo_ctc", |
| 247 | ) | 248 | ) |
| 248 | 249 | ||
| 249 | - feat_config = OfflineFeatureExtractorConfig( | 250 | + feat_config = FeatureExtractorConfig( |
| 250 | sampling_rate=sample_rate, | 251 | sampling_rate=sample_rate, |
| 251 | feature_dim=feature_dim, | 252 | feature_dim=feature_dim, |
| 252 | ) | 253 | ) |
| @@ -326,7 +327,7 @@ class OfflineRecognizer(object): | @@ -326,7 +327,7 @@ class OfflineRecognizer(object): | ||
| 326 | model_type="whisper", | 327 | model_type="whisper", |
| 327 | ) | 328 | ) |
| 328 | 329 | ||
| 329 | - feat_config = OfflineFeatureExtractorConfig( | 330 | + feat_config = FeatureExtractorConfig( |
| 330 | sampling_rate=16000, | 331 | sampling_rate=16000, |
| 331 | feature_dim=80, | 332 | feature_dim=80, |
| 332 | ) | 333 | ) |
| @@ -389,7 +390,7 @@ class OfflineRecognizer(object): | @@ -389,7 +390,7 @@ class OfflineRecognizer(object): | ||
| 389 | model_type="tdnn", | 390 | model_type="tdnn", |
| 390 | ) | 391 | ) |
| 391 | 392 | ||
| 392 | - feat_config = OfflineFeatureExtractorConfig( | 393 | + feat_config = FeatureExtractorConfig( |
| 393 | sampling_rate=sample_rate, | 394 | sampling_rate=sample_rate, |
| 394 | feature_dim=feature_dim, | 395 | feature_dim=feature_dim, |
| 395 | ) | 396 | ) |
| @@ -453,7 +454,7 @@ class OfflineRecognizer(object): | @@ -453,7 +454,7 @@ class OfflineRecognizer(object): | ||
| 453 | model_type="wenet_ctc", | 454 | model_type="wenet_ctc", |
| 454 | ) | 455 | ) |
| 455 | 456 | ||
| 456 | - feat_config = OfflineFeatureExtractorConfig( | 457 | + feat_config = FeatureExtractorConfig( |
| 457 | sampling_rate=sample_rate, | 458 | sampling_rate=sample_rate, |
| 458 | feature_dim=feature_dim, | 459 | feature_dim=feature_dim, |
| 459 | ) | 460 | ) |