Fangjun Kuang
Committed by GitHub

Reduce model initialization time for offline speech recognition (#213)

@@ -387,6 +387,7 @@ void CNonStreamingSpeechRecognitionDlg::InitParaformer() { @@ -387,6 +387,7 @@ void CNonStreamingSpeechRecognitionDlg::InitParaformer() {
387 config_.model_config.tokens = tokens.c_str(); 387 config_.model_config.tokens = tokens.c_str();
388 config_.model_config.num_threads = 1; 388 config_.model_config.num_threads = 1;
389 config_.model_config.debug = 1; 389 config_.model_config.debug = 1;
  390 + config_.model_config.model_type = "paraformer";
390 391
391 config_.decoding_method = "greedy_search"; 392 config_.decoding_method = "greedy_search";
392 config_.max_active_paths = 4; 393 config_.max_active_paths = 4;
@@ -447,6 +448,7 @@ void CNonStreamingSpeechRecognitionDlg::InitRecognizer() { @@ -447,6 +448,7 @@ void CNonStreamingSpeechRecognitionDlg::InitRecognizer() {
447 config_.model_config.tokens = tokens.c_str(); 448 config_.model_config.tokens = tokens.c_str();
448 config_.model_config.num_threads = 1; 449 config_.model_config.num_threads = 1;
449 config_.model_config.debug = 0; 450 config_.model_config.debug = 0;
  451 + config_.model_config.model_type = "transducer";
450 452
451 config_.decoding_method = "greedy_search"; 453 config_.decoding_method = "greedy_search";
452 config_.max_active_paths = 4; 454 config_.max_active_paths = 4;
@@ -76,6 +76,8 @@ namespace SherpaOnnx @@ -76,6 +76,8 @@ namespace SherpaOnnx
76 Tokens = ""; 76 Tokens = "";
77 NumThreads = 1; 77 NumThreads = 1;
78 Debug = 0; 78 Debug = 0;
  79 + Provider = "cpu";
  80 + ModelType = "";
79 } 81 }
80 public OfflineTransducerModelConfig Transducer; 82 public OfflineTransducerModelConfig Transducer;
81 public OfflineParaformerModelConfig Paraformer; 83 public OfflineParaformerModelConfig Paraformer;
@@ -87,6 +89,12 @@ namespace SherpaOnnx @@ -87,6 +89,12 @@ namespace SherpaOnnx
87 public int NumThreads; 89 public int NumThreads;
88 90
89 public int Debug; 91 public int Debug;
  92 +
  93 + [MarshalAs(UnmanagedType.LPStr)]
  94 + public string Provider;
  95 +
  96 + [MarshalAs(UnmanagedType.LPStr)]
  97 + public string ModelType;
90 } 98 }
91 99
92 [StructLayout(LayoutKind.Sequential)] 100 [StructLayout(LayoutKind.Sequential)]
@@ -33,23 +33,33 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( @@ -33,23 +33,33 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
33 const SherpaOnnxOnlineRecognizerConfig *config) { 33 const SherpaOnnxOnlineRecognizerConfig *config) {
34 sherpa_onnx::OnlineRecognizerConfig recognizer_config; 34 sherpa_onnx::OnlineRecognizerConfig recognizer_config;
35 35
36 - recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);  
37 - recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); 36 + recognizer_config.feat_config.sampling_rate =
  37 + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
  38 + recognizer_config.feat_config.feature_dim =
  39 + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
38 40
39 recognizer_config.model_config.encoder_filename = 41 recognizer_config.model_config.encoder_filename =
40 SHERPA_ONNX_OR(config->model_config.encoder, ""); 42 SHERPA_ONNX_OR(config->model_config.encoder, "");
41 recognizer_config.model_config.decoder_filename = 43 recognizer_config.model_config.decoder_filename =
42 SHERPA_ONNX_OR(config->model_config.decoder, ""); 44 SHERPA_ONNX_OR(config->model_config.decoder, "");
43 - recognizer_config.model_config.joiner_filename = SHERPA_ONNX_OR(config->model_config.joiner, "");  
44 - recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, "");  
45 - recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1);  
46 - recognizer_config.model_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu");  
47 - recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0);  
48 -  
49 - recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search");  
50 - recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4);  
51 -  
52 - recognizer_config.enable_endpoint = SHERPA_ONNX_OR(config->enable_endpoint, 0); 45 + recognizer_config.model_config.joiner_filename =
  46 + SHERPA_ONNX_OR(config->model_config.joiner, "");
  47 + recognizer_config.model_config.tokens =
  48 + SHERPA_ONNX_OR(config->model_config.tokens, "");
  49 + recognizer_config.model_config.num_threads =
  50 + SHERPA_ONNX_OR(config->model_config.num_threads, 1);
  51 + recognizer_config.model_config.provider =
  52 + SHERPA_ONNX_OR(config->model_config.provider, "cpu");
  53 + recognizer_config.model_config.debug =
  54 + SHERPA_ONNX_OR(config->model_config.debug, 0);
  55 +
  56 + recognizer_config.decoding_method =
  57 + SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
  58 + recognizer_config.max_active_paths =
  59 + SHERPA_ONNX_OR(config->max_active_paths, 4);
  60 +
  61 + recognizer_config.enable_endpoint =
  62 + SHERPA_ONNX_OR(config->enable_endpoint, 0);
