PF Luo
Committed by GitHub

Fix lm fusion (#157)

* share GetHypsRowSplits interface and fix getting Topk not taking logprob

* fix lm score of lm fusion and make padding len same with 'icefall/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py'
@@ -35,7 +35,7 @@ class OnlineRnnLM::Impl { @@ -35,7 +35,7 @@ class OnlineRnnLM::Impl {
35 35
36 // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob 36 // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob
37 const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>(); 37 const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>();
38 - hyp->lm_log_prob = nn_lm_scores[hyp->ys.back()] * scale; 38 + hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale;
39 39
40 // get lm scores for next tokens given the hyp->ys[:] and save to 40 // get lm scores for next tokens given the hyp->ys[:] and save to
41 // nn_lm_scores 41 // nn_lm_scores
@@ -152,7 +152,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -152,7 +152,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
152 } else { 152 } else {
153 ++new_hyp.num_trailing_blanks; 153 ++new_hyp.num_trailing_blanks;
154 } 154 }
155 - new_hyp.log_prob = p_logprob[k] - prev_lm_log_prob; 155 + new_hyp.log_prob =
  156 + p_logprob[k] - prev_lm_log_prob; // log_prob only includes the
  157 + // score of the transducer
156 hyps.Add(std::move(new_hyp)); 158 hyps.Add(std::move(new_hyp));
157 } // for (auto k : topk) 159 } // for (auto k : topk)
158 cur.push_back(std::move(hyps)); 160 cur.push_back(std::move(hyps));
@@ -94,7 +94,7 @@ for a list of pre-trained models to download. @@ -94,7 +94,7 @@ for a list of pre-trained models to download.
94 auto s = recognizer.CreateStream(); 94 auto s = recognizer.CreateStream();
95 s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); 95 s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
96 96
97 - std::vector<float> tail_paddings(static_cast<int>(0.5 * sampling_rate)); 97 + std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate));
98 // Note: We can call AcceptWaveform() multiple times. 98 // Note: We can call AcceptWaveform() multiple times.
99 s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size()); 99 s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size());
100 100