SilverSulfide
Committed by GitHub

Re-implement LM rescore for online transducer (#1231)

Co-authored-by: Martins Kronis <martins.kuznecovs@tilde.lv>
@@ -51,8 +51,13 @@ struct Hypothesis {
51 // LM log prob if any. 51 // LM log prob if any.
52 double lm_log_prob = 0; 52 double lm_log_prob = 0;
53 53
54 - // the nn lm score for next token given the current ys 54 + // the nn lm score for next token given the current ys,
  55 + // when using shallow fusion
55 CopyableOrtValue nn_lm_scores; 56 CopyableOrtValue nn_lm_scores;
  57 +
  58 + // number of tokens already scored by the RNN LM, when rescoring
  59 + int32_t cur_scored_pos = 0;
  60 +
56 // the nn lm states 61 // the nn lm states
57 std::vector<CopyableOrtValue> nn_lm_states; 62 std::vector<CopyableOrtValue> nn_lm_states;
58 63
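
The new cur_scored_pos field is only used on the rescore path. A sketch of the hyp.ys layout it indexes into, as I read ComputeLMScore() further down (my interpretation, not stated in the PR):

    // hyp.ys in rescoring mode:
    //
    //   [ context_size initial tokens | cur_scored_pos tokens already sent to the RNN LM | not yet scored ]
    //
    // ComputeLMScore() feeds the not-yet-scored span (minus the newest token) to the
    // LM and then advances cur_scored_pos by the number of tokens it scored.
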
@@ -18,6 +18,8 @@ void OnlineLMConfig::Register(ParseOptions *po) {
18 "Number of threads to run the neural network of LM model"); 18 "Number of threads to run the neural network of LM model");
19 po->Register("lm-provider", &lm_provider, 19 po->Register("lm-provider", &lm_provider,
20 "Specify a provider to LM model use: cpu, cuda, coreml"); 20 "Specify a provider to LM model use: cpu, cuda, coreml");
  21 + po->Register("lm-shallow-fusion", &shallow_fusion,
  22 + "Boolean whether to use shallow fusion or rescore.");
21 } 23 }
22 24
23 bool OnlineLMConfig::Validate() const { 25 bool OnlineLMConfig::Validate() const {
@@ -34,7 +36,8 @@ std::string OnlineLMConfig::ToString() const {
34 36
35 os << "OnlineLMConfig("; 37 os << "OnlineLMConfig(";
36 os << "model=\"" << model << "\", "; 38 os << "model=\"" << model << "\", ";
37 - os << "scale=" << scale << ")"; 39 + os << "scale=" << scale << ", ";
  40 + os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")";
38 41
39 return os.str(); 42 return os.str();
40 } 43 }
@@ -18,15 +18,18 @@ struct OnlineLMConfig {
18 float scale = 0.5; 18 float scale = 0.5;
19 int32_t lm_num_threads = 1; 19 int32_t lm_num_threads = 1;
20 std::string lm_provider = "cpu"; 20 std::string lm_provider = "cpu";
  21 + // enable shallow fusion
  22 + bool shallow_fusion = true;
21 23
22 OnlineLMConfig() = default; 24 OnlineLMConfig() = default;
23 25
24 OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, 26 OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
25 - const std::string &lm_provider) 27 + const std::string &lm_provider, bool shallow_fusion)
26 : model(model), 28 : model(model),
27 scale(scale), 29 scale(scale),
28 lm_num_threads(lm_num_threads), 30 lm_num_threads(lm_num_threads),
29 - lm_provider(lm_provider) {} 31 + lm_provider(lm_provider),
  32 + shallow_fusion(shallow_fusion) {}
30 33
31 void Register(ParseOptions *po); 34 void Register(ParseOptions *po);
32 bool Validate() const; 35 bool Validate() const;
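
For reference, a minimal sketch of constructing the extended config from C++ (all values illustrative):

    sherpa_onnx::OnlineLMConfig lm_config(
        /*model=*/"lm.onnx",
        /*scale=*/0.5f,
        /*lm_num_threads=*/1,
        /*lm_provider=*/"cpu",
        /*shallow_fusion=*/false);  // false selects the classic rescore path

The default (true) keeps the previous shallow-fusion behaviour.
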
@@ -21,13 +21,17 @@ class OnlineLM {
21 21
22 static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config); 22 static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
23 23
24 - virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0; 24 + // init states for classic rescore
  25 + virtual std::vector<Ort::Value> GetInitStates() = 0;
25 26
26 - /** ScoreToken a batch of sentences. 27 + // init states for shallow fusion
  28 + virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() = 0;
  29 +
  30 + /** ScoreToken a batch of sentences (shallow fusion).
27 * 31 *
28 * @param x A 2-D tensor of shape (N, 1) with data type int64. 32 * @param x A 2-D tensor of shape (N, 1) with data type int64.
29 * @param states It contains the states for the LM model 33 * @param states It contains the states for the LM model
30 - * @return Return a pair containingo 34 + * @return Return a pair containing
31 * - log_prob of NN LM 35 * - log_prob of NN LM
32 * - updated states 36 * - updated states
33 * 37 *
@@ -35,13 +39,23 @@ class OnlineLM {
35 virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( 39 virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
36 Ort::Value x, std::vector<Ort::Value> states) = 0; 40 Ort::Value x, std::vector<Ort::Value> states) = 0;
37 41
38 - /** This function updates lm_lob_prob and nn_lm_scores of hyp 42 + /** This function updates hyp.lm_log_prob of hyps (classic rescore).
  43 + *
  44 + * @param scale LM score
  45 + * @param context_size Context size of the transducer decoder model
  46 + * @param hyps It is changed in-place.
  47 + *
  48 + */
  49 + virtual void ComputeLMScore(float scale, int32_t context_size,
  50 + std::vector<Hypotheses> *hyps) = 0;
  51 +
  52 + /** This function updates lm_log_prob and nn_lm_scores of hyp (shallow fusion).
39 * 53 *
40 * @param scale LM score 54 * @param scale LM score
41 * @param hyps It is changed in-place. 55 * @param hyps It is changed in-place.
42 * 56 *
43 */ 57 */
44 - virtual void ComputeLMScore(float scale, Hypothesis *hyp) = 0; 58 + virtual void ComputeLMScoreSF(float scale, Hypothesis *hyp) = 0;
45 }; 59 };
46 60
47 } // namespace sherpa_onnx 61 } // namespace sherpa_onnx
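
A rough sketch of how a decoder is expected to drive the two paths (hedged; the real call sites are in the decoder hunks further down):

    // Shallow fusion: score incrementally while the beam is being extended.
    lm->ComputeLMScoreSF(scale, &hyp);               // updates hyp.lm_log_prob, nn_lm_scores, nn_lm_states

    // Classic rescore: score every hypothesis of every stream once per chunk.
    lm->ComputeLMScore(scale, context_size, &hyps);  // updates hyp.lm_log_prob in place
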
@@ -107,7 +107,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
107 107
108 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 108 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
109 model_.get(), lm_.get(), config_.max_active_paths, 109 model_.get(), lm_.get(), config_.max_active_paths,
110 - config_.lm_config.scale, unk_id_, config_.blank_penalty, 110 + config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
  111 + config_.blank_penalty,
111 config_.temperature_scale); 112 config_.temperature_scale);
112 113
113 } else if (config.decoding_method == "greedy_search") { 114 } else if (config.decoding_method == "greedy_search") {
@@ -156,7 +157,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
156 157
157 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 158 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
158 model_.get(), lm_.get(), config_.max_active_paths, 159 model_.get(), lm_.get(), config_.max_active_paths,
159 - config_.lm_config.scale, unk_id_, config_.blank_penalty, 160 + config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
  161 + config_.blank_penalty,
160 config_.temperature_scale); 162 config_.temperature_scale);
161 163
162 } else if (config.decoding_method == "greedy_search") { 164 } else if (config.decoding_method == "greedy_search") {
@@ -8,6 +8,7 @@
8 #include <string> 8 #include <string>
9 #include <utility> 9 #include <utility>
10 #include <vector> 10 #include <vector>
  11 +#include <algorithm>
11 12
12 #include "onnxruntime_cxx_api.h" // NOLINT 13 #include "onnxruntime_cxx_api.h" // NOLINT
13 #include "sherpa-onnx/csrc/macros.h" 14 #include "sherpa-onnx/csrc/macros.h"
@@ -27,9 +28,10 @@ class OnlineRnnLM::Impl {
27 Init(config); 28 Init(config);
28 } 29 }
29 30
30 - void ComputeLMScore(float scale, Hypothesis *hyp) { 31 + // shallow fusion scoring function
  32 + void ComputeLMScoreSF(float scale, Hypothesis *hyp) {
31 if (hyp->nn_lm_states.empty()) { 33 if (hyp->nn_lm_states.empty()) {
32 - auto init_states = GetInitStates(); 34 + auto init_states = GetInitStatesSF();
33 hyp->nn_lm_scores.value = std::move(init_states.first); 35 hyp->nn_lm_scores.value = std::move(init_states.first);
34 hyp->nn_lm_states = Convert(std::move(init_states.second)); 36 hyp->nn_lm_states = Convert(std::move(init_states.second));
35 } 37 }
@@ -49,6 +51,52 @@ class OnlineRnnLM::Impl {
49 hyp->nn_lm_states = Convert(std::move(lm_out.second)); 51 hyp->nn_lm_states = Convert(std::move(lm_out.second));
50 } 52 }
51 53
  54 + // classic rescore function
  55 + void ComputeLMScore(float scale, int32_t context_size,
  56 + std::vector<Hypotheses> *hyps) {
  57 + Ort::AllocatorWithDefaultOptions allocator;
  58 +
  59 + for (auto &hyp : *hyps) {
  60 + for (auto &h_m : hyp) {
  61 + auto &h = h_m.second;
  62 + auto &ys = h.ys;
  63 + const int32_t token_num_in_chunk =
  64 + ys.size() - context_size - h.cur_scored_pos - 1;
  65 +
  66 + if (token_num_in_chunk < 1) {
  67 + continue;
  68 + }
  69 +
  70 + if (h.nn_lm_states.empty()) {
  71 + h.nn_lm_states = Convert(GetInitStates());
  72 + }
  73 +
  74 + if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
  75 + std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
  76 +
  77 + Ort::Value x = Ort::Value::CreateTensor<int64_t>(
  78 + allocator, x_shape.data(), x_shape.size());
  79 + int64_t *p_x = x.GetTensorMutableData<int64_t>();
  80 + std::copy(ys.begin() + context_size + h.cur_scored_pos,
  81 + ys.end() - 1, p_x);
  82 +
  83 + // streaming forward by NN LM
  84 + auto out = ScoreToken(std::move(x),
  85 + Convert(std::move(h.nn_lm_states)));
  86 +
  87 + // update NN LM score in hyp
  88 + const float *p_nll = out.first.GetTensorData<float>();
  89 + h.lm_log_prob = -scale * (*p_nll);
  90 +
  91 + // update NN LM states in hyp
  92 + h.nn_lm_states = Convert(std::move(out.second));
  93 +
  94 + h.cur_scored_pos += token_num_in_chunk;
  95 + }
  96 + }
  97 + }
  98 + }
  99 +
52 std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( 100 std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
53 Ort::Value x, std::vector<Ort::Value> states) { 101 Ort::Value x, std::vector<Ort::Value> states) {
54 std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]), 102 std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]),
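
To make the indexing in ComputeLMScore() concrete, a small worked example (values are illustrative; lm_rescore_min_chunk is a Hypothesis field referenced above but not shown in this diff):

    // Suppose context_size = 2 and a hypothesis currently has
    //   ys             = {b, b, 25, 7, 913, 4}   // 2 context tokens + 4 emitted tokens
    //   cur_scored_pos = 1                       // one token scored in an earlier chunk
    //
    // token_num_in_chunk = ys.size() - context_size - cur_scored_pos - 1
    //                    = 6 - 2 - 1 - 1 = 2
    //
    // so ys[3] and ys[4] are copied into x (the newest token ys[5] is held back),
    // ScoreToken() returns the negative log-likelihood of that span plus updated
    // LM states, lm_log_prob is set to -scale * nll, and cur_scored_pos advances
    // from 1 to 3.
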
@@ -66,7 +114,8 @@ class OnlineRnnLM::Impl {
66 return {std::move(out[0]), std::move(next_states)}; 114 return {std::move(out[0]), std::move(next_states)};
67 } 115 }
68 116
69 - std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() { 117 + // get init states for shallow fusion
  118 + std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() {
70 std::vector<Ort::Value> ans; 119 std::vector<Ort::Value> ans;
71 ans.reserve(init_states_.size()); 120 ans.reserve(init_states_.size());
72 for (auto &s : init_states_) { 121 for (auto &s : init_states_) {
@@ -75,6 +124,18 @@ class OnlineRnnLM::Impl {
75 return {View(&init_scores_.value), std::move(ans)}; 124 return {View(&init_scores_.value), std::move(ans)};
76 } 125 }
77 126
  127 + // get init states for classic rescore
  128 + std::vector<Ort::Value> GetInitStates() const {
  129 + std::vector<Ort::Value> ans;
  130 + ans.reserve(init_states_.size());
  131 +
  132 + for (const auto &s : init_states_) {
  133 + ans.emplace_back(Clone(allocator_, &s));
  134 + }
  135 +
  136 + return ans;
  137 + }
  138 +
78 private: 139 private:
79 void Init(const OnlineLMConfig &config) { 140 void Init(const OnlineLMConfig &config) {
80 auto buf = ReadFile(config_.model); 141 auto buf = ReadFile(config_.model);
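
A note on the two init-state helpers above (my reading, not stated in the PR): the rescore-path GetInitStates() clones each cached state so every hypothesis owns independent LM state tensors, whereas the shallow-fusion GetInitStatesSF() returns Views of the cached init_scores_ (and, presumably, of the cached states):

    std::vector<Ort::Value> states = lm.GetInitStates();  // cloned per call
    auto scores_and_states = lm.GetInitStatesSF();        // views of cached tensors
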
@@ -116,7 +177,8 @@ class OnlineRnnLM::Impl {
116 states.push_back(std::move(c)); 177 states.push_back(std::move(c));
117 auto pair = ScoreToken(std::move(x), std::move(states)); 178 auto pair = ScoreToken(std::move(x), std::move(states));
118 179
119 - init_scores_.value = std::move(pair.first); 180 + init_scores_.value = std::move(pair.first); // only used during
  181 + // shallow fusion
120 init_states_ = std::move(pair.second); 182 init_states_ = std::move(pair.second);
121 } 183 }
122 184
@@ -147,17 +209,31 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
147 209
148 OnlineRnnLM::~OnlineRnnLM() = default; 210 OnlineRnnLM::~OnlineRnnLM() = default;
149 211
150 -std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStates() { 212 +// classic rescore state init
  213 +std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
151 return impl_->GetInitStates(); 214 return impl_->GetInitStates();
152 } 215 }
153 216
  217 +// shallow fusion state init
  218 +std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStatesSF() {
  219 + return impl_->GetInitStatesSF();
  220 +}
  221 +
154 std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken( 222 std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
155 Ort::Value x, std::vector<Ort::Value> states) { 223 Ort::Value x, std::vector<Ort::Value> states) {
156 return impl_->ScoreToken(std::move(x), std::move(states)); 224 return impl_->ScoreToken(std::move(x), std::move(states));
157 } 225 }
158 226
159 -void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) {  
160 - return impl_->ComputeLMScore(scale, hyp); 227 +// classic rescore scores
  228 +void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
  229 + std::vector<Hypotheses> *hyps) {
  230 + return impl_->ComputeLMScore(scale, context_size, hyps);
161 } 231 }
162 232
  233 +// shallow fusion scores
  234 +void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
  235 + return impl_->ComputeLMScoreSF(scale, hyp);
  236 +}
  237 +
  238 +
163 } // namespace sherpa_onnx 239 } // namespace sherpa_onnx
@@ -22,13 +22,17 @@ class OnlineRnnLM : public OnlineLM {
22 22
23 explicit OnlineRnnLM(const OnlineLMConfig &config); 23 explicit OnlineRnnLM(const OnlineLMConfig &config);
24 24
25 - std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override; 25 + // init states for classic rescore
  26 + std::vector<Ort::Value> GetInitStates() override;
26 27
27 - /** ScoreToken a batch of sentences. 28 + // init scores for shallow fusion
  29 + std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() override;
  30 +
  31 + /** ScoreToken a batch of sentences (shallow fusion).
28 * 32 *
29 * @param x A 2-D tensor of shape (N, L) with data type int64. 33 * @param x A 2-D tensor of shape (N, L) with data type int64.
30 * @param states It contains the states for the LM model 34 * @param states It contains the states for the LM model
31 - * @return Return a pair containingo 35 + * @return Return a pair containing
32 * - log_prob of NN LM 36 * - log_prob of NN LM
33 * - updated states 37 * - updated states
34 * 38 *
@@ -36,13 +40,23 @@ class OnlineRnnLM : public OnlineLM {
36 std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( 40 std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
37 Ort::Value x, std::vector<Ort::Value> states) override; 41 Ort::Value x, std::vector<Ort::Value> states) override;
38 42
39 - /** This function updates lm_lob_prob and nn_lm_scores of hyp 43 + /** This function updates hyp.lm_log_prob of hyps (classic rescore).
  44 + *
  45 + * @param scale LM score
  46 + * @param context_size Context size of the transducer decoder model
  47 + * @param hyps It is changed in-place.
  48 + *
  49 + */
  50 + void ComputeLMScore(float scale, int32_t context_size,
  51 + std::vector<Hypotheses> *hyps) override;
  52 +
  53 + /** This function updates lm_log_prob and nn_lm_scores of hyp (shallow fusion).
40 * 54 *
41 * @param scale LM score 55 * @param scale LM score
42 * @param hyps It is changed in-place. 56 * @param hyps It is changed in-place.
43 * 57 *
44 */ 58 */
45 - void ComputeLMScore(float scale, Hypothesis *hyp) override; 59 + void ComputeLMScoreSF(float scale, Hypothesis *hyp) override;
46 60
47 private: 61 private:
48 class Impl; 62 class Impl;
@@ -156,7 +156,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
156 156
157 // add log_prob of each hypothesis to p_logprob before taking top_k 157 // add log_prob of each hypothesis to p_logprob before taking top_k
158 for (int32_t i = 0; i != num_hyps; ++i) { 158 for (int32_t i = 0; i != num_hyps; ++i) {
159 - float log_prob = prev[i].log_prob + prev[i].lm_log_prob; 159 + float log_prob = prev[i].log_prob;
  160 + if (lm_ && shallow_fusion_) {
  161 + log_prob += prev[i].lm_log_prob;
  162 + }
  163 +
160 for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { 164 for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
161 *p_logprob += log_prob; 165 *p_logprob += log_prob;
162 } 166 }
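
The net effect of this hunk (my summary): the score used for top-k pruning now depends on the mode.

    // shallow fusion:   pruning score = log_prob + lm_log_prob  (LM guides the search)
    // rescore / no LM:  pruning score = log_prob                (LM is applied after the chunk)
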
@@ -192,22 +196,31 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
192 context_score = std::get<0>(context_res); 196 context_score = std::get<0>(context_res);
193 new_hyp.context_state = std::get<1>(context_res); 197 new_hyp.context_state = std::get<1>(context_res);
194 } 198 }
195 - if (lm_) {  
196 - lm_->ComputeLMScore(lm_scale_, &new_hyp); 199 + if (lm_ && shallow_fusion_) {
  200 + lm_->ComputeLMScoreSF(lm_scale_, &new_hyp);
197 } 201 }
198 } else { 202 } else {
199 ++new_hyp.num_trailing_blanks; 203 ++new_hyp.num_trailing_blanks;
200 } 204 }
201 - new_hyp.log_prob = p_logprob[k] + context_score - 205 + if (lm_ && shallow_fusion_) {
  206 + new_hyp.log_prob = p_logprob[k] + context_score -
202 prev_lm_log_prob; // log_prob only includes the 207 prev_lm_log_prob; // log_prob only includes the
203 // score of the transducer 208 // score of the transducer
  209 + } else {
  210 + new_hyp.log_prob = p_logprob[k] + context_score; // rescore or no LM
  211 + // previous token
  212 + // score is ignored
  213 + }
  214 +
204 // export the per-token log scores 215 // export the per-token log scores
205 if (new_token != 0 && new_token != unk_id_) { 216 if (new_token != 0 && new_token != unk_id_) {
206 float y_prob = logit_with_temperature[start * vocab_size + k]; 217 float y_prob = logit_with_temperature[start * vocab_size + k];
207 new_hyp.ys_probs.push_back(y_prob); 218 new_hyp.ys_probs.push_back(y_prob);
208 219
209 - if (lm_) { // export only when LM is used 220 + if (lm_ && shallow_fusion_) { // export only if
  221 + // LM shallow fusion is used
210 float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob; 222 float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
  223 +
211 if (lm_scale_ != 0.0) { 224 if (lm_scale_ != 0.0) {
212 lm_prob /= lm_scale_; // remove lm-scale 225 lm_prob /= lm_scale_; // remove lm-scale
213 } 226 }
@@ -227,6 +240,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
227 } // for (int32_t b = 0; b != batch_size; ++b) 240 } // for (int32_t b = 0; b != batch_size; ++b)
228 } // for (int32_t t = 0; t != num_frames; ++t) 241 } // for (int32_t t = 0; t != num_frames; ++t)
229 242
  243 + // classic lm rescore
  244 + if (lm_ && !shallow_fusion_) {
  245 + lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
  246 + }
  247 +
230 for (int32_t b = 0; b != batch_size; ++b) { 248 for (int32_t b = 0; b != batch_size; ++b) {
231 auto &hyps = cur[b]; 249 auto &hyps = cur[b];
232 auto best_hyp = hyps.GetMostProbable(true); 250 auto best_hyp = hyps.GetMostProbable(true);
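
Putting the decoder changes together, my reading of the per-chunk flow in rescoring mode (a hedged sketch, not taken verbatim from the PR):

    // 1. Run modified beam search over the chunk using transducer (+ context) scores only.
    // 2. After the last frame of the chunk:
    //        lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);  // fills hyp.lm_log_prob
    // 3. Pick the best hypothesis per stream; that selection can now take the freshly
    //    computed lm_log_prob into account alongside log_prob.
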
@@ -21,13 +21,16 @@ class OnlineTransducerModifiedBeamSearchDecoder
21 OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, 21 OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
22 OnlineLM *lm, 22 OnlineLM *lm,
23 int32_t max_active_paths, 23 int32_t max_active_paths,
24 - float lm_scale, int32_t unk_id, 24 + float lm_scale,
  25 + bool shallow_fusion,
  26 + int32_t unk_id,
25 float blank_penalty, 27 float blank_penalty,
26 float temperature_scale) 28 float temperature_scale)
27 : model_(model), 29 : model_(model),
28 lm_(lm), 30 lm_(lm),
29 max_active_paths_(max_active_paths), 31 max_active_paths_(max_active_paths),
30 lm_scale_(lm_scale), 32 lm_scale_(lm_scale),
  33 + shallow_fusion_(shallow_fusion),
31 unk_id_(unk_id), 34 unk_id_(unk_id),
32 blank_penalty_(blank_penalty), 35 blank_penalty_(blank_penalty),
33 temperature_scale_(temperature_scale) {} 36 temperature_scale_(temperature_scale) {}
@@ -50,6 +53,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
50 53
51 int32_t max_active_paths_; 54 int32_t max_active_paths_;
52 float lm_scale_; // used only when lm_ is not nullptr 55 float lm_scale_; // used only when lm_ is not nullptr
  56 + bool shallow_fusion_; // used only when lm_ is not nullptr
53 int32_t unk_id_; 57 int32_t unk_id_;
54 float blank_penalty_; 58 float blank_penalty_;
55 float temperature_scale_; 59 float temperature_scale_;
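
For reference, a constructor call with the new argument order (values illustrative, mirroring how the recognizer passes its config earlier in this diff):

    OnlineTransducerModifiedBeamSearchDecoder decoder(
        model.get(), lm.get(),
        /*max_active_paths=*/4,
        /*lm_scale=*/0.5f,
        /*shallow_fusion=*/false,  // false -> classic LM rescore
        /*unk_id=*/2,
        /*blank_penalty=*/0.0f,
        /*temperature_scale=*/2.0f);
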
@@ -13,13 +13,16 @@ namespace sherpa_onnx {
13 void PybindOnlineLMConfig(py::module *m) { 13 void PybindOnlineLMConfig(py::module *m) {
14 using PyClass = OnlineLMConfig; 14 using PyClass = OnlineLMConfig;
15 py::class_<PyClass>(*m, "OnlineLMConfig") 15 py::class_<PyClass>(*m, "OnlineLMConfig")
16 - .def(py::init<const std::string &, float, int32_t, const std::string &>(), 16 + .def(py::init<const std::string &, float, int32_t,
  17 + const std::string &, bool>(),
17 py::arg("model") = "", py::arg("scale") = 0.5f, 18 py::arg("model") = "", py::arg("scale") = 0.5f,
18 - py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu") 19 + py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu",
  20 + py::arg("shallow_fusion") = true)
19 .def_readwrite("model", &PyClass::model) 21 .def_readwrite("model", &PyClass::model)
20 .def_readwrite("scale", &PyClass::scale) 22 .def_readwrite("scale", &PyClass::scale)
21 .def_readwrite("lm_provider", &PyClass::lm_provider) 23 .def_readwrite("lm_provider", &PyClass::lm_provider)
22 .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) 24 .def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
  25 + .def_readwrite("shallow_fusion", &PyClass::shallow_fusion)
23 .def("__str__", &PyClass::ToString); 26 .def("__str__", &PyClass::ToString);
24 } 27 }
25 28
@@ -64,6 +64,7 @@ class OnlineRecognizer(object):
64 bpe_vocab: str = "", 64 bpe_vocab: str = "",
65 lm: str = "", 65 lm: str = "",
66 lm_scale: float = 0.1, 66 lm_scale: float = 0.1,
  67 + lm_shallow_fusion: bool = True,
67 temperature_scale: float = 2.0, 68 temperature_scale: float = 2.0,
68 debug: bool = False, 69 debug: bool = False,
69 rule_fsts: str = "", 70 rule_fsts: str = "",
@@ -274,6 +275,7 @@ class OnlineRecognizer(object):
274 lm_config = OnlineLMConfig( 275 lm_config = OnlineLMConfig(
275 model=lm, 276 model=lm,
276 scale=lm_scale, 277 scale=lm_scale,
  278 + shallow_fusion=lm_shallow_fusion,
277 ) 279 )
278 280
279 recognizer_config = OnlineRecognizerConfig( 281 recognizer_config = OnlineRecognizerConfig(
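
Finally, a hedged end-to-end sketch of enabling the new mode from Python (model paths are placeholders, and the from_transducer constructor is assumed; only the lm-related arguments come from this diff):

    import sherpa_onnx

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens="tokens.txt",
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        decoding_method="modified_beam_search",
        lm="lm.onnx",
        lm_scale=0.1,
        lm_shallow_fusion=False,  # False -> classic LM rescore; True (default) -> shallow fusion
    )
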