Askars Salimbajevs
Committed by GitHub

Preserve more context after endpointing in transducer (#2061)

@@ -388,16 +388,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
     auto r = decoder_->GetEmptyResult();
     auto last_result = s->GetResult();
     // if last result is not empty, then
-    // preserve last tokens as the context for next result
+    // truncate all last hyps and save as the context for next result
     if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
-      std::vector<int64_t> context(last_result.tokens.end() - context_size,
-                                   last_result.tokens.end());
+      for (const auto &it : last_result.hyps) {
+        auto h = it.second;
+        r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
+                                         h.ys.end()),
+                    h.log_prob});
+      }
 
-      Hypotheses context_hyp({{context, 0}});
-      r.hyps = std::move(context_hyp);
-      r.tokens = std::move(context);
+      r.tokens = std::vector<int64_t>(last_result.tokens.end() - context_size,
+                                      last_result.tokens.end());
     }
 
+    // but reset all contextual biasing graph states to root
     if (config_.decoding_method == "modified_beam_search" &&
         nullptr != s->GetContextGraph()) {
      for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
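In plain terms, the old code kept only the last `context_size` tokens of the best path as a single hypothesis with a reset score of 0, while the new code truncates every surviving hypothesis to its last `context_size` tokens and re-adds it with its original `log_prob`, so beam search after an endpoint resumes with the previous beam's context (while contextual biasing graph states are still reset to root). Below is a minimal, self-contained sketch of that truncation step only; the `Hyp` struct and the containers are simplified stand-ins for illustration, not the actual sherpa-onnx `Hypothesis`/`Hypotheses` types.

// Minimal sketch of the per-hypothesis truncation idea from the diff above.
// Hyp and the map below are simplified stand-ins, not sherpa-onnx types.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Hyp {
  std::vector<int64_t> ys;  // decoded token ids
  double log_prob = 0.0;    // accumulated log probability
};

int main() {
  const int32_t context_size = 2;  // decoder (predictor) context length

  // Pretend these are the hypotheses of the result that just hit an endpoint.
  std::map<std::string, Hyp> last_hyps = {
      {"a", {{5, 9, 17, 23, 42}, -1.2}},
      {"b", {{5, 9, 17, 23, 57}, -1.9}},
  };

  // New behaviour: truncate *every* hypothesis to its last `context_size`
  // tokens and keep its original log_prob, instead of collapsing to a single
  // zero-scored hypothesis built from the best path.
  std::vector<Hyp> next_hyps;
  for (const auto &it : last_hyps) {
    const Hyp &h = it.second;
    if (static_cast<int32_t>(h.ys.size()) > context_size) {
      next_hyps.push_back({std::vector<int64_t>(h.ys.end() - context_size,
                                                h.ys.end()),
                           h.log_prob});
    }
  }

  for (const auto &h : next_hyps) {
    std::cout << "context:";
    for (auto t : h.ys) std::cout << ' ' << t;
    std::cout << "  log_prob: " << h.log_prob << '\n';
  }
  return 0;
}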