Wilson Wongso
Committed by GitHub

Reduce model initialization time for online speech recognition (#215)

* Reduce model initialization time for online speech recognition

* Fixed Styling

---------

Co-authored-by: w11wo <wilsowong961@gmail.com>
@@ -50,6 +50,8 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
       SHERPA_ONNX_OR(config->model_config.num_threads, 1);
   recognizer_config.model_config.provider =
       SHERPA_ONNX_OR(config->model_config.provider, "cpu");
+  recognizer_config.model_config.model_type =
+      SHERPA_ONNX_OR(config->model_config.model_type, "");
   recognizer_config.model_config.debug =
       SHERPA_ONNX_OR(config->model_config.debug, 0);

@@ -53,6 +53,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig {
   const char *tokens;
   int32_t num_threads;
   const char *provider;
+  const char *model_type;
   int32_t debug;  // true to print debug information of the model
 } SherpaOnnxOnlineTransducerModelConfig;

@@ -22,26 +22,30 @@ void OnlineTransducerModelConfig::Register(ParseOptions *po) {

   po->Register("debug", &debug,
                "true to print model information while loading it.");
+  po->Register("model-type", &model_type,
+               "Specify it to reduce model initialization time. "
+               "Valid values are: conformer, lstm, zipformer, zipformer2. "
+               "All other values lead to loading the model twice.");
 }

 bool OnlineTransducerModelConfig::Validate() const {
   if (!FileExists(tokens)) {
-    SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
+    SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
     return false;
   }

   if (!FileExists(encoder_filename)) {
-    SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
+    SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
     return false;
   }

   if (!FileExists(decoder_filename)) {
-    SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
+    SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
     return false;
   }

   if (!FileExists(joiner_filename)) {
-    SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
+    SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
     return false;
   }

@@ -63,6 +67,7 @@ std::string OnlineTransducerModelConfig::ToString() const {
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
   os << "provider=\"" << provider << "\", ";
+  os << "model_type=\"" << model_type << "\", ";
   os << "debug=" << (debug ? "True" : "False") << ")";

   return os.str();
@@ -19,19 +19,33 @@ struct OnlineTransducerModelConfig {
   bool debug = false;
   std::string provider = "cpu";

+  // With the help of this field, we only need to load the model once
+  // instead of twice; and therefore it reduces initialization time.
+  //
+  // Valid values:
+  //  - conformer
+  //  - lstm
+  //  - zipformer
+  //  - zipformer2
+  //
+  // All other values are invalid and lead to loading the model twice.
+  std::string model_type;
+
   OnlineTransducerModelConfig() = default;
   OnlineTransducerModelConfig(const std::string &encoder_filename,
                               const std::string &decoder_filename,
                               const std::string &joiner_filename,
                               const std::string &tokens, int32_t num_threads,
-                              bool debug, const std::string &provider)
+                              bool debug, const std::string &provider,
+                              const std::string &model_type)
       : encoder_filename(encoder_filename),
         decoder_filename(decoder_filename),
         joiner_filename(joiner_filename),
         tokens(tokens),
         num_threads(num_threads),
         debug(debug),
-        provider(provider) {}
+        provider(provider),
+        model_type(model_type) {}

   void Register(ParseOptions *po);
   bool Validate() const;
@@ -77,6 +77,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,

 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
     const OnlineTransducerModelConfig &config) {
+  if (!config.model_type.empty()) {
+    const auto &model_type = config.model_type;
+    if (model_type == "conformer") {
+      return std::make_unique<OnlineConformerTransducerModel>(config);
+    } else if (model_type == "lstm") {
+      return std::make_unique<OnlineLstmTransducerModel>(config);
+    } else if (model_type == "zipformer") {
+      return std::make_unique<OnlineZipformerTransducerModel>(config);
+    } else if (model_type == "zipformer2") {
+      return std::make_unique<OnlineZipformer2TransducerModel>(config);
+    } else {
+      SHERPA_ONNX_LOGE(
+          "Invalid model_type: %s. Trying to load the model to get its type",
+          model_type.c_str());
+    }
+  }
   ModelType model_type = ModelType::kUnkown;

   {
@@ -140,6 +156,23 @@ Ort::Value OnlineTransducerModel::BuildDecoderInput(
 #if __ANDROID_API__ >= 9
 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
+  if (!config.model_type.empty()) {
+    const auto &model_type = config.model_type;
+    if (model_type == "conformer") {
+      return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
+    } else if (model_type == "lstm") {
+      return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
+    } else if (model_type == "zipformer") {
+      return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
+    } else if (model_type == "zipformer2") {
+      return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
+    } else {
+      SHERPA_ONNX_LOGE(
+          "Invalid model_type: %s. Trying to load the model to get its type",
+          model_type.c_str());
+    }
+  }
+
   auto buffer = ReadFile(mgr, config.encoder_filename);
   auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);

@@ -15,11 +15,11 @@ void PybindOnlineTransducerModelConfig(py::module *m) {
   py::class_<PyClass>(*m, "OnlineTransducerModelConfig")
       .def(py::init<const std::string &, const std::string &,
                     const std::string &, const std::string &, int32_t, bool,
-                    const std::string &>(),
+                    const std::string &, const std::string &>(),
           py::arg("encoder_filename"), py::arg("decoder_filename"),
           py::arg("joiner_filename"), py::arg("tokens"),
           py::arg("num_threads"), py::arg("debug") = false,
-          py::arg("provider") = "cpu")
+          py::arg("provider") = "cpu", py::arg("model_type") = "")
      .def_readwrite("encoder_filename", &PyClass::encoder_filename)
      .def_readwrite("decoder_filename", &PyClass::decoder_filename)
      .def_readwrite("joiner_filename", &PyClass::joiner_filename)
@@ -27,6 +27,7 @@ void PybindOnlineTransducerModelConfig(py::module *m) {
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
      .def_readwrite("provider", &PyClass::provider)
+     .def_readwrite("model_type", &PyClass::model_type)
      .def("__str__", &PyClass::ToString);
 }

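For reference, a minimal sketch of how the extended binding can be exercised from Python once the wheel is rebuilt. It assumes the class is importable from the sherpa_onnx package (it may instead live in the underlying _sherpa_onnx extension module), and the model file names are hypothetical placeholders for a real pre-trained streaming transducer:

    import sherpa_onnx

    # Sketch only; substitute the paths of an actual model.
    model_config = sherpa_onnx.OnlineTransducerModelConfig(
        encoder_filename="encoder.onnx",
        decoder_filename="decoder.onnx",
        joiner_filename="joiner.onnx",
        tokens="tokens.txt",
        num_threads=1,
        debug=False,
        provider="cpu",
        model_type="zipformer",  # new field: skips the extra model load used for type detection
    )

    # __str__ is bound to ToString(), which now also prints model_type.
    print(model_config)
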
@@ -41,6 +41,7 @@ class OnlineRecognizer(object):
         max_active_paths: int = 4,
         context_score: float = 1.5,
         provider: str = "cpu",
+        model_type: str = "",
     ):
         """
         Please refer to
@@ -90,6 +91,9 @@ class OnlineRecognizer(object):
             the maximum number of active paths during beam search.
           provider:
             onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+          model_type:
+            Online transducer model type. Valid values are: conformer, lstm,
+            zipformer, zipformer2. All other values lead to loading the model twice.
         """
         _assert_file_exists(tokens)
         _assert_file_exists(encoder)
@@ -105,6 +109,7 @@ class OnlineRecognizer(object):
            tokens=tokens,
            num_threads=num_threads,
            provider=provider,
+           model_type=model_type,
         )

         feat_config = FeatureExtractorConfig(
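At the user level the new argument is a plain pass-through, so opting in is a one-line change. Below is a hedged sketch, assuming placeholder model paths and keyword names matching the _assert_file_exists calls shown above; all other constructor arguments keep their defaults:

    from sherpa_onnx import OnlineRecognizer

    # Sketch only; replace the paths with those of a real streaming zipformer transducer.
    recognizer = OnlineRecognizer(
        tokens="tokens.txt",
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        num_threads=1,
        provider="cpu",
        model_type="zipformer",  # avoids reading the encoder a second time just to detect its type
    )

Leaving model_type empty (the default) keeps the previous behavior: the encoder is first loaded to detect the model type and then loaded again to build the recognizer.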