keanu
Committed by GitHub

RNNLM model supports lm_num_threads and lm_provider settings (#173)

* RNNLM model inference supports a num_threads setting.

* RNNLM parameters decouple num_threads and provider from the transducer's.

* Fix a bug in the Python csrc bindings: the constructor arguments in offline-lm-config.cc and online-lm-config.cc did not match the updated C++ signatures.

* Set default values for lm_num_threads and lm_provider.

---------

Co-authored-by: cuidongcai1035 <cuidongcai1035@wezhuiyi.com>
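
A minimal sketch of how the decoupled API reads after this change. The types and functions are the ones touched in the diff below; the model path, thread count, and the standalone GetSessionOptions() call are illustrative placeholders, not from the PR:

#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-lm.h"
#include "sherpa-onnx/csrc/session.h"

int main() {
  // The LM now carries its own threading/provider settings,
  // independent of the transducer's num_threads and provider.
  sherpa_onnx::OfflineLMConfig lm_config("/path/to/rnnlm.onnx", /*scale=*/0.5f,
                                         /*lm_num_threads=*/2,
                                         /*lm_provider=*/"cpu");

  // New overload: session options are derived from the LM config itself.
  Ort::SessionOptions sess_opts = sherpa_onnx::GetSessionOptions(lm_config);

  // Create() now takes OfflineLMConfig instead of OfflineRecognizerConfig.
  auto lm = sherpa_onnx::OfflineLM::Create(lm_config);
  return 0;
}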

sherpa-onnx/csrc/offline-lm-config.cc
@@ -14,6 +14,10 @@ namespace sherpa_onnx {
 void OfflineLMConfig::Register(ParseOptions *po) {
   po->Register("lm", &model, "Path to LM model.");
   po->Register("lm-scale", &scale, "LM scale.");
+  po->Register("lm-num-threads", &lm_num_threads,
+               "Number of threads to run the neural network of LM model");
+  po->Register("lm-provider", &lm_provider,
+               "Specify a provider to LM model use: cpu, cuda, coreml");
 }

 bool OfflineLMConfig::Validate() const {

sherpa-onnx/csrc/offline-lm-config.h
@@ -16,11 +16,17 @@ struct OfflineLMConfig {

   // LM scale
   float scale = 0.5;
+  int32_t lm_num_threads = 1;
+  std::string lm_provider = "cpu";

   OfflineLMConfig() = default;

-  OfflineLMConfig(const std::string &model, float scale)
-      : model(model), scale(scale) {}
+  OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
+                  const std::string &lm_provider)
+      : model(model),
+        scale(scale),
+        lm_num_threads(lm_num_threads),
+        lm_provider(lm_provider) {}

   void Register(ParseOptions *po);
   bool Validate() const;

sherpa-onnx/csrc/offline-lm.cc
@@ -12,8 +12,7 @@

 namespace sherpa_onnx {

-std::unique_ptr<OfflineLM> OfflineLM::Create(
-    const OfflineRecognizerConfig &config) {
+std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
   return std::make_unique<OfflineRnnLM>(config);
 }


sherpa-onnx/csrc/offline-lm.h
@@ -10,7 +10,7 @@

 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/hypothesis.h"
-#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/offline-lm-config.h"

 namespace sherpa_onnx {

@@ -18,8 +18,7 @@ class OfflineLM {
  public:
   virtual ~OfflineLM() = default;

-  static std::unique_ptr<OfflineLM> Create(
-      const OfflineRecognizerConfig &config);
+  static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);

   /** Rescore a batch of sentences.
    *

sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
@@ -59,7 +59,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
           std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
     } else if (config_.decoding_method == "modified_beam_search") {
       if (!config_.lm_config.model.empty()) {
-        lm_ = OfflineLM::Create(config);
+        lm_ = OfflineLM::Create(config.lm_config);
       }

       decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(

sherpa-onnx/csrc/offline-rnn-lm.cc
@@ -18,12 +18,12 @@ namespace sherpa_onnx {

 class OfflineRnnLM::Impl {
  public:
-  explicit Impl(const OfflineRecognizerConfig &config)
-      : config_(config.lm_config),
+  explicit Impl(const OfflineLMConfig &config)
+      : config_(config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{GetSessionOptions(config.model_config)},
+        sess_opts_{GetSessionOptions(config)},
         allocator_{} {
-    Init(config.lm_config);
+    Init(config);
   }

   Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
@@ -63,7 +63,7 @@ class OfflineRnnLM::Impl {
   std::vector<const char *> output_names_ptr_;
 };

-OfflineRnnLM::OfflineRnnLM(const OfflineRecognizerConfig &config)
+OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
     : impl_(std::make_unique<Impl>(config)) {}

 OfflineRnnLM::~OfflineRnnLM() = default;

sherpa-onnx/csrc/offline-rnn-lm.h
@@ -17,7 +17,7 @@ class OfflineRnnLM : public OfflineLM {
  public:
   ~OfflineRnnLM() override;

-  explicit OfflineRnnLM(const OfflineRecognizerConfig &config);
+  explicit OfflineRnnLM(const OfflineLMConfig &config);

   /** Rescore a batch of sentences.
    *

sherpa-onnx/csrc/online-lm-config.cc
@@ -14,6 +14,10 @@ namespace sherpa_onnx {
 void OnlineLMConfig::Register(ParseOptions *po) {
   po->Register("lm", &model, "Path to LM model.");
   po->Register("lm-scale", &scale, "LM scale.");
+  po->Register("lm-num-threads", &lm_num_threads,
+               "Number of threads to run the neural network of LM model");
+  po->Register("lm-provider", &lm_provider,
+               "Specify a provider to LM model use: cpu, cuda, coreml");
 }

 bool OnlineLMConfig::Validate() const {

sherpa-onnx/csrc/online-lm-config.h
@@ -16,11 +16,17 @@ struct OnlineLMConfig {

   // LM scale
   float scale = 0.5;
+  int32_t lm_num_threads = 1;
+  std::string lm_provider = "cpu";

   OnlineLMConfig() = default;

-  OnlineLMConfig(const std::string &model, float scale)
-      : model(model), scale(scale) {}
+  OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
+                 const std::string &lm_provider)
+      : model(model),
+        scale(scale),
+        lm_num_threads(lm_num_threads),
+        lm_provider(lm_provider) {}

   void Register(ParseOptions *po);
   bool Validate() const;

sherpa-onnx/csrc/online-lm.cc
@@ -13,8 +13,7 @@

 namespace sherpa_onnx {

-std::unique_ptr<OnlineLM> OnlineLM::Create(
-    const OnlineRecognizerConfig &config) {
+std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) {
   return std::make_unique<OnlineRnnLM>(config);
 }


sherpa-onnx/csrc/online-lm.h
@@ -11,7 +11,7 @@

 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/hypothesis.h"
-#include "sherpa-onnx/csrc/online-recognizer.h"
+#include "sherpa-onnx/csrc/online-lm-config.h"

 namespace sherpa_onnx {

@@ -19,7 +19,7 @@ class OnlineLM {
  public:
   virtual ~OnlineLM() = default;

-  static std::unique_ptr<OnlineLM> Create(const OnlineRecognizerConfig &config);
+  static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);

   virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;


sherpa-onnx/csrc/online-recognizer.cc
@@ -129,7 +129,7 @@ class OnlineRecognizer::Impl {
         endpoint_(config_.endpoint_config) {
     if (config.decoding_method == "modified_beam_search") {
       if (!config_.lm_config.model.empty()) {
-        lm_ = OnlineLM::Create(config);
+        lm_ = OnlineLM::Create(config.lm_config);
       }

       decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(

sherpa-onnx/csrc/online-rnn-lm.cc
@@ -19,12 +19,12 @@ namespace sherpa_onnx {

 class OnlineRnnLM::Impl {
  public:
-  explicit Impl(const OnlineRecognizerConfig &config)
-      : config_(config.lm_config),
+  explicit Impl(const OnlineLMConfig &config)
+      : config_(config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{GetSessionOptions(config.model_config)},
+        sess_opts_{GetSessionOptions(config)},
         allocator_{} {
-    Init(config.lm_config);
+    Init(config);
   }

   void ComputeLMScore(float scale, Hypothesis *hyp) {
@@ -143,7 +143,7 @@ class OnlineRnnLM::Impl {
   int32_t sos_id_ = 1;
 };

-OnlineRnnLM::OnlineRnnLM(const OnlineRecognizerConfig &config)
+OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
     : impl_(std::make_unique<Impl>(config)) {}

 OnlineRnnLM::~OnlineRnnLM() = default;

sherpa-onnx/csrc/online-rnn-lm.h
@@ -20,7 +20,7 @@ class OnlineRnnLM : public OnlineLM {
  public:
   ~OnlineRnnLM() override;

-  explicit OnlineRnnLM(const OnlineRecognizerConfig &config);
+  explicit OnlineRnnLM(const OnlineLMConfig &config);

   std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;


sherpa-onnx/csrc/session.cc
@@ -69,4 +69,12 @@ Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
   return GetSessionOptionsImpl(config.num_threads, config.provider);
 }

+Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) {
+  return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
+}
+
+Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) {
+  return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
+}
+
 }  // namespace sherpa_onnx
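
Since both new overloads forward to GetSessionOptionsImpl(), overload resolution picks the LM-specific thread count and provider per config type. A small sketch under that assumption (the field values are illustrative, not from the PR):

#include "sherpa-onnx/csrc/session.h"

void Sketch() {
  sherpa_onnx::OfflineLMConfig offline_lm;  // defaults: 1 thread, "cpu"
  offline_lm.lm_num_threads = 4;  // tuned independently of the transducer

  sherpa_onnx::OnlineLMConfig online_lm;
  online_lm.lm_provider = "cuda";

  // Each call resolves to the matching overload defined above.
  Ort::SessionOptions offline_opts = sherpa_onnx::GetSessionOptions(offline_lm);
  Ort::SessionOptions online_opts = sherpa_onnx::GetSessionOptions(online_lm);
}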

sherpa-onnx/csrc/session.h
@@ -6,7 +6,9 @@
 #define SHERPA_ONNX_CSRC_SESSION_H_

 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-lm-config.h"
 #include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/online-lm-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"

 namespace sherpa_onnx {
@@ -16,6 +18,9 @@ Ort::SessionOptions GetSessionOptions(

 Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);

+Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);
+
+Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
 }  // namespace sherpa_onnx

 #endif  // SHERPA_ONNX_CSRC_SESSION_H_

sherpa-onnx/python/csrc/offline-lm-config.cc
@@ -13,10 +13,13 @@ namespace sherpa_onnx {
 void PybindOfflineLMConfig(py::module *m) {
   using PyClass = OfflineLMConfig;
   py::class_<PyClass>(*m, "OfflineLMConfig")
-      .def(py::init<const std::string &, float>(), py::arg("model"),
-           py::arg("scale"))
+      .def(py::init<const std::string &, float, int32_t, const std::string &>(),
+           py::arg("model"), py::arg("scale") = 0.5f,
+           py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu")
       .def_readwrite("model", &PyClass::model)
       .def_readwrite("scale", &PyClass::scale)
+      .def_readwrite("lm_provider", &PyClass::lm_provider)
+      .def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
       .def("__str__", &PyClass::ToString);
 }


sherpa-onnx/python/csrc/online-lm-config.cc
@@ -13,10 +13,13 @@ namespace sherpa_onnx {
 void PybindOnlineLMConfig(py::module *m) {
   using PyClass = OnlineLMConfig;
   py::class_<PyClass>(*m, "OnlineLMConfig")
-      .def(py::init<const std::string &, float>(), py::arg("model") = "",
-           py::arg("scale") = 0.5f)
+      .def(py::init<const std::string &, float, int32_t, const std::string &>(),
+           py::arg("model") = "", py::arg("scale") = 0.5f,
+           py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu")
       .def_readwrite("model", &PyClass::model)
       .def_readwrite("scale", &PyClass::scale)
+      .def_readwrite("lm_provider", &PyClass::lm_provider)
+      .def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
       .def("__str__", &PyClass::ToString);
 }
