Committed by GitHub
Fix batch decoding for greedy search (#71)
正在显示 1 个修改的文件,包含 25 行增加和 33 行删除
| @@ -10,47 +10,39 @@ | @@ -10,47 +10,39 @@ | ||
| 10 | #include <utility> | 10 | #include <utility> |
| 11 | #include <vector> | 11 | #include <vector> |
| 12 | 12 | ||
| 13 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | #include "sherpa-onnx/csrc/onnx-utils.h" | 14 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 14 | 15 | ||
| 15 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 16 | 17 | ||
| 17 | -static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { | 18 | +static Ort::Value GetFrame(OrtAllocator *allocator, Ort::Value *encoder_out, |
| 19 | + int32_t t) { | ||
| 18 | std::vector<int64_t> encoder_out_shape = | 20 | std::vector<int64_t> encoder_out_shape = |
| 19 | encoder_out->GetTensorTypeAndShapeInfo().GetShape(); | 21 | encoder_out->GetTensorTypeAndShapeInfo().GetShape(); |
| 20 | - assert(encoder_out_shape[0] == 1); | ||
| 21 | 22 | ||
| 22 | - int32_t encoder_out_dim = encoder_out_shape[2]; | 23 | + auto batch_size = encoder_out_shape[0]; |
| 24 | + auto num_frames = encoder_out_shape[1]; | ||
| 25 | + assert(t < num_frames); | ||
| 23 | 26 | ||
| 24 | - auto memory_info = | ||
| 25 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 26 | - | ||
| 27 | - std::array<int64_t, 2> shape{1, encoder_out_dim}; | ||
| 28 | - | ||
| 29 | - return Ort::Value::CreateTensor( | ||
| 30 | - memory_info, | ||
| 31 | - encoder_out->GetTensorMutableData<float>() + t * encoder_out_dim, | ||
| 32 | - encoder_out_dim, shape.data(), shape.size()); | ||
| 33 | -} | 27 | + auto encoder_out_dim = encoder_out_shape[2]; |
| 34 | 28 | ||
| 35 | -static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, | ||
| 36 | - int32_t n) { | ||
| 37 | - if (n == 1) { | ||
| 38 | - return std::move(*cur_encoder_out); | ||
| 39 | - } | 29 | + auto offset = num_frames * encoder_out_dim; |
| 40 | 30 | ||
| 41 | - std::vector<int64_t> cur_encoder_out_shape = | ||
| 42 | - cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape(); | 31 | + auto memory_info = |
| 32 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 43 | 33 | ||
| 44 | - std::array<int64_t, 2> ans_shape{n, cur_encoder_out_shape[1]}; | 34 | + std::array<int64_t, 2> shape{batch_size, encoder_out_dim}; |
| 45 | 35 | ||
| 46 | - Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(), | ||
| 47 | - ans_shape.size()); | 36 | + Ort::Value ans = |
| 37 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 48 | 38 | ||
| 49 | - const float *src = cur_encoder_out->GetTensorData<float>(); | ||
| 50 | float *dst = ans.GetTensorMutableData<float>(); | 39 | float *dst = ans.GetTensorMutableData<float>(); |
| 51 | - for (int32_t i = 0; i != n; ++i) { | ||
| 52 | - std::copy(src, src + cur_encoder_out_shape[1], dst); | ||
| 53 | - dst += cur_encoder_out_shape[1]; | 40 | + const float *src = encoder_out->GetTensorData<float>(); |
| 41 | + | ||
| 42 | + for (int32_t i = 0; i != batch_size; ++i) { | ||
| 43 | + std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst); | ||
| 44 | + src += offset; | ||
| 45 | + dst += encoder_out_dim; | ||
| 54 | } | 46 | } |
| 55 | 47 | ||
| 56 | return ans; | 48 | return ans; |
| @@ -83,10 +75,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -83,10 +75,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 83 | encoder_out.GetTensorTypeAndShapeInfo().GetShape(); | 75 | encoder_out.GetTensorTypeAndShapeInfo().GetShape(); |
| 84 | 76 | ||
| 85 | if (encoder_out_shape[0] != result->size()) { | 77 | if (encoder_out_shape[0] != result->size()) { |
| 86 | - fprintf(stderr, | ||
| 87 | - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", | ||
| 88 | - static_cast<int32_t>(encoder_out_shape[0]), | ||
| 89 | - static_cast<int32_t>(result->size())); | 78 | + SHERPA_ONNX_LOGE( |
| 79 | + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d", | ||
| 80 | + static_cast<int32_t>(encoder_out_shape[0]), | ||
| 81 | + static_cast<int32_t>(result->size())); | ||
| 90 | exit(-1); | 82 | exit(-1); |
| 91 | } | 83 | } |
| 92 | 84 | ||
| @@ -98,10 +90,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -98,10 +90,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 98 | Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); | 90 | Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); |
| 99 | 91 | ||
| 100 | for (int32_t t = 0; t != num_frames; ++t) { | 92 | for (int32_t t = 0; t != num_frames; ++t) { |
| 101 | - Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); | ||
| 102 | - cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size); | 93 | + Ort::Value cur_encoder_out = GetFrame(model_->Allocator(), &encoder_out, t); |
| 103 | Ort::Value logit = model_->RunJoiner( | 94 | Ort::Value logit = model_->RunJoiner( |
| 104 | std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); | 95 | std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); |
| 96 | + | ||
| 105 | const float *p_logit = logit.GetTensorData<float>(); | 97 | const float *p_logit = logit.GetTensorData<float>(); |
| 106 | 98 | ||
| 107 | bool emitted = false; | 99 | bool emitted = false; |
-
请 注册 或 登录 后发表评论