Committed by
GitHub
Support non-streaming WeNet CTC models. (#426)
Showing 21 changed files with 469 additions and 32 deletions
| @@ -14,6 +14,47 @@ echo "PATH: $PATH" | @@ -14,6 +14,47 @@ echo "PATH: $PATH" | ||
| 14 | which $EXE | 14 | which $EXE |
| 15 | 15 | ||
| 16 | log "------------------------------------------------------------" | 16 | log "------------------------------------------------------------" |
| 17 | +log "Run Wenet models" | ||
| 18 | +log "------------------------------------------------------------" | ||
| 19 | +wenet_models=( | ||
| 20 | +sherpa-onnx-zh-wenet-aishell | ||
| 21 | +sherpa-onnx-zh-wenet-aishell2 | ||
| 22 | +sherpa-onnx-zh-wenet-wenetspeech | ||
| 23 | +sherpa-onnx-zh-wenet-multi-cn | ||
| 24 | +sherpa-onnx-en-wenet-librispeech | ||
| 25 | +sherpa-onnx-en-wenet-gigaspeech | ||
| 26 | +) | ||
| 27 | +for name in ${wenet_models[@]}; do | ||
| 28 | + repo_url=https://huggingface.co/csukuangfj/$name | ||
| 29 | + log "Start testing ${repo_url}" | ||
| 30 | + repo=$(basename $repo_url) | ||
| 31 | + log "Download pretrained model and test-data from $repo_url" | ||
| 32 | + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 33 | + pushd $repo | ||
| 34 | + git lfs pull --include "*.onnx" | ||
| 35 | + ls -lh *.onnx | ||
| 36 | + popd | ||
| 37 | + | ||
| 38 | + log "test float32 models" | ||
| 39 | + time $EXE \ | ||
| 40 | + --tokens=$repo/tokens.txt \ | ||
| 41 | + --wenet-ctc-model=$repo/model.onnx \ | ||
| 42 | + $repo/test_wavs/0.wav \ | ||
| 43 | + $repo/test_wavs/1.wav \ | ||
| 44 | + $repo/test_wavs/8k.wav | ||
| 45 | + | ||
| 46 | + log "test int8 models" | ||
| 47 | + time $EXE \ | ||
| 48 | + --tokens=$repo/tokens.txt \ | ||
| 49 | + --wenet-ctc-model=$repo/model.int8.onnx \ | ||
| 50 | + $repo/test_wavs/0.wav \ | ||
| 51 | + $repo/test_wavs/1.wav \ | ||
| 52 | + $repo/test_wavs/8k.wav | ||
| 53 | + | ||
| 54 | + rm -rf $repo | ||
| 55 | +done | ||
| 56 | + | ||
| 57 | +log "------------------------------------------------------------" | ||
| 17 | log "Run tdnn yesno (Hebrew)" | 58 | log "Run tdnn yesno (Hebrew)" |
| 18 | log "------------------------------------------------------------" | 59 | log "------------------------------------------------------------" |
| 19 | repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno | 60 | repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno |
| 1 | name: export-wenet-to-onnx | 1 | name: export-wenet-to-onnx |
| 2 | 2 | ||
| 3 | on: | 3 | on: |
| 4 | - push: | ||
| 5 | - branches: | ||
| 6 | - - master | ||
| 7 | - paths: | ||
| 8 | - - 'scripts/wenet/**' | ||
| 9 | - - '.github/workflows/export-wenet-to-onnx.yaml' | ||
| 10 | - pull_request: | ||
| 11 | - paths: | ||
| 12 | - - 'scripts/wenet/**' | ||
| 13 | - - '.github/workflows/export-wenet-to-onnx.yaml' | ||
| 14 | - | ||
| 15 | workflow_dispatch: | 4 | workflow_dispatch: |
| 16 | 5 | ||
| 17 | concurrency: | 6 | concurrency: |
| @@ -89,6 +89,14 @@ jobs: | @@ -89,6 +89,14 @@ jobs: | ||
| 89 | file build/bin/sherpa-onnx | 89 | file build/bin/sherpa-onnx |
| 90 | readelf -d build/bin/sherpa-onnx | 90 | readelf -d build/bin/sherpa-onnx |
| 91 | 91 | ||
| 92 | + - name: Test offline CTC | ||
| 93 | + shell: bash | ||
| 94 | + run: | | ||
| 95 | + export PATH=$PWD/build/bin:$PATH | ||
| 96 | + export EXE=sherpa-onnx-offline | ||
| 97 | + | ||
| 98 | + .github/scripts/test-offline-ctc.sh | ||
| 99 | + | ||
| 92 | - name: Test offline TTS | 100 | - name: Test offline TTS |
| 93 | shell: bash | 101 | shell: bash |
| 94 | run: | | 102 | run: | |
| @@ -115,14 +123,6 @@ jobs: | @@ -115,14 +123,6 @@ jobs: | ||
| 115 | 123 | ||
| 116 | .github/scripts/test-offline-whisper.sh | 124 | .github/scripts/test-offline-whisper.sh |
| 117 | 125 | ||
| 118 | - - name: Test offline CTC | ||
| 119 | - shell: bash | ||
| 120 | - run: | | ||
| 121 | - export PATH=$PWD/build/bin:$PATH | ||
| 122 | - export EXE=sherpa-onnx-offline | ||
| 123 | - | ||
| 124 | - .github/scripts/test-offline-ctc.sh | ||
| 125 | - | ||
| 126 | - name: Test offline transducer | 126 | - name: Test offline transducer |
| 127 | shell: bash | 127 | shell: bash |
| 128 | run: | | 128 | run: | |
| @@ -172,7 +172,7 @@ def main(): | @@ -172,7 +172,7 @@ def main(): | ||
| 172 | # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz | 172 | # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz |
| 173 | url = os.environ.get("WENET_URL", "") | 173 | url = os.environ.get("WENET_URL", "") |
| 174 | meta_data = { | 174 | meta_data = { |
| 175 | - "model_type": "wenet-ctc", | 175 | + "model_type": "wenet_ctc", |
| 176 | "version": "1", | 176 | "version": "1", |
| 177 | "model_author": "wenet", | 177 | "model_author": "wenet", |
| 178 | "comment": "streaming", | 178 | "comment": "streaming", |
| @@ -185,6 +185,7 @@ def main(): | @@ -185,6 +185,7 @@ def main(): | ||
| 185 | "cnn_module_kernel": cnn_module_kernel, | 185 | "cnn_module_kernel": cnn_module_kernel, |
| 186 | "right_context": right_context, | 186 | "right_context": right_context, |
| 187 | "subsampling_factor": subsampling_factor, | 187 | "subsampling_factor": subsampling_factor, |
| 188 | + "vocab_size": torch_model.ctc.ctc_lo.weight.shape[0], | ||
| 188 | } | 189 | } |
| 189 | add_meta_data(filename=filename, meta_data=meta_data) | 190 | add_meta_data(filename=filename, meta_data=meta_data) |
| 190 | 191 |
| @@ -107,10 +107,12 @@ def main(): | @@ -107,10 +107,12 @@ def main(): | ||
| 107 | # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz | 107 | # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz |
| 108 | url = os.environ.get("WENET_URL", "") | 108 | url = os.environ.get("WENET_URL", "") |
| 109 | meta_data = { | 109 | meta_data = { |
| 110 | - "model_type": "wenet-ctc", | 110 | + "model_type": "wenet_ctc", |
| 111 | "version": "1", | 111 | "version": "1", |
| 112 | "model_author": "wenet", | 112 | "model_author": "wenet", |
| 113 | "comment": "non-streaming", | 113 | "comment": "non-streaming", |
| 114 | + "subsampling_factor": torch_model.encoder.embed.subsampling_rate, | ||
| 115 | + "vocab_size": torch_model.ctc.ctc_lo.weight.shape[0], | ||
| 114 | "url": url, | 116 | "url": url, |
| 115 | } | 117 | } |
| 116 | add_meta_data(filename=filename, meta_data=meta_data) | 118 | add_meta_data(filename=filename, meta_data=meta_data) |
| @@ -41,6 +41,8 @@ set(sources | @@ -41,6 +41,8 @@ set(sources | ||
| 41 | offline-transducer-model-config.cc | 41 | offline-transducer-model-config.cc |
| 42 | offline-transducer-model.cc | 42 | offline-transducer-model.cc |
| 43 | offline-transducer-modified-beam-search-decoder.cc | 43 | offline-transducer-modified-beam-search-decoder.cc |
| 44 | + offline-wenet-ctc-model-config.cc | ||
| 45 | + offline-wenet-ctc-model.cc | ||
| 44 | offline-whisper-greedy-search-decoder.cc | 46 | offline-whisper-greedy-search-decoder.cc |
| 45 | offline-whisper-model-config.cc | 47 | offline-whisper-model-config.cc |
| 46 | offline-whisper-model.cc | 48 | offline-whisper-model.cc |
| @@ -12,6 +12,7 @@ | @@ -12,6 +12,7 @@ | ||
| 12 | #include "sherpa-onnx/csrc/macros.h" | 12 | #include "sherpa-onnx/csrc/macros.h" |
| 13 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" | 13 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" |
| 14 | #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" | 14 | #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" |
| 15 | +#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h" | ||
| 15 | #include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" | 16 | #include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" |
| 16 | #include "sherpa-onnx/csrc/onnx-utils.h" | 17 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 17 | 18 | ||
| @@ -21,10 +22,11 @@ enum class ModelType { | @@ -21,10 +22,11 @@ enum class ModelType { | ||
| 21 | kEncDecCTCModelBPE, | 22 | kEncDecCTCModelBPE, |
| 22 | kTdnn, | 23 | kTdnn, |
| 23 | kZipformerCtc, | 24 | kZipformerCtc, |
| 25 | + kWenetCtc, | ||
| 24 | kUnkown, | 26 | kUnkown, |
| 25 | }; | 27 | }; |
| 26 | 28 | ||
| 27 | -} | 29 | +} // namespace |
| 28 | 30 | ||
| 29 | namespace sherpa_onnx { | 31 | namespace sherpa_onnx { |
| 30 | 32 | ||
| @@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 52 | "If you are using models from NeMo, please refer to\n" | 54 | "If you are using models from NeMo, please refer to\n" |
| 53 | "https://huggingface.co/csukuangfj/" | 55 | "https://huggingface.co/csukuangfj/" |
| 54 | "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py" | 56 | "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py" |
| 57 | + "If you are using models from WeNet, please refer to\n" | ||
| 58 | + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/" | ||
| 59 | + "run.sh\n" | ||
| 55 | "\n" | 60 | "\n" |
| 56 | "for how to add metadta to model.onnx\n"); | 61 | "for how to add metadta to model.onnx\n"); |
| 57 | return ModelType::kUnkown; | 62 | return ModelType::kUnkown; |
| @@ -63,6 +68,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -63,6 +68,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 63 | return ModelType::kTdnn; | 68 | return ModelType::kTdnn; |
| 64 | } else if (model_type.get() == std::string("zipformer2_ctc")) { | 69 | } else if (model_type.get() == std::string("zipformer2_ctc")) { |
| 65 | return ModelType::kZipformerCtc; | 70 | return ModelType::kZipformerCtc; |
| 71 | + } else if (model_type.get() == std::string("wenet_ctc")) { | ||
| 72 | + return ModelType::kWenetCtc; | ||
| 66 | } else { | 73 | } else { |
| 67 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 74 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); |
| 68 | return ModelType::kUnkown; | 75 | return ModelType::kUnkown; |
| @@ -80,6 +87,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -80,6 +87,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 80 | filename = config.tdnn.model; | 87 | filename = config.tdnn.model; |
| 81 | } else if (!config.zipformer_ctc.model.empty()) { | 88 | } else if (!config.zipformer_ctc.model.empty()) { |
| 82 | filename = config.zipformer_ctc.model; | 89 | filename = config.zipformer_ctc.model; |
| 90 | + } else if (!config.wenet_ctc.model.empty()) { | ||
| 91 | + filename = config.wenet_ctc.model; | ||
| 83 | } else { | 92 | } else { |
| 84 | SHERPA_ONNX_LOGE("Please specify a CTC model"); | 93 | SHERPA_ONNX_LOGE("Please specify a CTC model"); |
| 85 | exit(-1); | 94 | exit(-1); |
| @@ -101,6 +110,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -101,6 +110,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 101 | case ModelType::kZipformerCtc: | 110 | case ModelType::kZipformerCtc: |
| 102 | return std::make_unique<OfflineZipformerCtcModel>(config); | 111 | return std::make_unique<OfflineZipformerCtcModel>(config); |
| 103 | break; | 112 | break; |
| 113 | + case ModelType::kWenetCtc: | ||
| 114 | + return std::make_unique<OfflineWenetCtcModel>(config); | ||
| 115 | + break; | ||
| 104 | case ModelType::kUnkown: | 116 | case ModelType::kUnkown: |
| 105 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); | 117 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); |
| 106 | return nullptr; | 118 | return nullptr; |
| @@ -122,6 +134,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -122,6 +134,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 122 | filename = config.tdnn.model; | 134 | filename = config.tdnn.model; |
| 123 | } else if (!config.zipformer_ctc.model.empty()) { | 135 | } else if (!config.zipformer_ctc.model.empty()) { |
| 124 | filename = config.zipformer_ctc.model; | 136 | filename = config.zipformer_ctc.model; |
| 137 | + } else if (!config.wenet_ctc.model.empty()) { | ||
| 138 | + filename = config.wenet_ctc.model; | ||
| 125 | } else { | 139 | } else { |
| 126 | SHERPA_ONNX_LOGE("Please specify a CTC model"); | 140 | SHERPA_ONNX_LOGE("Please specify a CTC model"); |
| 127 | exit(-1); | 141 | exit(-1); |
| @@ -143,6 +157,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -143,6 +157,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 143 | case ModelType::kZipformerCtc: | 157 | case ModelType::kZipformerCtc: |
| 144 | return std::make_unique<OfflineZipformerCtcModel>(mgr, config); | 158 | return std::make_unique<OfflineZipformerCtcModel>(mgr, config); |
| 145 | break; | 159 | break; |
| 160 | + case ModelType::kWenetCtc: | ||
| 161 | + return std::make_unique<OfflineWenetCtcModel>(mgr, config); | ||
| 162 | + break; | ||
| 146 | case ModelType::kUnkown: | 163 | case ModelType::kUnkown: |
| 147 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); | 164 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); |
| 148 | return nullptr; | 165 | return nullptr; |
| @@ -63,6 +63,9 @@ class OfflineCtcModel { | @@ -63,6 +63,9 @@ class OfflineCtcModel { | ||
| 63 | * for the features. | 63 | * for the features. |
| 64 | */ | 64 | */ |
| 65 | virtual std::string FeatureNormalizationMethod() const { return {}; } | 65 | virtual std::string FeatureNormalizationMethod() const { return {}; } |
| 66 | + | ||
| 67 | + // Return true if the model supports batch size > 1 | ||
| 68 | + virtual bool SupportBatchProcessing() const { return true; } | ||
| 66 | }; | 69 | }; |
| 67 | 70 | ||
| 68 | } // namespace sherpa_onnx | 71 | } // namespace sherpa_onnx |
| @@ -17,6 +17,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { | @@ -17,6 +17,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { | ||
| 17 | whisper.Register(po); | 17 | whisper.Register(po); |
| 18 | tdnn.Register(po); | 18 | tdnn.Register(po); |
| 19 | zipformer_ctc.Register(po); | 19 | zipformer_ctc.Register(po); |
| 20 | + wenet_ctc.Register(po); | ||
| 20 | 21 | ||
| 21 | po->Register("tokens", &tokens, "Path to tokens.txt"); | 22 | po->Register("tokens", &tokens, "Path to tokens.txt"); |
| 22 | 23 | ||
| @@ -67,6 +68,10 @@ bool OfflineModelConfig::Validate() const { | @@ -67,6 +68,10 @@ bool OfflineModelConfig::Validate() const { | ||
| 67 | return zipformer_ctc.Validate(); | 68 | return zipformer_ctc.Validate(); |
| 68 | } | 69 | } |
| 69 | 70 | ||
| 71 | + if (!wenet_ctc.model.empty()) { | ||
| 72 | + return wenet_ctc.Validate(); | ||
| 73 | + } | ||
| 74 | + | ||
| 70 | return transducer.Validate(); | 75 | return transducer.Validate(); |
| 71 | } | 76 | } |
| 72 | 77 | ||
| @@ -80,6 +85,7 @@ std::string OfflineModelConfig::ToString() const { | @@ -80,6 +85,7 @@ std::string OfflineModelConfig::ToString() const { | ||
| 80 | os << "whisper=" << whisper.ToString() << ", "; | 85 | os << "whisper=" << whisper.ToString() << ", "; |
| 81 | os << "tdnn=" << tdnn.ToString() << ", "; | 86 | os << "tdnn=" << tdnn.ToString() << ", "; |
| 82 | os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; | 87 | os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; |
| 88 | + os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; | ||
| 83 | os << "tokens=\"" << tokens << "\", "; | 89 | os << "tokens=\"" << tokens << "\", "; |
| 84 | os << "num_threads=" << num_threads << ", "; | 90 | os << "num_threads=" << num_threads << ", "; |
| 85 | os << "debug=" << (debug ? "True" : "False") << ", "; | 91 | os << "debug=" << (debug ? "True" : "False") << ", "; |
| @@ -10,6 +10,7 @@ | @@ -10,6 +10,7 @@ | ||
| 10 | #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" |
| 11 | #include "sherpa-onnx/csrc/offline-tdnn-model-config.h" | 11 | #include "sherpa-onnx/csrc/offline-tdnn-model-config.h" |
| 12 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" | 12 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" |
| 13 | +#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h" | ||
| 13 | #include "sherpa-onnx/csrc/offline-whisper-model-config.h" | 14 | #include "sherpa-onnx/csrc/offline-whisper-model-config.h" |
| 14 | #include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h" | 15 | #include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h" |
| 15 | 16 | ||
| @@ -22,6 +23,7 @@ struct OfflineModelConfig { | @@ -22,6 +23,7 @@ struct OfflineModelConfig { | ||
| 22 | OfflineWhisperModelConfig whisper; | 23 | OfflineWhisperModelConfig whisper; |
| 23 | OfflineTdnnModelConfig tdnn; | 24 | OfflineTdnnModelConfig tdnn; |
| 24 | OfflineZipformerCtcModelConfig zipformer_ctc; | 25 | OfflineZipformerCtcModelConfig zipformer_ctc; |
| 26 | + OfflineWenetCtcModelConfig wenet_ctc; | ||
| 25 | 27 | ||
| 26 | std::string tokens; | 28 | std::string tokens; |
| 27 | int32_t num_threads = 2; | 29 | int32_t num_threads = 2; |
| @@ -46,6 +48,7 @@ struct OfflineModelConfig { | @@ -46,6 +48,7 @@ struct OfflineModelConfig { | ||
| 46 | const OfflineWhisperModelConfig &whisper, | 48 | const OfflineWhisperModelConfig &whisper, |
| 47 | const OfflineTdnnModelConfig &tdnn, | 49 | const OfflineTdnnModelConfig &tdnn, |
| 48 | const OfflineZipformerCtcModelConfig &zipformer_ctc, | 50 | const OfflineZipformerCtcModelConfig &zipformer_ctc, |
| 51 | + const OfflineWenetCtcModelConfig &wenet_ctc, | ||
| 49 | const std::string &tokens, int32_t num_threads, bool debug, | 52 | const std::string &tokens, int32_t num_threads, bool debug, |
| 50 | const std::string &provider, const std::string &model_type) | 53 | const std::string &provider, const std::string &model_type) |
| 51 | : transducer(transducer), | 54 | : transducer(transducer), |
| @@ -54,6 +57,7 @@ struct OfflineModelConfig { | @@ -54,6 +57,7 @@ struct OfflineModelConfig { | ||
| 54 | whisper(whisper), | 57 | whisper(whisper), |
| 55 | tdnn(tdnn), | 58 | tdnn(tdnn), |
| 56 | zipformer_ctc(zipformer_ctc), | 59 | zipformer_ctc(zipformer_ctc), |
| 60 | + wenet_ctc(wenet_ctc), | ||
| 57 | tokens(tokens), | 61 | tokens(tokens), |
| 58 | num_threads(num_threads), | 62 | num_threads(num_threads), |
| 59 | debug(debug), | 63 | debug(debug), |
| @@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 75 | #endif | 75 | #endif |
| 76 | 76 | ||
| 77 | void Init() { | 77 | void Init() { |
| 78 | + if (!config_.model_config.wenet_ctc.model.empty()) { | ||
| 79 | + // WeNet CTC models assume input samples are in the range | ||
| 80 | + // [-32768, 32767], so we set normalize_samples to false | ||
| 81 | + config_.feat_config.normalize_samples = false; | ||
| 82 | + } | ||
| 83 | + | ||
| 78 | config_.feat_config.nemo_normalize_type = | 84 | config_.feat_config.nemo_normalize_type = |
| 79 | model_->FeatureNormalizationMethod(); | 85 | model_->FeatureNormalizationMethod(); |
| 80 | 86 | ||
| @@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 85 | config_.ctc_fst_decoder_config); | 91 | config_.ctc_fst_decoder_config); |
| 86 | } else if (config_.decoding_method == "greedy_search") { | 92 | } else if (config_.decoding_method == "greedy_search") { |
| 87 | if (!symbol_table_.contains("<blk>") && | 93 | if (!symbol_table_.contains("<blk>") && |
| 88 | - !symbol_table_.contains("<eps>")) { | 94 | + !symbol_table_.contains("<eps>") && |
| 95 | + !symbol_table_.contains("<blank>")) { | ||
| 89 | SHERPA_ONNX_LOGE( | 96 | SHERPA_ONNX_LOGE( |
| 90 | "We expect that tokens.txt contains " | 97 | "We expect that tokens.txt contains " |
| 91 | - "the symbol <blk> or <eps> and its ID."); | 98 | + "the symbol <blk> or <eps> or <blank> and its ID."); |
| 92 | exit(-1); | 99 | exit(-1); |
| 93 | } | 100 | } |
| 94 | 101 | ||
| @@ -98,6 +105,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -98,6 +105,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 98 | } else if (symbol_table_.contains("<eps>")) { | 105 | } else if (symbol_table_.contains("<eps>")) { |
| 99 | // for tdnn models of the yesno recipe from icefall | 106 | // for tdnn models of the yesno recipe from icefall |
| 100 | blank_id = symbol_table_["<eps>"]; | 107 | blank_id = symbol_table_["<eps>"]; |
| 108 | + } else if (symbol_table_.contains("<blank>")) { | ||
| 109 | + // for WeNet CTC models | ||
| 110 | + blank_id = symbol_table_["<blank>"]; | ||
| 101 | } | 111 | } |
| 102 | 112 | ||
| 103 | decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id); | 113 | decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id); |
| @@ -113,6 +123,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -113,6 +123,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 113 | } | 123 | } |
| 114 | 124 | ||
| 115 | void DecodeStreams(OfflineStream **ss, int32_t n) const override { | 125 | void DecodeStreams(OfflineStream **ss, int32_t n) const override { |
| 126 | + if (!model_->SupportBatchProcessing()) { | ||
| 127 | + // If the model does not support batch processing, | ||
| 128 | + // we process each stream independently. | ||
| 129 | + for (int32_t i = 0; i != n; ++i) { | ||
| 130 | + DecodeStream(ss[i]); | ||
| 131 | + } | ||
| 132 | + return; | ||
| 133 | + } | ||
| 134 | + | ||
| 116 | auto memory_info = | 135 | auto memory_info = |
| 117 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 136 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); |
| 118 | 137 | ||
| @@ -165,6 +184,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -165,6 +184,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 165 | } | 184 | } |
| 166 | 185 | ||
| 167 | private: | 186 | private: |
| 187 | + // Decode a single stream. | ||
| 188 | + // Some models do not support batch size > 1, e.g., WeNet CTC models. | ||
| 189 | + void DecodeStream(OfflineStream *s) const { | ||
| 190 | + auto memory_info = | ||
| 191 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 192 | + | ||
| 193 | + int32_t feat_dim = config_.feat_config.feature_dim; | ||
| 194 | + std::vector<float> f = s->GetFrames(); | ||
| 195 | + | ||
| 196 | + int32_t num_frames = f.size() / feat_dim; | ||
| 197 | + | ||
| 198 | + std::array<int64_t, 3> shape = {1, num_frames, feat_dim}; | ||
| 199 | + | ||
| 200 | + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), | ||
| 201 | + shape.data(), shape.size()); | ||
| 202 | + | ||
| 203 | + int64_t x_length_scalar = num_frames; | ||
| 204 | + std::array<int64_t, 1> x_length_shape = {1}; | ||
| 205 | + Ort::Value x_length = | ||
| 206 | + Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1, | ||
| 207 | + x_length_shape.data(), x_length_shape.size()); | ||
| 208 | + | ||
| 209 | + auto t = model_->Forward(std::move(x), std::move(x_length)); | ||
| 210 | + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1])); | ||
| 211 | + int32_t frame_shift_ms = 10; | ||
| 212 | + | ||
| 213 | + auto r = Convert(results[0], symbol_table_, frame_shift_ms, | ||
| 214 | + model_->SubsamplingFactor()); | ||
| 215 | + s->SetResult(r); | ||
| 216 | + } | ||
| 217 | + | ||
| 218 | + private: | ||
| 168 | OfflineRecognizerConfig config_; | 219 | OfflineRecognizerConfig config_; |
| 169 | SymbolTable symbol_table_; | 220 | SymbolTable symbol_table_; |
| 170 | std::unique_ptr<OfflineCtcModel> model_; | 221 | std::unique_ptr<OfflineCtcModel> model_; |
| @@ -26,7 +26,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -26,7 +26,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 26 | } else if (model_type == "paraformer") { | 26 | } else if (model_type == "paraformer") { |
| 27 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); | 27 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); |
| 28 | } else if (model_type == "nemo_ctc" || model_type == "tdnn" || | 28 | } else if (model_type == "nemo_ctc" || model_type == "tdnn" || |
| 29 | - model_type == "zipformer2_ctc") { | 29 | + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { |
| 30 | return std::make_unique<OfflineRecognizerCtcImpl>(config); | 30 | return std::make_unique<OfflineRecognizerCtcImpl>(config); |
| 31 | } else if (model_type == "whisper") { | 31 | } else if (model_type == "whisper") { |
| 32 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); | 32 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); |
| @@ -51,6 +51,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -51,6 +51,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 51 | model_filename = config.model_config.tdnn.model; | 51 | model_filename = config.model_config.tdnn.model; |
| 52 | } else if (!config.model_config.zipformer_ctc.model.empty()) { | 52 | } else if (!config.model_config.zipformer_ctc.model.empty()) { |
| 53 | model_filename = config.model_config.zipformer_ctc.model; | 53 | model_filename = config.model_config.zipformer_ctc.model; |
| 54 | + } else if (!config.model_config.wenet_ctc.model.empty()) { | ||
| 55 | + model_filename = config.model_config.wenet_ctc.model; | ||
| 54 | } else if (!config.model_config.whisper.encoder.empty()) { | 56 | } else if (!config.model_config.whisper.encoder.empty()) { |
| 55 | model_filename = config.model_config.whisper.encoder; | 57 | model_filename = config.model_config.whisper.encoder; |
| 56 | } else { | 58 | } else { |
| @@ -99,6 +101,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -99,6 +101,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 99 | "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" | 101 | "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" |
| 100 | "zipformer/export-onnx-ctc.py" | 102 | "zipformer/export-onnx-ctc.py" |
| 101 | "\n" | 103 | "\n" |
| 104 | + "(6) CTC models from WeNet" | ||
| 105 | + "\n " | ||
| 106 | + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh" | ||
| 107 | + "\n" | ||
| 102 | "\n"); | 108 | "\n"); |
| 103 | exit(-1); | 109 | exit(-1); |
| 104 | } | 110 | } |
| @@ -114,7 +120,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -114,7 +120,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 114 | } | 120 | } |
| 115 | 121 | ||
| 116 | if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || | 122 | if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || |
| 117 | - model_type == "zipformer2_ctc") { | 123 | + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { |
| 118 | return std::make_unique<OfflineRecognizerCtcImpl>(config); | 124 | return std::make_unique<OfflineRecognizerCtcImpl>(config); |
| 119 | } | 125 | } |
| 120 | 126 | ||
| @@ -130,7 +136,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -130,7 +136,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 130 | " - EncDecCTCModelBPE models from NeMo\n" | 136 | " - EncDecCTCModelBPE models from NeMo\n" |
| 131 | " - Whisper models\n" | 137 | " - Whisper models\n" |
| 132 | " - Tdnn models\n" | 138 | " - Tdnn models\n" |
| 133 | - " - Zipformer CTC models\n", | 139 | + " - Zipformer CTC models\n" |
| 140 | + " - WeNet CTC models\n", | ||
| 134 | model_type.c_str()); | 141 | model_type.c_str()); |
| 135 | 142 | ||
| 136 | exit(-1); | 143 | exit(-1); |
| @@ -146,7 +153,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -146,7 +153,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 146 | } else if (model_type == "paraformer") { | 153 | } else if (model_type == "paraformer") { |
| 147 | return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); | 154 | return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); |
| 148 | } else if (model_type == "nemo_ctc" || model_type == "tdnn" || | 155 | } else if (model_type == "nemo_ctc" || model_type == "tdnn" || |
| 149 | - model_type == "zipformer2_ctc") { | 156 | + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { |
| 150 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); | 157 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); |
| 151 | } else if (model_type == "whisper") { | 158 | } else if (model_type == "whisper") { |
| 152 | return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); | 159 | return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); |
| @@ -171,6 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -171,6 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 171 | model_filename = config.model_config.tdnn.model; | 178 | model_filename = config.model_config.tdnn.model; |
| 172 | } else if (!config.model_config.zipformer_ctc.model.empty()) { | 179 | } else if (!config.model_config.zipformer_ctc.model.empty()) { |
| 173 | model_filename = config.model_config.zipformer_ctc.model; | 180 | model_filename = config.model_config.zipformer_ctc.model; |
| 181 | + } else if (!config.model_config.wenet_ctc.model.empty()) { | ||
| 182 | + model_filename = config.model_config.wenet_ctc.model; | ||
| 174 | } else if (!config.model_config.whisper.encoder.empty()) { | 183 | } else if (!config.model_config.whisper.encoder.empty()) { |
| 175 | model_filename = config.model_config.whisper.encoder; | 184 | model_filename = config.model_config.whisper.encoder; |
| 176 | } else { | 185 | } else { |
| @@ -219,6 +228,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -219,6 +228,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 219 | "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" | 228 | "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" |
| 220 | "zipformer/export-onnx-ctc.py" | 229 | "zipformer/export-onnx-ctc.py" |
| 221 | "\n" | 230 | "\n" |
| 231 | + "(6) CTC models from WeNet" | ||
| 232 | + "\n " | ||
| 233 | + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh" | ||
| 234 | + "\n" | ||
| 222 | "\n"); | 235 | "\n"); |
| 223 | exit(-1); | 236 | exit(-1); |
| 224 | } | 237 | } |
| @@ -234,7 +247,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -234,7 +247,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 234 | } | 247 | } |
| 235 | 248 | ||
| 236 | if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || | 249 | if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || |
| 237 | - model_type == "zipformer2_ctc") { | 250 | + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { |
| 238 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); | 251 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); |
| 239 | } | 252 | } |
| 240 | 253 | ||
| @@ -250,7 +263,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -250,7 +263,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 250 | " - EncDecCTCModelBPE models from NeMo\n" | 263 | " - EncDecCTCModelBPE models from NeMo\n" |
| 251 | " - Whisper models\n" | 264 | " - Whisper models\n" |
| 252 | " - Tdnn models\n" | 265 | " - Tdnn models\n" |
| 253 | - " - Zipformer CTC models\n", | 266 | + " - Zipformer CTC models\n" |
| 267 | + " - WeNet CTC models\n", | ||
| 254 | model_type.c_str()); | 268 | model_type.c_str()); |
| 255 | 269 | ||
| 256 | exit(-1); | 270 | exit(-1); |
| @@ -40,7 +40,8 @@ struct OfflineFeatureExtractorConfig { | @@ -40,7 +40,8 @@ struct OfflineFeatureExtractorConfig { | ||
| 40 | // Feature dimension | 40 | // Feature dimension |
| 41 | int32_t feature_dim = 80; | 41 | int32_t feature_dim = 80; |
| 42 | 42 | ||
| 43 | - // Set internally by some models, e.g., paraformer sets it to false. | 43 | + // Set internally by some models, e.g., paraformer and wenet CTC models set |
| 44 | + // it to false. | ||
| 44 | // This parameter is not exposed to users from the commandline | 45 | // This parameter is not exposed to users from the commandline |
| 45 | // If true, the feature extractor expects inputs to be normalized to | 46 | // If true, the feature extractor expects inputs to be normalized to |
| 46 | // the range [-1, 1]. | 47 | // the range [-1, 1]. |
| 1 | +// sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineWenetCtcModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register( | ||
| 14 | + "wenet-ctc-model", &model, | ||
| 15 | + "Path to model.onnx from WeNet. Please see " | ||
| 16 | + "https://github.com/k2-fsa/sherpa-onnx/pull/425 for available models"); | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +bool OfflineWenetCtcModelConfig::Validate() const { | ||
| 20 | + if (!FileExists(model)) { | ||
| 21 | + SHERPA_ONNX_LOGE("WeNet model: %s does not exist", model.c_str()); | ||
| 22 | + return false; | ||
| 23 | + } | ||
| 24 | + | ||
| 25 | + return true; | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +std::string OfflineWenetCtcModelConfig::ToString() const { | ||
| 29 | + std::ostringstream os; | ||
| 30 | + | ||
| 31 | + os << "OfflineWenetCtcModelConfig("; | ||
| 32 | + os << "model=\"" << model << "\")"; | ||
| 33 | + | ||
| 34 | + return os.str(); | ||
| 35 | +} | ||
| 36 | + | ||
| 37 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-wenet-ctc-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineWenetCtcModelConfig { | ||
| 14 | + std::string model; | ||
| 15 | + | ||
| 16 | + OfflineWenetCtcModelConfig() = default; | ||
| 17 | + explicit OfflineWenetCtcModelConfig(const std::string &model) | ||
| 18 | + : model(model) {} | ||
| 19 | + | ||
| 20 | + void Register(ParseOptions *po); | ||
| 21 | + bool Validate() const; | ||
| 22 | + | ||
| 23 | + std::string ToString() const; | ||
| 24 | +}; | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx | ||
| 27 | + | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ |
sherpa-onnx/csrc/offline-wenet-ctc-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-wenet-ctc-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 8 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 9 | +#include "sherpa-onnx/csrc/session.h" | ||
| 10 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 11 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class OfflineWenetCtcModel::Impl { | ||
| 16 | + public: | ||
| 17 | + explicit Impl(const OfflineModelConfig &config) | ||
| 18 | + : config_(config), | ||
| 19 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 20 | + sess_opts_(GetSessionOptions(config)), | ||
| 21 | + allocator_{} { | ||
| 22 | + auto buf = ReadFile(config_.wenet_ctc.model); | ||
| 23 | + Init(buf.data(), buf.size()); | ||
| 24 | + } | ||
| 25 | + | ||
| 26 | +#if __ANDROID_API__ >= 9 | ||
| 27 | + Impl(AAssetManager *mgr, const OfflineModelConfig &config) | ||
| 28 | + : config_(config), | ||
| 29 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 30 | + sess_opts_(GetSessionOptions(config)), | ||
| 31 | + allocator_{} { | ||
| 32 | + auto buf = ReadFile(mgr, config_.wenet_ctc.model); | ||
| 33 | + Init(buf.data(), buf.size()); | ||
| 34 | + } | ||
| 35 | +#endif | ||
| 36 | + | ||
| 37 | + std::vector<Ort::Value> Forward(Ort::Value features, | ||
| 38 | + Ort::Value features_length) { | ||
| 39 | + std::array<Ort::Value, 2> inputs = {std::move(features), | ||
| 40 | + std::move(features_length)}; | ||
| 41 | + | ||
| 42 | + return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 43 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + int32_t VocabSize() const { return vocab_size_; } | ||
| 47 | + | ||
| 48 | + int32_t SubsamplingFactor() const { return subsampling_factor_; } | ||
| 49 | + | ||
| 50 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 51 | + | ||
| 52 | + private: | ||
| 53 | + void Init(void *model_data, size_t model_data_length) { | ||
| 54 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 55 | + sess_opts_); | ||
| 56 | + | ||
| 57 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 58 | + | ||
| 59 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 60 | + | ||
| 61 | + // get meta data | ||
| 62 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 63 | + if (config_.debug) { | ||
| 64 | + std::ostringstream os; | ||
| 65 | + PrintModelMetadata(os, meta_data); | ||
| 66 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 70 | + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); | ||
| 71 | + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + private: | ||
| 75 | + OfflineModelConfig config_; | ||
| 76 | + Ort::Env env_; | ||
| 77 | + Ort::SessionOptions sess_opts_; | ||
| 78 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 79 | + | ||
| 80 | + std::unique_ptr<Ort::Session> sess_; | ||
| 81 | + | ||
| 82 | + std::vector<std::string> input_names_; | ||
| 83 | + std::vector<const char *> input_names_ptr_; | ||
| 84 | + | ||
| 85 | + std::vector<std::string> output_names_; | ||
| 86 | + std::vector<const char *> output_names_ptr_; | ||
| 87 | + | ||
| 88 | + int32_t vocab_size_ = 0; | ||
| 89 | + int32_t subsampling_factor_ = 0; | ||
| 90 | +}; | ||
| 91 | + | ||
| 92 | +OfflineWenetCtcModel::OfflineWenetCtcModel(const OfflineModelConfig &config) | ||
| 93 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 94 | + | ||
| 95 | +#if __ANDROID_API__ >= 9 | ||
| 96 | +OfflineWenetCtcModel::OfflineWenetCtcModel(AAssetManager *mgr, | ||
| 97 | + const OfflineModelConfig &config) | ||
| 98 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 99 | +#endif | ||
| 100 | + | ||
| 101 | +OfflineWenetCtcModel::~OfflineWenetCtcModel() = default; | ||
| 102 | + | ||
| 103 | +std::vector<Ort::Value> OfflineWenetCtcModel::Forward( | ||
| 104 | + Ort::Value features, Ort::Value features_length) { | ||
| 105 | + return impl_->Forward(std::move(features), std::move(features_length)); | ||
| 106 | +} | ||
| 107 | + | ||
| 108 | +int32_t OfflineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); } | ||
| 109 | + | ||
| 110 | +int32_t OfflineWenetCtcModel::SubsamplingFactor() const { | ||
| 111 | + return impl_->SubsamplingFactor(); | ||
| 112 | +} | ||
| 113 | + | ||
| 114 | +OrtAllocator *OfflineWenetCtcModel::Allocator() const { | ||
| 115 | + return impl_->Allocator(); | ||
| 116 | +} | ||
| 117 | + | ||
| 118 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-wenet-ctc-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-wenet-ctc-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_ | ||
| 6 | +#include <memory> | ||
| 7 | +#include <string> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 17 | +#include "sherpa-onnx/csrc/offline-ctc-model.h" | ||
| 18 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +/** This class implements the CTC model from WeNet. | ||
| 23 | + * | ||
| 24 | + * See | ||
| 25 | + * https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/export-onnx.py | ||
| 26 | + * https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/test-onnx.py | ||
| 27 | + * https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh | ||
| 28 | + * | ||
| 29 | + */ | ||
| 30 | +class OfflineWenetCtcModel : public OfflineCtcModel { | ||
| 31 | + public: | ||
| 32 | + explicit OfflineWenetCtcModel(const OfflineModelConfig &config); | ||
| 33 | + | ||
| 34 | +#if __ANDROID_API__ >= 9 | ||
| 35 | + OfflineWenetCtcModel(AAssetManager *mgr, const OfflineModelConfig &config); | ||
| 36 | +#endif | ||
| 37 | + | ||
| 38 | + ~OfflineWenetCtcModel() override; | ||
| 39 | + | ||
| 40 | + /** Run the forward method of the model. | ||
| 41 | + * | ||
| 42 | + * @param features A tensor of shape (N, T, C). | ||
| 43 | + * @param features_length A 1-D tensor of shape (N,) containing number of | ||
| 44 | + * valid frames in `features` before padding. | ||
| 45 | + * Its dtype is int64_t. | ||
| 46 | + * | ||
| 47 | + * @return Return a vector containing: | ||
| 48 | + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). | ||
| 49 | + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t | ||
| 50 | + */ | ||
| 51 | + std::vector<Ort::Value> Forward(Ort::Value features, | ||
| 52 | + Ort::Value features_length) override; | ||
| 53 | + | ||
| 54 | + /** Return the vocabulary size of the model | ||
| 55 | + */ | ||
| 56 | + int32_t VocabSize() const override; | ||
| 57 | + | ||
| 58 | + /** SubsamplingFactor of the model | ||
| 59 | + * | ||
| 60 | + * It is read from the "subsampling_factor" entry of the model metadata. | ||
| 61 | + * (For reference: Citrinet usually uses 4 and Conformer CTC usually 8.) | ||
| 62 | + */ | ||
| 63 | + int32_t SubsamplingFactor() const override; | ||
| 64 | + | ||
| 65 | + /** Return an allocator for allocating memory | ||
| 66 | + */ | ||
| 67 | + OrtAllocator *Allocator() const override; | ||
| 68 | + | ||
| 69 | + // WeNet CTC models do not support batch size > 1 | ||
| 70 | + bool SupportBatchProcessing() const override { return false; } | ||
| 71 | + | ||
| 72 | + private: | ||
| 73 | + class Impl; | ||
| 74 | + std::unique_ptr<Impl> impl_; | ||
| 75 | +}; | ||
| 76 | + | ||
| 77 | +} // namespace sherpa_onnx | ||
| 78 | + | ||
| 79 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_ |
| @@ -17,6 +17,7 @@ pybind11_add_module(_sherpa_onnx | @@ -17,6 +17,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 17 | offline-tts-model-config.cc | 17 | offline-tts-model-config.cc |
| 18 | offline-tts-vits-model-config.cc | 18 | offline-tts-vits-model-config.cc |
| 19 | offline-tts.cc | 19 | offline-tts.cc |
| 20 | + offline-wenet-ctc-model-config.cc | ||
| 20 | offline-whisper-model-config.cc | 21 | offline-whisper-model-config.cc |
| 21 | offline-zipformer-ctc-model-config.cc | 22 | offline-zipformer-ctc-model-config.cc |
| 22 | online-lm-config.cc | 23 | online-lm-config.cc |
| @@ -12,6 +12,7 @@ | @@ -12,6 +12,7 @@ | ||
| 12 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" | 12 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" |
| 13 | #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" | 13 | #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" |
| 14 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" | 14 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" |
| 15 | +#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h" | ||
| 15 | #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" | 16 | #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" |
| 16 | #include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h" | 17 | #include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h" |
| 17 | 18 | ||
| @@ -24,6 +25,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -24,6 +25,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 24 | PybindOfflineWhisperModelConfig(m); | 25 | PybindOfflineWhisperModelConfig(m); |
| 25 | PybindOfflineTdnnModelConfig(m); | 26 | PybindOfflineTdnnModelConfig(m); |
| 26 | PybindOfflineZipformerCtcModelConfig(m); | 27 | PybindOfflineZipformerCtcModelConfig(m); |
| 28 | + PybindOfflineWenetCtcModelConfig(m); | ||
| 27 | 29 | ||
| 28 | using PyClass = OfflineModelConfig; | 30 | using PyClass = OfflineModelConfig; |
| 29 | py::class_<PyClass>(*m, "OfflineModelConfig") | 31 | py::class_<PyClass>(*m, "OfflineModelConfig") |
| @@ -32,7 +34,8 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -32,7 +34,8 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 32 | const OfflineNemoEncDecCtcModelConfig &, | 34 | const OfflineNemoEncDecCtcModelConfig &, |
| 33 | const OfflineWhisperModelConfig &, | 35 | const OfflineWhisperModelConfig &, |
| 34 | const OfflineTdnnModelConfig &, | 36 | const OfflineTdnnModelConfig &, |
| 35 | - const OfflineZipformerCtcModelConfig &, const std::string &, | 37 | + const OfflineZipformerCtcModelConfig &, |
| 38 | + const OfflineWenetCtcModelConfig &, const std::string &, | ||
| 36 | int32_t, bool, const std::string &, const std::string &>(), | 39 | int32_t, bool, const std::string &, const std::string &>(), |
| 37 | py::arg("transducer") = OfflineTransducerModelConfig(), | 40 | py::arg("transducer") = OfflineTransducerModelConfig(), |
| 38 | py::arg("paraformer") = OfflineParaformerModelConfig(), | 41 | py::arg("paraformer") = OfflineParaformerModelConfig(), |
| @@ -40,6 +43,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -40,6 +43,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 40 | py::arg("whisper") = OfflineWhisperModelConfig(), | 43 | py::arg("whisper") = OfflineWhisperModelConfig(), |
| 41 | py::arg("tdnn") = OfflineTdnnModelConfig(), | 44 | py::arg("tdnn") = OfflineTdnnModelConfig(), |
| 42 | py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), | 45 | py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), |
| 46 | + py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), | ||
| 43 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, | 47 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, |
| 44 | py::arg("provider") = "cpu", py::arg("model_type") = "") | 48 | py::arg("provider") = "cpu", py::arg("model_type") = "") |
| 45 | .def_readwrite("transducer", &PyClass::transducer) | 49 | .def_readwrite("transducer", &PyClass::transducer) |
| @@ -48,6 +52,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -48,6 +52,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 48 | .def_readwrite("whisper", &PyClass::whisper) | 52 | .def_readwrite("whisper", &PyClass::whisper) |
| 49 | .def_readwrite("tdnn", &PyClass::tdnn) | 53 | .def_readwrite("tdnn", &PyClass::tdnn) |
| 50 | .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) | 54 | .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) |
| 55 | + .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) | ||
| 51 | .def_readwrite("tokens", &PyClass::tokens) | 56 | .def_readwrite("tokens", &PyClass::tokens) |
| 52 | .def_readwrite("num_threads", &PyClass::num_threads) | 57 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 53 | .def_readwrite("debug", &PyClass::debug) | 58 | .def_readwrite("debug", &PyClass::debug) |
| 1 | +// sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindOfflineWenetCtcModelConfig(py::module *m) { | ||
| 15 | + using PyClass = OfflineWenetCtcModelConfig; | ||
| 16 | + py::class_<PyClass>(*m, "OfflineWenetCtcModelConfig") | ||
| 17 | + .def(py::init<const std::string &>(), py::arg("model")) | ||
| 18 | + .def_readwrite("model", &PyClass::model) | ||
| 19 | + .def("__str__", &PyClass::ToString); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOfflineWenetCtcModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ |
-
请 注册 或 登录 后发表评论