Committed by
GitHub
Re-implement LM rescore for online transducer (#1231)
Co-authored-by: Martins Kronis <martins.kuznecovs@tilde.lv>
正在显示
11 个修改的文件
包含
175 行增加
和
31 行删除
| @@ -51,8 +51,13 @@ struct Hypothesis { | @@ -51,8 +51,13 @@ struct Hypothesis { | ||
| 51 | // LM log prob if any. | 51 | // LM log prob if any. |
| 52 | double lm_log_prob = 0; | 52 | double lm_log_prob = 0; |
| 53 | 53 | ||
| 54 | - // the nn lm score for next token given the current ys | 54 | + // the nn lm score for next token given the current ys, |
| 55 | + // when using shallow fusion | ||
| 55 | CopyableOrtValue nn_lm_scores; | 56 | CopyableOrtValue nn_lm_scores; |
| 57 | + | ||
| 58 | + // cur scored tokens by RNN LM, when rescoring | ||
| 59 | + int32_t cur_scored_pos = 0; | ||
| 60 | + | ||
| 56 | // the nn lm states | 61 | // the nn lm states |
| 57 | std::vector<CopyableOrtValue> nn_lm_states; | 62 | std::vector<CopyableOrtValue> nn_lm_states; |
| 58 | 63 |
| @@ -18,6 +18,8 @@ void OnlineLMConfig::Register(ParseOptions *po) { | @@ -18,6 +18,8 @@ void OnlineLMConfig::Register(ParseOptions *po) { | ||
| 18 | "Number of threads to run the neural network of LM model"); | 18 | "Number of threads to run the neural network of LM model"); |
| 19 | po->Register("lm-provider", &lm_provider, | 19 | po->Register("lm-provider", &lm_provider, |
| 20 | "Specify a provider to LM model use: cpu, cuda, coreml"); | 20 | "Specify a provider to LM model use: cpu, cuda, coreml"); |
| 21 | + po->Register("lm-shallow-fusion", &shallow_fusion, | ||
| 22 | + "Boolean whether to use shallow fusion or rescore."); | ||
| 21 | } | 23 | } |
| 22 | 24 | ||
| 23 | bool OnlineLMConfig::Validate() const { | 25 | bool OnlineLMConfig::Validate() const { |
| @@ -34,7 +36,8 @@ std::string OnlineLMConfig::ToString() const { | @@ -34,7 +36,8 @@ std::string OnlineLMConfig::ToString() const { | ||
| 34 | 36 | ||
| 35 | os << "OnlineLMConfig("; | 37 | os << "OnlineLMConfig("; |
| 36 | os << "model=\"" << model << "\", "; | 38 | os << "model=\"" << model << "\", "; |
| 37 | - os << "scale=" << scale << ")"; | 39 | + os << "scale=" << scale << ", "; |
| 40 | + os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")"; | ||
| 38 | 41 | ||
| 39 | return os.str(); | 42 | return os.str(); |
| 40 | } | 43 | } |
| @@ -18,15 +18,18 @@ struct OnlineLMConfig { | @@ -18,15 +18,18 @@ struct OnlineLMConfig { | ||
| 18 | float scale = 0.5; | 18 | float scale = 0.5; |
| 19 | int32_t lm_num_threads = 1; | 19 | int32_t lm_num_threads = 1; |
| 20 | std::string lm_provider = "cpu"; | 20 | std::string lm_provider = "cpu"; |
| 21 | + // enable shallow fusion | ||
| 22 | + bool shallow_fusion = true; | ||
| 21 | 23 | ||
| 22 | OnlineLMConfig() = default; | 24 | OnlineLMConfig() = default; |
| 23 | 25 | ||
| 24 | OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, | 26 | OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, |
| 25 | - const std::string &lm_provider) | 27 | + const std::string &lm_provider, bool shallow_fusion) |
| 26 | : model(model), | 28 | : model(model), |
| 27 | scale(scale), | 29 | scale(scale), |
| 28 | lm_num_threads(lm_num_threads), | 30 | lm_num_threads(lm_num_threads), |
| 29 | - lm_provider(lm_provider) {} | 31 | + lm_provider(lm_provider), |
| 32 | + shallow_fusion(shallow_fusion) {} | ||
| 30 | 33 | ||
| 31 | void Register(ParseOptions *po); | 34 | void Register(ParseOptions *po); |
| 32 | bool Validate() const; | 35 | bool Validate() const; |
| @@ -21,13 +21,17 @@ class OnlineLM { | @@ -21,13 +21,17 @@ class OnlineLM { | ||
| 21 | 21 | ||
| 22 | static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config); | 22 | static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config); |
| 23 | 23 | ||
| 24 | - virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0; | 24 | + // init states for classic rescore |
| 25 | + virtual std::vector<Ort::Value> GetInitStates() = 0; | ||
| 25 | 26 | ||
| 26 | - /** ScoreToken a batch of sentences. | 27 | + // init states for shallow fusion |
| 28 | + virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() = 0; | ||
| 29 | + | ||
| 30 | + /** ScoreToken a batch of sentences (shallow fusion). | ||
| 27 | * | 31 | * |
| 28 | * @param x A 2-D tensor of shape (N, 1) with data type int64. | 32 | * @param x A 2-D tensor of shape (N, 1) with data type int64. |
| 29 | * @param states It contains the states for the LM model | 33 | * @param states It contains the states for the LM model |
| 30 | - * @return Return a pair containingo | 34 | + * @return Return a pair containing |
| 31 | * - log_prob of NN LM | 35 | * - log_prob of NN LM |
| 32 | * - updated states | 36 | * - updated states |
| 33 | * | 37 | * |
| @@ -35,13 +39,23 @@ class OnlineLM { | @@ -35,13 +39,23 @@ class OnlineLM { | ||
| 35 | virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( | 39 | virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( |
| 36 | Ort::Value x, std::vector<Ort::Value> states) = 0; | 40 | Ort::Value x, std::vector<Ort::Value> states) = 0; |
| 37 | 41 | ||
| 38 | - /** This function updates lm_lob_prob and nn_lm_scores of hyp | 42 | + /** This function updates hyp.lm_log_prob of hyps (classic rescore). |
| 43 | + * | ||
| 44 | + * @param scale LM score | ||
| 45 | + * @param context_size Context size of the transducer decoder model | ||
| 46 | + * @param hyps It is changed in-place. | ||
| 47 | + * | ||
| 48 | + */ | ||
| 49 | + virtual void ComputeLMScore(float scale, int32_t context_size, | ||
| 50 | + std::vector<Hypotheses> *hyps) = 0; | ||
| 51 | + | ||
| 52 | + /** This function updates lm_log_prob and nn_lm_scores of hyp (shallow fusion). | ||
| 39 | * | 53 | * |
| 40 | * @param scale LM score | 54 | * @param scale LM score |
| 41 | * @param hyps It is changed in-place. | 55 | * @param hyps It is changed in-place. |
| 42 | * | 56 | * |
| 43 | */ | 57 | */ |
| 44 | - virtual void ComputeLMScore(float scale, Hypothesis *hyp) = 0; | 58 | + virtual void ComputeLMScoreSF(float scale, Hypothesis *hyp) = 0; |
| 45 | }; | 59 | }; |
| 46 | 60 | ||
| 47 | } // namespace sherpa_onnx | 61 | } // namespace sherpa_onnx |
| @@ -107,7 +107,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -107,7 +107,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 107 | 107 | ||
| 108 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 108 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 109 | model_.get(), lm_.get(), config_.max_active_paths, | 109 | model_.get(), lm_.get(), config_.max_active_paths, |
| 110 | - config_.lm_config.scale, unk_id_, config_.blank_penalty, | 110 | + config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, |
| 111 | + config_.blank_penalty, | ||
| 111 | config_.temperature_scale); | 112 | config_.temperature_scale); |
| 112 | 113 | ||
| 113 | } else if (config.decoding_method == "greedy_search") { | 114 | } else if (config.decoding_method == "greedy_search") { |
| @@ -156,7 +157,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -156,7 +157,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 156 | 157 | ||
| 157 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 158 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 158 | model_.get(), lm_.get(), config_.max_active_paths, | 159 | model_.get(), lm_.get(), config_.max_active_paths, |
| 159 | - config_.lm_config.scale, unk_id_, config_.blank_penalty, | 160 | + config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, |
| 161 | + config_.blank_penalty, | ||
| 160 | config_.temperature_scale); | 162 | config_.temperature_scale); |
| 161 | 163 | ||
| 162 | } else if (config.decoding_method == "greedy_search") { | 164 | } else if (config.decoding_method == "greedy_search") { |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | #include <string> | 8 | #include <string> |
| 9 | #include <utility> | 9 | #include <utility> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | +#include <algorithm> | ||
| 11 | 12 | ||
| 12 | #include "onnxruntime_cxx_api.h" // NOLINT | 13 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 13 | #include "sherpa-onnx/csrc/macros.h" | 14 | #include "sherpa-onnx/csrc/macros.h" |
| @@ -27,9 +28,10 @@ class OnlineRnnLM::Impl { | @@ -27,9 +28,10 @@ class OnlineRnnLM::Impl { | ||
| 27 | Init(config); | 28 | Init(config); |
| 28 | } | 29 | } |
| 29 | 30 | ||
| 30 | - void ComputeLMScore(float scale, Hypothesis *hyp) { | 31 | + // shallow fusion scoring function |
| 32 | + void ComputeLMScoreSF(float scale, Hypothesis *hyp) { | ||
| 31 | if (hyp->nn_lm_states.empty()) { | 33 | if (hyp->nn_lm_states.empty()) { |
| 32 | - auto init_states = GetInitStates(); | 34 | + auto init_states = GetInitStatesSF(); |
| 33 | hyp->nn_lm_scores.value = std::move(init_states.first); | 35 | hyp->nn_lm_scores.value = std::move(init_states.first); |
| 34 | hyp->nn_lm_states = Convert(std::move(init_states.second)); | 36 | hyp->nn_lm_states = Convert(std::move(init_states.second)); |
| 35 | } | 37 | } |
| @@ -49,6 +51,52 @@ class OnlineRnnLM::Impl { | @@ -49,6 +51,52 @@ class OnlineRnnLM::Impl { | ||
| 49 | hyp->nn_lm_states = Convert(std::move(lm_out.second)); | 51 | hyp->nn_lm_states = Convert(std::move(lm_out.second)); |
| 50 | } | 52 | } |
| 51 | 53 | ||
| 54 | + // classic rescore function | ||
| 55 | + void ComputeLMScore(float scale, int32_t context_size, | ||
| 56 | + std::vector<Hypotheses> *hyps) { | ||
| 57 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 58 | + | ||
| 59 | + for (auto &hyp : *hyps) { | ||
| 60 | + for (auto &h_m : hyp) { | ||
| 61 | + auto &h = h_m.second; | ||
| 62 | + auto &ys = h.ys; | ||
| 63 | + const int32_t token_num_in_chunk = | ||
| 64 | + ys.size() - context_size - h.cur_scored_pos - 1; | ||
| 65 | + | ||
| 66 | + if (token_num_in_chunk < 1) { | ||
| 67 | + continue; | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + if (h.nn_lm_states.empty()) { | ||
| 71 | + h.nn_lm_states = Convert(GetInitStates()); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + if (token_num_in_chunk >= h.lm_rescore_min_chunk) { | ||
| 75 | + std::array<int64_t, 2> x_shape{1, token_num_in_chunk}; | ||
| 76 | + | ||
| 77 | + Ort::Value x = Ort::Value::CreateTensor<int64_t>( | ||
| 78 | + allocator, x_shape.data(), x_shape.size()); | ||
| 79 | + int64_t *p_x = x.GetTensorMutableData<int64_t>(); | ||
| 80 | + std::copy(ys.begin() + context_size + h.cur_scored_pos, | ||
| 81 | + ys.end() - 1, p_x); | ||
| 82 | + | ||
| 83 | + // streaming forward by NN LM | ||
| 84 | + auto out = ScoreToken(std::move(x), | ||
| 85 | + Convert(std::move(h.nn_lm_states))); | ||
| 86 | + | ||
| 87 | + // update NN LM score in hyp | ||
| 88 | + const float *p_nll = out.first.GetTensorData<float>(); | ||
| 89 | + h.lm_log_prob = -scale * (*p_nll); | ||
| 90 | + | ||
| 91 | + // update NN LM states in hyp | ||
| 92 | + h.nn_lm_states = Convert(std::move(out.second)); | ||
| 93 | + | ||
| 94 | + h.cur_scored_pos += token_num_in_chunk; | ||
| 95 | + } | ||
| 96 | + } | ||
| 97 | + } | ||
| 98 | + } | ||
| 99 | + | ||
| 52 | std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( | 100 | std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( |
| 53 | Ort::Value x, std::vector<Ort::Value> states) { | 101 | Ort::Value x, std::vector<Ort::Value> states) { |
| 54 | std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]), | 102 | std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]), |
| @@ -66,7 +114,8 @@ class OnlineRnnLM::Impl { | @@ -66,7 +114,8 @@ class OnlineRnnLM::Impl { | ||
| 66 | return {std::move(out[0]), std::move(next_states)}; | 114 | return {std::move(out[0]), std::move(next_states)}; |
| 67 | } | 115 | } |
| 68 | 116 | ||
| 69 | - std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() { | 117 | + // get init states for shallow fusion |
| 118 | + std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() { | ||
| 70 | std::vector<Ort::Value> ans; | 119 | std::vector<Ort::Value> ans; |
| 71 | ans.reserve(init_states_.size()); | 120 | ans.reserve(init_states_.size()); |
| 72 | for (auto &s : init_states_) { | 121 | for (auto &s : init_states_) { |
| @@ -75,6 +124,18 @@ class OnlineRnnLM::Impl { | @@ -75,6 +124,18 @@ class OnlineRnnLM::Impl { | ||
| 75 | return {View(&init_scores_.value), std::move(ans)}; | 124 | return {View(&init_scores_.value), std::move(ans)}; |
| 76 | } | 125 | } |
| 77 | 126 | ||
| 127 | + // get init states for classic rescore | ||
| 128 | + std::vector<Ort::Value> GetInitStates() const { | ||
| 129 | + std::vector<Ort::Value> ans; | ||
| 130 | + ans.reserve(init_states_.size()); | ||
| 131 | + | ||
| 132 | + for (const auto &s : init_states_) { | ||
| 133 | + ans.emplace_back(Clone(allocator_, &s)); | ||
| 134 | + } | ||
| 135 | + | ||
| 136 | + return ans; | ||
| 137 | + } | ||
| 138 | + | ||
| 78 | private: | 139 | private: |
| 79 | void Init(const OnlineLMConfig &config) { | 140 | void Init(const OnlineLMConfig &config) { |
| 80 | auto buf = ReadFile(config_.model); | 141 | auto buf = ReadFile(config_.model); |
| @@ -116,7 +177,8 @@ class OnlineRnnLM::Impl { | @@ -116,7 +177,8 @@ class OnlineRnnLM::Impl { | ||
| 116 | states.push_back(std::move(c)); | 177 | states.push_back(std::move(c)); |
| 117 | auto pair = ScoreToken(std::move(x), std::move(states)); | 178 | auto pair = ScoreToken(std::move(x), std::move(states)); |
| 118 | 179 | ||
| 119 | - init_scores_.value = std::move(pair.first); | 180 | + init_scores_.value = std::move(pair.first); // only used during |
| 181 | + // shallow fusion | ||
| 120 | init_states_ = std::move(pair.second); | 182 | init_states_ = std::move(pair.second); |
| 121 | } | 183 | } |
| 122 | 184 | ||
| @@ -147,17 +209,31 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) | @@ -147,17 +209,31 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) | ||
| 147 | 209 | ||
| 148 | OnlineRnnLM::~OnlineRnnLM() = default; | 210 | OnlineRnnLM::~OnlineRnnLM() = default; |
| 149 | 211 | ||
| 150 | -std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStates() { | 212 | +// classic rescore state init |
| 213 | +std::vector<Ort::Value> OnlineRnnLM::GetInitStates() { | ||
| 151 | return impl_->GetInitStates(); | 214 | return impl_->GetInitStates(); |
| 152 | } | 215 | } |
| 153 | 216 | ||
| 217 | +// shallow fusion state init | ||
| 218 | +std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStatesSF() { | ||
| 219 | + return impl_->GetInitStatesSF(); | ||
| 220 | +} | ||
| 221 | + | ||
| 154 | std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken( | 222 | std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken( |
| 155 | Ort::Value x, std::vector<Ort::Value> states) { | 223 | Ort::Value x, std::vector<Ort::Value> states) { |
| 156 | return impl_->ScoreToken(std::move(x), std::move(states)); | 224 | return impl_->ScoreToken(std::move(x), std::move(states)); |
| 157 | } | 225 | } |
| 158 | 226 | ||
| 159 | -void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) { | ||
| 160 | - return impl_->ComputeLMScore(scale, hyp); | 227 | +// classic rescore scores |
| 228 | +void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size, | ||
| 229 | + std::vector<Hypotheses> *hyps) { | ||
| 230 | + return impl_->ComputeLMScore(scale, context_size, hyps); | ||
| 161 | } | 231 | } |
| 162 | 232 | ||
| 233 | +// shallow fusion scores | ||
| 234 | +void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) { | ||
| 235 | + return impl_->ComputeLMScoreSF(scale, hyp); | ||
| 236 | +} | ||
| 237 | + | ||
| 238 | + | ||
| 163 | } // namespace sherpa_onnx | 239 | } // namespace sherpa_onnx |
| @@ -22,13 +22,17 @@ class OnlineRnnLM : public OnlineLM { | @@ -22,13 +22,17 @@ class OnlineRnnLM : public OnlineLM { | ||
| 22 | 22 | ||
| 23 | explicit OnlineRnnLM(const OnlineLMConfig &config); | 23 | explicit OnlineRnnLM(const OnlineLMConfig &config); |
| 24 | 24 | ||
| 25 | - std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override; | 25 | + // init scores for classic rescore |
| 26 | + std::vector<Ort::Value> GetInitStates() override; | ||
| 26 | 27 | ||
| 27 | - /** ScoreToken a batch of sentences. | 28 | + // init scores for shallow fusion |
| 29 | + std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() override; | ||
| 30 | + | ||
| 31 | + /** ScoreToken a batch of sentences (shallow fusion). | ||
| 28 | * | 32 | * |
| 29 | * @param x A 2-D tensor of shape (N, L) with data type int64. | 33 | * @param x A 2-D tensor of shape (N, L) with data type int64. |
| 30 | * @param states It contains the states for the LM model | 34 | * @param states It contains the states for the LM model |
| 31 | - * @return Return a pair containingo | 35 | + * @return Return a pair containing |
| 32 | * - log_prob of NN LM | 36 | * - log_prob of NN LM |
| 33 | * - updated states | 37 | * - updated states |
| 34 | * | 38 | * |
| @@ -36,13 +40,23 @@ class OnlineRnnLM : public OnlineLM { | @@ -36,13 +40,23 @@ class OnlineRnnLM : public OnlineLM { | ||
| 36 | std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( | 40 | std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( |
| 37 | Ort::Value x, std::vector<Ort::Value> states) override; | 41 | Ort::Value x, std::vector<Ort::Value> states) override; |
| 38 | 42 | ||
| 39 | - /** This function updates lm_lob_prob and nn_lm_scores of hyp | 43 | + /** This function updates hyp.lm_lob_prob of hyps (classic rescore). |
| 44 | + * | ||
| 45 | + * @param scale LM score | ||
| 46 | + * @param context_size Context size of the transducer decoder model | ||
| 47 | + * @param hyps It is changed in-place. | ||
| 48 | + * | ||
| 49 | + */ | ||
| 50 | + void ComputeLMScore(float scale, int32_t context_size, | ||
| 51 | + std::vector<Hypotheses> *hyps) override; | ||
| 52 | + | ||
| 53 | + /** This function updates lm_lob_prob and nn_lm_scores of hyp (shallow fusion). | ||
| 40 | * | 54 | * |
| 41 | * @param scale LM score | 55 | * @param scale LM score |
| 42 | * @param hyps It is changed in-place. | 56 | * @param hyps It is changed in-place. |
| 43 | * | 57 | * |
| 44 | */ | 58 | */ |
| 45 | - void ComputeLMScore(float scale, Hypothesis *hyp) override; | 59 | + void ComputeLMScoreSF(float scale, Hypothesis *hyp) override; |
| 46 | 60 | ||
| 47 | private: | 61 | private: |
| 48 | class Impl; | 62 | class Impl; |
| @@ -156,7 +156,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -156,7 +156,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 156 | 156 | ||
| 157 | // add log_prob of each hypothesis to p_logprob before taking top_k | 157 | // add log_prob of each hypothesis to p_logprob before taking top_k |
| 158 | for (int32_t i = 0; i != num_hyps; ++i) { | 158 | for (int32_t i = 0; i != num_hyps; ++i) { |
| 159 | - float log_prob = prev[i].log_prob + prev[i].lm_log_prob; | 159 | + float log_prob = prev[i].log_prob; |
| 160 | + if (lm_ && shallow_fusion_) { | ||
| 161 | + log_prob += prev[i].lm_log_prob; | ||
| 162 | + } | ||
| 163 | + | ||
| 160 | for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { | 164 | for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { |
| 161 | *p_logprob += log_prob; | 165 | *p_logprob += log_prob; |
| 162 | } | 166 | } |
| @@ -192,22 +196,31 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -192,22 +196,31 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 192 | context_score = std::get<0>(context_res); | 196 | context_score = std::get<0>(context_res); |
| 193 | new_hyp.context_state = std::get<1>(context_res); | 197 | new_hyp.context_state = std::get<1>(context_res); |
| 194 | } | 198 | } |
| 195 | - if (lm_) { | ||
| 196 | - lm_->ComputeLMScore(lm_scale_, &new_hyp); | 199 | + if (lm_ && shallow_fusion_) { |
| 200 | + lm_->ComputeLMScoreSF(lm_scale_, &new_hyp); | ||
| 197 | } | 201 | } |
| 198 | } else { | 202 | } else { |
| 199 | ++new_hyp.num_trailing_blanks; | 203 | ++new_hyp.num_trailing_blanks; |
| 200 | } | 204 | } |
| 201 | - new_hyp.log_prob = p_logprob[k] + context_score - | 205 | + if (lm_ && shallow_fusion_) { |
| 206 | + new_hyp.log_prob = p_logprob[k] + context_score - | ||
| 202 | prev_lm_log_prob; // log_prob only includes the | 207 | prev_lm_log_prob; // log_prob only includes the |
| 203 | // score of the transducer | 208 | // score of the transducer |
| 209 | + } else { | ||
| 210 | + new_hyp.log_prob = p_logprob[k] + context_score; // rescore or no LM | ||
| 211 | + // previous token | ||
| 212 | + // score is ignored | ||
| 213 | + } | ||
| 214 | + | ||
| 204 | // export the per-token log scores | 215 | // export the per-token log scores |
| 205 | if (new_token != 0 && new_token != unk_id_) { | 216 | if (new_token != 0 && new_token != unk_id_) { |
| 206 | float y_prob = logit_with_temperature[start * vocab_size + k]; | 217 | float y_prob = logit_with_temperature[start * vocab_size + k]; |
| 207 | new_hyp.ys_probs.push_back(y_prob); | 218 | new_hyp.ys_probs.push_back(y_prob); |
| 208 | 219 | ||
| 209 | - if (lm_) { // export only when LM is used | 220 | + if (lm_ && shallow_fusion_) { // export only if |
| 221 | + // LM shallow fusion is used | ||
| 210 | float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob; | 222 | float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob; |
| 223 | + | ||
| 211 | if (lm_scale_ != 0.0) { | 224 | if (lm_scale_ != 0.0) { |
| 212 | lm_prob /= lm_scale_; // remove lm-scale | 225 | lm_prob /= lm_scale_; // remove lm-scale |
| 213 | } | 226 | } |
| @@ -227,6 +240,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -227,6 +240,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 227 | } // for (int32_t b = 0; b != batch_size; ++b) | 240 | } // for (int32_t b = 0; b != batch_size; ++b) |
| 228 | } // for (int32_t t = 0; t != num_frames; ++t) | 241 | } // for (int32_t t = 0; t != num_frames; ++t) |
| 229 | 242 | ||
| 243 | + // classic lm rescore | ||
| 244 | + if (lm_ && !shallow_fusion_) { | ||
| 245 | + lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur); | ||
| 246 | + } | ||
| 247 | + | ||
| 230 | for (int32_t b = 0; b != batch_size; ++b) { | 248 | for (int32_t b = 0; b != batch_size; ++b) { |
| 231 | auto &hyps = cur[b]; | 249 | auto &hyps = cur[b]; |
| 232 | auto best_hyp = hyps.GetMostProbable(true); | 250 | auto best_hyp = hyps.GetMostProbable(true); |
| @@ -21,13 +21,16 @@ class OnlineTransducerModifiedBeamSearchDecoder | @@ -21,13 +21,16 @@ class OnlineTransducerModifiedBeamSearchDecoder | ||
| 21 | OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, | 21 | OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, |
| 22 | OnlineLM *lm, | 22 | OnlineLM *lm, |
| 23 | int32_t max_active_paths, | 23 | int32_t max_active_paths, |
| 24 | - float lm_scale, int32_t unk_id, | 24 | + float lm_scale, |
| 25 | + bool shallow_fusion, | ||
| 26 | + int32_t unk_id, | ||
| 25 | float blank_penalty, | 27 | float blank_penalty, |
| 26 | float temperature_scale) | 28 | float temperature_scale) |
| 27 | : model_(model), | 29 | : model_(model), |
| 28 | lm_(lm), | 30 | lm_(lm), |
| 29 | max_active_paths_(max_active_paths), | 31 | max_active_paths_(max_active_paths), |
| 30 | lm_scale_(lm_scale), | 32 | lm_scale_(lm_scale), |
| 33 | + shallow_fusion_(shallow_fusion), | ||
| 31 | unk_id_(unk_id), | 34 | unk_id_(unk_id), |
| 32 | blank_penalty_(blank_penalty), | 35 | blank_penalty_(blank_penalty), |
| 33 | temperature_scale_(temperature_scale) {} | 36 | temperature_scale_(temperature_scale) {} |
| @@ -50,6 +53,7 @@ class OnlineTransducerModifiedBeamSearchDecoder | @@ -50,6 +53,7 @@ class OnlineTransducerModifiedBeamSearchDecoder | ||
| 50 | 53 | ||
| 51 | int32_t max_active_paths_; | 54 | int32_t max_active_paths_; |
| 52 | float lm_scale_; // used only when lm_ is not nullptr | 55 | float lm_scale_; // used only when lm_ is not nullptr |
| 56 | + bool shallow_fusion_; // used only when lm_ is not nullptr | ||
| 53 | int32_t unk_id_; | 57 | int32_t unk_id_; |
| 54 | float blank_penalty_; | 58 | float blank_penalty_; |
| 55 | float temperature_scale_; | 59 | float temperature_scale_; |
| @@ -13,13 +13,16 @@ namespace sherpa_onnx { | @@ -13,13 +13,16 @@ namespace sherpa_onnx { | ||
| 13 | void PybindOnlineLMConfig(py::module *m) { | 13 | void PybindOnlineLMConfig(py::module *m) { |
| 14 | using PyClass = OnlineLMConfig; | 14 | using PyClass = OnlineLMConfig; |
| 15 | py::class_<PyClass>(*m, "OnlineLMConfig") | 15 | py::class_<PyClass>(*m, "OnlineLMConfig") |
| 16 | - .def(py::init<const std::string &, float, int32_t, const std::string &>(), | 16 | + .def(py::init<const std::string &, float, int32_t, |
| 17 | + const std::string &, bool>(), | ||
| 17 | py::arg("model") = "", py::arg("scale") = 0.5f, | 18 | py::arg("model") = "", py::arg("scale") = 0.5f, |
| 18 | - py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu") | 19 | + py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu", |
| 20 | + py::arg("shallow_fusion") = true) | ||
| 19 | .def_readwrite("model", &PyClass::model) | 21 | .def_readwrite("model", &PyClass::model) |
| 20 | .def_readwrite("scale", &PyClass::scale) | 22 | .def_readwrite("scale", &PyClass::scale) |
| 21 | .def_readwrite("lm_provider", &PyClass::lm_provider) | 23 | .def_readwrite("lm_provider", &PyClass::lm_provider) |
| 22 | .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) | 24 | .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) |
| 25 | + .def_readwrite("shallow_fusion", &PyClass::shallow_fusion) | ||
| 23 | .def("__str__", &PyClass::ToString); | 26 | .def("__str__", &PyClass::ToString); |
| 24 | } | 27 | } |
| 25 | 28 |
| @@ -64,6 +64,7 @@ class OnlineRecognizer(object): | @@ -64,6 +64,7 @@ class OnlineRecognizer(object): | ||
| 64 | bpe_vocab: str = "", | 64 | bpe_vocab: str = "", |
| 65 | lm: str = "", | 65 | lm: str = "", |
| 66 | lm_scale: float = 0.1, | 66 | lm_scale: float = 0.1, |
| 67 | + lm_shallow_fusion: bool = True, | ||
| 67 | temperature_scale: float = 2.0, | 68 | temperature_scale: float = 2.0, |
| 68 | debug: bool = False, | 69 | debug: bool = False, |
| 69 | rule_fsts: str = "", | 70 | rule_fsts: str = "", |
| @@ -274,6 +275,7 @@ class OnlineRecognizer(object): | @@ -274,6 +275,7 @@ class OnlineRecognizer(object): | ||
| 274 | lm_config = OnlineLMConfig( | 275 | lm_config = OnlineLMConfig( |
| 275 | model=lm, | 276 | model=lm, |
| 276 | scale=lm_scale, | 277 | scale=lm_scale, |
| 278 | + shallow_fusion=lm_shallow_fusion, | ||
| 277 | ) | 279 | ) |
| 278 | 280 | ||
| 279 | recognizer_config = OnlineRecognizerConfig( | 281 | recognizer_config = OnlineRecognizerConfig( |
-
请 注册 或 登录 后发表评论