Askars committed (via GitHub)

Preserve previous result as context for next segment (#1335)

Co-authored-by: vsd-vector <askars.salimbajevs@tilde.lv>
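
For context: this change seeds the fresh result for the next segment with the last context_size token ids of the previous segment's result (both the token list and the initial hypothesis), so the stateless transducer decoder starts the new segment with real left context rather than only blank padding. Below is a minimal standalone sketch of that carry-over idea; the struct and function names are simplified stand-ins for illustration, not the sherpa-onnx classes.

// carry_over_sketch.cc -- illustrative only; FakeResult and SeedNextSegment
// are hypothetical stand-ins, not part of sherpa-onnx.
#include <cstdint>
#include <iostream>
#include <vector>

struct FakeResult {
  std::vector<int64_t> tokens;  // decoded token ids for one segment
};

// Seed the next (otherwise empty) result with the last `context_size`
// tokens of the previous result, mirroring the logic added in this commit.
FakeResult SeedNextSegment(const FakeResult &last, int32_t context_size) {
  FakeResult next;
  if (static_cast<int32_t>(last.tokens.size()) > context_size) {
    next.tokens.assign(last.tokens.end() - context_size, last.tokens.end());
  }
  return next;
}

int main() {
  FakeResult last{{-1, -1, 0, 17, 23, 42}};  // previous segment's tokens
  FakeResult next = SeedNextSegment(last, /*context_size=*/2);
  for (int64_t t : next.tokens) std::cout << t << ' ';  // prints: 23 42
  std::cout << '\n';
  return 0;
}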
@@ -360,11 +360,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
   }
   void Reset(OnlineStream *s) const override {
     int32_t context_size = model_->ContextSize();
     {
       // segment is incremented only when the last
-      // result is not empty
+      // result is not empty, contains non-blanks, and is longer than context_size
       const auto &r = s->GetResult();
-      if (!r.tokens.empty() && r.tokens.back() != 0) {
+      if (!r.tokens.empty() && r.tokens.back() != 0 &&
+          r.tokens.size() > context_size) {
         s->GetCurrentSegment() += 1;
       }
     }
@@ -372,10 +374,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
     // reset encoder states
     // s->SetStates(model_->GetEncoderInitStates());
-    // we keep the decoder_out
-    decoder_->UpdateDecoderOut(&s->GetResult());
-    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
     auto r = decoder_->GetEmptyResult();
     if (config_.decoding_method == "modified_beam_search" &&
         nullptr != s->GetContextGraph()) {
@@ -383,8 +381,19 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
         it->second.context_state = s->GetContextGraph()->Root();
       }
     }
+    auto last_result = s->GetResult();
+    // if the last result is not empty, preserve its last
+    // context_size tokens as the context for the next result
+    if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
+      std::vector<int64_t> context(last_result.tokens.end() - context_size,
+                                   last_result.tokens.end());
+      Hypotheses context_hyp({{context, 0}});
+      r.hyps = std::move(context_hyp);
+      r.tokens = std::move(context);
+    }
     s->SetResult(r);
-    s->GetResult().decoder_out = std::move(decoder_out);
     // Note: We only update counters. The underlying audio samples
     // are not discarded.
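
A note on the two new size checks: a freshly reset result is not literally empty. Assuming the decoder's GetEmptyResult() pads the token list with context_size placeholder entries ending in a blank id, a result whose size equals context_size contains no real output yet, so it should neither advance the segment counter nor be carried over as context. A hedged sketch of that padded empty state follows; MakeEmptyTokens is a hypothetical helper, not the sherpa-onnx API.

// empty_result_sketch.cc -- illustrative only, under the padding assumption above.
#include <cstdint>
#include <vector>

// Build the token list of a padded "empty" result: context_size placeholder
// entries with the final one set to the blank id. Its size equals
// context_size, so the `size() > context_size` checks reject it.
std::vector<int64_t> MakeEmptyTokens(int32_t context_size,
                                     int64_t blank_id = 0) {
  std::vector<int64_t> tokens(context_size, -1);
  tokens.back() = blank_id;
  return tokens;
}

int main() {
  auto tokens = MakeEmptyTokens(2);  // {-1, 0}: padding only, no real output
  return tokens.size() > 2 ? 1 : 0;  // returns 0: not a decoded segment
}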