Fangjun Kuang
Committed by GitHub

Support batch greedy search decoding (#30)

@@ -3,6 +3,7 @@
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" 4 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
5 5
  6 +#include <algorithm>
6 #include <memory> 7 #include <memory>
7 #include <sstream> 8 #include <sstream>
8 #include <string> 9 #include <string>
@@ -10,6 +11,7 @@
10 #include <vector> 11 #include <vector>
11 12
12 #include "onnxruntime_cxx_api.h" // NOLINT 13 #include "onnxruntime_cxx_api.h" // NOLINT
  14 +#include "sherpa-onnx/csrc/online-transducer-decoder.h"
13 #include "sherpa-onnx/csrc/onnx-utils.h" 15 #include "sherpa-onnx/csrc/onnx-utils.h"
14 16
15 #define SHERPA_ONNX_READ_META_DATA(dst, src_key) \ 17 #define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
@@ -114,23 +116,85 @@ void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
114 } 116 }
115 } 117 }
116 118
117 -Ort::Value OnlineLstmTransducerModel::StackStates(  
118 - const std::vector<Ort::Value> &states) const {  
119 - fprintf(stderr, "implement me: %s:%d!\n", __func__,  
120 - static_cast<int>(__LINE__));  
121 - auto memory_info =  
122 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
123 - int64_t a;  
124 - std::array<int64_t, 3> x_shape{1, 1, 1};  
125 - Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);  
126 - return x; 119 +std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
  120 + const std::vector<std::vector<Ort::Value>> &states) const {
  121 + int32_t batch_size = static_cast<int32_t>(states.size());
  122 +
  123 + std::array<int64_t, 3> h_shape{num_encoder_layers_, batch_size, d_model_};
  124 + Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
  125 + h_shape.size());
  126 +
  127 + std::array<int64_t, 3> c_shape{num_encoder_layers_, batch_size,
  128 + rnn_hidden_size_};
  129 +
  130 + Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
  131 + c_shape.size());
  132 +
  133 + float *dst_h = h.GetTensorMutableData<float>();
  134 + float *dst_c = c.GetTensorMutableData<float>();
  135 +
  136 + for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
  137 + for (int32_t i = 0; i != batch_size; ++i) {
  138 + const float *src_h =
  139 + states[i][0].GetTensorData<float>() + layer * d_model_;
  140 +
  141 + const float *src_c =
  142 + states[i][1].GetTensorData<float>() + layer * rnn_hidden_size_;
  143 +
  144 + std::copy(src_h, src_h + d_model_, dst_h);
  145 + std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
  146 +
  147 + dst_h += d_model_;
  148 + dst_c += rnn_hidden_size_;
  149 + }
  150 + }
  151 +
  152 + std::vector<Ort::Value> ans;
  153 +
  154 + ans.reserve(2);
  155 + ans.push_back(std::move(h));
  156 + ans.push_back(std::move(c));
  157 +
  158 + return ans;
127 } 159 }
128 160
129 -std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(  
130 - Ort::Value states) const {  
131 - fprintf(stderr, "implement me: %s:%d!\n", __func__,  
132 - static_cast<int>(__LINE__));  
133 - return {}; 161 +std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
  162 + const std::vector<Ort::Value> &states) const {
  163 + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
  164 +
  165 + std::vector<std::vector<Ort::Value>> ans(batch_size);
  166 +
  167 + // allocate space
  168 + std::array<int64_t, 3> h_shape{num_encoder_layers_, 1, d_model_};
  169 + std::array<int64_t, 3> c_shape{num_encoder_layers_, 1, rnn_hidden_size_};
  170 +
  171 + for (int32_t i = 0; i != batch_size; ++i) {
  172 + Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
  173 + h_shape.size());
  174 + Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
  175 + c_shape.size());
  176 + ans[i].push_back(std::move(h));
  177 + ans[i].push_back(std::move(c));
  178 + }
  179 +
  180 + for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
  181 + for (int32_t i = 0; i != batch_size; ++i) {
  182 + const float *src_h = states[0].GetTensorData<float>() +
  183 + layer * batch_size * d_model_ + i * d_model_;
  184 + const float *src_c = states[1].GetTensorData<float>() +
  185 + layer * batch_size * rnn_hidden_size_ +
  186 + i * rnn_hidden_size_;
  187 +
  188 + float *dst_h = ans[i][0].GetTensorMutableData<float>() + layer * d_model_;
  189 + float *dst_c =
  190 + ans[i][1].GetTensorMutableData<float>() + layer * rnn_hidden_size_;
  191 +
  192 + std::copy(src_h, src_h + d_model_, dst_h);
  193 + std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
  194 + }
  195 + }
  196 +
  197 + return ans;
