Fangjun Kuang
Committed by GitHub

Support non-streaming WeNet CTC models. (#426)

@@ -14,6 +14,47 @@ echo "PATH: $PATH" @@ -14,6 +14,47 @@ echo "PATH: $PATH"
14 which $EXE 14 which $EXE
15 15
16 log "------------------------------------------------------------" 16 log "------------------------------------------------------------"
  17 +log "Run Wenet models"
  18 +log "------------------------------------------------------------"
  19 +wenet_models=(
  20 +sherpa-onnx-zh-wenet-aishell
  21 +sherpa-onnx-zh-wenet-aishell2
  22 +sherpa-onnx-zh-wenet-wenetspeech
  23 +sherpa-onnx-zh-wenet-multi-cn
  24 +sherpa-onnx-en-wenet-librispeech
  25 +sherpa-onnx-en-wenet-gigaspeech
  26 +)
  27 +for name in ${wenet_models[@]}; do
  28 + repo_url=https://huggingface.co/csukuangfj/$name
  29 + log "Start testing ${repo_url}"
  30 + repo=$(basename $repo_url)
  31 + log "Download pretrained model and test-data from $repo_url"
  32 + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  33 + pushd $repo
  34 + git lfs pull --include "*.onnx"
  35 + ls -lh *.onnx
  36 + popd
  37 +
  38 + log "test float32 models"
  39 + time $EXE \
  40 + --tokens=$repo/tokens.txt \
  41 + --wenet-ctc-model=$repo/model.onnx \
  42 + $repo/test_wavs/0.wav \
  43 + $repo/test_wavs/1.wav \
  44 + $repo/test_wavs/8k.wav
  45 +
  46 + log "test int8 models"
  47 + time $EXE \
  48 + --tokens=$repo/tokens.txt \
  49 + --wenet-ctc-model=$repo/model.int8.onnx \
  50 + $repo/test_wavs/0.wav \
  51 + $repo/test_wavs/1.wav \
  52 + $repo/test_wavs/8k.wav
  53 +
  54 + rm -rf $repo
  55 +done
  56 +
  57 +log "------------------------------------------------------------"
17 log "Run tdnn yesno (Hebrew)" 58 log "Run tdnn yesno (Hebrew)"
18 log "------------------------------------------------------------" 59 log "------------------------------------------------------------"
19 repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno 60 repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
1 name: export-wenet-to-onnx 1 name: export-wenet-to-onnx
2 2
3 on: 3 on:
4 - push:  
5 - branches:  
6 - - master  
7 - paths:  
8 - - 'scripts/wenet/**'  
9 - - '.github/workflows/export-wenet-to-onnx.yaml'  
10 - pull_request:  
11 - paths:  
12 - - 'scripts/wenet/**'  
13 - - '.github/workflows/export-wenet-to-onnx.yaml'  
14 -  
15 workflow_dispatch: 4 workflow_dispatch:
16 5
17 concurrency: 6 concurrency:
@@ -89,6 +89,14 @@ jobs: @@ -89,6 +89,14 @@ jobs:
89 file build/bin/sherpa-onnx 89 file build/bin/sherpa-onnx
90 readelf -d build/bin/sherpa-onnx 90 readelf -d build/bin/sherpa-onnx
91 91
  92 + - name: Test offline CTC
  93 + shell: bash
  94 + run: |
  95 + export PATH=$PWD/build/bin:$PATH
  96 + export EXE=sherpa-onnx-offline
  97 +
  98 + .github/scripts/test-offline-ctc.sh
  99 +
92 - name: Test offline TTS 100 - name: Test offline TTS
93 shell: bash 101 shell: bash
94 run: | 102 run: |
@@ -115,14 +123,6 @@ jobs: @@ -115,14 +123,6 @@ jobs:
115 123
116 .github/scripts/test-offline-whisper.sh 124 .github/scripts/test-offline-whisper.sh
117 125
118 - - name: Test offline CTC  
119 - shell: bash  
120 - run: |  
121 - export PATH=$PWD/build/bin:$PATH  
122 - export EXE=sherpa-onnx-offline  
123 -  
124 - .github/scripts/test-offline-ctc.sh  
125 -  
126 - name: Test offline transducer 126 - name: Test offline transducer
127 shell: bash 127 shell: bash
128 run: | 128 run: |
@@ -172,7 +172,7 @@ def main(): @@ -172,7 +172,7 @@ def main():
172 # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz 172 # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
173 url = os.environ.get("WENET_URL", "") 173 url = os.environ.get("WENET_URL", "")
174 meta_data = { 174 meta_data = {
175 - "model_type": "wenet-ctc", 175 + "model_type": "wenet_ctc",
176 "version": "1", 176 "version": "1",
177 "model_author": "wenet", 177 "model_author": "wenet",
178 "comment": "streaming", 178 "comment": "streaming",
@@ -185,6 +185,7 @@ def main(): @@ -185,6 +185,7 @@ def main():
185 "cnn_module_kernel": cnn_module_kernel, 185 "cnn_module_kernel": cnn_module_kernel,
186 "right_context": right_context, 186 "right_context": right_context,
187 "subsampling_factor": subsampling_factor, 187 "subsampling_factor": subsampling_factor,
  188 + "vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
188 } 189 }
189 add_meta_data(filename=filename, meta_data=meta_data) 190 add_meta_data(filename=filename, meta_data=meta_data)
190 191
@@ -107,10 +107,12 @@ def main(): @@ -107,10 +107,12 @@ def main():
107 # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz 107 # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
108 url = os.environ.get("WENET_URL", "") 108 url = os.environ.get("WENET_URL", "")
109 meta_data = { 109 meta_data = {
110 - "model_type": "wenet-ctc", 110 + "model_type": "wenet_ctc",
111 "version": "1", 111 "version": "1",
112 "model_author": "wenet", 112 "model_author": "wenet",
113 "comment": "non-streaming", 113 "comment": "non-streaming",
  114 + "subsampling_factor": torch_model.encoder.embed.subsampling_rate,
  115 + "vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
114 "url": url, 116 "url": url,
115 } 117 }
116 add_meta_data(filename=filename, meta_data=meta_data) 118 add_meta_data(filename=filename, meta_data=meta_data)
@@ -41,6 +41,8 @@ set(sources @@ -41,6 +41,8 @@ set(sources
41 offline-transducer-model-config.cc 41 offline-transducer-model-config.cc
42 offline-transducer-model.cc 42 offline-transducer-model.cc
43 offline-transducer-modified-beam-search-decoder.cc 43 offline-transducer-modified-beam-search-decoder.cc
  44 + offline-wenet-ctc-model-config.cc
  45 + offline-wenet-ctc-model.cc
44 offline-whisper-greedy-search-decoder.cc 46 offline-whisper-greedy-search-decoder.cc
45 offline-whisper-model-config.cc 47 offline-whisper-model-config.cc
46 offline-whisper-model.cc 48 offline-whisper-model.cc
@@ -12,6 +12,7 @@ @@ -12,6 +12,7 @@
12 #include "sherpa-onnx/csrc/macros.h" 12 #include "sherpa-onnx/csrc/macros.h"
13 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" 13 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
14 #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" 14 #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
  15 +#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
15 #include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" 16 #include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
16 #include "sherpa-onnx/csrc/onnx-utils.h" 17 #include "sherpa-onnx/csrc/onnx-utils.h"
17 18
@@ -21,10 +22,11 @@ enum class ModelType { @@ -21,10 +22,11 @@ enum class ModelType {
21 kEncDecCTCModelBPE, 22 kEncDecCTCModelBPE,
22 kTdnn, 23 kTdnn,
23 kZipformerCtc, 24 kZipformerCtc,
  25 + kWenetCtc,
24 kUnkown, 26 kUnkown,
25 }; 27 };
26 28
27 -} 29 +} // namespace
28 30
29 namespace sherpa_onnx { 31 namespace sherpa_onnx {
30 32
@@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
52 "If you are using models from NeMo, please refer to\n" 54 "If you are using models from NeMo, please refer to\n"
53 "https://huggingface.co/csukuangfj/" 55 "https://huggingface.co/csukuangfj/"
54 "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py" 56 "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
  57 + "If you are using models from WeNet, please refer to\n"
  58 + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
  59 + "run.sh\n"
55 "\n" 60 "\n"
56 "for how to add metadta to model.onnx\n"); 61 "for how to add metadta to model.onnx\n");
57 return ModelType::kUnkown; 62 return ModelType::kUnkown;
@@ -63,6 +68,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -63,6 +68,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
63 return ModelType::kTdnn; 68 return ModelType::kTdnn;
64 } else if (model_type.get() == std::string("zipformer2_ctc")) { 69 } else if (model_type.get() == std::string("zipformer2_ctc")) {
65 return ModelType::kZipformerCtc; 70 return ModelType::kZipformerCtc;
  71 + } else if (model_type.get() == std::string("wenet_ctc")) {
  72 + return ModelType::kWenetCtc;
66 } else { 73 } else {
67 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 74 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
68 return ModelType::kUnkown; 75 return ModelType::kUnkown;
@@ -80,6 +87,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -80,6 +87,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
80 filename = config.tdnn.model; 87 filename = config.tdnn.model;
81 } else if (!config.zipformer_ctc.model.empty()) { 88 } else if (!config.zipformer_ctc.model.empty()) {
82 filename = config.zipformer_ctc.model; 89 filename = config.zipformer_ctc.model;
  90 + } else if (!config.wenet_ctc.model.empty()) {
  91 + filename = config.wenet_ctc.model;
83 } else { 92 } else {
84 SHERPA_ONNX_LOGE("Please specify a CTC model"); 93 SHERPA_ONNX_LOGE("Please specify a CTC model");
85 exit(-1); 94 exit(-1);
@@ -101,6 +110,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -101,6 +110,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
101 case ModelType::kZipformerCtc: 110 case ModelType::kZipformerCtc:
102 return std::make_unique<OfflineZipformerCtcModel>(config); 111 return std::make_unique<OfflineZipformerCtcModel>(config);
103 break; 112 break;
  113 + case ModelType::kWenetCtc:
  114 + return std::make_unique<OfflineWenetCtcModel>(config);
  115 + break;
104 case ModelType::kUnkown: 116 case ModelType::kUnkown:
105 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); 117 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
106 return nullptr; 118 return nullptr;
@@ -122,6 +134,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -122,6 +134,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
122 filename = config.tdnn.model; 134 filename = config.tdnn.model;
123 } else if (!config.zipformer_ctc.model.empty()) { 135 } else if (!config.zipformer_ctc.model.empty()) {
124 filename = config.zipformer_ctc.model; 136 filename = config.zipformer_ctc.model;
  137 + } else if (!config.wenet_ctc.model.empty()) {
  138 + filename = config.wenet_ctc.model;
125 } else { 139 } else {
126 SHERPA_ONNX_LOGE("Please specify a CTC model"); 140 SHERPA_ONNX_LOGE("Please specify a CTC model");
127 exit(-1); 141 exit(-1);
@@ -143,6 +157,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -143,6 +157,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
143 case ModelType::kZipformerCtc: 157 case ModelType::kZipformerCtc:
144 return std::make_unique<OfflineZipformerCtcModel>(mgr, config); 158 return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
145 break; 159 break;
  160 + case ModelType::kWenetCtc:
  161 + return std::make_unique<OfflineWenetCtcModel>(mgr, config);
  162 + break;
146 case ModelType::kUnkown: 163 case ModelType::kUnkown:
147 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); 164 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
148 return nullptr; 165 return nullptr;
@@ -63,6 +63,9 @@ class OfflineCtcModel { @@ -63,6 +63,9 @@ class OfflineCtcModel {
63 * for the features. 63 * for the features.
64 */ 64 */
65 virtual std::string FeatureNormalizationMethod() const { return {}; } 65 virtual std::string FeatureNormalizationMethod() const { return {}; }
  66 +
  67 + // Return true if the model supports batch size > 1
  68 + virtual bool SupportBatchProcessing() const { return true; }
66 }; 69 };
67 70
68 } // namespace sherpa_onnx 71 } // namespace sherpa_onnx
@@ -17,6 +17,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { @@ -17,6 +17,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
17 whisper.Register(po); 17 whisper.Register(po);
18 tdnn.Register(po); 18 tdnn.Register(po);
19 zipformer_ctc.Register(po); 19 zipformer_ctc.Register(po);
  20 + wenet_ctc.Register(po);
20 21
21 po->Register("tokens", &tokens, "Path to tokens.txt"); 22 po->Register("tokens", &tokens, "Path to tokens.txt");
22 23
@@ -67,6 +68,10 @@ bool OfflineModelConfig::Validate() const { @@ -67,6 +68,10 @@ bool OfflineModelConfig::Validate() const {
67 return zipformer_ctc.Validate(); 68 return zipformer_ctc.Validate();
68 } 69 }
69 70
  71 + if (!wenet_ctc.model.empty()) {
  72 + return wenet_ctc.Validate();
  73 + }
  74 +
70 return transducer.Validate(); 75 return transducer.Validate();
71 } 76 }
72 77
@@ -80,6 +85,7 @@ std::string OfflineModelConfig::ToString() const { @@ -80,6 +85,7 @@ std::string OfflineModelConfig::ToString() const {
80 os << "whisper=" << whisper.ToString() << ", "; 85 os << "whisper=" << whisper.ToString() << ", ";
81 os << "tdnn=" << tdnn.ToString() << ", "; 86 os << "tdnn=" << tdnn.ToString() << ", ";
82 os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; 87 os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
  88 + os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
83 os << "tokens=\"" << tokens << "\", "; 89 os << "tokens=\"" << tokens << "\", ";
84 os << "num_threads=" << num_threads << ", "; 90 os << "num_threads=" << num_threads << ", ";
85 os << "debug=" << (debug ? "True" : "False") << ", "; 91 os << "debug=" << (debug ? "True" : "False") << ", ";
@@ -10,6 +10,7 @@ @@ -10,6 +10,7 @@
10 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" 10 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
11 #include "sherpa-onnx/csrc/offline-tdnn-model-config.h" 11 #include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
12 #include "sherpa-onnx/csrc/offline-transducer-model-config.h" 12 #include "sherpa-onnx/csrc/offline-transducer-model-config.h"
  13 +#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
13 #include "sherpa-onnx/csrc/offline-whisper-model-config.h" 14 #include "sherpa-onnx/csrc/offline-whisper-model-config.h"
14 #include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h" 15 #include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
15 16
@@ -22,6 +23,7 @@ struct OfflineModelConfig { @@ -22,6 +23,7 @@ struct OfflineModelConfig {
22 OfflineWhisperModelConfig whisper; 23 OfflineWhisperModelConfig whisper;
23 OfflineTdnnModelConfig tdnn; 24 OfflineTdnnModelConfig tdnn;
24 OfflineZipformerCtcModelConfig zipformer_ctc; 25 OfflineZipformerCtcModelConfig zipformer_ctc;
  26 + OfflineWenetCtcModelConfig wenet_ctc;
25 27
26 std::string tokens; 28 std::string tokens;
27 int32_t num_threads = 2; 29 int32_t num_threads = 2;
@@ -46,6 +48,7 @@ struct OfflineModelConfig { @@ -46,6 +48,7 @@ struct OfflineModelConfig {
46 const OfflineWhisperModelConfig &whisper, 48 const OfflineWhisperModelConfig &whisper,
47 const OfflineTdnnModelConfig &tdnn, 49 const OfflineTdnnModelConfig &tdnn,
48 const OfflineZipformerCtcModelConfig &zipformer_ctc, 50 const OfflineZipformerCtcModelConfig &zipformer_ctc,
  51 + const OfflineWenetCtcModelConfig &wenet_ctc,
49 const std::string &tokens, int32_t num_threads, bool debug, 52 const std::string &tokens, int32_t num_threads, bool debug,
50 const std::string &provider, const std::string &model_type) 53 const std::string &provider, const std::string &model_type)
51 : transducer(transducer), 54 : transducer(transducer),
@@ -54,6 +57,7 @@ struct OfflineModelConfig { @@ -54,6 +57,7 @@ struct OfflineModelConfig {
54 whisper(whisper), 57 whisper(whisper),
55 tdnn(tdnn), 58 tdnn(tdnn),
56 zipformer_ctc(zipformer_ctc), 59 zipformer_ctc(zipformer_ctc),
  60 + wenet_ctc(wenet_ctc),
57 tokens(tokens), 61 tokens(tokens),
58 num_threads(num_threads), 62 num_threads(num_threads),
59 debug(debug), 63 debug(debug),
@@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
75 #endif 75 #endif
76 76
77 void Init() { 77 void Init() {
  78 + if (!config_.model_config.wenet_ctc.model.empty()) {
  79 + // WeNet CTC models assume input samples are in the range
  80 + // [-32768, 32767], so we set normalize_samples to false
  81 + config_.feat_config.normalize_samples = false;
  82 + }
  83 +
78 config_.feat_config.nemo_normalize_type = 84 config_.feat_config.nemo_normalize_type =
79 model_->FeatureNormalizationMethod(); 85 model_->FeatureNormalizationMethod();
80 86
@@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
85 config_.ctc_fst_decoder_config); 91 config_.ctc_fst_decoder_config);
86 } else if (config_.decoding_method == "greedy_search") { 92 } else if (config_.decoding_method == "greedy_search") {
87 if (!symbol_table_.contains("<blk>") && 93 if (!symbol_table_.contains("<blk>") &&
88 - !symbol_table_.contains("<eps>")) { 94 + !symbol_table_.contains("<eps>") &&
  95 + !symbol_table_.contains("<blank>")) {
89 SHERPA_ONNX_LOGE( 96 SHERPA_ONNX_LOGE(
90 "We expect that tokens.txt contains " 97 "We expect that tokens.txt contains "
91 - "the symbol <blk> or <eps> and its ID."); 98 + "the symbol <blk> or <eps> or <blank> and its ID.");
92 exit(-1); 99 exit(-1);
93 } 100 }
94 101
@@ -98,6 +105,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -98,6 +105,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
98 } else if (symbol_table_.contains("<eps>")) { 105 } else if (symbol_table_.contains("<eps>")) {
99 // for tdnn models of the yesno recipe from icefall 106 // for tdnn models of the yesno recipe from icefall
100 blank_id = symbol_table_["<eps>"]; 107 blank_id = symbol_table_["<eps>"];
  108 + } else if (symbol_table_.contains("<blank>")) {
  109 + // for WeNet CTC models
  110 + blank_id = symbol_table_["<blank>"];
101 } 111 }
102 112
103 decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id); 113 decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
@@ -113,6 +123,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -113,6 +123,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
113 } 123 }
114 124
115 void DecodeStreams(OfflineStream **ss, int32_t n) const override { 125 void DecodeStreams(OfflineStream **ss, int32_t n) const override {
  126 + if (!model_->SupportBatchProcessing()) {
  127 + // If the model does not support batch processing,
  128 + // we process each stream independently.
  129 + for (int32_t i = 0; i != n; ++i) {
  130 + DecodeStream(ss[i]);
  131 + }
  132 + return;
  133 + }
  134 +
116 auto memory_info = 135 auto memory_info =
117 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 136 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
118 137
@@ -165,6 +184,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -165,6 +184,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
165 } 184 }
166 185
167 private: 186 private:
  187 + // Decode a single stream.
  188 + // Some models do not support batch size > 1, e.g., WeNet CTC models.
  189 + void DecodeStream(OfflineStream *s) const {
  190 + auto memory_info =
  191 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  192 +
  193 + int32_t feat_dim = config_.feat_config.feature_dim;
  194 + std::vector<float> f = s->GetFrames();
  195 +
  196 + int32_t num_frames = f.size() / feat_dim;
  197 +
  198 + std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
  199 +
  200 + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
  201 + shape.data(), shape.size());
  202 +
  203 + int64_t x_length_scalar = num_frames;
  204 + std::array<int64_t, 1> x_length_shape = {1};
  205 + Ort::Value x_length =
  206 + Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
  207 + x_length_shape.data(), x_length_shape.size());
  208 +
  209 + auto t = model_->Forward(std::move(x), std::move(x_length));
  210 + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
  211 + int32_t frame_shift_ms = 10;
  212 +
  213 + auto r = Convert(results[0], symbol_table_, frame_shift_ms,
  214 + model_->SubsamplingFactor());
  215 + s->SetResult(r);
  216 + }
  217 +
  218 + private:
