Committed by
GitHub
Add lm rescore to online-modified-beam-search (#133)
正在显示
26 个修改的文件
包含
495 行增加
和
39 行删除
| @@ -182,9 +182,10 @@ class MainActivity : AppCompatActivity() { | @@ -182,9 +182,10 @@ class MainActivity : AppCompatActivity() { | ||
| 182 | val config = OnlineRecognizerConfig( | 182 | val config = OnlineRecognizerConfig( |
| 183 | featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), | 183 | featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), |
| 184 | modelConfig = getModelConfig(type = type)!!, | 184 | modelConfig = getModelConfig(type = type)!!, |
| 185 | + lmConfig = getOnlineLMConfig(type = type), | ||
| 185 | endpointConfig = getEndpointConfig(), | 186 | endpointConfig = getEndpointConfig(), |
| 186 | enableEndpoint = true, | 187 | enableEndpoint = true, |
| 187 | - decodingMethod = "greedy_search", | 188 | + decodingMethod = "modified_beam_search", |
| 188 | maxActivePaths = 4, | 189 | maxActivePaths = 4, |
| 189 | ) | 190 | ) |
| 190 | 191 |
| @@ -23,6 +23,11 @@ data class OnlineTransducerModelConfig( | @@ -23,6 +23,11 @@ data class OnlineTransducerModelConfig( | ||
| 23 | var debug: Boolean = false, | 23 | var debug: Boolean = false, |
| 24 | ) | 24 | ) |
| 25 | 25 | ||
| 26 | +data class OnlineLMConfig( | ||
| 27 | + var model: String = "", | ||
| 28 | + var scale: Float = 0.5f, | ||
| 29 | +) | ||
| 30 | + | ||
| 26 | data class FeatureConfig( | 31 | data class FeatureConfig( |
| 27 | var sampleRate: Int = 16000, | 32 | var sampleRate: Int = 16000, |
| 28 | var featureDim: Int = 80, | 33 | var featureDim: Int = 80, |
| @@ -31,6 +36,7 @@ data class FeatureConfig( | @@ -31,6 +36,7 @@ data class FeatureConfig( | ||
| 31 | data class OnlineRecognizerConfig( | 36 | data class OnlineRecognizerConfig( |
| 32 | var featConfig: FeatureConfig = FeatureConfig(), | 37 | var featConfig: FeatureConfig = FeatureConfig(), |
| 33 | var modelConfig: OnlineTransducerModelConfig, | 38 | var modelConfig: OnlineTransducerModelConfig, |
| 39 | + var lmConfig : OnlineLMConfig, | ||
| 34 | var endpointConfig: EndpointConfig = EndpointConfig(), | 40 | var endpointConfig: EndpointConfig = EndpointConfig(), |
| 35 | var enableEndpoint: Boolean = true, | 41 | var enableEndpoint: Boolean = true, |
| 36 | var decodingMethod: String = "greedy_search", | 42 | var decodingMethod: String = "greedy_search", |
| @@ -151,6 +157,32 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | @@ -151,6 +157,32 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | ||
| 151 | return null; | 157 | return null; |
| 152 | } | 158 | } |
| 153 | 159 | ||
| 160 | +/* | ||
| 161 | +Please see | ||
| 162 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 163 | +for a list of pre-trained models. | ||
| 164 | + | ||
| 165 | +We only add a few here. Please change the following code | ||
| 166 | +to add your own LM model. (It should be straightforward to train a new NN LM model | ||
| 167 | +by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py) | ||
| 168 | + | ||
| 169 | +@param type | ||
| 170 | +0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) | ||
| 171 | + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english | ||
| 172 | + */ | ||
| 173 | +fun getOnlineLMConfig(type : Int): OnlineLMConfig { | ||
| 174 | + when (type) { | ||
| 175 | + 0 -> { | ||
| 176 | + val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" | ||
| 177 | + return OnlineLMConfig( | ||
| 178 | + model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx", | ||
| 179 | + scale = 0.5f, | ||
| 180 | + ) | ||
| 181 | + } | ||
| 182 | + } | ||
| 183 | + return OnlineLMConfig(); | ||
| 184 | +} | ||
| 185 | + | ||
| 154 | fun getEndpointConfig(): EndpointConfig { | 186 | fun getEndpointConfig(): EndpointConfig { |
| 155 | return EndpointConfig( | 187 | return EndpointConfig( |
| 156 | rule1 = EndpointRule(false, 2.4f, 0.0f), | 188 | rule1 = EndpointRule(false, 2.4f, 0.0f), |
| @@ -22,8 +22,11 @@ fun main() { | @@ -22,8 +22,11 @@ fun main() { | ||
| 22 | 22 | ||
| 23 | var endpointConfig = EndpointConfig() | 23 | var endpointConfig = EndpointConfig() |
| 24 | 24 | ||
| 25 | + var lmConfig = OnlineLMConfig() | ||
| 26 | + | ||
| 25 | var config = OnlineRecognizerConfig( | 27 | var config = OnlineRecognizerConfig( |
| 26 | modelConfig = modelConfig, | 28 | modelConfig = modelConfig, |
| 29 | + lmConfig = lmConfig, | ||
| 27 | featConfig = featConfig, | 30 | featConfig = featConfig, |
| 28 | endpointConfig = endpointConfig, | 31 | endpointConfig = endpointConfig, |
| 29 | enableEndpoint = true, | 32 | enableEndpoint = true, |
| @@ -34,9 +34,11 @@ set(sources | @@ -34,9 +34,11 @@ set(sources | ||
| 34 | offline-transducer-model-config.cc | 34 | offline-transducer-model-config.cc |
| 35 | offline-transducer-model.cc | 35 | offline-transducer-model.cc |
| 36 | offline-transducer-modified-beam-search-decoder.cc | 36 | offline-transducer-modified-beam-search-decoder.cc |
| 37 | + online-lm.cc | ||
| 37 | online-lm-config.cc | 38 | online-lm-config.cc |
| 38 | online-lstm-transducer-model.cc | 39 | online-lstm-transducer-model.cc |
| 39 | online-recognizer.cc | 40 | online-recognizer.cc |
| 41 | + online-rnn-lm.cc | ||
| 40 | online-stream.cc | 42 | online-stream.cc |
| 41 | online-transducer-decoder.cc | 43 | online-transducer-decoder.cc |
| 42 | online-transducer-greedy-search-decoder.cc | 44 | online-transducer-greedy-search-decoder.cc |
| 1 | /** | 1 | /** |
| 2 | * Copyright (c) 2023 Xiaomi Corporation | 2 | * Copyright (c) 2023 Xiaomi Corporation |
| 3 | + * Copyright (c) 2023 Pingfeng Luo | ||
| 3 | * | 4 | * |
| 4 | */ | 5 | */ |
| 5 | 6 | ||
| @@ -12,7 +13,9 @@ | @@ -12,7 +13,9 @@ | ||
| 12 | #include <utility> | 13 | #include <utility> |
| 13 | #include <vector> | 14 | #include <vector> |
| 14 | 15 | ||
| 16 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 15 | #include "sherpa-onnx/csrc/math.h" | 17 | #include "sherpa-onnx/csrc/math.h" |
| 18 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 16 | 19 | ||
| 17 | namespace sherpa_onnx { | 20 | namespace sherpa_onnx { |
| 18 | 21 | ||
| @@ -31,6 +34,13 @@ struct Hypothesis { | @@ -31,6 +34,13 @@ struct Hypothesis { | ||
| 31 | // LM log prob if any. | 34 | // LM log prob if any. |
| 32 | double lm_log_prob = 0; | 35 | double lm_log_prob = 0; |
| 33 | 36 | ||
| 37 | + int32_t cur_scored_pos = 0; // cur scored tokens by RNN LM | ||
| 38 | + std::vector<CopyableOrtValue> nn_lm_states; | ||
| 39 | + | ||
| 40 | + // TODO(fangjun): Make it configurable | ||
| 41 | + // the minimum of tokens in a chunk for streaming RNN LM | ||
| 42 | + int32_t lm_rescore_min_chunk = 2; // a const | ||
| 43 | + | ||
| 34 | int32_t num_trailing_blanks = 0; | 44 | int32_t num_trailing_blanks = 0; |
| 35 | 45 | ||
| 36 | Hypothesis() = default; | 46 | Hypothesis() = default; |
| @@ -96,17 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { | @@ -96,17 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { | ||
| 96 | } | 96 | } |
| 97 | } | 97 | } |
| 98 | 98 | ||
| 99 | -// TODO(fangjun): use std::partial_sort to replace std::sort. | ||
| 100 | -// Remember also to fix sherpa-ncnn | ||
| 101 | template <class T> | 99 | template <class T> |
| 102 | std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { | 100 | std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { |
| 103 | std::vector<int32_t> vec_index(size); | 101 | std::vector<int32_t> vec_index(size); |
| 104 | std::iota(vec_index.begin(), vec_index.end(), 0); | 102 | std::iota(vec_index.begin(), vec_index.end(), 0); |
| 105 | 103 | ||
| 106 | - std::sort(vec_index.begin(), vec_index.end(), | ||
| 107 | - [vec](int32_t index_1, int32_t index_2) { | ||
| 108 | - return vec[index_1] > vec[index_2]; | ||
| 109 | - }); | 104 | + std::partial_sort(vec_index.begin(), vec_index.begin() + topk, |
| 105 | + vec_index.end(), [vec](int32_t index_1, int32_t index_2) { | ||
| 106 | + return vec[index_1] > vec[index_2]; | ||
| 107 | + }); | ||
| 110 | 108 | ||
| 111 | int32_t k_num = std::min<int32_t>(size, topk); | 109 | int32_t k_num = std::min<int32_t>(size, topk); |
| 112 | std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num); | 110 | std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num); |
sherpa-onnx/csrc/online-lm.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-lm.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Pingfeng Luo | ||
| 4 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 5 | + | ||
| 6 | +#include "sherpa-onnx/csrc/online-lm.h" | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/online-rnn-lm.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +static std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values) { | ||
| 17 | + std::vector<CopyableOrtValue> ans; | ||
| 18 | + ans.reserve(values.size()); | ||
| 19 | + | ||
| 20 | + for (auto &v : values) { | ||
| 21 | + ans.emplace_back(std::move(v)); | ||
| 22 | + } | ||
| 23 | + | ||
| 24 | + return ans; | ||
| 25 | +} | ||
| 26 | + | ||
| 27 | +static std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) { | ||
| 28 | + std::vector<Ort::Value> ans; | ||
| 29 | + ans.reserve(values.size()); | ||
| 30 | + | ||
| 31 | + for (auto &v : values) { | ||
| 32 | + ans.emplace_back(std::move(v.value)); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + return ans; | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) { | ||
| 39 | + return std::make_unique<OnlineRnnLM>(config); | ||
| 40 | +} | ||
| 41 | + | ||
| 42 | +void OnlineLM::ComputeLMScore(float scale, int32_t context_size, | ||
| 43 | + std::vector<Hypotheses> *hyps) { | ||
| 44 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 45 | + | ||
| 46 | + for (auto &hyp : *hyps) { | ||
| 47 | + for (auto &h_m : hyp) { | ||
| 48 | + auto &h = h_m.second; | ||
| 49 | + auto &ys = h.ys; | ||
| 50 | + const int32_t token_num_in_chunk = | ||
| 51 | + ys.size() - context_size - h.cur_scored_pos - 1; | ||
| 52 | + | ||
| 53 | + if (token_num_in_chunk < 1) { | ||
| 54 | + continue; | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + if (h.nn_lm_states.empty()) { | ||
| 58 | + h.nn_lm_states = Convert(GetInitStates()); | ||
| 59 | + } | ||
| 60 | + | ||
| 61 | + if (token_num_in_chunk >= h.lm_rescore_min_chunk) { | ||
| 62 | + std::array<int64_t, 2> x_shape{1, token_num_in_chunk}; | ||
| 63 | + // shape of x and y are same | ||
| 64 | + Ort::Value x = Ort::Value::CreateTensor<int64_t>( | ||
| 65 | + allocator, x_shape.data(), x_shape.size()); | ||
| 66 | + Ort::Value y = Ort::Value::CreateTensor<int64_t>( | ||
| 67 | + allocator, x_shape.data(), x_shape.size()); | ||
| 68 | + int64_t *p_x = x.GetTensorMutableData<int64_t>(); | ||
| 69 | + int64_t *p_y = y.GetTensorMutableData<int64_t>(); | ||
| 70 | + std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1, | ||
| 71 | + p_x); | ||
| 72 | + std::copy(ys.begin() + context_size + h.cur_scored_pos + 1, ys.end(), | ||
| 73 | + p_y); | ||
| 74 | + | ||
| 75 | + // streaming forward by NN LM | ||
| 76 | + auto out = Rescore(std::move(x), std::move(y), | ||
| 77 | + Convert(std::move(h.nn_lm_states))); | ||
| 78 | + | ||
| 79 | + // update NN LM score in hyp | ||
| 80 | + const float *p_nll = out.first.GetTensorData<float>(); | ||
| 81 | + h.lm_log_prob = -scale * (*p_nll); | ||
| 82 | + | ||
| 83 | + // update NN LM states in hyp | ||
| 84 | + h.nn_lm_states = Convert(std::move(out.second)); | ||
| 85 | + | ||
| 86 | + h.cur_scored_pos += token_num_in_chunk; | ||
| 87 | + } | ||
| 88 | + } | ||
| 89 | + } | ||
| 90 | +} | ||
| 91 | + | ||
| 92 | +} // namespace sherpa_onnx |
| @@ -34,7 +34,7 @@ class OnlineLM { | @@ -34,7 +34,7 @@ class OnlineLM { | ||
| 34 | * | 34 | * |
| 35 | * Caution: It returns negative log likelihood (nll), not log likelihood | 35 | * Caution: It returns negative log likelihood (nll), not log likelihood |
| 36 | */ | 36 | */ |
| 37 | - std::pair<Ort::Value, std::vector<Ort::Value>> Ort::Value Rescore( | 37 | + virtual std::pair<Ort::Value, std::vector<Ort::Value>> Rescore( |
| 38 | Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0; | 38 | Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0; |
| 39 | 39 | ||
| 40 | // This function updates hyp.lm_log_prob of hyps. | 40 | // This function updates hyp.lm_log_prob of hyps. |
| @@ -44,19 +44,6 @@ class OnlineLM { | @@ -44,19 +44,6 @@ class OnlineLM { | ||
| 44 | // @param hyps It is changed in-place. | 44 | // @param hyps It is changed in-place. |
| 45 | void ComputeLMScore(float scale, int32_t context_size, | 45 | void ComputeLMScore(float scale, int32_t context_size, |
| 46 | std::vector<Hypotheses> *hyps); | 46 | std::vector<Hypotheses> *hyps); |
| 47 | - /** TODO(fangjun): | ||
| 48 | - * | ||
| 49 | - * 1. Add two fields to Hypothesis | ||
| 50 | - * (a) int32_t lm_cur_pos = 0; number of scored tokens so far | ||
| 51 | - * (b) std::vector<Ort::Value> lm_states; | ||
| 52 | - * 2. When we want to score a hypothesis, we construct x and y as follows: | ||
| 53 | - * | ||
| 54 | - * std::vector x = {hyp.ys.begin() + context_size + lm_cur_pos, | ||
| 55 | - * hyp.ys.end() - 1}; | ||
| 56 | - * std::vector y = {hyp.ys.begin() + context_size + lm_cur_pos + 1 | ||
| 57 | - * hyp.ys.end()}; | ||
| 58 | - * hyp.lm_cur_pos += hyp.ys.size() - context_size - lm_cur_pos; | ||
| 59 | - */ | ||
| 60 | }; | 47 | }; |
| 61 | 48 | ||
| 62 | } // namespace sherpa_onnx | 49 | } // namespace sherpa_onnx |
| @@ -16,6 +16,8 @@ | @@ -16,6 +16,8 @@ | ||
| 16 | 16 | ||
| 17 | #include "nlohmann/json.hpp" | 17 | #include "nlohmann/json.hpp" |
| 18 | #include "sherpa-onnx/csrc/file-utils.h" | 18 | #include "sherpa-onnx/csrc/file-utils.h" |
| 19 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 20 | +#include "sherpa-onnx/csrc/online-lm.h" | ||
| 19 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 21 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 20 | #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" | 22 | #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" |
| 21 | #include "sherpa-onnx/csrc/online-transducer-model.h" | 23 | #include "sherpa-onnx/csrc/online-transducer-model.h" |
| @@ -80,6 +82,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | @@ -80,6 +82,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 80 | feat_config.Register(po); | 82 | feat_config.Register(po); |
| 81 | model_config.Register(po); | 83 | model_config.Register(po); |
| 82 | endpoint_config.Register(po); | 84 | endpoint_config.Register(po); |
| 85 | + lm_config.Register(po); | ||
| 83 | 86 | ||
| 84 | po->Register("enable-endpoint", &enable_endpoint, | 87 | po->Register("enable-endpoint", &enable_endpoint, |
| 85 | "True to enable endpoint detection. False to disable it."); | 88 | "True to enable endpoint detection. False to disable it."); |
| @@ -91,6 +94,14 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | @@ -91,6 +94,14 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 91 | } | 94 | } |
| 92 | 95 | ||
| 93 | bool OnlineRecognizerConfig::Validate() const { | 96 | bool OnlineRecognizerConfig::Validate() const { |
| 97 | + if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) { | ||
| 98 | + if (max_active_paths <= 0) { | ||
| 99 | + SHERPA_ONNX_LOGE("max_active_paths is less than or equal to 0! Given: %d", | ||
| 100 | + max_active_paths); | ||
| 101 | + return false; | ||
| 102 | + } | ||
| 103 | + if (!lm_config.Validate()) return false; | ||
| 104 | + } | ||
| 94 | return model_config.Validate(); | 105 | return model_config.Validate(); |
| 95 | } | 106 | } |
| 96 | 107 | ||
| @@ -100,6 +111,7 @@ std::string OnlineRecognizerConfig::ToString() const { | @@ -100,6 +111,7 @@ std::string OnlineRecognizerConfig::ToString() const { | ||
| 100 | os << "OnlineRecognizerConfig("; | 111 | os << "OnlineRecognizerConfig("; |
| 101 | os << "feat_config=" << feat_config.ToString() << ", "; | 112 | os << "feat_config=" << feat_config.ToString() << ", "; |
| 102 | os << "model_config=" << model_config.ToString() << ", "; | 113 | os << "model_config=" << model_config.ToString() << ", "; |
| 114 | + os << "lm_config=" << lm_config.ToString() << ", "; | ||
| 103 | os << "endpoint_config=" << endpoint_config.ToString() << ", "; | 115 | os << "endpoint_config=" << endpoint_config.ToString() << ", "; |
| 104 | os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; | 116 | os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; |
| 105 | os << "max_active_paths=" << max_active_paths << ", "; | 117 | os << "max_active_paths=" << max_active_paths << ", "; |
| @@ -116,8 +128,13 @@ class OnlineRecognizer::Impl { | @@ -116,8 +128,13 @@ class OnlineRecognizer::Impl { | ||
| 116 | sym_(config.model_config.tokens), | 128 | sym_(config.model_config.tokens), |
| 117 | endpoint_(config_.endpoint_config) { | 129 | endpoint_(config_.endpoint_config) { |
| 118 | if (config.decoding_method == "modified_beam_search") { | 130 | if (config.decoding_method == "modified_beam_search") { |
| 131 | + if (!config_.lm_config.model.empty()) { | ||
| 132 | + lm_ = OnlineLM::Create(config.lm_config); | ||
| 133 | + } | ||
| 134 | + | ||
| 119 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 135 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 120 | - model_.get(), config_.max_active_paths); | 136 | + model_.get(), lm_.get(), config_.max_active_paths, |
| 137 | + config_.lm_config.scale); | ||
| 121 | } else if (config.decoding_method == "greedy_search") { | 138 | } else if (config.decoding_method == "greedy_search") { |
| 122 | decoder_ = | 139 | decoder_ = |
| 123 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); | 140 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); |
| @@ -136,7 +153,8 @@ class OnlineRecognizer::Impl { | @@ -136,7 +153,8 @@ class OnlineRecognizer::Impl { | ||
| 136 | endpoint_(config_.endpoint_config) { | 153 | endpoint_(config_.endpoint_config) { |
| 137 | if (config.decoding_method == "modified_beam_search") { | 154 | if (config.decoding_method == "modified_beam_search") { |
| 138 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 155 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 139 | - model_.get(), config_.max_active_paths); | 156 | + model_.get(), lm_.get(), config_.max_active_paths, |
| 157 | + config_.lm_config.scale); | ||
| 140 | } else if (config.decoding_method == "greedy_search") { | 158 | } else if (config.decoding_method == "greedy_search") { |
| 141 | decoder_ = | 159 | decoder_ = |
| 142 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); | 160 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); |
| @@ -246,6 +264,7 @@ class OnlineRecognizer::Impl { | @@ -246,6 +264,7 @@ class OnlineRecognizer::Impl { | ||
| 246 | private: | 264 | private: |
| 247 | OnlineRecognizerConfig config_; | 265 | OnlineRecognizerConfig config_; |
| 248 | std::unique_ptr<OnlineTransducerModel> model_; | 266 | std::unique_ptr<OnlineTransducerModel> model_; |
| 267 | + std::unique_ptr<OnlineLM> lm_; | ||
| 249 | std::unique_ptr<OnlineTransducerDecoder> decoder_; | 268 | std::unique_ptr<OnlineTransducerDecoder> decoder_; |
| 250 | SymbolTable sym_; | 269 | SymbolTable sym_; |
| 251 | Endpoint endpoint_; | 270 | Endpoint endpoint_; |
| @@ -16,6 +16,7 @@ | @@ -16,6 +16,7 @@ | ||
| 16 | 16 | ||
| 17 | #include "sherpa-onnx/csrc/endpoint.h" | 17 | #include "sherpa-onnx/csrc/endpoint.h" |
| 18 | #include "sherpa-onnx/csrc/features.h" | 18 | #include "sherpa-onnx/csrc/features.h" |
| 19 | +#include "sherpa-onnx/csrc/online-lm-config.h" | ||
| 19 | #include "sherpa-onnx/csrc/online-stream.h" | 20 | #include "sherpa-onnx/csrc/online-stream.h" |
| 20 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 21 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 21 | #include "sherpa-onnx/csrc/parse-options.h" | 22 | #include "sherpa-onnx/csrc/parse-options.h" |
| @@ -67,10 +68,11 @@ struct OnlineRecognizerResult { | @@ -67,10 +68,11 @@ struct OnlineRecognizerResult { | ||
| 67 | struct OnlineRecognizerConfig { | 68 | struct OnlineRecognizerConfig { |
| 68 | FeatureExtractorConfig feat_config; | 69 | FeatureExtractorConfig feat_config; |
| 69 | OnlineTransducerModelConfig model_config; | 70 | OnlineTransducerModelConfig model_config; |
| 71 | + OnlineLMConfig lm_config; | ||
| 70 | EndpointConfig endpoint_config; | 72 | EndpointConfig endpoint_config; |
| 71 | bool enable_endpoint = true; | 73 | bool enable_endpoint = true; |
| 72 | 74 | ||
| 73 | - std::string decoding_method = "greedy_search"; | 75 | + std::string decoding_method = "modified_beam_search"; |
| 74 | // now support modified_beam_search and greedy_search | 76 | // now support modified_beam_search and greedy_search |
| 75 | 77 | ||
| 76 | int32_t max_active_paths = 4; // used only for modified_beam_search | 78 | int32_t max_active_paths = 4; // used only for modified_beam_search |
| @@ -79,6 +81,7 @@ struct OnlineRecognizerConfig { | @@ -79,6 +81,7 @@ struct OnlineRecognizerConfig { | ||
| 79 | 81 | ||
| 80 | OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, | 82 | OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, |
| 81 | const OnlineTransducerModelConfig &model_config, | 83 | const OnlineTransducerModelConfig &model_config, |
| 84 | + const OnlineLMConfig &lm_config, | ||
| 82 | const EndpointConfig &endpoint_config, | 85 | const EndpointConfig &endpoint_config, |
| 83 | bool enable_endpoint, | 86 | bool enable_endpoint, |
| 84 | const std::string &decoding_method, | 87 | const std::string &decoding_method, |
sherpa-onnx/csrc/online-rnn-lm.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-rnn-lm.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Pingfeng Luo | ||
| 4 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 5 | + | ||
| 6 | +#include "sherpa-onnx/csrc/online-rnn-lm.h" | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 13 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 14 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 15 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +class OnlineRnnLM::Impl { | ||
| 20 | + public: | ||
| 21 | + explicit Impl(const OnlineLMConfig &config) | ||
| 22 | + : config_(config), | ||
| 23 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 24 | + sess_opts_{}, | ||
| 25 | + allocator_{} { | ||
| 26 | + Init(config); | ||
| 27 | + } | ||
| 28 | + | ||
| 29 | + std::pair<Ort::Value, std::vector<Ort::Value>> Rescore( | ||
| 30 | + Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) { | ||
| 31 | + std::array<Ort::Value, 4> inputs = { | ||
| 32 | + std::move(x), std::move(y), std::move(states[0]), std::move(states[1])}; | ||
| 33 | + | ||
| 34 | + auto out = | ||
| 35 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 36 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 37 | + | ||
| 38 | + std::vector<Ort::Value> next_states; | ||
| 39 | + next_states.reserve(2); | ||
| 40 | + next_states.push_back(std::move(out[1])); | ||
| 41 | + next_states.push_back(std::move(out[2])); | ||
| 42 | + | ||
| 43 | + return {std::move(out[0]), std::move(next_states)}; | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + std::vector<Ort::Value> GetInitStates() const { | ||
| 47 | + std::vector<Ort::Value> ans; | ||
| 48 | + ans.reserve(init_states_.size()); | ||
| 49 | + | ||
| 50 | + for (const auto &s : init_states_) { | ||
| 51 | + ans.emplace_back(Clone(allocator_, &s)); | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + return ans; | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + private: | ||
| 58 | + void Init(const OnlineLMConfig &config) { | ||
| 59 | + auto buf = ReadFile(config_.model); | ||
| 60 | + | ||
| 61 | + sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(), | ||
| 62 | + sess_opts_); | ||
| 63 | + | ||
| 64 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 65 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 66 | + | ||
| 67 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 68 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 69 | + SHERPA_ONNX_READ_META_DATA(rnn_num_layers_, "num_layers"); | ||
| 70 | + SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "hidden_size"); | ||
| 71 | + SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id"); | ||
| 72 | + | ||
| 73 | + ComputeInitStates(); | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + void ComputeInitStates() { | ||
| 77 | + constexpr int32_t kBatchSize = 1; | ||
| 78 | + std::array<int64_t, 3> h_shape{rnn_num_layers_, kBatchSize, | ||
| 79 | + rnn_hidden_size_}; | ||
| 80 | + std::array<int64_t, 3> c_shape{rnn_num_layers_, kBatchSize, | ||
| 81 | + rnn_hidden_size_}; | ||
| 82 | + Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(), | ||
| 83 | + h_shape.size()); | ||
| 84 | + Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(), | ||
| 85 | + c_shape.size()); | ||
| 86 | + Fill<float>(&h, 0); | ||
| 87 | + Fill<float>(&c, 0); | ||
| 88 | + std::array<int64_t, 2> x_shape{1, 1}; | ||
| 89 | + // shape of x and y are same | ||
| 90 | + Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(), | ||
| 91 | + x_shape.size()); | ||
| 92 | + Ort::Value y = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(), | ||
| 93 | + x_shape.size()); | ||
| 94 | + *x.GetTensorMutableData<int64_t>() = sos_id_; | ||
| 95 | + *y.GetTensorMutableData<int64_t>() = sos_id_; | ||
| 96 | + | ||
| 97 | + std::vector<Ort::Value> states; | ||
| 98 | + states.push_back(std::move(h)); | ||
| 99 | + states.push_back(std::move(c)); | ||
| 100 | + auto pair = Rescore(std::move(x), std::move(y), std::move(states)); | ||
| 101 | + | ||
| 102 | + init_states_ = std::move(pair.second); | ||
| 103 | + } | ||
| 104 | + | ||
| 105 | + private: | ||
| 106 | + OnlineLMConfig config_; | ||
| 107 | + Ort::Env env_; | ||
| 108 | + Ort::SessionOptions sess_opts_; | ||
| 109 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 110 | + | ||
| 111 | + std::unique_ptr<Ort::Session> sess_; | ||
| 112 | + | ||
| 113 | + std::vector<std::string> input_names_; | ||
| 114 | + std::vector<const char *> input_names_ptr_; | ||
| 115 | + | ||
| 116 | + std::vector<std::string> output_names_; | ||
| 117 | + std::vector<const char *> output_names_ptr_; | ||
| 118 | + | ||
| 119 | + std::vector<Ort::Value> init_states_; | ||
| 120 | + | ||
| 121 | + int32_t rnn_num_layers_ = 2; | ||
| 122 | + int32_t rnn_hidden_size_ = 512; | ||
| 123 | + int32_t sos_id_ = 1; | ||
| 124 | +}; | ||
| 125 | + | ||
| 126 | +OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) | ||
| 127 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 128 | + | ||
| 129 | +OnlineRnnLM::~OnlineRnnLM() = default; | ||
| 130 | + | ||
| 131 | +std::vector<Ort::Value> OnlineRnnLM::GetInitStates() { | ||
| 132 | + return impl_->GetInitStates(); | ||
| 133 | +} | ||
| 134 | + | ||
| 135 | +std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::Rescore( | ||
| 136 | + Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) { | ||
| 137 | + return impl_->Rescore(std::move(x), std::move(y), std::move(states)); | ||
| 138 | +} | ||
| 139 | + | ||
| 140 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/online-rnn-lm.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-rnn-lm.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Pingfeng Luo | ||
| 4 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 5 | + | ||
| 6 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_ | ||
| 7 | +#define SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_ | ||
| 8 | + | ||
| 9 | +#include <memory> | ||
| 10 | +#include <utility> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 14 | +#include "sherpa-onnx/csrc/online-lm-config.h" | ||
| 15 | +#include "sherpa-onnx/csrc/online-lm.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +class OnlineRnnLM : public OnlineLM { | ||
| 20 | + public: | ||
| 21 | + ~OnlineRnnLM() override; | ||
| 22 | + | ||
| 23 | + explicit OnlineRnnLM(const OnlineLMConfig &config); | ||
| 24 | + | ||
| 25 | + std::vector<Ort::Value> GetInitStates() override; | ||
| 26 | + | ||
| 27 | + /** Rescore a batch of sentences. | ||
| 28 | + * | ||
| 29 | + * @param x A 2-D tensor of shape (N, L) with data type int64. | ||
| 30 | + * @param y A 2-D tensor of shape (N, L) with data type int64. | ||
| 31 | + * @param states It contains the states for the LM model | ||
| 32 | + * @return Return a pair containing | ||
| 33 | + * - negative loglike | ||
| 34 | + * - updated states | ||
| 35 | + * | ||
| 36 | + * Caution: It returns negative log likelihood (nll), not log likelihood | ||
| 37 | + */ | ||
| 38 | + std::pair<Ort::Value, std::vector<Ort::Value>> Rescore( | ||
| 39 | + Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) override; | ||
| 40 | + | ||
| 41 | + private: | ||
| 42 | + class Impl; | ||
| 43 | + std::unique_ptr<Impl> impl_; | ||
| 44 | +}; | ||
| 45 | + | ||
| 46 | +} // namespace sherpa_onnx | ||
| 47 | + | ||
| 48 | +#endif // SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_ |
| @@ -156,6 +156,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -156,6 +156,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 156 | } // for (int32_t b = 0; b != batch_size; ++b) | 156 | } // for (int32_t b = 0; b != batch_size; ++b) |
| 157 | } | 157 | } |
| 158 | 158 | ||
| 159 | + if (lm_) { | ||
| 160 | + lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur); | ||
| 161 | + } | ||
| 162 | + | ||
| 159 | for (int32_t b = 0; b != batch_size; ++b) { | 163 | for (int32_t b = 0; b != batch_size; ++b) { |
| 160 | auto &hyps = cur[b]; | 164 | auto &hyps = cur[b]; |
| 161 | auto best_hyp = hyps.GetMostProbable(true); | 165 | auto best_hyp = hyps.GetMostProbable(true); |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | 8 | ||
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | +#include "sherpa-onnx/csrc/online-lm.h" | ||
| 11 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 12 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 12 | #include "sherpa-onnx/csrc/online-transducer-model.h" | 13 | #include "sherpa-onnx/csrc/online-transducer-model.h" |
| 13 | 14 | ||
| @@ -17,8 +18,13 @@ class OnlineTransducerModifiedBeamSearchDecoder | @@ -17,8 +18,13 @@ class OnlineTransducerModifiedBeamSearchDecoder | ||
| 17 | : public OnlineTransducerDecoder { | 18 | : public OnlineTransducerDecoder { |
| 18 | public: | 19 | public: |
| 19 | OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, | 20 | OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, |
| 20 | - int32_t max_active_paths) | ||
| 21 | - : model_(model), max_active_paths_(max_active_paths) {} | 21 | + OnlineLM *lm, |
| 22 | + int32_t max_active_paths, | ||
| 23 | + float lm_scale) | ||
| 24 | + : model_(model), | ||
| 25 | + lm_(lm), | ||
| 26 | + max_active_paths_(max_active_paths), | ||
| 27 | + lm_scale_(lm_scale) {} | ||
| 22 | 28 | ||
| 23 | OnlineTransducerDecoderResult GetEmptyResult() const override; | 29 | OnlineTransducerDecoderResult GetEmptyResult() const override; |
| 24 | 30 | ||
| @@ -31,7 +37,10 @@ class OnlineTransducerModifiedBeamSearchDecoder | @@ -31,7 +37,10 @@ class OnlineTransducerModifiedBeamSearchDecoder | ||
| 31 | 37 | ||
| 32 | private: | 38 | private: |
| 33 | OnlineTransducerModel *model_; // Not owned | 39 | OnlineTransducerModel *model_; // Not owned |
| 40 | + OnlineLM *lm_; // Not owned | ||
| 41 | + | ||
| 34 | int32_t max_active_paths_; | 42 | int32_t max_active_paths_; |
| 43 | + float lm_scale_; // used only when lm_ is not nullptr | ||
| 35 | }; | 44 | }; |
| 36 | 45 | ||
| 37 | } // namespace sherpa_onnx | 46 | } // namespace sherpa_onnx |
| 1 | // sherpa-onnx/csrc/onnx-utils.cc | 1 | // sherpa-onnx/csrc/onnx-utils.cc |
| 2 | // | 2 | // |
| 3 | // Copyright (c) 2023 Xiaomi Corporation | 3 | // Copyright (c) 2023 Xiaomi Corporation |
| 4 | +// Copyright (c) 2023 Pingfeng Luo | ||
| 4 | #include "sherpa-onnx/csrc/onnx-utils.h" | 5 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 5 | 6 | ||
| 6 | #include <algorithm> | 7 | #include <algorithm> |
| 7 | #include <fstream> | 8 | #include <fstream> |
| 8 | #include <sstream> | 9 | #include <sstream> |
| 9 | #include <string> | 10 | #include <string> |
| 10 | -#include <vector> | ||
| 11 | 11 | ||
| 12 | #if __ANDROID_API__ >= 9 | 12 | #if __ANDROID_API__ >= 9 |
| 13 | #include "android/asset_manager.h" | 13 | #include "android/asset_manager.h" |
| @@ -218,4 +218,31 @@ Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, | @@ -218,4 +218,31 @@ Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, | ||
| 218 | return ans; | 218 | return ans; |
| 219 | } | 219 | } |
| 220 | 220 | ||
| 221 | +CopyableOrtValue::CopyableOrtValue(const CopyableOrtValue &other) { | ||
| 222 | + *this = other; | ||
| 223 | +} | ||
| 224 | + | ||
| 225 | +CopyableOrtValue &CopyableOrtValue::operator=(const CopyableOrtValue &other) { | ||
| 226 | + if (this == &other) { | ||
| 227 | + return *this; | ||
| 228 | + } | ||
| 229 | + if (other.value) { | ||
| 230 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 231 | + value = Clone(allocator, &other.value); | ||
| 232 | + } | ||
| 233 | + return *this; | ||
| 234 | +} | ||
| 235 | + | ||
| 236 | +CopyableOrtValue::CopyableOrtValue(CopyableOrtValue &&other) { | ||
| 237 | + *this = std::move(other); | ||
| 238 | +} | ||
| 239 | + | ||
| 240 | +CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) { | ||
| 241 | + if (this == &other) { | ||
| 242 | + return *this; | ||
| 243 | + } | ||
| 244 | + value = std::move(other.value); | ||
| 245 | + return *this; | ||
| 246 | +} | ||
| 247 | + | ||
| 221 | } // namespace sherpa_onnx | 248 | } // namespace sherpa_onnx |
| 1 | // sherpa-onnx/csrc/onnx-utils.h | 1 | // sherpa-onnx/csrc/onnx-utils.h |
| 2 | // | 2 | // |
| 3 | // Copyright (c) 2023 Xiaomi Corporation | 3 | // Copyright (c) 2023 Xiaomi Corporation |
| 4 | +// Copyright (c) 2023 Pingfeng Luo | ||
| 4 | #ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_ |
| 5 | #define SHERPA_ONNX_CSRC_ONNX_UTILS_H_ | 6 | #define SHERPA_ONNX_CSRC_ONNX_UTILS_H_ |
| 6 | 7 | ||
| @@ -13,6 +14,7 @@ | @@ -13,6 +14,7 @@ | ||
| 13 | #include <cassert> | 14 | #include <cassert> |
| 14 | #include <ostream> | 15 | #include <ostream> |
| 15 | #include <string> | 16 | #include <string> |
| 17 | +#include <utility> | ||
| 16 | #include <vector> | 18 | #include <vector> |
| 17 | 19 | ||
| 18 | #if __ANDROID_API__ >= 9 | 20 | #if __ANDROID_API__ >= 9 |
| @@ -89,6 +91,24 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename); | @@ -89,6 +91,24 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename); | ||
| 89 | // TODO(fangjun): Document it | 91 | // TODO(fangjun): Document it |
| 90 | Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, | 92 | Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, |
| 91 | const std::vector<int32_t> &hyps_num_split); | 93 | const std::vector<int32_t> &hyps_num_split); |
| 94 | + | ||
| 95 | +struct CopyableOrtValue { | ||
| 96 | + Ort::Value value{nullptr}; | ||
| 97 | + | ||
| 98 | + CopyableOrtValue() = default; | ||
| 99 | + | ||
| 100 | + /*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT | ||
| 101 | + : value(std::move(v)) {} | ||
| 102 | + | ||
| 103 | + CopyableOrtValue(const CopyableOrtValue &other); | ||
| 104 | + | ||
| 105 | + CopyableOrtValue &operator=(const CopyableOrtValue &other); | ||
| 106 | + | ||
| 107 | + CopyableOrtValue(CopyableOrtValue &&other); | ||
| 108 | + | ||
| 109 | + CopyableOrtValue &operator=(CopyableOrtValue &&other); | ||
| 110 | +}; | ||
| 111 | + | ||
| 92 | } // namespace sherpa_onnx | 112 | } // namespace sherpa_onnx |
| 93 | 113 | ||
| 94 | #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ | 114 | #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ |
| @@ -13,8 +13,9 @@ | @@ -13,8 +13,9 @@ | ||
| 13 | #include "sherpa-onnx/csrc/symbol-table.h" | 13 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 14 | #include "sherpa-onnx/csrc/wave-reader.h" | 14 | #include "sherpa-onnx/csrc/wave-reader.h" |
| 15 | 15 | ||
| 16 | +// TODO(fangjun): Use ParseOptions as we are getting more args | ||
| 16 | int main(int32_t argc, char *argv[]) { | 17 | int main(int32_t argc, char *argv[]) { |
| 17 | - if (argc < 6 || argc > 8) { | 18 | + if (argc < 6 || argc > 9) { |
| 18 | const char *usage = R"usage( | 19 | const char *usage = R"usage( |
| 19 | Usage: | 20 | Usage: |
| 20 | ./bin/sherpa-onnx \ | 21 | ./bin/sherpa-onnx \ |
| @@ -22,7 +23,7 @@ Usage: | @@ -22,7 +23,7 @@ Usage: | ||
| 22 | /path/to/encoder.onnx \ | 23 | /path/to/encoder.onnx \ |
| 23 | /path/to/decoder.onnx \ | 24 | /path/to/decoder.onnx \ |
| 24 | /path/to/joiner.onnx \ | 25 | /path/to/joiner.onnx \ |
| 25 | - /path/to/foo.wav [num_threads [decoding_method]] | 26 | + /path/to/foo.wav [num_threads [decoding_method [/path/to/rnn_lm.onnx]]] |
| 26 | 27 | ||
| 27 | Default value for num_threads is 2. | 28 | Default value for num_threads is 2. |
| 28 | Valid values for decoding_method: greedy_search (default), modified_beam_search. | 29 | Valid values for decoding_method: greedy_search (default), modified_beam_search. |
| @@ -53,10 +54,12 @@ for a list of pre-trained models to download. | @@ -53,10 +54,12 @@ for a list of pre-trained models to download. | ||
| 53 | if (argc == 7 && atoi(argv[6]) > 0) { | 54 | if (argc == 7 && atoi(argv[6]) > 0) { |
| 54 | config.model_config.num_threads = atoi(argv[6]); | 55 | config.model_config.num_threads = atoi(argv[6]); |
| 55 | } | 56 | } |
| 56 | - | ||
| 57 | if (argc == 8) { | 57 | if (argc == 8) { |
| 58 | config.decoding_method = argv[7]; | 58 | config.decoding_method = argv[7]; |
| 59 | } | 59 | } |
| 60 | + if (argc == 9) { | ||
| 61 | + config.lm_config.model = argv[8]; | ||
| 62 | + } | ||
| 60 | config.max_active_paths = 4; | 63 | config.max_active_paths = 4; |
| 61 | 64 | ||
| 62 | fprintf(stderr, "%s\n", config.ToString().c_str()); | 65 | fprintf(stderr, "%s\n", config.ToString().c_str()); |
| @@ -16,9 +16,8 @@ | @@ -16,9 +16,8 @@ | ||
| 16 | #if __ANDROID_API__ >= 9 | 16 | #if __ANDROID_API__ >= 9 |
| 17 | #include "android/asset_manager.h" | 17 | #include "android/asset_manager.h" |
| 18 | #include "android/asset_manager_jni.h" | 18 | #include "android/asset_manager_jni.h" |
| 19 | -#else | ||
| 20 | -#include <fstream> | ||
| 21 | #endif | 19 | #endif |
| 20 | +#include <fstream> | ||
| 22 | 21 | ||
| 23 | #include "sherpa-onnx/csrc/macros.h" | 22 | #include "sherpa-onnx/csrc/macros.h" |
| 24 | #include "sherpa-onnx/csrc/online-recognizer.h" | 23 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| @@ -188,6 +187,21 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | @@ -188,6 +187,21 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | ||
| 188 | fid = env->GetFieldID(model_config_cls, "debug", "Z"); | 187 | fid = env->GetFieldID(model_config_cls, "debug", "Z"); |
| 189 | ans.model_config.debug = env->GetBooleanField(model_config, fid); | 188 | ans.model_config.debug = env->GetBooleanField(model_config, fid); |
| 190 | 189 | ||
| 190 | + //---------- rnn lm model config ---------- | ||
| 191 | + fid = env->GetFieldID(cls, "lmConfig", | ||
| 192 | + "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); | ||
| 193 | + jobject lm_model_config = env->GetObjectField(config, fid); | ||
| 194 | + jclass lm_model_config_cls = env->GetObjectClass(lm_model_config); | ||
| 195 | + | ||
| 196 | + fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;"); | ||
| 197 | + s = (jstring)env->GetObjectField(lm_model_config, fid); | ||
| 198 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 199 | + ans.lm_config.model = p; | ||
| 200 | + env->ReleaseStringUTFChars(s, p); | ||
| 201 | + | ||
| 202 | + fid = env->GetFieldID(lm_model_config_cls, "scale", "F"); | ||
| 203 | + ans.lm_config.scale = env->GetFloatField(lm_model_config, fid); | ||
| 204 | + | ||
| 191 | return ans; | 205 | return ans; |
| 192 | } | 206 | } |
| 193 | 207 |
| @@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx | @@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 11 | offline-recognizer.cc | 11 | offline-recognizer.cc |
| 12 | offline-stream.cc | 12 | offline-stream.cc |
| 13 | offline-transducer-model-config.cc | 13 | offline-transducer-model-config.cc |
| 14 | + online-lm-config.cc | ||
| 14 | online-recognizer.cc | 15 | online-recognizer.cc |
| 15 | online-stream.cc | 16 | online-stream.cc |
| 16 | online-transducer-model-config.cc | 17 | online-transducer-model-config.cc |
sherpa-onnx/python/csrc/online-lm-config.cc
0 → 100644
| 1 | +// sherpa-onnx/python/csrc/online-lm-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/online-lm-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx//csrc/online-lm-config.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void PybindOnlineLMConfig(py::module *m) { | ||
| 14 | + using PyClass = OnlineLMConfig; | ||
| 15 | + py::class_<PyClass>(*m, "OnlineLMConfig") | ||
| 16 | + .def(py::init<const std::string &, float>(), py::arg("model") = "", | ||
| 17 | + py::arg("scale") = 0.5f) | ||
| 18 | + .def_readwrite("model", &PyClass::model) | ||
| 19 | + .def_readwrite("scale", &PyClass::scale) | ||
| 20 | + .def("__str__", &PyClass::ToString); | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +} // namespace sherpa_onnx |
sherpa-onnx/python/csrc/online-lm-config.h
0 → 100644
| 1 | +// sherpa-onnx/python/csrc/online-lm-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOnlineLMConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_ |
| @@ -21,11 +21,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -21,11 +21,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 21 | using PyClass = OnlineRecognizerConfig; | 21 | using PyClass = OnlineRecognizerConfig; |
| 22 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") | 22 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") |
| 23 | .def(py::init<const FeatureExtractorConfig &, | 23 | .def(py::init<const FeatureExtractorConfig &, |
| 24 | - const OnlineTransducerModelConfig &, const EndpointConfig &, | ||
| 25 | - bool, const std::string &, int32_t>(), | 24 | + const OnlineTransducerModelConfig &, const OnlineLMConfig &, |
| 25 | + const EndpointConfig &, bool, const std::string &, | ||
| 26 | + int32_t>(), | ||
| 26 | py::arg("feat_config"), py::arg("model_config"), | 27 | py::arg("feat_config"), py::arg("model_config"), |
| 27 | - py::arg("endpoint_config"), py::arg("enable_endpoint"), | ||
| 28 | - py::arg("decoding_method"), py::arg("max_active_paths")) | 28 | + py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), |
| 29 | + py::arg("enable_endpoint"), py::arg("decoding_method"), | ||
| 30 | + py::arg("max_active_paths")) | ||
| 29 | .def_readwrite("feat_config", &PyClass::feat_config) | 31 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 30 | .def_readwrite("model_config", &PyClass::model_config) | 32 | .def_readwrite("model_config", &PyClass::model_config) |
| 31 | .def_readwrite("endpoint_config", &PyClass::endpoint_config) | 33 | .def_readwrite("endpoint_config", &PyClass::endpoint_config) |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "sherpa-onnx/python/csrc/offline-model-config.h" | 11 | #include "sherpa-onnx/python/csrc/offline-model-config.h" |
| 12 | #include "sherpa-onnx/python/csrc/offline-recognizer.h" | 12 | #include "sherpa-onnx/python/csrc/offline-recognizer.h" |
| 13 | #include "sherpa-onnx/python/csrc/offline-stream.h" | 13 | #include "sherpa-onnx/python/csrc/offline-stream.h" |
| 14 | +#include "sherpa-onnx/python/csrc/online-lm-config.h" | ||
| 14 | #include "sherpa-onnx/python/csrc/online-recognizer.h" | 15 | #include "sherpa-onnx/python/csrc/online-recognizer.h" |
| 15 | #include "sherpa-onnx/python/csrc/online-stream.h" | 16 | #include "sherpa-onnx/python/csrc/online-stream.h" |
| 16 | #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" | 17 | #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" |
| @@ -22,6 +23,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -22,6 +23,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 22 | 23 | ||
| 23 | PybindFeatures(&m); | 24 | PybindFeatures(&m); |
| 24 | PybindOnlineTransducerModelConfig(&m); | 25 | PybindOnlineTransducerModelConfig(&m); |
| 26 | + PybindOnlineLMConfig(&m); | ||
| 25 | PybindOnlineStream(&m); | 27 | PybindOnlineStream(&m); |
| 26 | PybindEndpoint(&m); | 28 | PybindEndpoint(&m); |
| 27 | PybindOnlineRecognizer(&m); | 29 | PybindOnlineRecognizer(&m); |
-
请 注册 或 登录 后发表评论