134 } 198 }
135 199
136 std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() { 200 std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
@@ -189,16 +253,21 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
189 } 253 }
190 254
191 Ort::Value OnlineLstmTransducerModel::BuildDecoderInput( 255 Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
192 - const std::vector<int64_t> &hyp) {  
193 - auto memory_info =  
194 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
195 -  
196 - std::array<int64_t, 2> shape{1, context_size_}; 256 + const std::vector<OnlineTransducerDecoderResult> &results) {
  257 + int32_t batch_size = static_cast<int32_t>(results.size());
  258 + std::array<int64_t, 2> shape{batch_size, context_size_};
  259 + Ort::Value decoder_input =
  260 + Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
  261 + int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
  262 +
  263 + for (const auto &r : results) {
  264 + const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
  265 + const int64_t *end = r.tokens.data() + r.tokens.size();
  266 + std::copy(begin, end, p);
  267 + p += context_size_;
  268 + }
197 269
198 - return Ort::Value::CreateTensor(  
199 - memory_info,  
200 - const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),  
201 - context_size_, shape.data(), shape.size()); 270 + return decoder_input;
202 } 271 }
203 272
204 Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) { 273 Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
@@ -19,16 +19,19 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
19 public: 19 public:
20 explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config); 20 explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
21 21
22 - Ort::Value StackStates(const std::vector<Ort::Value> &states) const override; 22 + std::vector<Ort::Value> StackStates(
  23 + const std::vector<std::vector<Ort::Value>> &states) const override;
23 24
24 - std::vector<Ort::Value> UnStackStates(Ort::Value states) const override; 25 + std::vector<std::vector<Ort::Value>> UnStackStates(
  26 + const std::vector<Ort::Value> &states) const override;
25 27
26 std::vector<Ort::Value> GetEncoderInitStates() override; 28 std::vector<Ort::Value> GetEncoderInitStates() override;
27 29
28 std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( 30 std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
29 Ort::Value features, std::vector<Ort::Value> &states) override; 31 Ort::Value features, std::vector<Ort::Value> &states) override;
30 32
31 - Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override; 33 + Ort::Value BuildDecoderInput(
  34 + const std::vector<OnlineTransducerDecoderResult> &results) override;
32 35
33 Ort::Value RunDecoder(Ort::Value decoder_input) override; 36 Ort::Value RunDecoder(Ort::Value decoder_input) override;
34 37
@@ -41,6 +44,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
41 int32_t ChunkShift() const override { return decode_chunk_len_; } 44 int32_t ChunkShift() const override { return decode_chunk_len_; }
42 45
43 int32_t VocabSize() const override { return vocab_size_; } 46 int32_t VocabSize() const override { return vocab_size_; }
  47 + OrtAllocator *Allocator() override { return allocator_; }
44 48
45 private: 49 private:
46 void InitEncoder(const std::string &encoder_filename); 50 void InitEncoder(const std::string &encoder_filename);
@@ -50,7 +54,6 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
50 private: 54 private:
51 Ort::Env env_; 55 Ort::Env env_;
52 Ort::SessionOptions sess_opts_; 56 Ort::SessionOptions sess_opts_;
53 -  
54 Ort::AllocatorWithDefaultOptions allocator_; 57 Ort::AllocatorWithDefaultOptions allocator_;
55 58
56 std::unique_ptr<Ort::Session> encoder_sess_; 59 std::unique_ptr<Ort::Session> encoder_sess_;
@@ -6,6 +6,7 @@
6 6
7 #include <assert.h> 7 #include <assert.h>
8 8
  9 +#include <algorithm>