168 OfflineRecognizerConfig config_; 219 OfflineRecognizerConfig config_;
169 SymbolTable symbol_table_; 220 SymbolTable symbol_table_;
170 std::unique_ptr<OfflineCtcModel> model_; 221 std::unique_ptr<OfflineCtcModel> model_;
@@ -26,7 +26,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -26,7 +26,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
26 } else if (model_type == "paraformer") { 26 } else if (model_type == "paraformer") {
27 return std::make_unique<OfflineRecognizerParaformerImpl>(config); 27 return std::make_unique<OfflineRecognizerParaformerImpl>(config);
28 } else if (model_type == "nemo_ctc" || model_type == "tdnn" || 28 } else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
29 - model_type == "zipformer2_ctc") { 29 + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
30 return std::make_unique<OfflineRecognizerCtcImpl>(config); 30 return std::make_unique<OfflineRecognizerCtcImpl>(config);
31 } else if (model_type == "whisper") { 31 } else if (model_type == "whisper") {
32 return std::make_unique<OfflineRecognizerWhisperImpl>(config); 32 return std::make_unique<OfflineRecognizerWhisperImpl>(config);
@@ -51,6 +51,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -51,6 +51,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
51 model_filename = config.model_config.tdnn.model; 51 model_filename = config.model_config.tdnn.model;
52 } else if (!config.model_config.zipformer_ctc.model.empty()) { 52 } else if (!config.model_config.zipformer_ctc.model.empty()) {
53 model_filename = config.model_config.zipformer_ctc.model; 53 model_filename = config.model_config.zipformer_ctc.model;
  54 + } else if (!config.model_config.wenet_ctc.model.empty()) {
  55 + model_filename = config.model_config.wenet_ctc.model;
54 } else if (!config.model_config.whisper.encoder.empty()) { 56 } else if (!config.model_config.whisper.encoder.empty()) {
55 model_filename = config.model_config.whisper.encoder; 57 model_filename = config.model_config.whisper.encoder;
56 } else { 58 } else {
@@ -99,6 +101,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -99,6 +101,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
99 "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" 101 "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
100 "zipformer/export-onnx-ctc.py" 102 "zipformer/export-onnx-ctc.py"
101 "\n" 103 "\n"
  104 + "(6) CTC models from WeNet"
  105 + "\n "
  106 + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
  107 + "\n"
102 "\n"); 108 "\n");
103 exit(-1); 109 exit(-1);
104 } 110 }
@@ -114,7 +120,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -114,7 +120,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
114 } 120 }
115 121
116 if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || 122 if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
117 - model_type == "zipformer2_ctc") { 123 + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
118 return std::make_unique<OfflineRecognizerCtcImpl>(config); 124 return std::make_unique<OfflineRecognizerCtcImpl>(config);
119 } 125 }
120 126
@@ -130,7 +136,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -130,7 +136,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
130 " - EncDecCTCModelBPE models from NeMo\n" 136 " - EncDecCTCModelBPE models from NeMo\n"
131 " - Whisper models\n" 137 " - Whisper models\n"
132 " - Tdnn models\n" 138 " - Tdnn models\n"
133 - " - Zipformer CTC models\n", 139 + " - Zipformer CTC models\n"
  140 + " - WeNet CTC models\n",
