Support batch greedy search decoding (#30)
Committed by GitHub

Showing 5 changed files with 182 additions and 69 deletions.
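This commit removes the single-stream restriction from the streaming transducer pipeline: StackStates() and UnStackStates() get real implementations that pack per-stream LSTM (h, c) states into batched tensors and back, BuildDecoderInput() now takes the whole list of per-stream results instead of a single hypothesis, DecodeStreams() packs features and states for n streams into one encoder call, and the greedy-search decoder does a per-row argmax over the batched logits. A new OnlineTransducerModel::Allocator() accessor lets the decoder allocate tensors from the model's allocator.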
sherpa-onnx/csrc/online-lstm-transducer-model.cc

@@ -3,6 +3,7 @@
 // Copyright (c)  2023  Xiaomi Corporation
 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
 
+#include <algorithm>
 #include <memory>
 #include <sstream>
 #include <string>
@@ -10,6 +11,7 @@
 #include <vector>
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 
 #define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
@@ -114,23 +116,85 @@ void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
   }
 }
 
-Ort::Value OnlineLstmTransducerModel::StackStates(
-    const std::vector<Ort::Value> &states) const {
-  fprintf(stderr, "implement me: %s:%d!\n", __func__,
-          static_cast<int>(__LINE__));
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-  int64_t a;
-  std::array<int64_t, 3> x_shape{1, 1, 1};
-  Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);
-  return x;
+std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
+    const std::vector<std::vector<Ort::Value>> &states) const {
+  int32_t batch_size = static_cast<int32_t>(states.size());
+
+  std::array<int64_t, 3> h_shape{num_encoder_layers_, batch_size, d_model_};
+  Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
+                                                 h_shape.size());
+
+  std::array<int64_t, 3> c_shape{num_encoder_layers_, batch_size,
+                                 rnn_hidden_size_};
+
+  Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
+                                                 c_shape.size());
+
+  float *dst_h = h.GetTensorMutableData<float>();
+  float *dst_c = c.GetTensorMutableData<float>();
+
+  for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
+    for (int32_t i = 0; i != batch_size; ++i) {
+      const float *src_h =
+          states[i][0].GetTensorData<float>() + layer * d_model_;
+
+      const float *src_c =
+          states[i][1].GetTensorData<float>() + layer * rnn_hidden_size_;
+
+      std::copy(src_h, src_h + d_model_, dst_h);
+      std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
+
+      dst_h += d_model_;
+      dst_c += rnn_hidden_size_;
+    }
+  }
+
+  std::vector<Ort::Value> ans;
+
+  ans.reserve(2);
+  ans.push_back(std::move(h));
+  ans.push_back(std::move(c));
+
+  return ans;
 }
 
-std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(
-    Ort::Value states) const {
-  fprintf(stderr, "implement me: %s:%d!\n", __func__,
-          static_cast<int>(__LINE__));
-  return {};
+std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
+    const std::vector<Ort::Value> &states) const {
+  int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
+
+  std::vector<std::vector<Ort::Value>> ans(batch_size);
+
+  // allocate space
+  std::array<int64_t, 3> h_shape{num_encoder_layers_, 1, d_model_};
+  std::array<int64_t, 3> c_shape{num_encoder_layers_, 1, rnn_hidden_size_};
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
+                                                   h_shape.size());
+    Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
+                                                   c_shape.size());
+    ans[i].push_back(std::move(h));
+    ans[i].push_back(std::move(c));
+  }
+
+  for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
+    for (int32_t i = 0; i != batch_size; ++i) {
+      const float *src_h = states[0].GetTensorData<float>() +
+                           layer * batch_size * d_model_ + i * d_model_;
+      const float *src_c = states[1].GetTensorData<float>() +
+                           layer * batch_size * rnn_hidden_size_ +
+                           i * rnn_hidden_size_;
+
+      float *dst_h = ans[i][0].GetTensorMutableData<float>() + layer * d_model_;
+      float *dst_c =
+          ans[i][1].GetTensorMutableData<float>() + layer * rnn_hidden_size_;
+
+      std::copy(src_h, src_h + d_model_, dst_h);
+      std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
+    }
+  }
+
+  return ans;
 }
 
 std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
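Note on the layout these two functions agree on: each per-stream h (or c) tensor is {num_encoder_layers, 1, dim} and the batched tensor is {num_encoder_layers, batch_size, dim}, so stream i's slice for layer l starts at offset (l * batch_size + i) * dim; UnStackStates walks the same indexing in reverse. A minimal standalone sketch of the interleaving, using plain float buffers instead of Ort::Value (the Stack name and the toy dimensions are illustrative, not part of the API):

#include <algorithm>
#include <cstdio>
#include <vector>

// Pack per-stream states of shape {num_layers, 1, dim} into one
// {num_layers, batch, dim} buffer, matching the copy loop in StackStates().
std::vector<float> Stack(const std::vector<std::vector<float>> &states,
                         int num_layers, int dim) {
  int batch = static_cast<int>(states.size());
  std::vector<float> ans(num_layers * batch * dim);
  float *dst = ans.data();
  for (int layer = 0; layer != num_layers; ++layer) {
    for (int i = 0; i != batch; ++i) {
      const float *src = states[i].data() + layer * dim;  // layer l of stream i
      std::copy(src, src + dim, dst);
      dst += dim;
    }
  }
  return ans;
}

int main() {
  // Two streams, two layers, dim = 3; each stream stores layer 0 then layer 1.
  std::vector<std::vector<float>> states = {
      {0, 1, 2, 3, 4, 5},
      {10, 11, 12, 13, 14, 15},
  };
  std::vector<float> batched = Stack(states, /*num_layers=*/2, /*dim=*/3);
  for (float v : batched) std::printf("%g ", v);
  std::printf("\n");  // prints: 0 1 2 10 11 12 3 4 5 13 14 15
  return 0;
}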
@@ -189,16 +253,21 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
 }
 
 Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
-    const std::vector<int64_t> &hyp) {
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-
-  std::array<int64_t, 2> shape{1, context_size_};
+    const std::vector<OnlineTransducerDecoderResult> &results) {
+  int32_t batch_size = static_cast<int32_t>(results.size());
+  std::array<int64_t, 2> shape{batch_size, context_size_};
+  Ort::Value decoder_input =
+      Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
+  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
+
+  for (const auto &r : results) {
+    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
+    const int64_t *end = r.tokens.data() + r.tokens.size();
+    std::copy(begin, end, p);
+    p += context_size_;
+  }
 
-  return Ort::Value::CreateTensor(
-      memory_info,
-      const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),
-      context_size_, shape.data(), shape.size());
+  return decoder_input;
 }
 
 Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
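Two things changed in BuildDecoderInput() besides batching. The old version returned a tensor that merely aliased the caller's hyp vector (CreateTensor over user-owned memory, via a const_cast); the new version copies the last context_size tokens of every hypothesis into an allocator-owned {batch_size, context_size} tensor, one row per stream, so the tensor's lifetime no longer depends on the hypothesis vector. The pointer arithmetic assumes every hypothesis already holds at least context_size tokens, which holds if GetEmptyResult() seeds each hypothesis with context_size blanks (that function's body is not shown in this diff). A standalone sketch of the row layout, with plain vectors instead of Ort::Value (names are illustrative):

#include <algorithm>
#include <cstdint>
#include <vector>

// Row i = the last context_size tokens of hypothesis i.
// Precondition: every hypothesis has at least context_size tokens.
std::vector<int64_t> BuildDecoderInput(
    const std::vector<std::vector<int64_t>> &hyps, int context_size) {
  std::vector<int64_t> ans(hyps.size() * context_size);
  int64_t *p = ans.data();
  for (const auto &h : hyps) {
    std::copy(h.end() - context_size, h.end(), p);
    p += context_size;
  }
  return ans;
}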
sherpa-onnx/csrc/online-lstm-transducer-model.h

@@ -19,16 +19,19 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
  public:
  explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
 
-  Ort::Value StackStates(const std::vector<Ort::Value> &states) const override;
+  std::vector<Ort::Value> StackStates(
+      const std::vector<std::vector<Ort::Value>> &states) const override;
 
-  std::vector<Ort::Value> UnStackStates(Ort::Value states) const override;
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      const std::vector<Ort::Value> &states) const override;
 
  std::vector<Ort::Value> GetEncoderInitStates() override;
 
  std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
      Ort::Value features, std::vector<Ort::Value> &states) override;
 
-  Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override;
+  Ort::Value BuildDecoderInput(
+      const std::vector<OnlineTransducerDecoderResult> &results) override;
 
  Ort::Value RunDecoder(Ort::Value decoder_input) override;
 
@@ -41,6 +44,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
  int32_t ChunkShift() const override { return decode_chunk_len_; }
 
  int32_t VocabSize() const override { return vocab_size_; }
+  OrtAllocator *Allocator() override { return allocator_; }
 
 private:
  void InitEncoder(const std::string &encoder_filename);
@@ -50,7 +54,6 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
 private:
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
-
  Ort::AllocatorWithDefaultOptions allocator_;
 
  std::unique_ptr<Ort::Session> encoder_sess_;
sherpa-onnx/csrc/online-recognizer.cc

@@ -6,6 +6,7 @@
 
 #include <assert.h>
 
+#include <algorithm>
 #include <memory>
 #include <sstream>
 #include <utility>
@@ -64,39 +65,50 @@ class OnlineRecognizer::Impl {
   }
 
   void DecodeStreams(OnlineStream **ss, int32_t n) {
-    if (n != 1) {
-      fprintf(stderr, "only n == 1 is implemented\n");
-      exit(-1);
-    }
-    OnlineStream *s = ss[0];
-    assert(IsReady(s));
-
     int32_t chunk_size = model_->ChunkSize();
     int32_t chunk_shift = model_->ChunkShift();
 
-    int32_t feature_dim = s->FeatureDim();
+    int32_t feature_dim = ss[0]->FeatureDim();
+
+    std::vector<OnlineTransducerDecoderResult> results(n);
+    std::vector<float> features_vec(n * chunk_size * feature_dim);
+    std::vector<std::vector<Ort::Value>> states_vec(n);
 
-    std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
+    for (int32_t i = 0; i != n; ++i) {
+      std::vector<float> features =
+          ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);
+
+      ss[i]->GetNumProcessedFrames() += chunk_shift;
+
+      std::copy(features.begin(), features.end(),
+                features_vec.data() + i * chunk_size * feature_dim);
+
+      results[i] = std::move(ss[i]->GetResult());
+      states_vec[i] = std::move(ss[i]->GetStates());
+    }
 
     auto memory_info =
         Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
 
-    std::vector<float> features =
-        s->GetFrames(s->GetNumProcessedFrames(), chunk_size);
-
-    s->GetNumProcessedFrames() += chunk_shift;
+    std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
 
-    Ort::Value x =
-        Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
-                                 x_shape.data(), x_shape.size());
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
+                                            features_vec.size(), x_shape.data(),
+                                            x_shape.size());
 
-    auto pair = model_->RunEncoder(std::move(x), s->GetStates());
+    auto states = model_->StackStates(states_vec);
 
-    s->SetStates(std::move(pair.second));
-    std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()};
+    auto pair = model_->RunEncoder(std::move(x), states);
 
     decoder_->Decode(std::move(pair.first), &results);
-    s->SetResult(results[0]);
+
+    std::vector<std::vector<Ort::Value>> next_states =
+        model_->UnStackStates(pair.second);
+
+    for (int32_t i = 0; i != n; ++i) {
+      ss[i]->SetResult(results[i]);
+      ss[i]->SetStates(std::move(next_states[i]));
+    }
   }
 
   OnlineRecognizerResult GetResult(OnlineStream *s) {
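Note: the batched DecodeStreams() moves each stream's result and states out of the stream objects before decoding and moves them back afterwards, so the streams never observe a half-updated batch. Also note what was dropped: the n != 1 bail-out and the assert(IsReady(s)) guard are gone, so the caller is now responsible for passing only streams that have a full chunk of frames available.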
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc

@@ -32,6 +32,30 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
       encoder_out_dim, shape.data(), shape.size());
 }
 
+static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
+                         int32_t n) {
+  if (n == 1) {
+    return std::move(*cur_encoder_out);
+  }
+
+  std::vector<int64_t> cur_encoder_out_shape =
+      cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  std::array<int64_t, 2> ans_shape{n, cur_encoder_out_shape[1]};
+
+  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
+                                                   ans_shape.size());
+
+  const float *src = cur_encoder_out->GetTensorData<float>();
+  float *dst = ans.GetTensorMutableData<float>();
+  for (int32_t i = 0; i != n; ++i) {
+    std::copy(src, src + cur_encoder_out_shape[1], dst);
+    dst += cur_encoder_out_shape[1];
+  }
+
+  return ans;
+}
+
 OnlineTransducerDecoderResult
 OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
   int32_t context_size = model_->ContextSize();
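Note: Repeat() is a mechanical row-tiling helper: it reads one row of shape[1] floats from its input and writes that row n times into a freshly allocated {n, shape[1]} tensor; when n == 1 it skips the copy entirely by moving its input through (Ort::Value is move-only, hence the std::move on the dereferenced pointer). It allocates from the OrtAllocator passed in, which is why the model interface grows an Allocator() accessor in this commit.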
@@ -66,33 +90,33 @@
     exit(-1);
   }
 
-  if (result->size() != 1) {
-    fprintf(stderr, "only batch size == 1 is implemented. Given: %d",
-            static_cast<int32_t>(result->size()));
-    exit(-1);
-  }
-
-  auto &hyp = (*result)[0].tokens;
-
-  int32_t num_frames = encoder_out_shape[1];
+  int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
+  int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
   int32_t vocab_size = model_->VocabSize();
 
-  Ort::Value decoder_input = model_->BuildDecoderInput(hyp);
+  Ort::Value decoder_input = model_->BuildDecoderInput(*result);
   Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
 
   for (int32_t t = 0; t != num_frames; ++t) {
     Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
+    cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size);
     Ort::Value logit =
         model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
     const float *p_logit = logit.GetTensorData<float>();
 
-    auto y = static_cast<int32_t>(std::distance(
-        static_cast<const float *>(p_logit),
-        std::max_element(static_cast<const float *>(p_logit),
-                         static_cast<const float *>(p_logit) + vocab_size)));
-    if (y != 0) {
-      hyp.push_back(y);
-      decoder_input = model_->BuildDecoderInput(hyp);
+    bool emitted = false;
+    for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
+      auto y = static_cast<int32_t>(std::distance(
+          static_cast<const float *>(p_logit),
+          std::max_element(static_cast<const float *>(p_logit),
+                           static_cast<const float *>(p_logit) + vocab_size)));
+      if (y != 0) {
+        emitted = true;
+        (*result)[i].tokens.push_back(y);
+      }
+    }
+    if (emitted) {
+      decoder_input = model_->BuildDecoderInput(*result);
       decoder_out = model_->RunDecoder(std::move(decoder_input));
     }
   }
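The scalar argmax becomes a per-row argmax over the {batch_size, vocab_size} logits, and the decoder is re-run only when at least one stream emits a non-blank token; BuildDecoderInput(*result) then rebuilds all rows, which keeps every stream's decoder output in step at the cost of recomputing rows that did not change. A standalone sketch of one such step, with a plain logit buffer (GreedyStep and its parameters are illustrative, not part of the API):

#include <algorithm>
#include <cstdint>
#include <vector>

// One greedy-search step: logits is a flat {batch, vocab} buffer, token 0 is
// the blank; a non-blank argmax is appended to that stream's hypothesis.
void GreedyStep(const float *logits, int batch, int vocab,
                std::vector<std::vector<int64_t>> *hyps, bool *emitted) {
  *emitted = false;
  for (int i = 0; i != batch; ++i, logits += vocab) {
    int y = static_cast<int>(std::max_element(logits, logits + vocab) - logits);
    if (y != 0) {
      *emitted = true;  // at least one stream advanced; re-run the decoder
      (*hyps)[i].push_back(y);
    }
  }
}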
sherpa-onnx/csrc/online-transducer-model.h

@@ -13,6 +13,8 @@
 
 namespace sherpa_onnx {
 
+class OnlineTransducerDecoderResult;
+
 class OnlineTransducerModel {
  public:
  virtual ~OnlineTransducerModel() = default;
@@ -27,8 +29,8 @@
   * @param states states[i] contains the state for the i-th utterance.
   * @return Return a single value representing the batched state.
   */
-  virtual Ort::Value StackStates(
-      const std::vector<Ort::Value> &states) const = 0;
+  virtual std::vector<Ort::Value> StackStates(
+      const std::vector<std::vector<Ort::Value>> &states) const = 0;
 
  /** Unstack a batch state into a list of individual states.
   *
@@ -37,7 +39,8 @@
   * @param states A batched state.
   * @return ans[i] contains the state for the i-th utterance.
   */
-  virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0;
+  virtual std::vector<std::vector<Ort::Value>> UnStackStates(
+      const std::vector<Ort::Value> &states) const = 0;
 
  /** Get the initial encoder states.
   *
@@ -58,7 +61,8 @@
      Ort::Value features,
      std::vector<Ort::Value> &states) = 0;  // NOLINT
 
-  virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0;
+  virtual Ort::Value BuildDecoderInput(
+      const std::vector<OnlineTransducerDecoderResult> &results) = 0;
 
  /** Run the decoder network.
   *
@@ -111,6 +115,7 @@
  virtual int32_t VocabSize() const = 0;
 
  virtual int32_t SubsamplingFactor() const { return 4; }
+  virtual OrtAllocator *Allocator() = 0;
 };
 
 }  // namespace sherpa_onnx
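Note: Allocator() is added as a pure virtual, so every OnlineTransducerModel subclass must now expose its OrtAllocator; the greedy-search decoder uses it to allocate the tensor returned by Repeat().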