9 #include <memory> 10 #include <memory>
10 #include <sstream> 11 #include <sstream>
11 #include <utility> 12 #include <utility>
@@ -64,39 +65,50 @@ class OnlineRecognizer::Impl {
64 } 65 }
65 66
66 void DecodeStreams(OnlineStream **ss, int32_t n) { 67 void DecodeStreams(OnlineStream **ss, int32_t n) {
67 - if (n != 1) {  
68 - fprintf(stderr, "only n == 1 is implemented\n");  
69 - exit(-1);  
70 - }  
71 - OnlineStream *s = ss[0];  
72 - assert(IsReady(s));  
73 -  
74 int32_t chunk_size = model_->ChunkSize(); 68 int32_t chunk_size = model_->ChunkSize();
75 int32_t chunk_shift = model_->ChunkShift(); 69 int32_t chunk_shift = model_->ChunkShift();
76 70
77 - int32_t feature_dim = s->FeatureDim(); 71 + int32_t feature_dim = ss[0]->FeatureDim();
  72 +
  73 + std::vector<OnlineTransducerDecoderResult> results(n);
  74 + std::vector<float> features_vec(n * chunk_size * feature_dim);
  75 + std::vector<std::vector<Ort::Value>> states_vec(n);
78 76
79 - std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim}; 77 + for (int32_t i = 0; i != n; ++i) {
  78 + std::vector<float> features =
  79 + ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);
  80 +
  81 + ss[i]->GetNumProcessedFrames() += chunk_shift;
  82 +
  83 + std::copy(features.begin(), features.end(),
  84 + features_vec.data() + i * chunk_size * feature_dim);
  85 +
  86 + results[i] = std::move(ss[i]->GetResult());
  87 + states_vec[i] = std::move(ss[i]->GetStates());
  88 + }
80 89
81 auto memory_info = 90 auto memory_info =
82 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 91 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
83 92
84 - std::vector<float> features =  
85 - s->GetFrames(s->GetNumProcessedFrames(), chunk_size);  
86 -  
87 - s->GetNumProcessedFrames() += chunk_shift; 93 + std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
88 94
89 - Ort::Value x =  
90 - Ort::Value::CreateTensor(memory_info, features.data(), features.size(),  
91 - x_shape.data(), x_shape.size()); 95 + Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
  96 + features_vec.size(), x_shape.data(),
  97 + x_shape.size());
92 98
93 - auto pair = model_->RunEncoder(std::move(x), s->GetStates()); 99 + auto states = model_->StackStates(states_vec);
94 100
95 - s->SetStates(std::move(pair.second));  
96 - std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()}; 101 + auto pair = model_->RunEncoder(std::move(x), states);
97 102
98 decoder_->Decode(std::move(pair.first), &results); 103 decoder_->Decode(std::move(pair.first), &results);
99 - s->SetResult(results[0]); 104 +
  105 + std::vector<std::vector<Ort::Value>> next_states =
  106 + model_->UnStackStates(pair.second);
  107 +
  108 + for (int32_t i = 0; i != n; ++i) {
  109 + ss[i]->SetResult(results[i]);
  110 + ss[i]->SetStates(std::move(next_states[i]));
  111 + }
100 } 112 }
101 113
102 OnlineRecognizerResult GetResult(OnlineStream *s) { 114 OnlineRecognizerResult GetResult(OnlineStream *s) {
@@ -32,6 +32,30 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
32 encoder_out_dim, shape.data(), shape.size()); 32 encoder_out_dim, shape.data(), shape.size());
33 } 33 }
34 34
  35 +static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
  36 + int32_t n) {
  37 + if (n == 1) {
  38 + return std::move(*cur_encoder_out);
  39 + }
  40 +
  41 + std::vector<int64_t> cur_encoder_out_shape =
  42 + cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
  43 +
  44 + std::array<int64_t, 2> ans_shape{n, cur_encoder_out_shape[1]};
  45 +
  46 + Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
  47 + ans_shape.size());
  48 +
  49 + const float *src = cur_encoder_out->GetTensorData<float>();
  50 + float *dst = ans.GetTensorMutableData<float>();
  51 + for (int32_t i = 0; i != n; ++i) {
  52 + std::copy(src, src + cur_encoder_out_shape[1], dst);
  53 + dst += cur_encoder_out_shape[1];
  54 + }
  55 +
  56 + return ans;
  57 +}
  58 +
35 OnlineTransducerDecoderResult 59 OnlineTransducerDecoderResult
36 OnlineTransducerGreedySearchDecoder::GetEmptyResult() const { 60 OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
37 int32_t context_size = model_->ContextSize(); 61 int32_t context_size = model_->ContextSize();
@@ -66,33 +90,33 @@ void OnlineTransducerGreedySearchDecoder::Decode(
66 exit(-1); 90 exit(-1);
67 } 91 }
68 92
69 - if (result->size() != 1) {  
70 - fprintf(stderr, "only batch size == 1 is implemented. Given: %d",  
71 - static_cast<int32_t>(result->size()));  
72 - exit(-1);  
73 - }  
74 -  
75 - auto &hyp = (*result)[0].tokens;  
76 -  
77 - int32_t num_frames = encoder_out_shape[1]; 93 + int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
  94 + int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