134 model_type.c_str()); 141 model_type.c_str());
135 142
136 exit(-1); 143 exit(-1);
@@ -146,7 +153,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -146,7 +153,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
146 } else if (model_type == "paraformer") { 153 } else if (model_type == "paraformer") {
147 return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); 154 return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
148 } else if (model_type == "nemo_ctc" || model_type == "tdnn" || 155 } else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
149 - model_type == "zipformer2_ctc") { 156 + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
150 return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); 157 return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
151 } else if (model_type == "whisper") { 158 } else if (model_type == "whisper") {
152 return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); 159 return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
@@ -171,6 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -171,6 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
171 model_filename = config.model_config.tdnn.model; 178 model_filename = config.model_config.tdnn.model;
172 } else if (!config.model_config.zipformer_ctc.model.empty()) { 179 } else if (!config.model_config.zipformer_ctc.model.empty()) {
173 model_filename = config.model_config.zipformer_ctc.model; 180 model_filename = config.model_config.zipformer_ctc.model;
  181 + } else if (!config.model_config.wenet_ctc.model.empty()) {
  182 + model_filename = config.model_config.wenet_ctc.model;
174 } else if (!config.model_config.whisper.encoder.empty()) { 183 } else if (!config.model_config.whisper.encoder.empty()) {
175 model_filename = config.model_config.whisper.encoder; 184 model_filename = config.model_config.whisper.encoder;
176 } else { 185 } else {
@@ -219,6 +228,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -219,6 +228,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
219 "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" 228 "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
220 "zipformer/export-onnx-ctc.py" 229 "zipformer/export-onnx-ctc.py"
221 "\n" 230 "\n"
  231 + "(6) CTC models from WeNet"
  232 + "\n "
  233 + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
  234 + "\n"
222 "\n"); 235 "\n");
223 exit(-1); 236 exit(-1);
224 } 237 }
@@ -234,7 +247,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -234,7 +247,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
234 } 247 }
235 248
236 if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || 249 if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
237 - model_type == "zipformer2_ctc") { 250 + model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
238 return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); 251 return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
239 } 252 }
240 253
@@ -250,7 +263,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -250,7 +263,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
250 " - EncDecCTCModelBPE models from NeMo\n" 263 " - EncDecCTCModelBPE models from NeMo\n"
251 " - Whisper models\n" 264 " - Whisper models\n"
252 " - Tdnn models\n" 265 " - Tdnn models\n"
253 - " - Zipformer CTC models\n", 266 + " - Zipformer CTC models\n"
  267 + " - WeNet CTC models\n",
