Askars
Committed by GitHub

Preserve previous result as context for next segment (#1335)

Co-authored-by: vsd-vector <askars.salimbajevs@tilde.lv>
@@ -360,11 +360,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -360,11 +360,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
360 } 360 }
361 361
362 void Reset(OnlineStream *s) const override { 362 void Reset(OnlineStream *s) const override {
  363 + int32_t context_size = model_->ContextSize();
  364 +
363 { 365 {
364 // segment is incremented only when the last 366 // segment is incremented only when the last
365 - // result is not empty 367 + // result is not empty, contains non-blanks and is longer than context_size
366 const auto &r = s->GetResult(); 368 const auto &r = s->GetResult();
367 - if (!r.tokens.empty() && r.tokens.back() != 0) { 369 + if (!r.tokens.empty() && r.tokens.back() != 0 && r.tokens.size() > context_size) {
368 s->GetCurrentSegment() += 1; 370 s->GetCurrentSegment() += 1;
369 } 371 }
370 } 372 }
@@ -372,10 +374,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -372,10 +374,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
372 // reset encoder states 374 // reset encoder states
373 // s->SetStates(model_->GetEncoderInitStates()); 375 // s->SetStates(model_->GetEncoderInitStates());
374 376
375 - // we keep the decoder_out  
376 - decoder_->UpdateDecoderOut(&s->GetResult());  
377 - Ort::Value decoder_out = std::move(s->GetResult().decoder_out);  
378 -  
379 auto r = decoder_->GetEmptyResult(); 377 auto r = decoder_->GetEmptyResult();
380 if (config_.decoding_method == "modified_beam_search" && 378 if (config_.decoding_method == "modified_beam_search" &&
381 nullptr != s->GetContextGraph()) { 379 nullptr != s->GetContextGraph()) {
@@ -383,8 +381,19 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -383,8 +381,19 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
383 it->second.context_state = s->GetContextGraph()->Root(); 381 it->second.context_state = s->GetContextGraph()->Root();
384 } 382 }
385 } 383 }
  384 +
  385 + auto last_result = s->GetResult();
  386 + // if the last result has more than context_size tokens, then
  387 + // preserve its last context_size tokens as the context for the next result
  388 + if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
  389 + std::vector<int64_t> context(last_result.tokens.end() - context_size, last_result.tokens.end());
  390 +
  391 + Hypotheses context_hyp({{context, 0}});
  392 + r.hyps = std::move(context_hyp);
  393 + r.tokens = std::move(context);
  394 + }
  395 +
386 s->SetResult(r); 396 s->SetResult(r);
387 - s->GetResult().decoder_out = std::move(decoder_out);  
388 397
389 // Note: We only update counters. The underlying audio samples 398 // Note: We only update counters. The underlying audio samples
390 // are not discarded. 399 // are not discarded.