53 63
54 recognizer_config.endpoint_config.rule1.min_trailing_silence = 64 recognizer_config.endpoint_config.rule1.min_trailing_silence =
55 SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4); 65 SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4);
@@ -173,9 +183,11 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( @@ -173,9 +183,11 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
173 const SherpaOnnxOfflineRecognizerConfig *config) { 183 const SherpaOnnxOfflineRecognizerConfig *config) {
174 sherpa_onnx::OfflineRecognizerConfig recognizer_config; 184 sherpa_onnx::OfflineRecognizerConfig recognizer_config;
175 185
176 - recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); 186 + recognizer_config.feat_config.sampling_rate =
  187 + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
177 188
178 - recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); 189 + recognizer_config.feat_config.feature_dim =
  190 + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
179 191
180 recognizer_config.model_config.transducer.encoder_filename = 192 recognizer_config.model_config.transducer.encoder_filename =
181 SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); 193 SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");
@@ -184,7 +196,7 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( @@ -184,7 +196,7 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
184 SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); 196 SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");
185 197
186 recognizer_config.model_config.transducer.joiner_filename = 198 recognizer_config.model_config.transducer.joiner_filename =
187 - SHERPA_ONNX_OR(config->model_config.transducer.joiner,""); 199 + SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");
188 200
189 recognizer_config.model_config.paraformer.model = 201 recognizer_config.model_config.paraformer.model =
190 SHERPA_ONNX_OR(config->model_config.paraformer.model, ""); 202 SHERPA_ONNX_OR(config->model_config.paraformer.model, "");
@@ -192,15 +204,26 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( @@ -192,15 +204,26 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
192 recognizer_config.model_config.nemo_ctc.model = 204 recognizer_config.model_config.nemo_ctc.model =
193 SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, ""); 205 SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, "");
194 206
195 - recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, "");  
196 - recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1);  
197 - recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0);  
198 -  
199 - recognizer_config.lm_config.model = SHERPA_ONNX_OR(config->lm_config.model, "");  
200 - recognizer_config.lm_config.scale = SHERPA_ONNX_OR(config->lm_config.scale, 1.0);  
201 -  
202 - recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search");  
203 - recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); 207 + recognizer_config.model_config.tokens =
  208 + SHERPA_ONNX_OR(config->model_config.tokens, "");
  209 + recognizer_config.model_config.num_threads =
  210 + SHERPA_ONNX_OR(config->model_config.num_threads, 1);
  211 + recognizer_config.model_config.debug =
  212 + SHERPA_ONNX_OR(config->model_config.debug, 0);
  213 + recognizer_config.model_config.provider =
  214 + SHERPA_ONNX_OR(config->model_config.provider, "cpu");
  215 + recognizer_config.model_config.model_type =
  216 + SHERPA_ONNX_OR(config->model_config.model_type, "");
  217 +
  218 + recognizer_config.lm_config.model =
  219 + SHERPA_ONNX_OR(config->lm_config.model, "");
  220 + recognizer_config.lm_config.scale =
  221 + SHERPA_ONNX_OR(config->lm_config.scale, 1.0);
  222 +
  223 + recognizer_config.decoding_method =
  224 + SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
  225 + recognizer_config.max_active_paths =
  226 + SHERPA_ONNX_OR(config->max_active_paths, 4);
