Committed by GitHub
Reduce model initialization time for online speech recognition (#215)
* Reduce model initialization time for online speech recognition
* Fixed Styling

Co-authored-by: w11wo <wilsowong961@gmail.com>
Showing 7 changed files with 69 additions and 8 deletions.
@@ -50,6 +50,8 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
       SHERPA_ONNX_OR(config->model_config.num_threads, 1);
   recognizer_config.model_config.provider =
       SHERPA_ONNX_OR(config->model_config.provider, "cpu");
+  recognizer_config.model_config.model_type =
+      SHERPA_ONNX_OR(config->model_config.model_type, "");
   recognizer_config.model_config.debug =
       SHERPA_ONNX_OR(config->model_config.debug, 0);
 
@@ -53,6 +53,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig {
   const char *tokens;
   int32_t num_threads;
   const char *provider;
+  const char *model_type;
   int32_t debug;  // true to print debug information of the model
 } SherpaOnnxOnlineTransducerModelConfig;
 
@@ -22,26 +22,30 @@ void OnlineTransducerModelConfig::Register(ParseOptions *po) {
 
   po->Register("debug", &debug,
                "true to print model information while loading it.");
+  po->Register("model-type", &model_type,
+               "Specify it to reduce model initialization time. "
+               "Valid values are: conformer, lstm, zipformer, zipformer2. "
+               "All other values lead to loading the model twice.");
 }
 
 bool OnlineTransducerModelConfig::Validate() const {
   if (!FileExists(tokens)) {
-    SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
+    SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
     return false;
   }
 
   if (!FileExists(encoder_filename)) {
-    SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
+    SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
     return false;
   }
 
   if (!FileExists(decoder_filename)) {
-    SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
+    SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
     return false;
   }
 
   if (!FileExists(joiner_filename)) {
-    SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
+    SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
     return false;
   }
 
@@ -63,6 +67,7 @@ std::string OnlineTransducerModelConfig::ToString() const {
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
   os << "provider=\"" << provider << "\", ";
+  os << "model_type=\"" << model_type << "\", ";
   os << "debug=" << (debug ? "True" : "False") << ")";
 
   return os.str();
@@ -19,19 +19,33 @@ struct OnlineTransducerModelConfig {
   bool debug = false;
   std::string provider = "cpu";
 
+  // With the help of this field, we only need to load the model once
+  // instead of twice; and therefore it reduces initialization time.
+  //
+  // Valid values:
+  //  - conformer
+  //  - lstm
+  //  - zipformer
+  //  - zipformer2
+  //
+  // All other values are invalid and lead to loading the model twice.
+  std::string model_type;
+
   OnlineTransducerModelConfig() = default;
   OnlineTransducerModelConfig(const std::string &encoder_filename,
                               const std::string &decoder_filename,
                               const std::string &joiner_filename,
                               const std::string &tokens, int32_t num_threads,
-                              bool debug, const std::string &provider)
+                              bool debug, const std::string &provider,
+                              const std::string &model_type)
       : encoder_filename(encoder_filename),
         decoder_filename(decoder_filename),
         joiner_filename(joiner_filename),
         tokens(tokens),
         num_threads(num_threads),
         debug(debug),
-        provider(provider) {}
+        provider(provider),
+        model_type(model_type) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
@@ -77,6 +77,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
 
 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
     const OnlineTransducerModelConfig &config) {
+  if (!config.model_type.empty()) {
+    const auto &model_type = config.model_type;
+    if (model_type == "conformer") {
+      return std::make_unique<OnlineConformerTransducerModel>(config);
+    } else if (model_type == "lstm") {
+      return std::make_unique<OnlineLstmTransducerModel>(config);
+    } else if (model_type == "zipformer") {
+      return std::make_unique<OnlineZipformerTransducerModel>(config);
+    } else if (model_type == "zipformer2") {
+      return std::make_unique<OnlineZipformer2TransducerModel>(config);
+    } else {
+      SHERPA_ONNX_LOGE(
+          "Invalid model_type: %s. Trying to load the model to get its type",
+          model_type.c_str());
+    }
+  }
   ModelType model_type = ModelType::kUnkown;
 
   {
@@ -140,6 +156,23 @@ Ort::Value OnlineTransducerModel::BuildDecoderInput(
 #if __ANDROID_API__ >= 9
 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
+  if (!config.model_type.empty()) {
+    const auto &model_type = config.model_type;
+    if (model_type == "conformer") {
+      return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
+    } else if (model_type == "lstm") {
+      return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
+    } else if (model_type == "zipformer") {
+      return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
+    } else if (model_type == "zipformer2") {
+      return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
+    } else {
+      SHERPA_ONNX_LOGE(
+          "Invalid model_type: %s. Trying to load the model to get its type",
+          model_type.c_str());
+    }
+  }
+
   auto buffer = ReadFile(mgr, config.encoder_filename);
   auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
 
@@ -15,11 +15,11 @@ void PybindOnlineTransducerModelConfig(py::module *m) {
   py::class_<PyClass>(*m, "OnlineTransducerModelConfig")
       .def(py::init<const std::string &, const std::string &,
                     const std::string &, const std::string &, int32_t, bool,
-                    const std::string &>(),
+                    const std::string &, const std::string &>(),
           py::arg("encoder_filename"), py::arg("decoder_filename"),
           py::arg("joiner_filename"), py::arg("tokens"),
           py::arg("num_threads"), py::arg("debug") = false,
-          py::arg("provider") = "cpu")
+          py::arg("provider") = "cpu", py::arg("model_type") = "")
       .def_readwrite("encoder_filename", &PyClass::encoder_filename)
       .def_readwrite("decoder_filename", &PyClass::decoder_filename)
       .def_readwrite("joiner_filename", &PyClass::joiner_filename)

@@ -27,6 +27,7 @@ void PybindOnlineTransducerModelConfig(py::module *m) {
       .def_readwrite("num_threads", &PyClass::num_threads)
       .def_readwrite("debug", &PyClass::debug)
       .def_readwrite("provider", &PyClass::provider)
+      .def_readwrite("model_type", &PyClass::model_type)
       .def("__str__", &PyClass::ToString);
 }
 
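As a rough sketch of what the extended binding exposes, the snippet below constructs the config from Python with model_type set. It is not part of this diff: the module name sherpa_onnx and the file paths are assumptions; only the keyword names come from the py::arg list above.

# Sketch only: assumes the pybind11 extension is importable as `sherpa_onnx`
# and that the model files below exist (the paths are placeholders).
import sherpa_onnx

model_config = sherpa_onnx.OnlineTransducerModelConfig(
    encoder_filename="encoder.onnx",
    decoder_filename="decoder.onnx",
    joiner_filename="joiner.onnx",
    tokens="tokens.txt",
    num_threads=1,
    debug=False,
    provider="cpu",
    model_type="zipformer",  # lets Create() dispatch directly instead of probing the encoder
)

# __str__ is bound to ToString(), so the printed config now includes model_type.
print(model_config)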
@@ -41,6 +41,7 @@ class OnlineRecognizer(object):
         max_active_paths: int = 4,
         context_score: float = 1.5,
         provider: str = "cpu",
+        model_type: str = "",
     ):
         """
         Please refer to

@@ -90,6 +91,9 @@ class OnlineRecognizer(object):
             the maximum number of active paths during beam search.
           provider:
             onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+          model_type:
+            Online transducer model type. Valid values are: conformer, lstm,
+            zipformer, zipformer2. All other values lead to loading the model twice.
         """
         _assert_file_exists(tokens)
         _assert_file_exists(encoder)

@@ -105,6 +109,7 @@ class OnlineRecognizer(object):
             tokens=tokens,
             num_threads=num_threads,
             provider=provider,
+            model_type=model_type,
         )
 
         feat_config = FeatureExtractorConfig(
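For end users the change surfaces as the new model_type argument of OnlineRecognizer. A minimal usage sketch follows; the file names are placeholders, the encoder/decoder/joiner keyword names are inferred from the asserts and config construction in this file, and the remaining constructor arguments are assumed to keep their defaults.

# Usage sketch with placeholder file names. Pass a model_type matching your
# model (conformer, lstm, zipformer, or zipformer2) to skip the second load
# that is otherwise needed to detect the model type from the encoder.
from sherpa_onnx import OnlineRecognizer

recognizer = OnlineRecognizer(
    tokens="tokens.txt",
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    num_threads=1,
    provider="cpu",
    model_type="zipformer",
)

With model_type set, OnlineTransducerModel::Create() dispatches straight to the matching subclass instead of first loading the encoder just to read its metadata, which is where the initialization-time saving comes from.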