254 model_type.c_str()); 268 model_type.c_str());
255 269
256 exit(-1); 270 exit(-1);
@@ -40,7 +40,8 @@ struct OfflineFeatureExtractorConfig { @@ -40,7 +40,8 @@ struct OfflineFeatureExtractorConfig {
40 // Feature dimension 40 // Feature dimension
41 int32_t feature_dim = 80; 41 int32_t feature_dim = 80;
42 42
43 - // Set internally by some models, e.g., paraformer sets it to false. 43 + // Set internally by some models, e.g., paraformer and wenet CTC models set
  44 + // it to false.
44 // This parameter is not exposed to users from the commandline 45 // This parameter is not exposed to users from the commandline
45 // If true, the feature extractor expects inputs to be normalized to 46 // If true, the feature extractor expects inputs to be normalized to
46 // the range [-1, 1]. 47 // the range [-1, 1].
  1 +// sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflineWenetCtcModelConfig::Register(ParseOptions *po) {
  13 + po->Register(
  14 + "wenet-ctc-model", &model,
  15 + "Path to model.onnx from WeNet. Please see "
  16 + "https://github.com/k2-fsa/sherpa-onnx/pull/425 for available models");
  17 +}
  18 +
  19 +bool OfflineWenetCtcModelConfig::Validate() const {
  20 + if (!FileExists(model)) {
  21 + SHERPA_ONNX_LOGE("WeNet model: %s does not exist", model.c_str());
  22 + return false;
  23 + }
  24 +
  25 + return true;
  26 +}
  27 +
  28 +std::string OfflineWenetCtcModelConfig::ToString() const {
  29 + std::ostringstream os;
  30 +
  31 + os << "OfflineWenetCtcModelConfig(";
  32 + os << "model=\"" << model << "\")";
  33 +
  34 + return os.str();
  35 +}
  36 +
  37 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-wenet-ctc-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OfflineWenetCtcModelConfig {
  14 + std::string model;
  15 +
  16 + OfflineWenetCtcModelConfig() = default;
  17 + explicit OfflineWenetCtcModelConfig(const std::string &model)
  18 + : model(model) {}
  19 +
  20 + void Register(ParseOptions *po);
  21 + bool Validate() const;
  22 +
  23 + std::string ToString() const;
  24 +};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-wenet-ctc-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/csrc/onnx-utils.h"
  9 +#include "sherpa-onnx/csrc/session.h"
  10 +#include "sherpa-onnx/csrc/text-utils.h"
  11 +#include "sherpa-onnx/csrc/transpose.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +class OfflineWenetCtcModel::Impl {
  16 + public:
  17 + explicit Impl(const OfflineModelConfig &config)
  18 + : config_(config),
  19 + env_(ORT_LOGGING_LEVEL_ERROR),
  20 + sess_opts_(GetSessionOptions(config)),
  21 + allocator_{} {
  22 + auto buf = ReadFile(config_.wenet_ctc.model);
  23 + Init(buf.data(), buf.size());
  24 + }
  25 +
  26 +#if __ANDROID_API__ >= 9
  27 + Impl(AAssetManager *mgr, const OfflineModelConfig &config)
  28 + : config_(config),
  29 + env_(ORT_LOGGING_LEVEL_ERROR),
  30 + sess_opts_(GetSessionOptions(config)),
  31 + allocator_{} {
  32 + auto buf = ReadFile(mgr, config_.wenet_ctc.model);
  33 + Init(buf.data(), buf.size());
  34 + }
  35 +#endif
  36 +
  37 + std::vector<Ort::Value> Forward(Ort::Value features,
  38 + Ort::Value features_length) {
  39 + std::array<Ort::Value, 2> inputs = {std::move(features),
  40 + std::move(features_length)};
  41 +
  42 + return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  43 + output_names_ptr_.data(), output_names_ptr_.size());
  44 + }
  45 +
  46 + int32_t VocabSize() const { return vocab_size_; }
  47 +
  48 + int32_t SubsamplingFactor() const { return subsampling_factor_; }
  49 +
  50 + OrtAllocator *Allocator() const { return allocator_; }
  51 +
  52 + private:
  53 + void Init(void *model_data, size_t model_data_length) {
  54 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  55 + sess_opts_);
  56 +
  57 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  58 +
  59 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  60 +
  61 + // get meta data
  62 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  63 + if (config_.debug) {
  64 + std::ostringstream os;
  65 + PrintModelMetadata(os, meta_data);
  66 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  67 + }
  68 +
  69 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  70 + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  71 + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
  72 + }
  73 +
  74 + private:
  75 + OfflineModelConfig config_;
  76 + Ort::Env env_;
  77 + Ort::SessionOptions sess_opts_;
  78 + Ort::AllocatorWithDefaultOptions allocator_;
  79 +
  80 + std::unique_ptr<Ort::Session> sess_;
  81 +
  82 + std::vector<std::string> input_names_;
  83 + std::vector<const char *> input_names_ptr_;
  84 +
  85 + std::vector<std::string> output_names_;
  86 + std::vector<const char *> output_names_ptr_;
  87 +
  88 + int32_t vocab_size_ = 0;
  89 + int32_t subsampling_factor_ = 0;
  90 +};
  91 +
  92 +OfflineWenetCtcModel::OfflineWenetCtcModel(const OfflineModelConfig &config)
  93 + : impl_(std::make_unique<Impl>(config)) {}
  94 +
  95 +#if __ANDROID_API__ >= 9
  96 +OfflineWenetCtcModel::OfflineWenetCtcModel(AAssetManager *mgr,
  97 + const OfflineModelConfig &config)
  98 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  99 +#endif
  100 +
  101 +OfflineWenetCtcModel::~OfflineWenetCtcModel() = default;
  102 +
  103 +std::vector<Ort::Value> OfflineWenetCtcModel::Forward(
  104 + Ort::Value features, Ort::Value features_length) {
  105 + return impl_->Forward(std::move(features), std::move(features_length));
  106 +}
  107 +
  108 +int32_t OfflineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); }
  109 +
  110 +int32_t OfflineWenetCtcModel::SubsamplingFactor() const {
  111 + return impl_->SubsamplingFactor();
  112 +}
  113 +
  114 +OrtAllocator *OfflineWenetCtcModel::Allocator() const {
  115 + return impl_->Allocator();
  116 +}
  117 +
  118 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-wenet-ctc-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
  6 +#include <memory>
  7 +#include <string>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
  16 +#include "onnxruntime_cxx_api.h" // NOLINT
  17 +#include "sherpa-onnx/csrc/offline-ctc-model.h"
  18 +#include "sherpa-onnx/csrc/offline-model-config.h"
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +/** This class implements the CTC model from WeNet.
  23 + *
  24 + * See
  25 + * https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/export-onnx.py
  26 + * https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/test-onnx.py
  27 + * https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh
  28 + *
  29 + */
  30 +class OfflineWenetCtcModel : public OfflineCtcModel {
  31 + public:
  32 + explicit OfflineWenetCtcModel(const OfflineModelConfig &config);
  33 +
  34 +#if __ANDROID_API__ >= 9
  35 + OfflineWenetCtcModel(AAssetManager *mgr, const OfflineModelConfig &config);
  36 +#endif
  37 +
  38 + ~OfflineWenetCtcModel() override;
  39 +
  40 + /** Run the forward method of the model.
  41 + *
  42 + * @param features A tensor of shape (N, T, C).
  43 + * @param features_length A 1-D tensor of shape (N,) containing the number
  44 + * of valid frames in `features` before padding.
  45 + * Its dtype is int64_t.
  46 + *
  47 + * @return Return a vector containing:
  48 + * - log_probs: A 3-D tensor of shape (N, T', vocab_size).
  49 + * - log_probs_length: A 1-D tensor of shape (N,). Its dtype is int64_t.
  50 + */
  51 + std::vector<Ort::Value> Forward(Ort::Value features,
  52 + Ort::Value features_length) override;
  53 +
  54 + /** Return the vocabulary size of the model
  55 + */
  56 + int32_t VocabSize() const override;
  57 +
  58 + /** SubsamplingFactor of the model
  59 + *
  60 + * The value is read from the "subsampling_factor" entry of the
  61 + * model's meta data.
  62 + */
  63 + int32_t SubsamplingFactor() const override;
  64 +
  65 + /** Return an allocator for allocating memory
  66 + */
  67 + OrtAllocator *Allocator() const override;
  68 +
  69 + // WeNet CTC models do not support batch size > 1
  70 + bool SupportBatchProcessing() const override { return false; }
  71 +
  72 + private:
  73 + class Impl;
  74 + std::unique_ptr<Impl> impl_;
  75 +};
  76 +
  77 +} // namespace sherpa_onnx
  78 +
  79 +#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