204 227
205 if (config->model_config.debug) { 228 if (config->model_config.debug) {
206 fprintf(stderr, "%s\n", recognizer_config.ToString().c_str()); 229 fprintf(stderr, "%s\n", recognizer_config.ToString().c_str());
@@ -272,6 +272,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig { @@ -272,6 +272,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
272 const char *tokens; 272 const char *tokens;
273 int32_t num_threads; 273 int32_t num_threads;
274 int32_t debug; 274 int32_t debug;
  275 + const char *provider;
  276 + const char *model_type;
275 } SherpaOnnxOfflineModelConfig; 277 } SherpaOnnxOfflineModelConfig;
276 278
277 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { 279 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {
@@ -25,6 +25,11 @@ void OfflineModelConfig::Register(ParseOptions *po) { @@ -25,6 +25,11 @@ void OfflineModelConfig::Register(ParseOptions *po) {
25 25
26 po->Register("provider", &provider, 26 po->Register("provider", &provider,
27 "Specify a provider to use: cpu, cuda, coreml"); 27 "Specify a provider to use: cpu, cuda, coreml");
  28 +
  29 + po->Register("model-type", &model_type,
  30 + "Specify it to reduce model initialization time. "
  31 + "Valid values are: transducer, paraformer, nemo_ctc. "
  32 + "All other values lead to loading the model twice.");
28 } 33 }
29 34
30 bool OfflineModelConfig::Validate() const { 35 bool OfflineModelConfig::Validate() const {
@@ -34,7 +39,7 @@ bool OfflineModelConfig::Validate() const { @@ -34,7 +39,7 @@ bool OfflineModelConfig::Validate() const {
34 } 39 }
35 40
36 if (!FileExists(tokens)) { 41 if (!FileExists(tokens)) {
37 - SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); 42 + SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
38 return false; 43 return false;
39 } 44 }
40 45
@@ -59,7 +64,8 @@ std::string OfflineModelConfig::ToString() const { @@ -59,7 +64,8 @@ std::string OfflineModelConfig::ToString() const {
59 os << "tokens=\"" << tokens << "\", "; 64 os << "tokens=\"" << tokens << "\", ";
60 os << "num_threads=" << num_threads << ", "; 65 os << "num_threads=" << num_threads << ", ";
61 os << "debug=" << (debug ? "True" : "False") << ", "; 66 os << "debug=" << (debug ? "True" : "False") << ", ";
62 - os << "provider=\"" << provider << "\")"; 67 + os << "provider=\"" << provider << "\", ";
  68 + os << "model_type=\"" << model_type << "\")";
63 69
64 return os.str(); 70 return os.str();
65 } 71 }
@@ -22,19 +22,31 @@ struct OfflineModelConfig { @@ -22,19 +22,31 @@ struct OfflineModelConfig {
22 bool debug = false; 22 bool debug = false;
23 std::string provider = "cpu"; 23 std::string provider = "cpu";
24 24
  25 + // When this field is set to one of the valid values below, the model
  26 + // is loaded only once instead of twice, which reduces initialization time.
  27 + //
  28 + // Valid values:
  29 + // - transducer. The given model is from icefall
  30 + // - paraformer. It is a paraformer model
  31 + // - nemo_ctc. It is a NeMo CTC model.
  32 + //
  33 + // All other values are invalid and lead to loading the model twice.
  34 + std::string model_type;
  35 +
25 OfflineModelConfig() = default; 36 OfflineModelConfig() = default;
26 OfflineModelConfig(const OfflineTransducerModelConfig &transducer, 37 OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
27 const OfflineParaformerModelConfig &paraformer, 38 const OfflineParaformerModelConfig &paraformer,
28 const OfflineNemoEncDecCtcModelConfig &nemo_ctc, 39 const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
29 const std::string &tokens, int32_t num_threads, bool debug, 40 const std::string &tokens, int32_t num_threads, bool debug,
30 - const std::string &provider) 41 + const std::string &provider, const std::string &model_type)
31 : transducer(transducer), 42 : transducer(transducer),
32 paraformer(paraformer), 43 paraformer(paraformer),
33 nemo_ctc(nemo_ctc), 44 nemo_ctc(nemo_ctc),
34 tokens(tokens), 45 tokens(tokens),
35 num_threads(num_threads), 46 num_threads(num_threads),
36 debug(debug), 47 debug(debug),
37 - provider(provider) {} 48 + provider(provider),
  49 + model_type(model_type) {}
38 50
39 void Register(ParseOptions *po); 51 void Register(ParseOptions *po);
40 bool Validate() const; 52 bool Validate() const;
@@ -18,6 +18,21 @@ namespace sherpa_onnx { @@ -18,6 +18,21 @@ namespace sherpa_onnx {
18 18
19 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( 19 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
20 const OfflineRecognizerConfig &config) { 20 const OfflineRecognizerConfig &config) {
  21 + if (!config.model_config.model_type.empty()) {
  22 + const auto &model_type = config.model_config.model_type;
  23 + if (model_type == "transducer") {
  24 + return std::make_unique<OfflineRecognizerTransducerImpl>(config);
  25 + } else if (model_type == "paraformer") {
  26 + return std::make_unique<OfflineRecognizerParaformerImpl>(config);
  27 + } else if (model_type == "nemo_ctc") {
  28 + return std::make_unique<OfflineRecognizerCtcImpl>(config);
  29 + } else {
  30 + SHERPA_ONNX_LOGE(
  31 + "Invalid model_type: %s. Trying to load the model to get its type",
  32 + model_type.c_str());
  33 + }
  34 + }
  35 +
21 Ort::Env env(ORT_LOGGING_LEVEL_ERROR); 36 Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
22 37
23 Ort::SessionOptions sess_opts; 38 Ort::SessionOptions sess_opts;
@@ -18,17 +18,17 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { @@ -18,17 +18,17 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
18 18
19 bool OfflineTransducerModelConfig::Validate() const { 19 bool OfflineTransducerModelConfig::Validate() const {
20 if (!FileExists(encoder_filename)) { 20 if (!FileExists(encoder_filename)) {
21 - SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); 21 + SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
22 return false; 22 return false;
23 } 23 }
24 24
25 if (!FileExists(decoder_filename)) { 25 if (!FileExists(decoder_filename)) {
26 - SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str()); 26 + SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
27 return false; 27 return false;
28 } 28 }
29 29
30 if (!FileExists(joiner_filename)) { 30 if (!FileExists(joiner_filename)) {
31 - SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str()); 31 + SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
32 return false; 32 return false;
33 } 33 }
34 34
@@ -21,15 +21,16 @@ void PybindOfflineModelConfig(py::module *m) { @@ -21,15 +21,16 @@ void PybindOfflineModelConfig(py::module *m) {
21 21
22 using PyClass = OfflineModelConfig; 22 using PyClass = OfflineModelConfig;
23 py::class_<PyClass>(*m, "OfflineModelConfig") 23 py::class_<PyClass>(*m, "OfflineModelConfig")
24 - .def(py::init<const OfflineTransducerModelConfig &,  
25 - const OfflineParaformerModelConfig &,  
26 - const OfflineNemoEncDecCtcModelConfig &,  
27 - const std::string &, int32_t, bool, const std::string &>(),  
28 - py::arg("transducer") = OfflineTransducerModelConfig(),  
29 - py::arg("paraformer") = OfflineParaformerModelConfig(),  
30 - py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),  
31 - py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,  
32 - py::arg("provider") = "cpu") 24 + .def(
  25 + py::init<const OfflineTransducerModelConfig &,
  26 + const OfflineParaformerModelConfig &,
  27 + const OfflineNemoEncDecCtcModelConfig &, const std::string &,
  28 + int32_t, bool, const std::string &, const std::string &>(),
  29 + py::arg("transducer") = OfflineTransducerModelConfig(),
  30 + py::arg("paraformer") = OfflineParaformerModelConfig(),
  31 + py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
  32 + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
  33 + py::arg("provider") = "cpu", py::arg("model_type") = "")
33 .def_readwrite("transducer", &PyClass::transducer) 34 .def_readwrite("transducer", &PyClass::transducer)
34 .def_readwrite("paraformer", &PyClass::paraformer) 35 .def_readwrite("paraformer", &PyClass::paraformer)
35 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) 36 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
@@ -37,6 +38,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -37,6 +38,7 @@ void PybindOfflineModelConfig(py::module *m) {
37 .def_readwrite("num_threads", &PyClass::num_threads) 38 .def_readwrite("num_threads", &PyClass::num_threads)
38 .def_readwrite("debug", &PyClass::debug) 39 .def_readwrite("debug", &PyClass::debug)
39 .def_readwrite("provider", &PyClass::provider) 40 .def_readwrite("provider", &PyClass::provider)
  41 + .def_readwrite("model_type", &PyClass::model_type)
40 .def("__str__", &PyClass::ToString); 42 .def("__str__", &PyClass::ToString);
41 } 43 }
42 44
@@ -86,6 +86,7 @@ class OfflineRecognizer(object): @@ -86,6 +86,7 @@ class OfflineRecognizer(object):
86 num_threads=num_threads, 86 num_threads=num_threads,
87 debug=debug, 87 debug=debug,
88 provider=provider, 88 provider=provider,
  89 + model_type="transducer",
89 ) 90 )
90 91
91 feat_config = OfflineFeatureExtractorConfig( 92 feat_config = OfflineFeatureExtractorConfig(
@@ -149,6 +150,7 @@ class OfflineRecognizer(object): @@ -149,6 +150,7 @@ class OfflineRecognizer(object):
149 num_threads=num_threads, 150 num_threads=num_threads,
150 debug=debug, 151 debug=debug,
151 provider=provider, 152 provider=provider,
  153 + model_type="paraformer",
152 ) 154 )
153 155
154 feat_config = OfflineFeatureExtractorConfig( 156 feat_config = OfflineFeatureExtractorConfig(
@@ -211,6 +213,7 @@ class OfflineRecognizer(object): @@ -211,6 +213,7 @@ class OfflineRecognizer(object):
211 num_threads=num_threads, 213 num_threads=num_threads,
212 debug=debug, 214 debug=debug,
213 provider=provider, 215 provider=provider,
  216 + model_type="nemo_ctc",
214 ) 217 )
215 218
216 feat_config = OfflineFeatureExtractorConfig( 219 feat_config = OfflineFeatureExtractorConfig(