Committed by
GitHub
Preserve previous result as context for next segment (#1335)
Co-authored-by: vsd-vector <askars.salimbajevs@tilde.lv>
正在显示
1 个修改的文件
包含
16 行增加
和
7 行删除
| @@ -360,11 +360,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -360,11 +360,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 360 | } | 360 | } |
| 361 | 361 | ||
| 362 | void Reset(OnlineStream *s) const override { | 362 | void Reset(OnlineStream *s) const override { |
| 363 | + int32_t context_size = model_->ContextSize(); | ||
| 364 | + | ||
| 363 | { | 365 | { |
| 364 | // segment is incremented only when the last | 366 | // segment is incremented only when the last |
| 365 | - // result is not empty | 367 | + // result is not empty, its last token is non-blank, and it is longer than context_size |
| 366 | const auto &r = s->GetResult(); | 368 | const auto &r = s->GetResult(); |
| 367 | - if (!r.tokens.empty() && r.tokens.back() != 0) { | 369 | + if (!r.tokens.empty() && r.tokens.back() != 0 && r.tokens.size() > context_size) { |
| 368 | s->GetCurrentSegment() += 1; | 370 | s->GetCurrentSegment() += 1; |
| 369 | } | 371 | } |
| 370 | } | 372 | } |
| @@ -372,10 +374,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -372,10 +374,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 372 | // reset encoder states | 374 | // reset encoder states |
| 373 | // s->SetStates(model_->GetEncoderInitStates()); | 375 | // s->SetStates(model_->GetEncoderInitStates()); |
| 374 | 376 | ||
| 375 | - // we keep the decoder_out | ||
| 376 | - decoder_->UpdateDecoderOut(&s->GetResult()); | ||
| 377 | - Ort::Value decoder_out = std::move(s->GetResult().decoder_out); | ||
| 378 | - | ||
| 379 | auto r = decoder_->GetEmptyResult(); | 377 | auto r = decoder_->GetEmptyResult(); |
| 380 | if (config_.decoding_method == "modified_beam_search" && | 378 | if (config_.decoding_method == "modified_beam_search" && |
| 381 | nullptr != s->GetContextGraph()) { | 379 | nullptr != s->GetContextGraph()) { |
| @@ -383,8 +381,19 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -383,8 +381,19 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 383 | it->second.context_state = s->GetContextGraph()->Root(); | 381 | it->second.context_state = s->GetContextGraph()->Root(); |
| 384 | } | 382 | } |
| 385 | } | 383 | } |
| 384 | + | ||
| 385 | + auto last_result = s->GetResult(); | ||
| 386 | + // if the last result has more than context_size tokens, then | ||
| 387 | + // preserve its last context_size tokens as the context for the next result | ||
| 388 | + if (static_cast<int32_t>(last_result.tokens.size()) > context_size) { | ||
| 389 | + std::vector<int64_t> context(last_result.tokens.end() - context_size, last_result.tokens.end()); | ||
| 390 | + | ||
| 391 | + Hypotheses context_hyp({{context, 0}}); | ||
| 392 | + r.hyps = std::move(context_hyp); | ||
| 393 | + r.tokens = std::move(context); | ||
| 394 | + } | ||
| 395 | + | ||
| 386 | s->SetResult(r); | 396 | s->SetResult(r); |
| 387 | - s->GetResult().decoder_out = std::move(decoder_out); | ||
| 388 | 397 | ||
| 389 | // Note: We only update counters. The underlying audio samples | 398 | // Note: We only update counters. The underlying audio samples |
| 390 | // are not discarded. | 399 | // are not discarded. |
-
请 注册 或 登录 后发表评论