78 int32_t vocab_size = model_->VocabSize(); 95 int32_t vocab_size = model_->VocabSize();
79 96
80 - Ort::Value decoder_input = model_->BuildDecoderInput(hyp); 97 + Ort::Value decoder_input = model_->BuildDecoderInput(*result);
81 Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); 98 Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
82 99
83 for (int32_t t = 0; t != num_frames; ++t) { 100 for (int32_t t = 0; t != num_frames; ++t) {
84 Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); 101 Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
  102 + cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size);
85 Ort::Value logit = 103 Ort::Value logit =
86 model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); 104 model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
87 const float *p_logit = logit.GetTensorData<float>(); 105 const float *p_logit = logit.GetTensorData<float>();
88 106
89 - auto y = static_cast<int32_t>(std::distance(  
90 - static_cast<const float *>(p_logit),  
91 - std::max_element(static_cast<const float *>(p_logit),  
92 - static_cast<const float *>(p_logit) + vocab_size)));  
93 - if (y != 0) {  
94 - hyp.push_back(y);  
95 - decoder_input = model_->BuildDecoderInput(hyp); 107 + bool emitted = false;
  108 + for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
  109 + auto y = static_cast<int32_t>(std::distance(
  110 + static_cast<const float *>(p_logit),
  111 + std::max_element(static_cast<const float *>(p_logit),
  112 + static_cast<const float *>(p_logit) + vocab_size)));
  113 + if (y != 0) {
  114 + emitted = true;
  115 + (*result)[i].tokens.push_back(y);
  116 + }
  117 + }
  118 + if (emitted) {
  119 + decoder_input = model_->BuildDecoderInput(*result);
96 decoder_out = model_->RunDecoder(std::move(decoder_input)); 120 decoder_out = model_->RunDecoder(std::move(decoder_input));
97 } 121 }
98 } 122 }
@@ -13,6 +13,8 @@
13 13
14 namespace sherpa_onnx { 14 namespace sherpa_onnx {
15 15
  16 +class OnlineTransducerDecoderResult;
  17 +
16 class OnlineTransducerModel { 18 class OnlineTransducerModel {
17 public: 19 public:
18 virtual ~OnlineTransducerModel() = default; 20 virtual ~OnlineTransducerModel() = default;
@@ -27,8 +29,8 @@ class OnlineTransducerModel {
27 * @param states states[i] contains the state for the i-th utterance. 29 * @param states states[i] contains the state for the i-th utterance.
28 * @return Return a single value representing the batched state. 30 * @return Return a single value representing the batched state.
29 */ 31 */
30 - virtual Ort::Value StackStates(  
31 - const std::vector<Ort::Value> &states) const = 0; 32 + virtual std::vector<Ort::Value> StackStates(
  33 + const std::vector<std::vector<Ort::Value>> &states) const = 0;
32 34
33 /** Unstack a batch state into a list of individual states. 35 /** Unstack a batch state into a list of individual states.
34 * 36 *
@@ -37,7 +39,8 @@ class OnlineTransducerModel {
37 * @param states A batched state. 39 * @param states A batched state.
38 * @return ans[i] contains the state for the i-th utterance. 40 * @return ans[i] contains the state for the i-th utterance.
39 */ 41 */
40 - virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0; 42 + virtual std::vector<std::vector<Ort::Value>> UnStackStates(
  43 + const std::vector<Ort::Value> &states) const = 0;
41 44
42 /** Get the initial encoder states. 45 /** Get the initial encoder states.
43 * 46 *
@@ -58,7 +61,8 @@ class OnlineTransducerModel {
58 Ort::Value features, 61 Ort::Value features,
59 std::vector<Ort::Value> &states) = 0; // NOLINT 62 std::vector<Ort::Value> &states) = 0; // NOLINT
60 63
61 - virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0; 64 + virtual Ort::Value BuildDecoderInput(
  65 + const std::vector<OnlineTransducerDecoderResult> &results) = 0;
62 66
63 /** Run the decoder network. 67 /** Run the decoder network.
64 * 68 *
@@ -111,6 +115,7 @@ class OnlineTransducerModel {
111 virtual int32_t VocabSize() const = 0; 115 virtual int32_t VocabSize() const = 0;
112 116
113 virtual int32_t SubsamplingFactor() const { return 4; } 117 virtual int32_t SubsamplingFactor() const { return 4; }
  118 + virtual OrtAllocator *Allocator() = 0;
114 }; 119 };
115 120
116 } // namespace sherpa_onnx 121 } // namespace sherpa_onnx