@@ -17,6 +17,7 @@ pybind11_add_module(_sherpa_onnx @@ -17,6 +17,7 @@ pybind11_add_module(_sherpa_onnx
17 offline-tts-model-config.cc 17 offline-tts-model-config.cc
18 offline-tts-vits-model-config.cc 18 offline-tts-vits-model-config.cc
19 offline-tts.cc 19 offline-tts.cc
  20 + offline-wenet-ctc-model-config.cc
20 offline-whisper-model-config.cc 21 offline-whisper-model-config.cc
21 offline-zipformer-ctc-model-config.cc 22 offline-zipformer-ctc-model-config.cc
22 online-lm-config.cc 23 online-lm-config.cc
@@ -12,6 +12,7 @@ @@ -12,6 +12,7 @@
12 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" 12 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
13 #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" 13 #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
14 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" 14 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
  15 +#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
15 #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" 16 #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
16 #include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h" 17 #include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
17 18
@@ -24,6 +25,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -24,6 +25,7 @@ void PybindOfflineModelConfig(py::module *m) {
24 PybindOfflineWhisperModelConfig(m); 25 PybindOfflineWhisperModelConfig(m);
25 PybindOfflineTdnnModelConfig(m); 26 PybindOfflineTdnnModelConfig(m);
26 PybindOfflineZipformerCtcModelConfig(m); 27 PybindOfflineZipformerCtcModelConfig(m);
  28 + PybindOfflineWenetCtcModelConfig(m);
27 29
28 using PyClass = OfflineModelConfig; 30 using PyClass = OfflineModelConfig;
29 py::class_<PyClass>(*m, "OfflineModelConfig") 31 py::class_<PyClass>(*m, "OfflineModelConfig")
@@ -32,7 +34,8 @@ void PybindOfflineModelConfig(py::module *m) { @@ -32,7 +34,8 @@ void PybindOfflineModelConfig(py::module *m) {
32 const OfflineNemoEncDecCtcModelConfig &, 34 const OfflineNemoEncDecCtcModelConfig &,
33 const OfflineWhisperModelConfig &, 35 const OfflineWhisperModelConfig &,
34 const OfflineTdnnModelConfig &, 36 const OfflineTdnnModelConfig &,
35 - const OfflineZipformerCtcModelConfig &, const std::string &, 37 + const OfflineZipformerCtcModelConfig &,
  38 + const OfflineWenetCtcModelConfig &, const std::string &,
36 int32_t, bool, const std::string &, const std::string &>(), 39 int32_t, bool, const std::string &, const std::string &>(),
37 py::arg("transducer") = OfflineTransducerModelConfig(), 40 py::arg("transducer") = OfflineTransducerModelConfig(),
38 py::arg("paraformer") = OfflineParaformerModelConfig(), 41 py::arg("paraformer") = OfflineParaformerModelConfig(),
@@ -40,6 +43,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -40,6 +43,7 @@ void PybindOfflineModelConfig(py::module *m) {
40 py::arg("whisper") = OfflineWhisperModelConfig(), 43 py::arg("whisper") = OfflineWhisperModelConfig(),
41 py::arg("tdnn") = OfflineTdnnModelConfig(), 44 py::arg("tdnn") = OfflineTdnnModelConfig(),
42 py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), 45 py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
  46 + py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
43 py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, 47 py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
44 py::arg("provider") = "cpu", py::arg("model_type") = "") 48 py::arg("provider") = "cpu", py::arg("model_type") = "")
45 .def_readwrite("transducer", &PyClass::transducer) 49 .def_readwrite("transducer", &PyClass::transducer)
@@ -48,6 +52,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -48,6 +52,7 @@ void PybindOfflineModelConfig(py::module *m) {
48 .def_readwrite("whisper", &PyClass::whisper) 52 .def_readwrite("whisper", &PyClass::whisper)
49 .def_readwrite("tdnn", &PyClass::tdnn) 53 .def_readwrite("tdnn", &PyClass::tdnn)
50 .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) 54 .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
  55 + .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
51 .def_readwrite("tokens", &PyClass::tokens) 56 .def_readwrite("tokens", &PyClass::tokens)
52 .def_readwrite("num_threads", &PyClass::num_threads) 57 .def_readwrite("num_threads", &PyClass::num_threads)
53 .def_readwrite("debug", &PyClass::debug) 58 .def_readwrite("debug", &PyClass::debug)
  1 +// sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOfflineWenetCtcModelConfig(py::module *m) {
  15 + using PyClass = OfflineWenetCtcModelConfig;
  16 + py::class_<PyClass>(*m, "OfflineWenetCtcModelConfig")
  17 + .def(py::init<const std::string &>(), py::arg("model"))
  18 + .def_readwrite("model", &PyClass::model)
  19 + .def("__str__", &PyClass::ToString);
  20 +}
  21 +
  22 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineWenetCtcModelConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_