Fangjun Kuang
Committed by GitHub

Fix batch decoding for greedy search (#71)

@@ -10,47 +10,39 @@ @@ -10,47 +10,39 @@
10 #include <utility> 10 #include <utility>
11 #include <vector> 11 #include <vector>
12 12
  13 +#include "sherpa-onnx/csrc/macros.h"
13 #include "sherpa-onnx/csrc/onnx-utils.h" 14 #include "sherpa-onnx/csrc/onnx-utils.h"
14 15
15 namespace sherpa_onnx { 16 namespace sherpa_onnx {
16 17
17 -static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { 18 +static Ort::Value GetFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
  19 + int32_t t) {
18 std::vector<int64_t> encoder_out_shape = 20 std::vector<int64_t> encoder_out_shape =
19 encoder_out->GetTensorTypeAndShapeInfo().GetShape(); 21 encoder_out->GetTensorTypeAndShapeInfo().GetShape();
20 - assert(encoder_out_shape[0] == 1);  
21 22
22 - int32_t encoder_out_dim = encoder_out_shape[2]; 23 + auto batch_size = encoder_out_shape[0];
  24 + auto num_frames = encoder_out_shape[1];
  25 + assert(t < num_frames);
23 26
24 - auto memory_info =  
25 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
26 -  
27 - std::array<int64_t, 2> shape{1, encoder_out_dim};  
28 -  
29 - return Ort::Value::CreateTensor(  
30 - memory_info,  
31 - encoder_out->GetTensorMutableData<float>() + t * encoder_out_dim,  
32 - encoder_out_dim, shape.data(), shape.size());  
33 -} 27 + auto encoder_out_dim = encoder_out_shape[2];
34 28
35 -static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,  
36 - int32_t n) {  
37 - if (n == 1) {  
38 - return std::move(*cur_encoder_out);  
39 - } 29 + auto offset = num_frames * encoder_out_dim;
40 30
41 - std::vector<int64_t> cur_encoder_out_shape =  
42 - cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape(); 31 + auto memory_info =
  32 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
43 33
44 - std::array<int64_t, 2> ans_shape{n, cur_encoder_out_shape[1]}; 34 + std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
45 35
46 - Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),  
47 - ans_shape.size()); 36 + Ort::Value ans =
  37 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
48 38
49 - const float *src = cur_encoder_out->GetTensorData<float>();  
50 float *dst = ans.GetTensorMutableData<float>(); 39 float *dst = ans.GetTensorMutableData<float>();
51 - for (int32_t i = 0; i != n; ++i) {  
52 - std::copy(src, src + cur_encoder_out_shape[1], dst);  
53 - dst += cur_encoder_out_shape[1]; 40 + const float *src = encoder_out->GetTensorData<float>();
  41 +
  42 + for (int32_t i = 0; i != batch_size; ++i) {
  43 + std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst);
  44 + src += offset;
  45 + dst += encoder_out_dim;
54 } 46 }
55 47
56 return ans; 48 return ans;
@@ -83,10 +75,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -83,10 +75,10 @@ void OnlineTransducerGreedySearchDecoder::Decode(
83 encoder_out.GetTensorTypeAndShapeInfo().GetShape(); 75 encoder_out.GetTensorTypeAndShapeInfo().GetShape();
84 76
85 if (encoder_out_shape[0] != result->size()) { 77 if (encoder_out_shape[0] != result->size()) {
86 - fprintf(stderr,  
87 - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",  
88 - static_cast<int32_t>(encoder_out_shape[0]),  
89 - static_cast<int32_t>(result->size())); 78 + SHERPA_ONNX_LOGE(
  79 + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d",
  80 + static_cast<int32_t>(encoder_out_shape[0]),
  81 + static_cast<int32_t>(result->size()));
90 exit(-1); 82 exit(-1);
91 } 83 }
92 84
@@ -98,10 +90,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -98,10 +90,10 @@ void OnlineTransducerGreedySearchDecoder::Decode(
98 Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); 90 Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
99 91
100 for (int32_t t = 0; t != num_frames; ++t) { 92 for (int32_t t = 0; t != num_frames; ++t) {
101 - Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);  
102 - cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size); 93 + Ort::Value cur_encoder_out = GetFrame(model_->Allocator(), &encoder_out, t);
103 Ort::Value logit = model_->RunJoiner( 94 Ort::Value logit = model_->RunJoiner(
104 std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); 95 std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
  96 +
105 const float *p_logit = logit.GetTensorData<float>(); 97 const float *p_logit = logit.GetTensorData<float>();
106 98
107 bool emitted = false; 99 bool emitted = false;