Judge before UseCachedDecoderOut (#431)

Committed by GitHub
Co-authored-by: hiedean <hiedean@tju.edu.cn>
Showing 1 changed file with 18 additions and 3 deletions.
```diff
@@ -89,9 +89,24 @@ void OnlineTransducerGreedySearchDecoder::Decode(
   int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
   int32_t vocab_size = model_->VocabSize();
 
-  Ort::Value decoder_input = model_->BuildDecoderInput(*result);
-  Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
-  UseCachedDecoderOut(*result, &decoder_out);
+  Ort::Value decoder_out{nullptr};
+  bool is_batch_decoder_out_cached = true;
+  for (const auto &r : *result) {
+    if (!r.decoder_out) {
+      is_batch_decoder_out_cached = false;
+      break;
+    }
+  }
+  if (is_batch_decoder_out_cached) {
+    auto &r = result->front();
+    std::vector<int64_t> decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
+    decoder_out_shape[0] = batch_size;
+    decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size());
+    UseCachedDecoderOut(*result, &decoder_out);
+  } else {
+    Ort::Value decoder_input = model_->BuildDecoderInput(*result);
+    decoder_out = model_->RunDecoder(std::move(decoder_input));
+  }
 
   for (int32_t t = 0; t != num_frames; ++t) {
     Ort::Value cur_encoder_out =
```
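In short, the change now reuses the per-stream cached decoder output only when every result in the batch actually has one; otherwise it falls back to building the decoder input and running the decoder as before. Below is a minimal standalone sketch of that batch-wide cache check, using a simplified `Result` stand-in rather than the real sherpa-onnx types (the `Result` struct and `AllDecoderOutCached` helper are illustrative assumptions, not code from the repository):

```cpp
#include <iostream>
#include <optional>
#include <vector>

// Illustrative stand-in for the per-stream decoding result. The real code
// stores an Ort::Value, which evaluates to false when it holds no tensor.
struct Result {
  std::optional<std::vector<float>> decoder_out;  // cached decoder output, if any
};

// Returns true only if every stream in the batch has a cached decoder output,
// mirroring the is_batch_decoder_out_cached check added in this commit.
bool AllDecoderOutCached(const std::vector<Result> &results) {
  for (const auto &r : results) {
    if (!r.decoder_out) {
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<Result> batch(3);
  std::cout << std::boolalpha << AllDecoderOutCached(batch) << "\n";  // false

  for (auto &r : batch) {
    r.decoder_out = std::vector<float>{0.0f};
  }
  std::cout << AllDecoderOutCached(batch) << "\n";  // true
}
```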