HieDean
Committed by GitHub

Judge before UseCachedDecoderOut (#431)

Co-authored-by: hiedean <hiedean@tju.edu.cn>
@@ -89,9 +89,24 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -89,9 +89,24 @@ void OnlineTransducerGreedySearchDecoder::Decode(
89 int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]); 89 int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
90 int32_t vocab_size = model_->VocabSize(); 90 int32_t vocab_size = model_->VocabSize();
91 91
92 - Ort::Value decoder_input = model_->BuildDecoderInput(*result);  
93 - Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));  
94 - UseCachedDecoderOut(*result, &decoder_out); 92 + Ort::Value decoder_out{nullptr};
  93 + bool is_batch_decoder_out_cached = true;
  94 + for (const auto &r : *result) {
  95 + if (!r.decoder_out) {
  96 + is_batch_decoder_out_cached = false;
  97 + break;
  98 + }
  99 + }
  100 + if (is_batch_decoder_out_cached) {
  101 + auto &r = result->front();
  102 + std::vector<int64_t> decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
  103 + decoder_out_shape[0] = batch_size;
  104 + decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size());
  105 + UseCachedDecoderOut(*result, &decoder_out);
  106 + } else {
  107 + Ort::Value decoder_input = model_->BuildDecoderInput(*result);
  108 + decoder_out = model_->RunDecoder(std::move(decoder_input));
  109 + }
95 110
96 for (int32_t t = 0; t != num_frames; ++t) { 111 for (int32_t t = 0; t != num_frames; ++t) {
97 Ort::Value cur_encoder_out = 112 Ort::Value cur_encoder_out =