Committed by GitHub
rnnlm model inference supports num_threads setting (#169)
Co-authored-by: cuidongcai1035 <cuidongcai1035@wezhuiyi.com>
正在显示 10 个修改的文件，包含 25 行增加和 20 行删除
| @@ -12,7 +12,8 @@ | @@ -12,7 +12,8 @@ | ||
| 12 | 12 | ||
| 13 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 14 | 14 | ||
| 15 | -std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) { | 15 | +std::unique_ptr<OfflineLM> OfflineLM::Create( |
| 16 | + const OfflineRecognizerConfig &config) { | ||
| 16 | return std::make_unique<OfflineRnnLM>(config); | 17 | return std::make_unique<OfflineRnnLM>(config); |
| 17 | } | 18 | } |
| 18 | 19 |
| @@ -10,7 +10,7 @@ | @@ -10,7 +10,7 @@ | ||
| 10 | 10 | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 11 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | #include "sherpa-onnx/csrc/hypothesis.h" | 12 | #include "sherpa-onnx/csrc/hypothesis.h" |
| 13 | -#include "sherpa-onnx/csrc/offline-lm-config.h" | 13 | +#include "sherpa-onnx/csrc/offline-recognizer.h" |
| 14 | 14 | ||
| 15 | namespace sherpa_onnx { | 15 | namespace sherpa_onnx { |
| 16 | 16 | ||
| @@ -18,7 +18,8 @@ class OfflineLM { | @@ -18,7 +18,8 @@ class OfflineLM { | ||
| 18 | public: | 18 | public: |
| 19 | virtual ~OfflineLM() = default; | 19 | virtual ~OfflineLM() = default; |
| 20 | 20 | ||
| 21 | - static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config); | 21 | + static std::unique_ptr<OfflineLM> Create( |
| 22 | + const OfflineRecognizerConfig &config); | ||
| 22 | 23 | ||
| 23 | /** Rescore a batch of sentences. | 24 | /** Rescore a batch of sentences. |
| 24 | * | 25 | * |
| @@ -59,7 +59,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -59,7 +59,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 59 | std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); | 59 | std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); |
| 60 | } else if (config_.decoding_method == "modified_beam_search") { | 60 | } else if (config_.decoding_method == "modified_beam_search") { |
| 61 | if (!config_.lm_config.model.empty()) { | 61 | if (!config_.lm_config.model.empty()) { |
| 62 | - lm_ = OfflineLM::Create(config.lm_config); | 62 | + lm_ = OfflineLM::Create(config); |
| 63 | } | 63 | } |
| 64 | 64 | ||
| 65 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( | 65 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( |
| @@ -12,17 +12,18 @@ | @@ -12,17 +12,18 @@ | ||
| 12 | #include "sherpa-onnx/csrc/macros.h" | 12 | #include "sherpa-onnx/csrc/macros.h" |
| 13 | #include "sherpa-onnx/csrc/onnx-utils.h" | 13 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 14 | #include "sherpa-onnx/csrc/text-utils.h" | 14 | #include "sherpa-onnx/csrc/text-utils.h" |
| 15 | +#include "sherpa-onnx/csrc/session.h" | ||
| 15 | 16 | ||
| 16 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| 17 | 18 | ||
| 18 | class OfflineRnnLM::Impl { | 19 | class OfflineRnnLM::Impl { |
| 19 | public: | 20 | public: |
| 20 | - explicit Impl(const OfflineLMConfig &config) | ||
| 21 | - : config_(config), | 21 | + explicit Impl(const OfflineRecognizerConfig &config) |
| 22 | + : config_(config.lm_config), | ||
| 22 | env_(ORT_LOGGING_LEVEL_ERROR), | 23 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 23 | - sess_opts_{}, | 24 | + sess_opts_{GetSessionOptions(config.model_config)}, |
| 24 | allocator_{} { | 25 | allocator_{} { |
| 25 | - Init(config); | 26 | + Init(config.lm_config); |
| 26 | } | 27 | } |
| 27 | 28 | ||
| 28 | Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) { | 29 | Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) { |
| @@ -62,7 +63,7 @@ class OfflineRnnLM::Impl { | @@ -62,7 +63,7 @@ class OfflineRnnLM::Impl { | ||
| 62 | std::vector<const char *> output_names_ptr_; | 63 | std::vector<const char *> output_names_ptr_; |
| 63 | }; | 64 | }; |
| 64 | 65 | ||
| 65 | -OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) | 66 | +OfflineRnnLM::OfflineRnnLM(const OfflineRecognizerConfig &config) |
| 66 | : impl_(std::make_unique<Impl>(config)) {} | 67 | : impl_(std::make_unique<Impl>(config)) {} |
| 67 | 68 | ||
| 68 | OfflineRnnLM::~OfflineRnnLM() = default; | 69 | OfflineRnnLM::~OfflineRnnLM() = default; |
| @@ -17,7 +17,7 @@ class OfflineRnnLM : public OfflineLM { | @@ -17,7 +17,7 @@ class OfflineRnnLM : public OfflineLM { | ||
| 17 | public: | 17 | public: |
| 18 | ~OfflineRnnLM() override; | 18 | ~OfflineRnnLM() override; |
| 19 | 19 | ||
| 20 | - explicit OfflineRnnLM(const OfflineLMConfig &config); | 20 | + explicit OfflineRnnLM(const OfflineRecognizerConfig &config); |
| 21 | 21 | ||
| 22 | /** Rescore a batch of sentences. | 22 | /** Rescore a batch of sentences. |
| 23 | * | 23 | * |
| @@ -13,7 +13,8 @@ | @@ -13,7 +13,8 @@ | ||
| 13 | 13 | ||
| 14 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 15 | 15 | ||
| 16 | -std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) { | 16 | +std::unique_ptr<OnlineLM> OnlineLM::Create( |
| 17 | + const OnlineRecognizerConfig &config) { | ||
| 17 | return std::make_unique<OnlineRnnLM>(config); | 18 | return std::make_unique<OnlineRnnLM>(config); |
| 18 | } | 19 | } |
| 19 | 20 |
| @@ -11,7 +11,7 @@ | @@ -11,7 +11,7 @@ | ||
| 11 | 11 | ||
| 12 | #include "onnxruntime_cxx_api.h" // NOLINT | 12 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 13 | #include "sherpa-onnx/csrc/hypothesis.h" | 13 | #include "sherpa-onnx/csrc/hypothesis.h" |
| 14 | -#include "sherpa-onnx/csrc/online-lm-config.h" | 14 | +#include "sherpa-onnx/csrc/online-recognizer.h" |
| 15 | 15 | ||
| 16 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 17 | 17 | ||
| @@ -19,7 +19,7 @@ class OnlineLM { | @@ -19,7 +19,7 @@ class OnlineLM { | ||
| 19 | public: | 19 | public: |
| 20 | virtual ~OnlineLM() = default; | 20 | virtual ~OnlineLM() = default; |
| 21 | 21 | ||
| 22 | - static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config); | 22 | + static std::unique_ptr<OnlineLM> Create(const OnlineRecognizerConfig &config); |
| 23 | 23 | ||
| 24 | virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0; | 24 | virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0; |
| 25 | 25 |
| @@ -129,7 +129,7 @@ class OnlineRecognizer::Impl { | @@ -129,7 +129,7 @@ class OnlineRecognizer::Impl { | ||
| 129 | endpoint_(config_.endpoint_config) { | 129 | endpoint_(config_.endpoint_config) { |
| 130 | if (config.decoding_method == "modified_beam_search") { | 130 | if (config.decoding_method == "modified_beam_search") { |
| 131 | if (!config_.lm_config.model.empty()) { | 131 | if (!config_.lm_config.model.empty()) { |
| 132 | - lm_ = OnlineLM::Create(config.lm_config); | 132 | + lm_ = OnlineLM::Create(config); |
| 133 | } | 133 | } |
| 134 | 134 | ||
| 135 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 135 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| @@ -13,17 +13,18 @@ | @@ -13,17 +13,18 @@ | ||
| 13 | #include "sherpa-onnx/csrc/macros.h" | 13 | #include "sherpa-onnx/csrc/macros.h" |
| 14 | #include "sherpa-onnx/csrc/onnx-utils.h" | 14 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 15 | #include "sherpa-onnx/csrc/text-utils.h" | 15 | #include "sherpa-onnx/csrc/text-utils.h" |
| 16 | +#include "sherpa-onnx/csrc/session.h" | ||
| 16 | 17 | ||
| 17 | namespace sherpa_onnx { | 18 | namespace sherpa_onnx { |
| 18 | 19 | ||
| 19 | class OnlineRnnLM::Impl { | 20 | class OnlineRnnLM::Impl { |
| 20 | public: | 21 | public: |
| 21 | - explicit Impl(const OnlineLMConfig &config) | ||
| 22 | - : config_(config), | 22 | + explicit Impl(const OnlineRecognizerConfig &config) |
| 23 | + : config_(config.lm_config), | ||
| 23 | env_(ORT_LOGGING_LEVEL_ERROR), | 24 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 24 | - sess_opts_{}, | 25 | + sess_opts_{GetSessionOptions(config.model_config)}, |
| 25 | allocator_{} { | 26 | allocator_{} { |
| 26 | - Init(config); | 27 | + Init(config.lm_config); |
| 27 | } | 28 | } |
| 28 | 29 | ||
| 29 | void ComputeLMScore(float scale, Hypothesis *hyp) { | 30 | void ComputeLMScore(float scale, Hypothesis *hyp) { |
| @@ -142,7 +143,7 @@ class OnlineRnnLM::Impl { | @@ -142,7 +143,7 @@ class OnlineRnnLM::Impl { | ||
| 142 | int32_t sos_id_ = 1; | 143 | int32_t sos_id_ = 1; |
| 143 | }; | 144 | }; |
| 144 | 145 | ||
| 145 | -OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) | 146 | +OnlineRnnLM::OnlineRnnLM(const OnlineRecognizerConfig &config) |
| 146 | : impl_(std::make_unique<Impl>(config)) {} | 147 | : impl_(std::make_unique<Impl>(config)) {} |
| 147 | 148 | ||
| 148 | OnlineRnnLM::~OnlineRnnLM() = default; | 149 | OnlineRnnLM::~OnlineRnnLM() = default; |
| @@ -20,7 +20,7 @@ class OnlineRnnLM : public OnlineLM { | @@ -20,7 +20,7 @@ class OnlineRnnLM : public OnlineLM { | ||
| 20 | public: | 20 | public: |
| 21 | ~OnlineRnnLM() override; | 21 | ~OnlineRnnLM() override; |
| 22 | 22 | ||
| 23 | - explicit OnlineRnnLM(const OnlineLMConfig &config); | 23 | + explicit OnlineRnnLM(const OnlineRecognizerConfig &config); |
| 24 | 24 | ||
| 25 | std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override; | 25 | std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override; |
| 26 | 26 |
-
请注册或登录后发表评论