Committed by
GitHub
RNNLM model support lm_num_thread and lm_provider setting (#173)
* rnnlm model inference supports num_threads setting * rnnlm params decouple num_thread and provider with Transducer. * fix python csrc bug which offline-lm-config.cc and online-lm-config.cc arguments problem * lm_num_threads and lm_provider set default values --------- Co-authored-by: cuidongcai1035 <cuidongcai1035@wezhuiyi.com>
正在显示
18 个修改的文件
包含
67 行增加
和
31 行删除
| @@ -14,6 +14,10 @@ namespace sherpa_onnx { | @@ -14,6 +14,10 @@ namespace sherpa_onnx { | ||
| 14 | void OfflineLMConfig::Register(ParseOptions *po) { | 14 | void OfflineLMConfig::Register(ParseOptions *po) { |
| 15 | po->Register("lm", &model, "Path to LM model."); | 15 | po->Register("lm", &model, "Path to LM model."); |
| 16 | po->Register("lm-scale", &scale, "LM scale."); | 16 | po->Register("lm-scale", &scale, "LM scale."); |
| 17 | + po->Register("lm-num-threads", &lm_num_threads, | ||
| 18 | + "Number of threads to run the neural network of LM model"); | ||
| 19 | + po->Register("lm-provider", &lm_provider, | ||
| 20 | + "Specify a provider to LM model use: cpu, cuda, coreml"); | ||
| 17 | } | 21 | } |
| 18 | 22 | ||
| 19 | bool OfflineLMConfig::Validate() const { | 23 | bool OfflineLMConfig::Validate() const { |
| @@ -16,11 +16,17 @@ struct OfflineLMConfig { | @@ -16,11 +16,17 @@ struct OfflineLMConfig { | ||
| 16 | 16 | ||
| 17 | // LM scale | 17 | // LM scale |
| 18 | float scale = 0.5; | 18 | float scale = 0.5; |
| 19 | + int32_t lm_num_threads = 1; | ||
| 20 | + std::string lm_provider = "cpu"; | ||
| 19 | 21 | ||
| 20 | OfflineLMConfig() = default; | 22 | OfflineLMConfig() = default; |
| 21 | 23 | ||
| 22 | - OfflineLMConfig(const std::string &model, float scale) | ||
| 23 | - : model(model), scale(scale) {} | 24 | + OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, |
| 25 | + const std::string &lm_provider) | ||
| 26 | + : model(model), | ||
| 27 | + scale(scale), | ||
| 28 | + lm_num_threads(lm_num_threads), | ||
| 29 | + lm_provider(lm_provider) {} | ||
| 24 | 30 | ||
| 25 | void Register(ParseOptions *po); | 31 | void Register(ParseOptions *po); |
| 26 | bool Validate() const; | 32 | bool Validate() const; |
| @@ -12,8 +12,7 @@ | @@ -12,8 +12,7 @@ | ||
| 12 | 12 | ||
| 13 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 14 | 14 | ||
| 15 | -std::unique_ptr<OfflineLM> OfflineLM::Create( | ||
| 16 | - const OfflineRecognizerConfig &config) { | 15 | +std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) { |
| 17 | return std::make_unique<OfflineRnnLM>(config); | 16 | return std::make_unique<OfflineRnnLM>(config); |
| 18 | } | 17 | } |
| 19 | 18 |
| @@ -10,7 +10,7 @@ | @@ -10,7 +10,7 @@ | ||
| 10 | 10 | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 11 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | #include "sherpa-onnx/csrc/hypothesis.h" | 12 | #include "sherpa-onnx/csrc/hypothesis.h" |
| 13 | -#include "sherpa-onnx/csrc/offline-recognizer.h" | 13 | +#include "sherpa-onnx/csrc/offline-lm-config.h" |
| 14 | 14 | ||
| 15 | namespace sherpa_onnx { | 15 | namespace sherpa_onnx { |
| 16 | 16 | ||
| @@ -18,8 +18,7 @@ class OfflineLM { | @@ -18,8 +18,7 @@ class OfflineLM { | ||
| 18 | public: | 18 | public: |
| 19 | virtual ~OfflineLM() = default; | 19 | virtual ~OfflineLM() = default; |
| 20 | 20 | ||
| 21 | - static std::unique_ptr<OfflineLM> Create( | ||
| 22 | - const OfflineRecognizerConfig &config); | 21 | + static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config); |
| 23 | 22 | ||
| 24 | /** Rescore a batch of sentences. | 23 | /** Rescore a batch of sentences. |
| 25 | * | 24 | * |
| @@ -59,7 +59,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -59,7 +59,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 59 | std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); | 59 | std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); |
| 60 | } else if (config_.decoding_method == "modified_beam_search") { | 60 | } else if (config_.decoding_method == "modified_beam_search") { |
| 61 | if (!config_.lm_config.model.empty()) { | 61 | if (!config_.lm_config.model.empty()) { |
| 62 | - lm_ = OfflineLM::Create(config); | 62 | + lm_ = OfflineLM::Create(config.lm_config); |
| 63 | } | 63 | } |
| 64 | 64 | ||
| 65 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( | 65 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( |
| @@ -18,12 +18,12 @@ namespace sherpa_onnx { | @@ -18,12 +18,12 @@ namespace sherpa_onnx { | ||
| 18 | 18 | ||
| 19 | class OfflineRnnLM::Impl { | 19 | class OfflineRnnLM::Impl { |
| 20 | public: | 20 | public: |
| 21 | - explicit Impl(const OfflineRecognizerConfig &config) | ||
| 22 | - : config_(config.lm_config), | 21 | + explicit Impl(const OfflineLMConfig &config) |
| 22 | + : config_(config), | ||
| 23 | env_(ORT_LOGGING_LEVEL_ERROR), | 23 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 24 | - sess_opts_{GetSessionOptions(config.model_config)}, | 24 | + sess_opts_{GetSessionOptions(config)}, |
| 25 | allocator_{} { | 25 | allocator_{} { |
| 26 | - Init(config.lm_config); | 26 | + Init(config); |
| 27 | } | 27 | } |
| 28 | 28 | ||
| 29 | Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) { | 29 | Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) { |
| @@ -63,7 +63,7 @@ class OfflineRnnLM::Impl { | @@ -63,7 +63,7 @@ class OfflineRnnLM::Impl { | ||
| 63 | std::vector<const char *> output_names_ptr_; | 63 | std::vector<const char *> output_names_ptr_; |
| 64 | }; | 64 | }; |
| 65 | 65 | ||
| 66 | -OfflineRnnLM::OfflineRnnLM(const OfflineRecognizerConfig &config) | 66 | +OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) |
| 67 | : impl_(std::make_unique<Impl>(config)) {} | 67 | : impl_(std::make_unique<Impl>(config)) {} |
| 68 | 68 | ||
| 69 | OfflineRnnLM::~OfflineRnnLM() = default; | 69 | OfflineRnnLM::~OfflineRnnLM() = default; |
| @@ -17,7 +17,7 @@ class OfflineRnnLM : public OfflineLM { | @@ -17,7 +17,7 @@ class OfflineRnnLM : public OfflineLM { | ||
| 17 | public: | 17 | public: |
| 18 | ~OfflineRnnLM() override; | 18 | ~OfflineRnnLM() override; |
| 19 | 19 | ||
| 20 | - explicit OfflineRnnLM(const OfflineRecognizerConfig &config); | 20 | + explicit OfflineRnnLM(const OfflineLMConfig &config); |
| 21 | 21 | ||
| 22 | /** Rescore a batch of sentences. | 22 | /** Rescore a batch of sentences. |
| 23 | * | 23 | * |
| @@ -14,6 +14,10 @@ namespace sherpa_onnx { | @@ -14,6 +14,10 @@ namespace sherpa_onnx { | ||
| 14 | void OnlineLMConfig::Register(ParseOptions *po) { | 14 | void OnlineLMConfig::Register(ParseOptions *po) { |
| 15 | po->Register("lm", &model, "Path to LM model."); | 15 | po->Register("lm", &model, "Path to LM model."); |
| 16 | po->Register("lm-scale", &scale, "LM scale."); | 16 | po->Register("lm-scale", &scale, "LM scale."); |
| 17 | + po->Register("lm-num-threads", &lm_num_threads, | ||
| 18 | + "Number of threads to run the neural network of LM model"); | ||
| 19 | + po->Register("lm-provider", &lm_provider, | ||
| 20 | + "Specify a provider to LM model use: cpu, cuda, coreml"); | ||
| 17 | } | 21 | } |
| 18 | 22 | ||
| 19 | bool OnlineLMConfig::Validate() const { | 23 | bool OnlineLMConfig::Validate() const { |
| @@ -16,11 +16,17 @@ struct OnlineLMConfig { | @@ -16,11 +16,17 @@ struct OnlineLMConfig { | ||
| 16 | 16 | ||
| 17 | // LM scale | 17 | // LM scale |
| 18 | float scale = 0.5; | 18 | float scale = 0.5; |
| 19 | + int32_t lm_num_threads = 1; | ||
| 20 | + std::string lm_provider = "cpu"; | ||
| 19 | 21 | ||
| 20 | OnlineLMConfig() = default; | 22 | OnlineLMConfig() = default; |
| 21 | 23 | ||
| 22 | - OnlineLMConfig(const std::string &model, float scale) | ||
| 23 | - : model(model), scale(scale) {} | 24 | + OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, |
| 25 | + const std::string &lm_provider) | ||
| 26 | + : model(model), | ||
| 27 | + scale(scale), | ||
| 28 | + lm_num_threads(lm_num_threads), | ||
| 29 | + lm_provider(lm_provider) {} | ||
| 24 | 30 | ||
| 25 | void Register(ParseOptions *po); | 31 | void Register(ParseOptions *po); |
| 26 | bool Validate() const; | 32 | bool Validate() const; |
| @@ -13,8 +13,7 @@ | @@ -13,8 +13,7 @@ | ||
| 13 | 13 | ||
| 14 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 15 | 15 | ||
| 16 | -std::unique_ptr<OnlineLM> OnlineLM::Create( | ||
| 17 | - const OnlineRecognizerConfig &config) { | 16 | +std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) { |
| 18 | return std::make_unique<OnlineRnnLM>(config); | 17 | return std::make_unique<OnlineRnnLM>(config); |
| 19 | } | 18 | } |
| 20 | 19 |
| @@ -11,7 +11,7 @@ | @@ -11,7 +11,7 @@ | ||
| 11 | 11 | ||
| 12 | #include "onnxruntime_cxx_api.h" // NOLINT | 12 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 13 | #include "sherpa-onnx/csrc/hypothesis.h" | 13 | #include "sherpa-onnx/csrc/hypothesis.h" |
| 14 | -#include "sherpa-onnx/csrc/online-recognizer.h" | 14 | +#include "sherpa-onnx/csrc/online-lm-config.h" |
| 15 | 15 | ||
| 16 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 17 | 17 | ||
| @@ -19,7 +19,7 @@ class OnlineLM { | @@ -19,7 +19,7 @@ class OnlineLM { | ||
| 19 | public: | 19 | public: |
| 20 | virtual ~OnlineLM() = default; | 20 | virtual ~OnlineLM() = default; |
| 21 | 21 | ||
| 22 | - static std::unique_ptr<OnlineLM> Create(const OnlineRecognizerConfig &config); | 22 | + static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config); |
| 23 | 23 | ||
| 24 | virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0; | 24 | virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0; |
| 25 | 25 |
| @@ -129,7 +129,7 @@ class OnlineRecognizer::Impl { | @@ -129,7 +129,7 @@ class OnlineRecognizer::Impl { | ||
| 129 | endpoint_(config_.endpoint_config) { | 129 | endpoint_(config_.endpoint_config) { |
| 130 | if (config.decoding_method == "modified_beam_search") { | 130 | if (config.decoding_method == "modified_beam_search") { |
| 131 | if (!config_.lm_config.model.empty()) { | 131 | if (!config_.lm_config.model.empty()) { |
| 132 | - lm_ = OnlineLM::Create(config); | 132 | + lm_ = OnlineLM::Create(config.lm_config); |
| 133 | } | 133 | } |
| 134 | 134 | ||
| 135 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 135 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| @@ -19,12 +19,12 @@ namespace sherpa_onnx { | @@ -19,12 +19,12 @@ namespace sherpa_onnx { | ||
| 19 | 19 | ||
| 20 | class OnlineRnnLM::Impl { | 20 | class OnlineRnnLM::Impl { |
| 21 | public: | 21 | public: |
| 22 | - explicit Impl(const OnlineRecognizerConfig &config) | ||
| 23 | - : config_(config.lm_config), | 22 | + explicit Impl(const OnlineLMConfig &config) |
| 23 | + : config_(config), | ||
| 24 | env_(ORT_LOGGING_LEVEL_ERROR), | 24 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 25 | - sess_opts_{GetSessionOptions(config.model_config)}, | 25 | + sess_opts_{GetSessionOptions(config)}, |
| 26 | allocator_{} { | 26 | allocator_{} { |
| 27 | - Init(config.lm_config); | 27 | + Init(config); |
| 28 | } | 28 | } |
| 29 | 29 | ||
| 30 | void ComputeLMScore(float scale, Hypothesis *hyp) { | 30 | void ComputeLMScore(float scale, Hypothesis *hyp) { |
| @@ -143,7 +143,7 @@ class OnlineRnnLM::Impl { | @@ -143,7 +143,7 @@ class OnlineRnnLM::Impl { | ||
| 143 | int32_t sos_id_ = 1; | 143 | int32_t sos_id_ = 1; |
| 144 | }; | 144 | }; |
| 145 | 145 | ||
| 146 | -OnlineRnnLM::OnlineRnnLM(const OnlineRecognizerConfig &config) | 146 | +OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) |
| 147 | : impl_(std::make_unique<Impl>(config)) {} | 147 | : impl_(std::make_unique<Impl>(config)) {} |
| 148 | 148 | ||
| 149 | OnlineRnnLM::~OnlineRnnLM() = default; | 149 | OnlineRnnLM::~OnlineRnnLM() = default; |
| @@ -20,7 +20,7 @@ class OnlineRnnLM : public OnlineLM { | @@ -20,7 +20,7 @@ class OnlineRnnLM : public OnlineLM { | ||
| 20 | public: | 20 | public: |
| 21 | ~OnlineRnnLM() override; | 21 | ~OnlineRnnLM() override; |
| 22 | 22 | ||
| 23 | - explicit OnlineRnnLM(const OnlineRecognizerConfig &config); | 23 | + explicit OnlineRnnLM(const OnlineLMConfig &config); |
| 24 | 24 | ||
| 25 | std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override; | 25 | std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override; |
| 26 | 26 |
| @@ -69,4 +69,12 @@ Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { | @@ -69,4 +69,12 @@ Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { | ||
| 69 | return GetSessionOptionsImpl(config.num_threads, config.provider); | 69 | return GetSessionOptionsImpl(config.num_threads, config.provider); |
| 70 | } | 70 | } |
| 71 | 71 | ||
| 72 | +Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) { | ||
| 73 | + return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); | ||
| 74 | +} | ||
| 75 | + | ||
| 76 | +Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) { | ||
| 77 | + return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); | ||
| 78 | +} | ||
| 79 | + | ||
| 72 | } // namespace sherpa_onnx | 80 | } // namespace sherpa_onnx |
| @@ -6,7 +6,9 @@ | @@ -6,7 +6,9 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_SESSION_H_ | 6 | #define SHERPA_ONNX_CSRC_SESSION_H_ |
| 7 | 7 | ||
| 8 | #include "onnxruntime_cxx_api.h" // NOLINT | 8 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 9 | +#include "sherpa-onnx/csrc/offline-lm-config.h" | ||
| 9 | #include "sherpa-onnx/csrc/offline-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 11 | +#include "sherpa-onnx/csrc/online-lm-config.h" | ||
| 10 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 12 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 11 | 13 | ||
| 12 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| @@ -16,6 +18,9 @@ Ort::SessionOptions GetSessionOptions( | @@ -16,6 +18,9 @@ Ort::SessionOptions GetSessionOptions( | ||
| 16 | 18 | ||
| 17 | Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); | 19 | Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); |
| 18 | 20 | ||
| 21 | +Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); | ||
| 22 | + | ||
| 23 | +Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); | ||
| 19 | } // namespace sherpa_onnx | 24 | } // namespace sherpa_onnx |
| 20 | 25 | ||
| 21 | #endif // SHERPA_ONNX_CSRC_SESSION_H_ | 26 | #endif // SHERPA_ONNX_CSRC_SESSION_H_ |
| @@ -13,10 +13,13 @@ namespace sherpa_onnx { | @@ -13,10 +13,13 @@ namespace sherpa_onnx { | ||
| 13 | void PybindOfflineLMConfig(py::module *m) { | 13 | void PybindOfflineLMConfig(py::module *m) { |
| 14 | using PyClass = OfflineLMConfig; | 14 | using PyClass = OfflineLMConfig; |
| 15 | py::class_<PyClass>(*m, "OfflineLMConfig") | 15 | py::class_<PyClass>(*m, "OfflineLMConfig") |
| 16 | - .def(py::init<const std::string &, float>(), py::arg("model"), | ||
| 17 | - py::arg("scale")) | 16 | + .def(py::init<const std::string &, float, int32_t, const std::string &>(), |
| 17 | + py::arg("model"), py::arg("scale") = 0.5f, | ||
| 18 | + py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu") | ||
| 18 | .def_readwrite("model", &PyClass::model) | 19 | .def_readwrite("model", &PyClass::model) |
| 19 | .def_readwrite("scale", &PyClass::scale) | 20 | .def_readwrite("scale", &PyClass::scale) |
| 21 | + .def_readwrite("lm_provider", &PyClass::lm_provider) | ||
| 22 | + .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) | ||
| 20 | .def("__str__", &PyClass::ToString); | 23 | .def("__str__", &PyClass::ToString); |
| 21 | } | 24 | } |
| 22 | 25 |
| @@ -13,10 +13,13 @@ namespace sherpa_onnx { | @@ -13,10 +13,13 @@ namespace sherpa_onnx { | ||
| 13 | void PybindOnlineLMConfig(py::module *m) { | 13 | void PybindOnlineLMConfig(py::module *m) { |
| 14 | using PyClass = OnlineLMConfig; | 14 | using PyClass = OnlineLMConfig; |
| 15 | py::class_<PyClass>(*m, "OnlineLMConfig") | 15 | py::class_<PyClass>(*m, "OnlineLMConfig") |
| 16 | - .def(py::init<const std::string &, float>(), py::arg("model") = "", | ||
| 17 | - py::arg("scale") = 0.5f) | 16 | + .def(py::init<const std::string &, float, int32_t, const std::string &>(), |
| 17 | + py::arg("model") = "", py::arg("scale") = 0.5f, | ||
| 18 | + py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu") | ||
| 18 | .def_readwrite("model", &PyClass::model) | 19 | .def_readwrite("model", &PyClass::model) |
| 19 | .def_readwrite("scale", &PyClass::scale) | 20 | .def_readwrite("scale", &PyClass::scale) |
| 21 | + .def_readwrite("lm_provider", &PyClass::lm_provider) | ||
| 22 | + .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) | ||
| 20 | .def("__str__", &PyClass::ToString); | 23 | .def("__str__", &PyClass::ToString); |
| 21 | } | 24 | } |
| 22 | 25 |
-
请 注册 或 登录 后发表评论