keanu
Committed by GitHub

rnnlm model inference supports num_threads setting (#169)

Co-authored-by: cuidongcai1035 <cuidongcai1035@wezhuiyi.com>
@@ -12,7 +12,8 @@
 
 namespace sherpa_onnx {
 
-std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
+std::unique_ptr<OfflineLM> OfflineLM::Create(
+    const OfflineRecognizerConfig &config) {
   return std::make_unique<OfflineRnnLM>(config);
 }
 
@@ -10,7 +10,7 @@
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/hypothesis.h"
-#include "sherpa-onnx/csrc/offline-lm-config.h"
+#include "sherpa-onnx/csrc/offline-recognizer.h"
 
 namespace sherpa_onnx {
 
@@ -18,7 +18,8 @@ class OfflineLM {
  public:
   virtual ~OfflineLM() = default;
 
-  static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
+  static std::unique_ptr<OfflineLM> Create(
+      const OfflineRecognizerConfig &config);
 
   /** Rescore a batch of sentences.
    *
@@ -59,7 +59,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
           std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
     } else if (config_.decoding_method == "modified_beam_search") {
       if (!config_.lm_config.model.empty()) {
-        lm_ = OfflineLM::Create(config.lm_config);
+        lm_ = OfflineLM::Create(config);
       }
 
       decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
@@ -12,17 +12,18 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 
 namespace sherpa_onnx {
 
 class OfflineRnnLM::Impl {
  public:
-  explicit Impl(const OfflineLMConfig &config)
-      : config_(config),
+  explicit Impl(const OfflineRecognizerConfig &config)
+      : config_(config.lm_config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{},
+        sess_opts_{GetSessionOptions(config.model_config)},
         allocator_{} {
-    Init(config);
+    Init(config.lm_config);
   }
 
   Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
@@ -62,7 +63,7 @@ class OfflineRnnLM::Impl {
   std::vector<const char *> output_names_ptr_;
 };
 
-OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
+OfflineRnnLM::OfflineRnnLM(const OfflineRecognizerConfig &config)
     : impl_(std::make_unique<Impl>(config)) {}
 
 OfflineRnnLM::~OfflineRnnLM() = default;
@@ -17,7 +17,7 @@ class OfflineRnnLM : public OfflineLM {
  public:
   ~OfflineRnnLM() override;
 
-  explicit OfflineRnnLM(const OfflineLMConfig &config);
+  explicit OfflineRnnLM(const OfflineRecognizerConfig &config);
 
   /** Rescore a batch of sentences.
    *
@@ -13,7 +13,8 @@
 
 namespace sherpa_onnx {
 
-std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) {
+std::unique_ptr<OnlineLM> OnlineLM::Create(
+    const OnlineRecognizerConfig &config) {
   return std::make_unique<OnlineRnnLM>(config);
 }
 
@@ -11,7 +11,7 @@
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/hypothesis.h"
-#include "sherpa-onnx/csrc/online-lm-config.h"
+#include "sherpa-onnx/csrc/online-recognizer.h"
 
 namespace sherpa_onnx {
 
@@ -19,7 +19,7 @@ class OnlineLM {
  public:
   virtual ~OnlineLM() = default;
 
-  static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
+  static std::unique_ptr<OnlineLM> Create(const OnlineRecognizerConfig &config);
 
   virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;
 
@@ -129,7 +129,7 @@ class OnlineRecognizer::Impl {
         endpoint_(config_.endpoint_config) {
     if (config.decoding_method == "modified_beam_search") {
       if (!config_.lm_config.model.empty()) {
-        lm_ = OnlineLM::Create(config.lm_config);
+        lm_ = OnlineLM::Create(config);
       }
 
       decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
@@ -13,17 +13,18 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 
 namespace sherpa_onnx {
 
 class OnlineRnnLM::Impl {
  public:
-  explicit Impl(const OnlineLMConfig &config)
-      : config_(config),
+  explicit Impl(const OnlineRecognizerConfig &config)
+      : config_(config.lm_config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{},
+        sess_opts_{GetSessionOptions(config.model_config)},
         allocator_{} {
-    Init(config);
+    Init(config.lm_config);
   }
 
   void ComputeLMScore(float scale, Hypothesis *hyp) {
@@ -142,7 +143,7 @@ class OnlineRnnLM::Impl {
   int32_t sos_id_ = 1;
 };
 
-OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
+OnlineRnnLM::OnlineRnnLM(const OnlineRecognizerConfig &config)
     : impl_(std::make_unique<Impl>(config)) {}
 
 OnlineRnnLM::~OnlineRnnLM() = default;
@@ -20,7 +20,7 @@ class OnlineRnnLM : public OnlineLM {
  public:
   ~OnlineRnnLM() override;
 
-  explicit OnlineRnnLM(const OnlineLMConfig &config);
+  explicit OnlineRnnLM(const OnlineRecognizerConfig &config);
 
   std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;
 