Committed by
GitHub
Preserve more context after endpointing in transducer (#2061)
正在显示 1 个修改的文件，包含 10 行增加和 6 行删除
@@ -388,16 +388,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
     auto r = decoder_->GetEmptyResult();
     auto last_result = s->GetResult();
     // if last result is not empty, then
-    // preserve last tokens as the context for next result
+    // truncate all last hyps and save as the context for next result
     if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
-      std::vector<int64_t> context(last_result.tokens.end() - context_size,
-                                   last_result.tokens.end());
+      for (const auto &it : last_result.hyps) {
+        auto h = it.second;
+        r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
+                                         h.ys.end()),
+                    h.log_prob});
+      }

-      Hypotheses context_hyp({{context, 0}});
-      r.hyps = std::move(context_hyp);
-      r.tokens = std::move(context);
+      r.tokens = std::vector<int64_t>(last_result.tokens.end() - context_size,
+                                      last_result.tokens.end());
     }

+    // but reset all contextual biasing graph states to root
     if (config_.decoding_method == "modified_beam_search" &&
         nullptr != s->GetContextGraph()) {
       for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
请注册或登录后发表评论