Committed by
GitHub
Reduce model initialization time for offline speech recognition (#213)
正在显示
10 个修改的文件
包含
108 行增加
和
35 行删除
| @@ -387,6 +387,7 @@ void CNonStreamingSpeechRecognitionDlg::InitParaformer() { | @@ -387,6 +387,7 @@ void CNonStreamingSpeechRecognitionDlg::InitParaformer() { | ||
| 387 | config_.model_config.tokens = tokens.c_str(); | 387 | config_.model_config.tokens = tokens.c_str(); |
| 388 | config_.model_config.num_threads = 1; | 388 | config_.model_config.num_threads = 1; |
| 389 | config_.model_config.debug = 1; | 389 | config_.model_config.debug = 1; |
| 390 | + config_.model_config.model_type = "paraformer"; | ||
| 390 | 391 | ||
| 391 | config_.decoding_method = "greedy_search"; | 392 | config_.decoding_method = "greedy_search"; |
| 392 | config_.max_active_paths = 4; | 393 | config_.max_active_paths = 4; |
| @@ -447,6 +448,7 @@ void CNonStreamingSpeechRecognitionDlg::InitRecognizer() { | @@ -447,6 +448,7 @@ void CNonStreamingSpeechRecognitionDlg::InitRecognizer() { | ||
| 447 | config_.model_config.tokens = tokens.c_str(); | 448 | config_.model_config.tokens = tokens.c_str(); |
| 448 | config_.model_config.num_threads = 1; | 449 | config_.model_config.num_threads = 1; |
| 449 | config_.model_config.debug = 0; | 450 | config_.model_config.debug = 0; |
| 451 | + config_.model_config.model_type = "transducer"; | ||
| 450 | 452 | ||
| 451 | config_.decoding_method = "greedy_search"; | 453 | config_.decoding_method = "greedy_search"; |
| 452 | config_.max_active_paths = 4; | 454 | config_.max_active_paths = 4; |
| @@ -76,6 +76,8 @@ namespace SherpaOnnx | @@ -76,6 +76,8 @@ namespace SherpaOnnx | ||
| 76 | Tokens = ""; | 76 | Tokens = ""; |
| 77 | NumThreads = 1; | 77 | NumThreads = 1; |
| 78 | Debug = 0; | 78 | Debug = 0; |
| 79 | + Provider = "cpu"; | ||
| 80 | + ModelType = ""; | ||
| 79 | } | 81 | } |
| 80 | public OfflineTransducerModelConfig Transducer; | 82 | public OfflineTransducerModelConfig Transducer; |
| 81 | public OfflineParaformerModelConfig Paraformer; | 83 | public OfflineParaformerModelConfig Paraformer; |
| @@ -87,6 +89,12 @@ namespace SherpaOnnx | @@ -87,6 +89,12 @@ namespace SherpaOnnx | ||
| 87 | public int NumThreads; | 89 | public int NumThreads; |
| 88 | 90 | ||
| 89 | public int Debug; | 91 | public int Debug; |
| 92 | + | ||
| 93 | + [MarshalAs(UnmanagedType.LPStr)] | ||
| 94 | + public string Provider; | ||
| 95 | + | ||
| 96 | + [MarshalAs(UnmanagedType.LPStr)] | ||
| 97 | + public string ModelType; | ||
| 90 | } | 98 | } |
| 91 | 99 | ||
| 92 | [StructLayout(LayoutKind.Sequential)] | 100 | [StructLayout(LayoutKind.Sequential)] |
| @@ -33,23 +33,33 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( | @@ -33,23 +33,33 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( | ||
| 33 | const SherpaOnnxOnlineRecognizerConfig *config) { | 33 | const SherpaOnnxOnlineRecognizerConfig *config) { |
| 34 | sherpa_onnx::OnlineRecognizerConfig recognizer_config; | 34 | sherpa_onnx::OnlineRecognizerConfig recognizer_config; |
| 35 | 35 | ||
| 36 | - recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); | ||
| 37 | - recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); | 36 | + recognizer_config.feat_config.sampling_rate = |
| 37 | + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); | ||
| 38 | + recognizer_config.feat_config.feature_dim = | ||
| 39 | + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); | ||
| 38 | 40 | ||
| 39 | recognizer_config.model_config.encoder_filename = | 41 | recognizer_config.model_config.encoder_filename = |
| 40 | SHERPA_ONNX_OR(config->model_config.encoder, ""); | 42 | SHERPA_ONNX_OR(config->model_config.encoder, ""); |
| 41 | recognizer_config.model_config.decoder_filename = | 43 | recognizer_config.model_config.decoder_filename = |
| 42 | SHERPA_ONNX_OR(config->model_config.decoder, ""); | 44 | SHERPA_ONNX_OR(config->model_config.decoder, ""); |
| 43 | - recognizer_config.model_config.joiner_filename = SHERPA_ONNX_OR(config->model_config.joiner, ""); | ||
| 44 | - recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); | ||
| 45 | - recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); | ||
| 46 | - recognizer_config.model_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu"); | ||
| 47 | - recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0); | ||
| 48 | - | ||
| 49 | - recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); | ||
| 50 | - recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); | ||
| 51 | - | ||
| 52 | - recognizer_config.enable_endpoint = SHERPA_ONNX_OR(config->enable_endpoint, 0); | 45 | + recognizer_config.model_config.joiner_filename = |
| 46 | + SHERPA_ONNX_OR(config->model_config.joiner, ""); | ||
| 47 | + recognizer_config.model_config.tokens = | ||
| 48 | + SHERPA_ONNX_OR(config->model_config.tokens, ""); | ||
| 49 | + recognizer_config.model_config.num_threads = | ||
| 50 | + SHERPA_ONNX_OR(config->model_config.num_threads, 1); | ||
| 51 | + recognizer_config.model_config.provider = | ||
| 52 | + SHERPA_ONNX_OR(config->model_config.provider, "cpu"); | ||
| 53 | + recognizer_config.model_config.debug = | ||
| 54 | + SHERPA_ONNX_OR(config->model_config.debug, 0); | ||
| 55 | + | ||
| 56 | + recognizer_config.decoding_method = | ||
| 57 | + SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); | ||
| 58 | + recognizer_config.max_active_paths = | ||
| 59 | + SHERPA_ONNX_OR(config->max_active_paths, 4); | ||
| 60 | + | ||
| 61 | + recognizer_config.enable_endpoint = | ||
| 62 | + SHERPA_ONNX_OR(config->enable_endpoint, 0); | ||
| 53 | 63 | ||
| 54 | recognizer_config.endpoint_config.rule1.min_trailing_silence = | 64 | recognizer_config.endpoint_config.rule1.min_trailing_silence = |
| 55 | SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4); | 65 | SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4); |
| @@ -173,9 +183,11 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | @@ -173,9 +183,11 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | ||
| 173 | const SherpaOnnxOfflineRecognizerConfig *config) { | 183 | const SherpaOnnxOfflineRecognizerConfig *config) { |
| 174 | sherpa_onnx::OfflineRecognizerConfig recognizer_config; | 184 | sherpa_onnx::OfflineRecognizerConfig recognizer_config; |
| 175 | 185 | ||
| 176 | - recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); | 186 | + recognizer_config.feat_config.sampling_rate = |
| 187 | + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); | ||
| 177 | 188 | ||
| 178 | - recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); | 189 | + recognizer_config.feat_config.feature_dim = |
| 190 | + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); | ||
| 179 | 191 | ||
| 180 | recognizer_config.model_config.transducer.encoder_filename = | 192 | recognizer_config.model_config.transducer.encoder_filename = |
| 181 | SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); | 193 | SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); |
| @@ -184,7 +196,7 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | @@ -184,7 +196,7 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | ||
| 184 | SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); | 196 | SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); |
| 185 | 197 | ||
| 186 | recognizer_config.model_config.transducer.joiner_filename = | 198 | recognizer_config.model_config.transducer.joiner_filename = |
| 187 | - SHERPA_ONNX_OR(config->model_config.transducer.joiner,""); | 199 | + SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); |
| 188 | 200 | ||
| 189 | recognizer_config.model_config.paraformer.model = | 201 | recognizer_config.model_config.paraformer.model = |
| 190 | SHERPA_ONNX_OR(config->model_config.paraformer.model, ""); | 202 | SHERPA_ONNX_OR(config->model_config.paraformer.model, ""); |
| @@ -192,15 +204,26 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | @@ -192,15 +204,26 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | ||
| 192 | recognizer_config.model_config.nemo_ctc.model = | 204 | recognizer_config.model_config.nemo_ctc.model = |
| 193 | SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, ""); | 205 | SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, ""); |
| 194 | 206 | ||
| 195 | - recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); | ||
| 196 | - recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); | ||
| 197 | - recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0); | ||
| 198 | - | ||
| 199 | - recognizer_config.lm_config.model = SHERPA_ONNX_OR(config->lm_config.model, ""); | ||
| 200 | - recognizer_config.lm_config.scale = SHERPA_ONNX_OR(config->lm_config.scale, 1.0); | ||
| 201 | - | ||
| 202 | - recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); | ||
| 203 | - recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); | 207 | + recognizer_config.model_config.tokens = |
| 208 | + SHERPA_ONNX_OR(config->model_config.tokens, ""); | ||
| 209 | + recognizer_config.model_config.num_threads = | ||
| 210 | + SHERPA_ONNX_OR(config->model_config.num_threads, 1); | ||
| 211 | + recognizer_config.model_config.debug = | ||
| 212 | + SHERPA_ONNX_OR(config->model_config.debug, 0); | ||
| 213 | + recognizer_config.model_config.provider = | ||
| 214 | + SHERPA_ONNX_OR(config->model_config.provider, "cpu"); | ||
| 215 | + recognizer_config.model_config.model_type = | ||
| 216 | + SHERPA_ONNX_OR(config->model_config.model_type, ""); | ||
| 217 | + | ||
| 218 | + recognizer_config.lm_config.model = | ||
| 219 | + SHERPA_ONNX_OR(config->lm_config.model, ""); | ||
| 220 | + recognizer_config.lm_config.scale = | ||
| 221 | + SHERPA_ONNX_OR(config->lm_config.scale, 1.0); | ||
| 222 | + | ||
| 223 | + recognizer_config.decoding_method = | ||
| 224 | + SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); | ||
| 225 | + recognizer_config.max_active_paths = | ||
| 226 | + SHERPA_ONNX_OR(config->max_active_paths, 4); | ||
| 204 | 227 | ||
| 205 | if (config->model_config.debug) { | 228 | if (config->model_config.debug) { |
| 206 | fprintf(stderr, "%s\n", recognizer_config.ToString().c_str()); | 229 | fprintf(stderr, "%s\n", recognizer_config.ToString().c_str()); |
| @@ -272,6 +272,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig { | @@ -272,6 +272,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig { | ||
| 272 | const char *tokens; | 272 | const char *tokens; |
| 273 | int32_t num_threads; | 273 | int32_t num_threads; |
| 274 | int32_t debug; | 274 | int32_t debug; |
| 275 | + const char *provider; | ||
| 276 | + const char *model_type; | ||
| 275 | } SherpaOnnxOfflineModelConfig; | 277 | } SherpaOnnxOfflineModelConfig; |
| 276 | 278 | ||
| 277 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { | 279 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { |
| @@ -25,6 +25,11 @@ void OfflineModelConfig::Register(ParseOptions *po) { | @@ -25,6 +25,11 @@ void OfflineModelConfig::Register(ParseOptions *po) { | ||
| 25 | 25 | ||
| 26 | po->Register("provider", &provider, | 26 | po->Register("provider", &provider, |
| 27 | "Specify a provider to use: cpu, cuda, coreml"); | 27 | "Specify a provider to use: cpu, cuda, coreml"); |
| 28 | + | ||
| 29 | + po->Register("model-type", &model_type, | ||
| 30 | + "Specify it to reduce model initialization time. " | ||
| 31 | + "Valid values are: transducer, paraformer, nemo_ctc. " | ||
| 32 | + "All other values lead to loading the model twice."); | ||
| 28 | } | 33 | } |
| 29 | 34 | ||
| 30 | bool OfflineModelConfig::Validate() const { | 35 | bool OfflineModelConfig::Validate() const { |
| @@ -34,7 +39,7 @@ bool OfflineModelConfig::Validate() const { | @@ -34,7 +39,7 @@ bool OfflineModelConfig::Validate() const { | ||
| 34 | } | 39 | } |
| 35 | 40 | ||
| 36 | if (!FileExists(tokens)) { | 41 | if (!FileExists(tokens)) { |
| 37 | - SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); | 42 | + SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str()); |
| 38 | return false; | 43 | return false; |
| 39 | } | 44 | } |
| 40 | 45 | ||
| @@ -59,7 +64,8 @@ std::string OfflineModelConfig::ToString() const { | @@ -59,7 +64,8 @@ std::string OfflineModelConfig::ToString() const { | ||
| 59 | os << "tokens=\"" << tokens << "\", "; | 64 | os << "tokens=\"" << tokens << "\", "; |
| 60 | os << "num_threads=" << num_threads << ", "; | 65 | os << "num_threads=" << num_threads << ", "; |
| 61 | os << "debug=" << (debug ? "True" : "False") << ", "; | 66 | os << "debug=" << (debug ? "True" : "False") << ", "; |
| 62 | - os << "provider=\"" << provider << "\")"; | 67 | + os << "provider=\"" << provider << "\", "; |
| 68 | + os << "model_type=\"" << model_type << "\")"; | ||
| 63 | 69 | ||
| 64 | return os.str(); | 70 | return os.str(); |
| 65 | } | 71 | } |
| @@ -22,19 +22,31 @@ struct OfflineModelConfig { | @@ -22,19 +22,31 @@ struct OfflineModelConfig { | ||
| 22 | bool debug = false; | 22 | bool debug = false; |
| 23 | std::string provider = "cpu"; | 23 | std::string provider = "cpu"; |
| 24 | 24 | ||
| 25 | + // With the help of this field, we only need to load the model once | ||
| 26 | + // instead of twice; and therefore it reduces initialization time. | ||
| 27 | + // | ||
| 28 | + // Valid values: | ||
| 29 | + // - transducer. The given model is from icefall | ||
| 30 | + // - paraformer. It is a paraformer model | ||
| 31 | + // - nemo_ctc. It is a NeMo CTC model. | ||
| 32 | + // | ||
| 33 | + // All other values are invalid and lead to loading the model twice. | ||
| 34 | + std::string model_type; | ||
| 35 | + | ||
| 25 | OfflineModelConfig() = default; | 36 | OfflineModelConfig() = default; |
| 26 | OfflineModelConfig(const OfflineTransducerModelConfig &transducer, | 37 | OfflineModelConfig(const OfflineTransducerModelConfig &transducer, |
| 27 | const OfflineParaformerModelConfig ¶former, | 38 | const OfflineParaformerModelConfig ¶former, |
| 28 | const OfflineNemoEncDecCtcModelConfig &nemo_ctc, | 39 | const OfflineNemoEncDecCtcModelConfig &nemo_ctc, |
| 29 | const std::string &tokens, int32_t num_threads, bool debug, | 40 | const std::string &tokens, int32_t num_threads, bool debug, |
| 30 | - const std::string &provider) | 41 | + const std::string &provider, const std::string &model_type) |
| 31 | : transducer(transducer), | 42 | : transducer(transducer), |
| 32 | paraformer(paraformer), | 43 | paraformer(paraformer), |
| 33 | nemo_ctc(nemo_ctc), | 44 | nemo_ctc(nemo_ctc), |
| 34 | tokens(tokens), | 45 | tokens(tokens), |
| 35 | num_threads(num_threads), | 46 | num_threads(num_threads), |
| 36 | debug(debug), | 47 | debug(debug), |
| 37 | - provider(provider) {} | 48 | + provider(provider), |
| 49 | + model_type(model_type) {} | ||
| 38 | 50 | ||
| 39 | void Register(ParseOptions *po); | 51 | void Register(ParseOptions *po); |
| 40 | bool Validate() const; | 52 | bool Validate() const; |
| @@ -18,6 +18,21 @@ namespace sherpa_onnx { | @@ -18,6 +18,21 @@ namespace sherpa_onnx { | ||
| 18 | 18 | ||
| 19 | std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | 19 | std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( |
| 20 | const OfflineRecognizerConfig &config) { | 20 | const OfflineRecognizerConfig &config) { |
| 21 | + if (!config.model_config.model_type.empty()) { | ||
| 22 | + const auto &model_type = config.model_config.model_type; | ||
| 23 | + if (model_type == "transducer") { | ||
| 24 | + return std::make_unique<OfflineRecognizerTransducerImpl>(config); | ||
| 25 | + } else if (model_type == "paraformer") { | ||
| 26 | + return std::make_unique<OfflineRecognizerParaformerImpl>(config); | ||
| 27 | + } else if (model_type == "nemo_ctc") { | ||
| 28 | + return std::make_unique<OfflineRecognizerCtcImpl>(config); | ||
| 29 | + } else { | ||
| 30 | + SHERPA_ONNX_LOGE( | ||
| 31 | + "Invalid model_type: %s. Trying to load the model to get its type", | ||
| 32 | + model_type.c_str()); | ||
| 33 | + } | ||
| 34 | + } | ||
| 35 | + | ||
| 21 | Ort::Env env(ORT_LOGGING_LEVEL_ERROR); | 36 | Ort::Env env(ORT_LOGGING_LEVEL_ERROR); |
| 22 | 37 | ||
| 23 | Ort::SessionOptions sess_opts; | 38 | Ort::SessionOptions sess_opts; |
| @@ -18,17 +18,17 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { | @@ -18,17 +18,17 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { | ||
| 18 | 18 | ||
| 19 | bool OfflineTransducerModelConfig::Validate() const { | 19 | bool OfflineTransducerModelConfig::Validate() const { |
| 20 | if (!FileExists(encoder_filename)) { | 20 | if (!FileExists(encoder_filename)) { |
| 21 | - SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); | 21 | + SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str()); |
| 22 | return false; | 22 | return false; |
| 23 | } | 23 | } |
| 24 | 24 | ||
| 25 | if (!FileExists(decoder_filename)) { | 25 | if (!FileExists(decoder_filename)) { |
| 26 | - SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str()); | 26 | + SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str()); |
| 27 | return false; | 27 | return false; |
| 28 | } | 28 | } |
| 29 | 29 | ||
| 30 | if (!FileExists(joiner_filename)) { | 30 | if (!FileExists(joiner_filename)) { |
| 31 | - SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str()); | 31 | + SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str()); |
| 32 | return false; | 32 | return false; |
| 33 | } | 33 | } |
| 34 | 34 |
| @@ -21,15 +21,16 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -21,15 +21,16 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 21 | 21 | ||
| 22 | using PyClass = OfflineModelConfig; | 22 | using PyClass = OfflineModelConfig; |
| 23 | py::class_<PyClass>(*m, "OfflineModelConfig") | 23 | py::class_<PyClass>(*m, "OfflineModelConfig") |
| 24 | - .def(py::init<const OfflineTransducerModelConfig &, | 24 | + .def( |
| 25 | + py::init<const OfflineTransducerModelConfig &, | ||
| 25 | const OfflineParaformerModelConfig &, | 26 | const OfflineParaformerModelConfig &, |
| 26 | - const OfflineNemoEncDecCtcModelConfig &, | ||
| 27 | - const std::string &, int32_t, bool, const std::string &>(), | 27 | + const OfflineNemoEncDecCtcModelConfig &, const std::string &, |
| 28 | + int32_t, bool, const std::string &, const std::string &>(), | ||
| 28 | py::arg("transducer") = OfflineTransducerModelConfig(), | 29 | py::arg("transducer") = OfflineTransducerModelConfig(), |
| 29 | py::arg("paraformer") = OfflineParaformerModelConfig(), | 30 | py::arg("paraformer") = OfflineParaformerModelConfig(), |
| 30 | py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), | 31 | py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), |
| 31 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, | 32 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, |
| 32 | - py::arg("provider") = "cpu") | 33 | + py::arg("provider") = "cpu", py::arg("model_type") = "") |
| 33 | .def_readwrite("transducer", &PyClass::transducer) | 34 | .def_readwrite("transducer", &PyClass::transducer) |
| 34 | .def_readwrite("paraformer", &PyClass::paraformer) | 35 | .def_readwrite("paraformer", &PyClass::paraformer) |
| 35 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) | 36 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) |
| @@ -37,6 +38,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -37,6 +38,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 37 | .def_readwrite("num_threads", &PyClass::num_threads) | 38 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 38 | .def_readwrite("debug", &PyClass::debug) | 39 | .def_readwrite("debug", &PyClass::debug) |
| 39 | .def_readwrite("provider", &PyClass::provider) | 40 | .def_readwrite("provider", &PyClass::provider) |
| 41 | + .def_readwrite("model_type", &PyClass::model_type) | ||
| 40 | .def("__str__", &PyClass::ToString); | 42 | .def("__str__", &PyClass::ToString); |
| 41 | } | 43 | } |
| 42 | 44 |
| @@ -86,6 +86,7 @@ class OfflineRecognizer(object): | @@ -86,6 +86,7 @@ class OfflineRecognizer(object): | ||
| 86 | num_threads=num_threads, | 86 | num_threads=num_threads, |
| 87 | debug=debug, | 87 | debug=debug, |
| 88 | provider=provider, | 88 | provider=provider, |
| 89 | + model_type="transducer", | ||
| 89 | ) | 90 | ) |
| 90 | 91 | ||
| 91 | feat_config = OfflineFeatureExtractorConfig( | 92 | feat_config = OfflineFeatureExtractorConfig( |
| @@ -149,6 +150,7 @@ class OfflineRecognizer(object): | @@ -149,6 +150,7 @@ class OfflineRecognizer(object): | ||
| 149 | num_threads=num_threads, | 150 | num_threads=num_threads, |
| 150 | debug=debug, | 151 | debug=debug, |
| 151 | provider=provider, | 152 | provider=provider, |
| 153 | + model_type="paraformer", | ||
| 152 | ) | 154 | ) |
| 153 | 155 | ||
| 154 | feat_config = OfflineFeatureExtractorConfig( | 156 | feat_config = OfflineFeatureExtractorConfig( |
| @@ -211,6 +213,7 @@ class OfflineRecognizer(object): | @@ -211,6 +213,7 @@ class OfflineRecognizer(object): | ||
| 211 | num_threads=num_threads, | 213 | num_threads=num_threads, |
| 212 | debug=debug, | 214 | debug=debug, |
| 213 | provider=provider, | 215 | provider=provider, |
| 216 | + model_type="nemo_ctc", | ||
| 214 | ) | 217 | ) |
| 215 | 218 | ||
| 216 | feat_config = OfflineFeatureExtractorConfig( | 219 | feat_config = OfflineFeatureExtractorConfig( |
-
请 注册 或 登录 后发表评论