HieDean
Committed by GitHub

Judge before UseCachedDecoderOut (#431)

Co-authored-by: hiedean <hiedean@tju.edu.cn>
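This change checks whether every stream in the batch already has a cached decoder_out before taking the UseCachedDecoderOut path: if all streams carry a cached decoder output, a batch-sized decoder_out tensor is allocated and filled from the cache; otherwise the decoder input is rebuilt and the decoder is run as before.
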
@@ -89,9 +89,24 @@ void OnlineTransducerGreedySearchDecoder::Decode(
   int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
   int32_t vocab_size = model_->VocabSize();
-  Ort::Value decoder_input = model_->BuildDecoderInput(*result);
-  Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
+  Ort::Value decoder_out{nullptr};
+  bool is_batch_decoder_out_cached = true;
+  for (const auto &r : *result) {
+    if (!r.decoder_out) {
+      is_batch_decoder_out_cached = false;
+      break;
+    }
+  }
+  if (is_batch_decoder_out_cached) {
+    auto &r = result->front();
+    std::vector<int64_t> decoder_out_shape =
+        r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
+    decoder_out_shape[0] = batch_size;
+    decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(),
+                                                  decoder_out_shape.data(),
+                                                  decoder_out_shape.size());
+    UseCachedDecoderOut(*result, &decoder_out);
+  } else {
+    Ort::Value decoder_input = model_->BuildDecoderInput(*result);
+    decoder_out = model_->RunDecoder(std::move(decoder_input));
+  }
   for (int32_t t = 0; t != num_frames; ++t) {
     Ort::Value cur_encoder_out =
... ...
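For orientation, here is a minimal sketch of what the cached branch relies on, assuming each per-stream result caches a decoder output of shape [1, decoder_dim] and that UseCachedDecoderOut (whose body is not shown in this hunk) simply copies each cached row into the preallocated [batch_size, decoder_dim] tensor. StreamResult and CopyCachedDecoderOut are illustrative stand-ins, not the actual sherpa-onnx types or helper:

#include <cstring>
#include <vector>

#include "onnxruntime_cxx_api.h"

// Illustrative stand-in for the per-stream result type; only the cached
// decoder output matters for this sketch.
struct StreamResult {
  Ort::Value decoder_out{nullptr};
};

// Copy each stream's cached [1, decoder_dim] decoder output into the
// corresponding row of the preallocated [batch_size, decoder_dim] tensor.
static void CopyCachedDecoderOut(const std::vector<StreamResult> &results,
                                 Ort::Value *batch_decoder_out) {
  std::vector<int64_t> shape =
      batch_decoder_out->GetTensorTypeAndShapeInfo().GetShape();
  int64_t decoder_dim = shape[1];

  float *dst = batch_decoder_out->GetTensorMutableData<float>();
  for (const auto &r : results) {
    const float *src = r.decoder_out.GetTensorData<float>();
    std::memcpy(dst, src, static_cast<size_t>(decoder_dim) * sizeof(float));
    dst += decoder_dim;
  }
}

A copy along these lines is what lets the is_batch_decoder_out_cached branch skip BuildDecoderInput and RunDecoder entirely when every stream already has a decoder output from a previous call.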