Committed by GitHub
Add C++ support for streaming NeMo CTC models. (#857)
Showing 22 changed files with 782 additions and 41 deletions.
@@ -14,6 +14,28 @@ echo "PATH: $PATH"
 which $EXE

 log "------------------------------------------------------------"
+log "Run streaming NeMo CTC "
+log "------------------------------------------------------------"
+
+url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
+name=$(basename $url)
+repo=$(basename -s .tar.bz2 $name)
+
+curl -SL -O $url
+tar xvf $name
+rm $name
+ls -lh $repo
+
+$EXE \
+  --nemo-ctc-model=$repo/model.onnx \
+  --tokens=$repo/tokens.txt \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/8k.wav
+
+rm -rf $repo
+
+log "------------------------------------------------------------"
 log "Run streaming Zipformer2 CTC HLG decoding "
 log "------------------------------------------------------------"
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
@@ -8,6 +8,19 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }

+log "test online NeMo CTC"
+
+url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
+name=$(basename $url)
+repo=$(basename -s .tar.bz2 $name)
+
+curl -SL -O $url
+tar xvf $name
+rm $name
+ls -lh $repo
+python3 ./python-api-examples/online-nemo-ctc-decode-files.py
+rm -rf $repo
+
 log "test offline punctuation"

 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
@@ -128,6 +128,14 @@ jobs:
           name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
           path: install/*

+      - name: Test online CTC
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx
+
+          .github/scripts/test-online-ctc.sh
+
       - name: Test offline transducer
         shell: bash
         run: |
@@ -163,14 +171,6 @@ jobs:

           .github/scripts/test-offline-ctc.sh

-      - name: Test online CTC
-        shell: bash
-        run: |
-          export PATH=$PWD/build/bin:$PATH
-          export EXE=sherpa-onnx
-
-          .github/scripts/test-online-ctc.sh
-
       - name: Test offline punctuation
         shell: bash
         run: |
+#!/usr/bin/env python3
+
+"""
+This file shows how to use a streaming CTC model from NeMo
+to decode files.
+
+Please download model files from
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+
+
+The example model is converted from
+https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_80ms
+"""
+
+from pathlib import Path
+
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+
+
+def create_recognizer():
+    model = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/model.onnx"
+    tokens = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt"
+
+    test_wav = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/test_wavs/0.wav"
+
+    if not Path(model).is_file() or not Path(test_wav).is_file():
+        raise ValueError(
+            """Please download model files from
+            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+            """
+        )
+    return (
+        sherpa_onnx.OnlineRecognizer.from_nemo_ctc(
+            model=model,
+            tokens=tokens,
+            debug=True,
+        ),
+        test_wav,
+    )
+
+
+def main():
+    recognizer, wave_filename = create_recognizer()
+
+    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
+    audio = audio[:, 0]  # only use the first channel
+
+    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
+    # sample_rate does not need to be 16000 Hz
+
+    stream = recognizer.create_stream()
+    stream.accept_waveform(sample_rate, audio)
+
+    tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
+    stream.accept_waveform(sample_rate, tail_paddings)
+    stream.input_finished()
+
+    while recognizer.is_ready(stream):
+        recognizer.decode_stream(stream)
+    print(wave_filename)
+    print(recognizer.get_result_all(stream))
+
+
+if __name__ == "__main__":
+    main()
@@ -100,7 +100,7 @@ class OnnxModel:
             dtype=torch.float32,
         ).numpy()

-        self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy()
+        self.cache_last_channel_len = torch.zeros([1], dtype=torch.int64).numpy()

     def __call__(self, x: np.ndarray):
         # x: (T, C)
@@ -142,7 +142,7 @@ class OnnxModel:
             dtype=torch.float32,
         ).numpy()

-        self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy()
+        self.cache_last_channel_len = torch.zeros([1], dtype=torch.int64).numpy()

     def run_encoder(self, x: np.ndarray):
         # x: (T, C)
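A likely rationale for the ones-to-zeros change above: cache_last_channel_len counts how many cached frames are valid, and before the first chunk is processed that count should be 0, not 1. The C++ InitStates() added later in this PR (in online-nemo-ctc-model.cc) initializes the same tensor to 0, so the export script's test code now matches the runtime initialization.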
@@ -61,6 +61,8 @@ set(sources
   online-lm.cc
   online-lstm-transducer-model.cc
   online-model-config.cc
+  online-nemo-ctc-model-config.cc
+  online-nemo-ctc-model.cc
   online-paraformer-model-config.cc
   online-paraformer-model.cc
   online-recognizer-impl.cc
@@ -4,11 +4,12 @@
 #ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
 #define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_

+#include <math.h>
+
 #include <memory>
 #include <string>
 #include <utility>
 #include <vector>
-#include <math.h>

 #if __ANDROID_API__ >= 9
 #include "android/asset_manager.h"
@@ -61,7 +62,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {

   int32_t segment_size = 20;
   int32_t max_len = 200;
-  int32_t num_segments = ceil(((float)token_ids.size() + segment_size - 1) / segment_size);
+  int32_t num_segments =
+      ceil((static_cast<float>(token_ids.size()) + segment_size - 1) /
+           segment_size);

   std::vector<int32_t> punctuations;
   int32_t last = -1;
@@ -10,6 +10,7 @@
 #include <string>

 #include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
 #include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
 #include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
@@ -22,6 +23,8 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
     return std::make_unique<OnlineWenetCtcModel>(config);
   } else if (!config.zipformer2_ctc.model.empty()) {
     return std::make_unique<OnlineZipformer2CtcModel>(config);
+  } else if (!config.nemo_ctc.model.empty()) {
+    return std::make_unique<OnlineNeMoCtcModel>(config);
   } else {
     SHERPA_ONNX_LOGE("Please specify a CTC model");
     exit(-1);
@@ -36,6 +39,8 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
     return std::make_unique<OnlineWenetCtcModel>(mgr, config);
   } else if (!config.zipformer2_ctc.model.empty()) {
     return std::make_unique<OnlineZipformer2CtcModel>(mgr, config);
+  } else if (!config.nemo_ctc.model.empty()) {
+    return std::make_unique<OnlineNeMoCtcModel>(mgr, config);
   } else {
     SHERPA_ONNX_LOGE("Please specify a CTC model");
     exit(-1);
@@ -15,6 +15,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   paraformer.Register(po);
   wenet_ctc.Register(po);
   zipformer2_ctc.Register(po);
+  nemo_ctc.Register(po);

   po->Register("tokens", &tokens, "Path to tokens.txt");

@@ -31,11 +32,11 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   po->Register("provider", &provider,
                "Specify a provider to use: cpu, cuda, coreml");

-  po->Register(
-      "model-type", &model_type,
-      "Specify it to reduce model initialization time. "
-      "Valid values are: conformer, lstm, zipformer, zipformer2, wenet_ctc"
-      "All other values lead to loading the model twice.");
+  po->Register("model-type", &model_type,
+               "Specify it to reduce model initialization time. "
+               "Valid values are: conformer, lstm, zipformer, zipformer2, "
+               "wenet_ctc, nemo_ctc. "
+               "All other values lead to loading the model twice.");
 }

 bool OnlineModelConfig::Validate() const {
@@ -61,6 +62,10 @@ bool OnlineModelConfig::Validate() const {
     return zipformer2_ctc.Validate();
   }

+  if (!nemo_ctc.model.empty()) {
+    return nemo_ctc.Validate();
+  }
+
   return transducer.Validate();
 }

@@ -72,6 +77,7 @@ std::string OnlineModelConfig::ToString() const {
   os << "paraformer=" << paraformer.ToString() << ", ";
   os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
   os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
+  os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
   os << "warm_up=" << warm_up << ", ";
@@ -6,6 +6,7 @@

 #include <string>

+#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
 #include "sherpa-onnx/csrc/online-paraformer-model-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
 #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
@@ -18,6 +19,7 @@ struct OnlineModelConfig {
   OnlineParaformerModelConfig paraformer;
   OnlineWenetCtcModelConfig wenet_ctc;
   OnlineZipformer2CtcModelConfig zipformer2_ctc;
+  OnlineNeMoCtcModelConfig nemo_ctc;
   std::string tokens;
   int32_t num_threads = 1;
   int32_t warm_up = 0;
@@ -30,6 +32,7 @@ struct OnlineModelConfig {
   // - zipformer, zipformer transducer from icefall
   // - zipformer2, zipformer2 transducer or CTC from icefall
   // - wenet_ctc, wenet CTC model
+  // - nemo_ctc, NeMo CTC model
   //
   // All other values are invalid and lead to loading the model twice.
   std::string model_type;
@@ -39,6 +42,7 @@ struct OnlineModelConfig {
                     const OnlineParaformerModelConfig &paraformer,
                     const OnlineWenetCtcModelConfig &wenet_ctc,
                     const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
+                    const OnlineNeMoCtcModelConfig &nemo_ctc,
                     const std::string &tokens, int32_t num_threads,
                     int32_t warm_up, bool debug, const std::string &provider,
                     const std::string &model_type)
@@ -46,6 +50,7 @@ struct OnlineModelConfig {
         paraformer(paraformer),
         wenet_ctc(wenet_ctc),
         zipformer2_ctc(zipformer2_ctc),
+        nemo_ctc(nemo_ctc),
         tokens(tokens),
         num_threads(num_threads),
         warm_up(warm_up),
+// sherpa-onnx/csrc/online-nemo-ctc-model-config.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OnlineNeMoCtcModelConfig::Register(ParseOptions *po) {
+  po->Register("nemo-ctc-model", &model,
+               "Path to CTC model.onnx from NeMo. Please see "
+               "https://github.com/k2-fsa/sherpa-onnx/pull/843");
+}
+
+bool OnlineNeMoCtcModelConfig::Validate() const {
+  if (!FileExists(model)) {
+    SHERPA_ONNX_LOGE("NeMo CTC model '%s' does not exist", model.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string OnlineNeMoCtcModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OnlineNeMoCtcModelConfig(";
+  os << "model=\"" << model << "\")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/online-nemo-ctc-model-config.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OnlineNeMoCtcModelConfig {
+  std::string model;
+
+  OnlineNeMoCtcModelConfig() = default;
+
+  explicit OnlineNeMoCtcModelConfig(const std::string &model) : model(model) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
sherpa-onnx/csrc/online-nemo-ctc-model.cc 0 → 100644
+// sherpa-onnx/csrc/online-nemo-ctc-model.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
+
+#include <algorithm>
+#include <cmath>
+#include <string>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/cat.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/transpose.h"
+#include "sherpa-onnx/csrc/unbind.h"
+
+namespace sherpa_onnx {
+
+class OnlineNeMoCtcModel::Impl {
+ public:
+  explicit Impl(const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.nemo_ctc.model);
+      Init(buf.data(), buf.size());
+    }
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_WARNING),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(mgr, config.nemo_ctc.model);
+      Init(buf.data(), buf.size());
+    }
+  }
+#endif
+
+  std::vector<Ort::Value> Forward(Ort::Value x,
+                                  std::vector<Ort::Value> states) {
+    Ort::Value &cache_last_channel = states[0];
+    Ort::Value &cache_last_time = states[1];
+    Ort::Value &cache_last_channel_len = states[2];
+
+    int32_t batch_size = x.GetTensorTypeAndShapeInfo().GetShape()[0];
+
+    std::array<int64_t, 1> length_shape{batch_size};
+
+    Ort::Value length = Ort::Value::CreateTensor<int64_t>(
+        allocator_, length_shape.data(), length_shape.size());
+
+    int64_t *p_length = length.GetTensorMutableData<int64_t>();
+
+    std::fill(p_length, p_length + batch_size, ChunkLength());
+
+    // (B, T, C) -> (B, C, T)
+    x = Transpose12(allocator_, &x);
+
+    std::array<Ort::Value, 5> inputs = {
+        std::move(x), View(&length), std::move(cache_last_channel),
+        std::move(cache_last_time), std::move(cache_last_channel_len)};
+
+    auto out =
+        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
+                   output_names_ptr_.data(), output_names_ptr_.size());
+    // out[0]: logit
+    // out[1]: logit_length
+    // out[2:]: states_next
+    //
+    // we need to remove out[1]
+
+    std::vector<Ort::Value> ans;
+    ans.reserve(out.size() - 1);
+
+    for (int32_t i = 0; i != out.size(); ++i) {
+      if (i == 1) {
+        continue;
+      }
+
+      ans.push_back(std::move(out[i]));
+    }
+
+    return ans;
+  }
+
+  int32_t VocabSize() const { return vocab_size_; }
+
+  int32_t ChunkLength() const { return window_size_; }
+
+  int32_t ChunkShift() const { return chunk_shift_; }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+  // Return a vector containing 3 tensors
+  // - cache_last_channel
+  // - cache_last_time_
+  // - cache_last_channel_len
+  std::vector<Ort::Value> GetInitStates() {
+    std::vector<Ort::Value> ans;
+    ans.reserve(3);
+    ans.push_back(View(&cache_last_channel_));
+    ans.push_back(View(&cache_last_time_));
+    ans.push_back(View(&cache_last_channel_len_));
+
+    return ans;
+  }
+
+  std::vector<Ort::Value> StackStates(
+      std::vector<std::vector<Ort::Value>> states) const {
+    int32_t batch_size = static_cast<int32_t>(states.size());
+    if (batch_size == 1) {
+      return std::move(states[0]);
+    }
+
+    std::vector<Ort::Value> ans;
+
+    // stack cache_last_channel
+    std::vector<const Ort::Value *> buf(batch_size);
+
+    // there are 3 states to be stacked
+    for (int32_t i = 0; i != 3; ++i) {
+      buf.clear();
+      buf.reserve(batch_size);
+
+      for (int32_t b = 0; b != batch_size; ++b) {
+        assert(states[b].size() == 3);
+        buf.push_back(&states[b][i]);
+      }
+
+      Ort::Value c{nullptr};
+      if (i == 2) {
+        c = Cat<int64_t>(allocator_, buf, 0);
+      } else {
+        c = Cat(allocator_, buf, 0);
+      }
+
+      ans.push_back(std::move(c));
+    }
+
+    return ans;
+  }
+
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      std::vector<Ort::Value> states) const {
+    assert(states.size() == 3);
+
+    std::vector<std::vector<Ort::Value>> ans;
+
+    auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
+    int32_t batch_size = shape[0];
+    ans.resize(batch_size);
+
+    if (batch_size == 1) {
+      ans[0] = std::move(states);
+      return ans;
+    }
+
+    for (int32_t i = 0; i != 3; ++i) {
+      std::vector<Ort::Value> v;
+      if (i == 2) {
+        v = Unbind<int64_t>(allocator_, &states[i], 0);
+      } else {
+        v = Unbind(allocator_, &states[i], 0);
+      }
+
+      assert(v.size() == batch_size);
+
+      for (int32_t b = 0; b != batch_size; ++b) {
+        ans[b].push_back(std::move(v[b]));
+      }
+    }
+
+    return ans;
+  }
+
+ private:
+  void Init(void *model_data, size_t model_data_length) {
+    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
+                                           sess_opts_);
+
+    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
+
+    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA(window_size_, "window_size");
+    SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift");
+    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
+    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_,
+                               "cache_last_channel_dim1");
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_,
+                               "cache_last_channel_dim2");
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_,
+                               "cache_last_channel_dim3");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3");
+
+    // need to increase by 1 since the blank token is not included in computing
+    // vocab_size in NeMo.
+    vocab_size_ += 1;
+
+    InitStates();
+  }
+
+  void InitStates() {
+    std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_,
+                                                    cache_last_channel_dim2_,
+                                                    cache_last_channel_dim3_};
+
+    cache_last_channel_ = Ort::Value::CreateTensor<float>(
+        allocator_, cache_last_channel_shape.data(),
+        cache_last_channel_shape.size());
+
+    Fill<float>(&cache_last_channel_, 0);
+
+    std::array<int64_t, 4> cache_last_time_shape{
+        1, cache_last_time_dim1_, cache_last_time_dim2_, cache_last_time_dim3_};
+
+    cache_last_time_ = Ort::Value::CreateTensor<float>(
+        allocator_, cache_last_time_shape.data(), cache_last_time_shape.size());
+
+    Fill<float>(&cache_last_time_, 0);
+
+    int64_t shape = 1;
+    cache_last_channel_len_ =
+        Ort::Value::CreateTensor<int64_t>(allocator_, &shape, 1);
+
+    cache_last_channel_len_.GetTensorMutableData<int64_t>()[0] = 0;
+  }
+
+ private:
+  OnlineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> sess_;
+
+  std::vector<std::string> input_names_;
+  std::vector<const char *> input_names_ptr_;
+
+  std::vector<std::string> output_names_;
+  std::vector<const char *> output_names_ptr_;
+
+  int32_t window_size_;
+  int32_t chunk_shift_;
+  int32_t subsampling_factor_;
+  int32_t vocab_size_;
+  int32_t cache_last_channel_dim1_;
+  int32_t cache_last_channel_dim2_;
+  int32_t cache_last_channel_dim3_;
+  int32_t cache_last_time_dim1_;
+  int32_t cache_last_time_dim2_;
+  int32_t cache_last_time_dim3_;
+
+  Ort::Value cache_last_channel_{nullptr};
+  Ort::Value cache_last_time_{nullptr};
+  Ort::Value cache_last_channel_len_{nullptr};
+};
+
+OnlineNeMoCtcModel::OnlineNeMoCtcModel(const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OnlineNeMoCtcModel::OnlineNeMoCtcModel(AAssetManager *mgr,
+                                       const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OnlineNeMoCtcModel::~OnlineNeMoCtcModel() = default;
+
+std::vector<Ort::Value> OnlineNeMoCtcModel::Forward(
+    Ort::Value x, std::vector<Ort::Value> states) const {
+  return impl_->Forward(std::move(x), std::move(states));
+}
+
+int32_t OnlineNeMoCtcModel::VocabSize() const { return impl_->VocabSize(); }
+
+int32_t OnlineNeMoCtcModel::ChunkLength() const { return impl_->ChunkLength(); }
+
+int32_t OnlineNeMoCtcModel::ChunkShift() const { return impl_->ChunkShift(); }
+
+OrtAllocator *OnlineNeMoCtcModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+std::vector<Ort::Value> OnlineNeMoCtcModel::GetInitStates() const {
+  return impl_->GetInitStates();
+}
+
+std::vector<Ort::Value> OnlineNeMoCtcModel::StackStates(
+    std::vector<std::vector<Ort::Value>> states) const {
+  return impl_->StackStates(std::move(states));
+}
+
+std::vector<std::vector<Ort::Value>> OnlineNeMoCtcModel::UnStackStates(
+    std::vector<Ort::Value> states) const {
+  return impl_->UnStackStates(std::move(states));
+}
+
+}  // namespace sherpa_onnx
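For readers who want to exercise model.onnx outside of sherpa-onnx, the following is a rough onnxruntime sketch of the same call sequence as Impl::Forward() and InitStates() above. The feature dimension of 80 is an assumption (it matches the Python API default elsewhere in this PR), and the input order is taken from the inputs array in Forward():

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx")
    meta = sess.get_modelmeta().custom_metadata_map  # same keys as SHERPA_ONNX_READ_META_DATA
    T = int(meta["window_size"])  # ChunkLength(): feature frames per chunk

    x = np.zeros((1, 80, T), dtype=np.float32)  # (B, C, T); Forward() transposes (B, T, C) first
    length = np.array([T], dtype=np.int64)
    cache_last_channel = np.zeros(
        (1, int(meta["cache_last_channel_dim1"]), int(meta["cache_last_channel_dim2"]),
         int(meta["cache_last_channel_dim3"])), dtype=np.float32)
    cache_last_time = np.zeros(
        (1, int(meta["cache_last_time_dim1"]), int(meta["cache_last_time_dim2"]),
         int(meta["cache_last_time_dim3"])), dtype=np.float32)
    cache_last_channel_len = np.zeros((1,), dtype=np.int64)  # 0 valid cached frames at start

    names = [i.name for i in sess.get_inputs()]
    out = sess.run(None, dict(zip(names, [x, length, cache_last_channel,
                                          cache_last_time, cache_last_channel_len])))
    # out[0]: log-probs, out[1]: their lengths (dropped by Forward()),
    # out[2:]: the next states to feed back in on the following chunk.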
sherpa-onnx/csrc/online-nemo-ctc-model.h 0 → 100644
+// sherpa-onnx/csrc/online-nemo-ctc-model.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-ctc-model.h"
+#include "sherpa-onnx/csrc/online-model-config.h"
+
+namespace sherpa_onnx {
+
+class OnlineNeMoCtcModel : public OnlineCtcModel {
+ public:
+  explicit OnlineNeMoCtcModel(const OnlineModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OnlineNeMoCtcModel(AAssetManager *mgr, const OnlineModelConfig &config);
+#endif
+
+  ~OnlineNeMoCtcModel() override;
+
+  // A list of 3 tensors:
+  //  - cache_last_channel
+  //  - cache_last_time
+  //  - cache_last_channel_len
+  std::vector<Ort::Value> GetInitStates() const override;
+
+  std::vector<Ort::Value> StackStates(
+      std::vector<std::vector<Ort::Value>> states) const override;
+
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      std::vector<Ort::Value> states) const override;
+
+  /**
+   *
+   * @param x A 3-D tensor of shape (N, T, C). N has to be 1.
+   * @param states It is from GetInitStates() or returned from this method.
+   *
+   * @return Return a list of tensors
+   *    - ans[0] contains log_probs, of shape (N, T, C)
+   *    - ans[1:] contains next_states
+   */
+  std::vector<Ort::Value> Forward(
+      Ort::Value x, std::vector<Ort::Value> states) const override;
+
+  /** Return the vocabulary size of the model
+   */
+  int32_t VocabSize() const override;
+
+  /** Return an allocator for allocating memory
+   */
+  OrtAllocator *Allocator() const override;
+
+  // The model accepts this number of frames before subsampling as input
+  int32_t ChunkLength() const override;
+
+  // Similar to frame_shift in feature extractor, after processing
+  // ChunkLength() frames, we advance by ChunkShift() frames
+  // before we process the next chunk.
+  int32_t ChunkShift() const override;
+
+  bool SupportBatchProcessing() const override { return true; }
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
@@ -21,7 +21,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
   }

   if (!config.model_config.wenet_ctc.model.empty() ||
-      !config.model_config.zipformer2_ctc.model.empty()) {
+      !config.model_config.zipformer2_ctc.model.empty() ||
+      !config.model_config.nemo_ctc.model.empty()) {
     return std::make_unique<OnlineRecognizerCtcImpl>(config);
   }

@@ -41,7 +42,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
   }

   if (!config.model_config.wenet_ctc.model.empty() ||
-      !config.model_config.zipformer2_ctc.model.empty()) {
+      !config.model_config.zipformer2_ctc.model.empty() ||
+      !config.model_config.nemo_ctc.model.empty()) {
     return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
   }

@@ -23,6 +23,7 @@ set(srcs
   online-ctc-fst-decoder-config.cc
   online-lm-config.cc
   online-model-config.cc
+  online-nemo-ctc-model-config.cc
   online-paraformer-model-config.cc
   online-recognizer.cc
   online-stream.cc
@@ -9,6 +9,7 @@

 #include "sherpa-onnx/csrc/online-model-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
 #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
@@ -21,26 +22,30 @@ void PybindOnlineModelConfig(py::module *m) {
   PybindOnlineParaformerModelConfig(m);
   PybindOnlineWenetCtcModelConfig(m);
   PybindOnlineZipformer2CtcModelConfig(m);
+  PybindOnlineNeMoCtcModelConfig(m);

   using PyClass = OnlineModelConfig;
   py::class_<PyClass>(*m, "OnlineModelConfig")
       .def(py::init<const OnlineTransducerModelConfig &,
                     const OnlineParaformerModelConfig &,
                     const OnlineWenetCtcModelConfig &,
-                    const OnlineZipformer2CtcModelConfig &, const std::string &,
+                    const OnlineZipformer2CtcModelConfig &,
+                    const OnlineNeMoCtcModelConfig &, const std::string &,
                     int32_t, int32_t, bool, const std::string &,
                     const std::string &>(),
           py::arg("transducer") = OnlineTransducerModelConfig(),
           py::arg("paraformer") = OnlineParaformerModelConfig(),
           py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
           py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
-          py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
+          py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"),
+          py::arg("num_threads"), py::arg("warm_up") = 0,
           py::arg("debug") = false, py::arg("provider") = "cpu",
           py::arg("model_type") = "")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
       .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
+      .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
       .def_readwrite("tokens", &PyClass::tokens)
       .def_readwrite("num_threads", &PyClass::num_threads)
       .def_readwrite("debug", &PyClass::debug)
+// sherpa-onnx/python/csrc/online-nemo-ctc-model-config.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
+
+namespace sherpa_onnx {
+
+void PybindOnlineNeMoCtcModelConfig(py::module *m) {
+  using PyClass = OnlineNeMoCtcModelConfig;
+  py::class_<PyClass>(*m, "OnlineNeMoCtcModelConfig")
+      .def(py::init<const std::string &>(), py::arg("model"))
+      .def_readwrite("model", &PyClass::model)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
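With the binding above in place, the config is constructible and printable from Python; a small sketch (assuming the class is re-exported at the package level like the other model configs):

    import sherpa_onnx

    cfg = sherpa_onnx.OnlineNeMoCtcModelConfig(model="./model.onnx")
    print(cfg)  # __str__ is bound to ToString(), so this prints the C++ config string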
+// sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOnlineNeMoCtcModelConfig(py::module *m);
+
+}
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
@@ -42,6 +42,8 @@ static void PybindOnlineRecognizerResult(py::module *m) {
           "segment", [](PyClass &self) -> int32_t { return self.segment; })
       .def_property_readonly(
           "is_final", [](PyClass &self) -> bool { return self.is_final; })
+      .def("__str__", &PyClass::AsJsonString,
+           py::call_guard<py::gil_scoped_release>())
       .def("as_json_string", &PyClass::AsJsonString,
            py::call_guard<py::gil_scoped_release>());
 }
@@ -50,29 +52,17 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
   using PyClass = OnlineRecognizerConfig;
   py::class_<PyClass>(*m, "OnlineRecognizerConfig")
       .def(
-          py::init<const FeatureExtractorConfig &,
-                   const OnlineModelConfig &,
-                   const OnlineLMConfig &,
-                   const EndpointConfig &,
-                   const OnlineCtcFstDecoderConfig &,
-                   bool,
-                   const std::string &,
-                   int32_t,
-                   const std::string &,
-                   float,
-                   float,
-                   float>(),
-          py::arg("feat_config"),
-          py::arg("model_config"),
+          py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
+                   const OnlineLMConfig &, const EndpointConfig &,
+                   const OnlineCtcFstDecoderConfig &, bool, const std::string &,
+                   int32_t, const std::string &, float, float, float>(),
+          py::arg("feat_config"), py::arg("model_config"),
           py::arg("lm_config") = OnlineLMConfig(),
           py::arg("endpoint_config") = EndpointConfig(),
           py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
-          py::arg("enable_endpoint"),
-          py::arg("decoding_method"),
-          py::arg("max_active_paths") = 4,
-          py::arg("hotwords_file") = "",
-          py::arg("hotwords_score") = 0,
-          py::arg("blank_penalty") = 0.0,
+          py::arg("enable_endpoint"), py::arg("decoding_method"),
+          py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
+          py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
           py::arg("temperature_scale") = 2.0)
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
@@ -12,9 +12,11 @@ from _sherpa_onnx import (
 from _sherpa_onnx import OnlineRecognizer as _Recognizer
 from _sherpa_onnx import (
     OnlineRecognizerConfig,
+    OnlineRecognizerResult,
     OnlineStream,
     OnlineTransducerModelConfig,
     OnlineWenetCtcModelConfig,
+    OnlineNeMoCtcModelConfig,
     OnlineZipformer2CtcModelConfig,
     OnlineCtcFstDecoderConfig,
 )
@@ -59,6 +61,7 @@ class OnlineRecognizer(object):
         lm: str = "",
         lm_scale: float = 0.1,
         temperature_scale: float = 2.0,
+        debug: bool = False,
     ):
         """
         Please refer to
@@ -154,6 +157,7 @@ class OnlineRecognizer(object):
             num_threads=num_threads,
             provider=provider,
             model_type=model_type,
+            debug=debug,
         )

         feat_config = FeatureExtractorConfig(
@@ -220,6 +224,7 @@ class OnlineRecognizer(object):
         rule3_min_utterance_length: float = 20.0,
         decoding_method: str = "greedy_search",
         provider: str = "cpu",
+        debug: bool = False,
     ):
         """
         Please refer to
@@ -283,6 +288,7 @@ class OnlineRecognizer(object):
             num_threads=num_threads,
             provider=provider,
             model_type="paraformer",
+            debug=debug,
         )

         feat_config = FeatureExtractorConfig(
@@ -324,6 +330,7 @@ class OnlineRecognizer(object):
         ctc_graph: str = "",
         ctc_max_active: int = 3000,
         provider: str = "cpu",
+        debug: bool = False,
     ):
         """
         Please refer to
@@ -386,6 +393,7 @@ class OnlineRecognizer(object):
             tokens=tokens,
             num_threads=num_threads,
             provider=provider,
+            debug=debug,
         )

         feat_config = FeatureExtractorConfig(
@@ -418,6 +426,106 @@ class OnlineRecognizer(object):
         return self

     @classmethod
+    def from_nemo_ctc(
+        cls,
+        tokens: str,
+        model: str,
+        num_threads: int = 2,
+        sample_rate: float = 16000,
+        feature_dim: int = 80,
+        enable_endpoint_detection: bool = False,
+        rule1_min_trailing_silence: float = 2.4,
+        rule2_min_trailing_silence: float = 1.2,
+        rule3_min_utterance_length: float = 20.0,
+        decoding_method: str = "greedy_search",
+        provider: str = "cpu",
+        debug: bool = False,
+    ):
+        """
+        Please refer to
+        `<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
+        to download pre-trained models.
+
+        Args:
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          model:
+            Path to ``model.onnx``.
+          num_threads:
+            Number of threads for neural network computation.
+          sample_rate:
+            Sample rate of the training data used to train the model.
+          feature_dim:
+            Dimension of the feature used to train the model.
+          enable_endpoint_detection:
+            True to enable endpoint detection. False to disable endpoint
+            detection.
+          rule1_min_trailing_silence:
+            Used only when enable_endpoint_detection is True. If the duration
+            of trailing silence in seconds is larger than this value, we assume
+            an endpoint is detected.
+          rule2_min_trailing_silence:
+            Used only when enable_endpoint_detection is True. If we have decoded
+            something that is nonsilence and if the duration of trailing silence
+            in seconds is larger than this value, we assume an endpoint is
+            detected.
+          rule3_min_utterance_length:
+            Used only when enable_endpoint_detection is True. If the utterance
+            length in seconds is larger than this value, we assume an endpoint
+            is detected.
+          decoding_method:
+            The only valid value is greedy_search.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+          debug:
+            True to show meta data in the model.
+        """
+        self = cls.__new__(cls)
+        _assert_file_exists(tokens)
+        _assert_file_exists(model)
+
+        assert num_threads > 0, num_threads
+
+        nemo_ctc_config = OnlineNeMoCtcModelConfig(
+            model=model,
+        )
+
+        model_config = OnlineModelConfig(
+            nemo_ctc=nemo_ctc_config,
+            tokens=tokens,
+            num_threads=num_threads,
+            provider=provider,
+            debug=debug,
+        )
+
+        feat_config = FeatureExtractorConfig(
+            sampling_rate=sample_rate,
+            feature_dim=feature_dim,
+        )
+
+        endpoint_config = EndpointConfig(
+            rule1_min_trailing_silence=rule1_min_trailing_silence,
+            rule2_min_trailing_silence=rule2_min_trailing_silence,
+            rule3_min_utterance_length=rule3_min_utterance_length,
+        )
+
+        recognizer_config = OnlineRecognizerConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            endpoint_config=endpoint_config,
+            enable_endpoint=enable_endpoint_detection,
+            decoding_method=decoding_method,
+        )
+
+        self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
+        return self
+
+    @classmethod
     def from_wenet_ctc(
         cls,
         tokens: str,
@@ -433,6 +541,7 @@ class OnlineRecognizer(object):
         rule3_min_utterance_length: float = 20.0,
         decoding_method: str = "greedy_search",
         provider: str = "cpu",
+        debug: bool = False,
     ):
         """
         Please refer to
@@ -497,6 +606,7 @@ class OnlineRecognizer(object):
             tokens=tokens,
             num_threads=num_threads,
             provider=provider,
+            debug=debug,
         )

         feat_config = FeatureExtractorConfig(
@@ -537,6 +647,9 @@ class OnlineRecognizer(object):
     def is_ready(self, s: OnlineStream) -> bool:
         return self.recognizer.is_ready(s)

+    def get_result_all(self, s: OnlineStream) -> OnlineRecognizerResult:
+        return self.recognizer.get_result(s)
+
     def get_result(self, s: OnlineStream) -> str:
         return self.recognizer.get_result(s).text.strip()
