Committed by
GitHub
Fix lm fusion (#157)
* share GetHypsRowSplits interface and fix getting Topk not taking logprob * fix lm score of lm fusion and make padding len same with 'icefall/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py'
正在显示
3 个修改的文件
包含
5 行增加
和
3 行删除
| @@ -35,7 +35,7 @@ class OnlineRnnLM::Impl { | @@ -35,7 +35,7 @@ class OnlineRnnLM::Impl { | ||
| 35 | 35 | ||
| 36 | // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob | 36 | // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob |
| 37 | const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>(); | 37 | const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>(); |
| 38 | - hyp->lm_log_prob = nn_lm_scores[hyp->ys.back()] * scale; | 38 | + hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale; |
| 39 | 39 | ||
| 40 | // get lm scores for next tokens given the hyp->ys[:] and save to | 40 | // get lm scores for next tokens given the hyp->ys[:] and save to |
| 41 | // nn_lm_scores | 41 | // nn_lm_scores |
| @@ -152,7 +152,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -152,7 +152,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 152 | } else { | 152 | } else { |
| 153 | ++new_hyp.num_trailing_blanks; | 153 | ++new_hyp.num_trailing_blanks; |
| 154 | } | 154 | } |
| 155 | - new_hyp.log_prob = p_logprob[k] - prev_lm_log_prob; | 155 | + new_hyp.log_prob = |
| 156 | + p_logprob[k] - prev_lm_log_prob; // log_prob only includes the | ||
| 157 | + // score of the transducer | ||
| 156 | hyps.Add(std::move(new_hyp)); | 158 | hyps.Add(std::move(new_hyp)); |
| 157 | } // for (auto k : topk) | 159 | } // for (auto k : topk) |
| 158 | cur.push_back(std::move(hyps)); | 160 | cur.push_back(std::move(hyps)); |
| @@ -94,7 +94,7 @@ for a list of pre-trained models to download. | @@ -94,7 +94,7 @@ for a list of pre-trained models to download. | ||
| 94 | auto s = recognizer.CreateStream(); | 94 | auto s = recognizer.CreateStream(); |
| 95 | s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | 95 | s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); |
| 96 | 96 | ||
| 97 | - std::vector<float> tail_paddings(static_cast<int>(0.5 * sampling_rate)); | 97 | + std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate)); |
| 98 | // Note: We can call AcceptWaveform() multiple times. | 98 | // Note: We can call AcceptWaveform() multiple times. |
| 99 | s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size()); | 99 | s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size()); |
| 100 | 100 |
-
请 注册 或 登录 后发表评论