danfu
Committed by GitHub

support streaming zipformer2 (#185)

Co-authored-by: danfu <danfu@tencent.com>
@@ -48,6 +48,7 @@ set(sources @@ -48,6 +48,7 @@ set(sources
48 online-transducer-model.cc 48 online-transducer-model.cc
49 online-transducer-modified-beam-search-decoder.cc 49 online-transducer-modified-beam-search-decoder.cc
50 online-zipformer-transducer-model.cc 50 online-zipformer-transducer-model.cc
  51 + online-zipformer2-transducer-model.cc
51 onnx-utils.cc 52 onnx-utils.cc
52 session.cc 53 session.cc
53 packed-sequence.cc 54 packed-sequence.cc
@@ -18,6 +18,7 @@ @@ -18,6 +18,7 @@
18 #include "sherpa-onnx/csrc/online-conformer-transducer-model.h" 18 #include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
19 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" 19 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
20 #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" 20 #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
  21 +#include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h"
21 #include "sherpa-onnx/csrc/onnx-utils.h" 22 #include "sherpa-onnx/csrc/onnx-utils.h"
22 23
23 namespace { 24 namespace {
@@ -26,6 +27,7 @@ enum class ModelType { @@ -26,6 +27,7 @@ enum class ModelType {
26 kConformer, 27 kConformer,
27 kLstm, 28 kLstm,
28 kZipformer, 29 kZipformer,
  30 + kZipformer2,
29 kUnkown, 31 kUnkown,
30 }; 32 };
31 33
@@ -65,6 +67,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -65,6 +67,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
65 return ModelType::kLstm; 67 return ModelType::kLstm;
66 } else if (model_type.get() == std::string("zipformer")) { 68 } else if (model_type.get() == std::string("zipformer")) {
67 return ModelType::kZipformer; 69 return ModelType::kZipformer;
  70 + } else if (model_type.get() == std::string("zipformer2")) {
  71 + return ModelType::kZipformer2;
68 } else { 72 } else {
69 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 73 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
70 return ModelType::kUnkown; 74 return ModelType::kUnkown;
@@ -88,6 +92,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( @@ -88,6 +92,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
88 return std::make_unique<OnlineLstmTransducerModel>(config); 92 return std::make_unique<OnlineLstmTransducerModel>(config);
89 case ModelType::kZipformer: 93 case ModelType::kZipformer:
90 return std::make_unique<OnlineZipformerTransducerModel>(config); 94 return std::make_unique<OnlineZipformerTransducerModel>(config);
  95 + case ModelType::kZipformer2:
  96 + return std::make_unique<OnlineZipformer2TransducerModel>(config);
91 case ModelType::kUnkown: 97 case ModelType::kUnkown:
92 SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); 98 SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
93 return nullptr; 99 return nullptr;
@@ -144,6 +150,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( @@ -144,6 +150,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
144 return std::make_unique<OnlineLstmTransducerModel>(mgr, config); 150 return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
145 case ModelType::kZipformer: 151 case ModelType::kZipformer:
146 return std::make_unique<OnlineZipformerTransducerModel>(mgr, config); 152 return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
  153 + case ModelType::kZipformer2:
  154 + return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
147 case ModelType::kUnkown: 155 case ModelType::kUnkown:
148 SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); 156 SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
149 return nullptr; 157 return nullptr;
  1 +// sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h"
  6 +
  31 +
  32 +namespace sherpa_onnx {
  33 +
  34 +OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
  35 + const OnlineTransducerModelConfig &config)
  36 + : env_(ORT_LOGGING_LEVEL_WARNING),
  37 + config_(config),
  38 + sess_opts_(GetSessionOptions(config)),
  39 + allocator_{} {
  40 + {
  41 + auto buf = ReadFile(config.encoder_filename);
  42 + InitEncoder(buf.data(), buf.size());
  43 + }
  44 +
  45 + {
  46 + auto buf = ReadFile(config.decoder_filename);
  47 + InitDecoder(buf.data(), buf.size());
  48 + }
  49 +
  50 + {
  51 + auto buf = ReadFile(config.joiner_filename);
  52 + InitJoiner(buf.data(), buf.size());
  53 + }
  54 +}
  55 +
#if __ANDROID_API__ >= 9
// Android variant: reads the three model files through the asset manager
// instead of the filesystem; otherwise identical to the file-based ctor.
OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
    AAssetManager *mgr, const OnlineTransducerModelConfig &config)
    : env_(ORT_LOGGING_LEVEL_WARNING),
      config_(config),
      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
  {
    auto encoder_buf = ReadFile(mgr, config.encoder_filename);
    InitEncoder(encoder_buf.data(), encoder_buf.size());
  }

  {
    auto decoder_buf = ReadFile(mgr, config.decoder_filename);
    InitDecoder(decoder_buf.data(), decoder_buf.size());
  }

  {
    auto joiner_buf = ReadFile(mgr, config.joiner_filename);
    InitJoiner(joiner_buf.data(), joiner_buf.size());
  }
}
#endif
  79 +
// Creates the encoder session from an in-memory model and reads the
// zipformer2 hyper-parameters from the model's metadata.
void OnlineZipformer2TransducerModel::InitEncoder(void *model_data,
                                                  size_t model_data_length) {
  encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                 model_data_length, sess_opts_);

  GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                &encoder_input_names_ptr_);

  GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                 &encoder_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---encoder---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  // NOTE: the SHERPA_ONNX_READ_META_DATA* macros below implicitly use the
  // locals named `allocator` and `meta_data`; do not rename them.
  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
  // Per-stack (one entry per encoder stack) hyper-parameters.
  SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims");
  SHERPA_ONNX_READ_META_DATA_VEC(query_head_dims_, "query_head_dims");
  SHERPA_ONNX_READ_META_DATA_VEC(value_head_dims_, "value_head_dims");
  SHERPA_ONNX_READ_META_DATA_VEC(num_heads_, "num_heads");
  SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers");
  SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels");
  SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len");

  // Scalars: T_ is the chunk size in frames; decode_chunk_len_ is the
  // chunk shift (see ChunkSize()/ChunkShift() in the header).
  SHERPA_ONNX_READ_META_DATA(T_, "T");
  SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");

  if (config_.debug) {
    // Dump all vector-valued hyper-parameters for debugging.
    auto print = [](const std::vector<int32_t> &v, const char *name) {
      fprintf(stderr, "%s: ", name);
      for (auto i : v) {
        fprintf(stderr, "%d ", i);
      }
      fprintf(stderr, "\n");
    };
    print(encoder_dims_, "encoder_dims");
    print(query_head_dims_, "query_head_dims");
    print(value_head_dims_, "value_head_dims");
    print(num_heads_, "num_heads");
    print(num_encoder_layers_, "num_encoder_layers");
    print(cnn_module_kernels_, "cnn_module_kernels");
    print(left_context_len_, "left_context_len");
    SHERPA_ONNX_LOGE("T: %d", T_);
    SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
  }
}
  131 +
// Creates the decoder (prediction network) session from an in-memory model
// and reads vocab_size/context_size from the model's metadata.
void OnlineZipformer2TransducerModel::InitDecoder(void *model_data,
                                                  size_t model_data_length) {
  decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                 model_data_length, sess_opts_);

  GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                &decoder_input_names_ptr_);

  GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                 &decoder_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---decoder---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  // NOTE: the macros below implicitly use the locals named `allocator`
  // and `meta_data`; do not rename them.
  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
  SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
}
  156 +
// Creates the joiner session from an in-memory model. The joiner has no
// metadata of its own to read; metadata is only printed in debug mode.
void OnlineZipformer2TransducerModel::InitJoiner(void *model_data,
                                                 size_t model_data_length) {
  joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                model_data_length, sess_opts_);

  GetInputNames(joiner_sess_.get(), &joiner_input_names_,
                &joiner_input_names_ptr_);

  GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
                 &joiner_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---joiner---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }
}
  177 +
  178 +std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
  179 + const std::vector<std::vector<Ort::Value>> &states) const {
  180 + int32_t batch_size = static_cast<int32_t>(states.size());
  181 + int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size());
  182 +
  183 + std::vector<const Ort::Value *> buf(batch_size);
  184 +
  185 + std::vector<Ort::Value> ans;
  186 + int32_t num_states = static_cast<int32_t>(states[0].size());
  187 + ans.reserve(num_states);
  188 +
  189 + for (int32_t i = 0; i != (num_states - 2) / 6; ++i) {
  190 + {
  191 + for (int32_t n = 0; n != batch_size; ++n) {
  192 + buf[n] = &states[n][6 * i];
  193 + }
  194 + auto v = Cat(allocator_, buf, 1);
  195 + ans.push_back(std::move(v));
  196 + }
  197 + {
  198 + for (int32_t n = 0; n != batch_size; ++n) {
  199 + buf[n] = &states[n][6 * i + 1];
  200 + }
  201 + auto v = Cat(allocator_, buf, 1);
  202 + ans.push_back(std::move(v));
  203 + }
  204 + {
  205 + for (int32_t n = 0; n != batch_size; ++n) {
  206 + buf[n] = &states[n][6 * i + 2];
  207 + }
  208 + auto v = Cat(allocator_, buf, 1);
  209 + ans.push_back(std::move(v));
  210 + }
  211 + {
  212 + for (int32_t n = 0; n != batch_size; ++n) {
  213 + buf[n] = &states[n][6 * i + 3];
  214 + }
  215 + auto v = Cat(allocator_, buf, 1);
  216 + ans.push_back(std::move(v));
  217 + }
  218 + {
  219 + for (int32_t n = 0; n != batch_size; ++n) {
  220 + buf[n] = &states[n][6 * i + 4];
  221 + }
  222 + auto v = Cat(allocator_, buf, 0);
  223 + ans.push_back(std::move(v));
  224 + }
  225 + {
  226 + for (int32_t n = 0; n != batch_size; ++n) {
  227 + buf[n] = &states[n][6 * i + 5];
  228 + }
  229 + auto v = Cat(allocator_, buf, 0);
  230 + ans.push_back(std::move(v));
  231 + }
  232 + }
  233 +
  234 + {
  235 + for (int32_t n = 0; n != batch_size; ++n) {
  236 + buf[n] = &states[n][num_states - 2];
  237 + }
  238 + auto v = Cat(allocator_, buf, 0);
  239 + ans.push_back(std::move(v));
  240 + }
  241 +
  242 + {
  243 + for (int32_t n = 0; n != batch_size; ++n) {
  244 + buf[n] = &states[n][num_states - 1];
  245 + }
  246 + auto v = Cat<int64_t>(allocator_, buf, 0);
  247 + ans.push_back(std::move(v));
  248 + }
  249 + return ans;
  250 +}
  251 +
  252 +std::vector<std::vector<Ort::Value>>
  253 +OnlineZipformer2TransducerModel::UnStackStates(
  254 + const std::vector<Ort::Value> &states) const {
  255 + int32_t m = std::accumulate(num_encoder_layers_.begin(), num_encoder_layers_.end(), 0);
  256 + assert(states.size() == m * 6 + 2);
  257 +
  258 + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
  259 + int32_t num_encoders = num_encoder_layers_.size();
  260 +
  261 + std::vector<std::vector<Ort::Value>> ans;
  262 + ans.resize(batch_size);
  263 +
  264 + for (int32_t i = 0; i != m; ++i) {
  265 + {
  266 + auto v = Unbind(allocator_, &states[i * 6], 1);
  267 + assert(v.size() == batch_size);
  268 +
  269 + for (int32_t n = 0; n != batch_size; ++n) {
  270 + ans[n].push_back(std::move(v[n]));
  271 + }
  272 + }
  273 + {
  274 + auto v = Unbind(allocator_, &states[i * 6 + 1], 1);
  275 + assert(v.size() == batch_size);
  276 +
  277 + for (int32_t n = 0; n != batch_size; ++n) {
  278 + ans[n].push_back(std::move(v[n]));
  279 + }
  280 + }
  281 + {
  282 + auto v = Unbind(allocator_, &states[i * 6 + 2], 1);
  283 + assert(v.size() == batch_size);
  284 +
  285 + for (int32_t n = 0; n != batch_size; ++n) {
  286 + ans[n].push_back(std::move(v[n]));
  287 + }
  288 + }
  289 + {
  290 + auto v = Unbind(allocator_, &states[i * 6 + 3], 1);
  291 + assert(v.size() == batch_size);
  292 +
  293 + for (int32_t n = 0; n != batch_size; ++n) {
  294 + ans[n].push_back(std::move(v[n]));
  295 + }
  296 + }
  297 + {
  298 + auto v = Unbind(allocator_, &states[i * 6 + 4], 0);
  299 + assert(v.size() == batch_size);
  300 +
  301 + for (int32_t n = 0; n != batch_size; ++n) {
  302 + ans[n].push_back(std::move(v[n]));
  303 + }
  304 + }
  305 + {
  306 + auto v = Unbind(allocator_, &states[i * 6 + 5], 0);
  307 + assert(v.size() == batch_size);
  308 +
  309 + for (int32_t n = 0; n != batch_size; ++n) {
  310 + ans[n].push_back(std::move(v[n]));
  311 + }
  312 + }
  313 + }
  314 +
  315 + {
  316 + auto v = Unbind(allocator_, &states[m * 6], 0);
  317 + assert(v.size() == batch_size);
  318 +
  319 + for (int32_t n = 0; n != batch_size; ++n) {
  320 + ans[n].push_back(std::move(v[n]));
  321 + }
  322 + }
  323 + {
  324 + auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0);
  325 + assert(v.size() == batch_size);
  326 +
  327 + for (int32_t n = 0; n != batch_size; ++n) {
  328 + ans[n].push_back(std::move(v[n]));
  329 + }
  330 + }
  331 +
  332 + return ans;
  333 +}
  334 +
  335 +std::vector<Ort::Value> OnlineZipformer2TransducerModel::GetEncoderInitStates() {
  336 + std::vector<Ort::Value> ans;
  337 + int32_t n = static_cast<int32_t>(encoder_dims_.size());
  338 + int32_t m = std::accumulate(num_encoder_layers_.begin(), num_encoder_layers_.end(), 0);
  339 + ans.reserve(m * 6 + 2);
  340 +
  341 + for (int32_t i = 0; i != n; ++i) {
  342 + int32_t num_layers = num_encoder_layers_[i];
  343 + int32_t key_dim = query_head_dims_[i] * num_heads_[i];
  344 + int32_t value_dim = value_head_dims_[i] * num_heads_[i];
  345 + int32_t nonlin_attn_head_dim = 3 * encoder_dims_[i] / 4;
  346 +
  347 + for (int32_t j = 0; j != num_layers; ++j) {
  348 + {
  349 + std::array<int64_t, 3> s{left_context_len_[i], 1, key_dim};
  350 + auto v =
  351 + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  352 + Fill(&v, 0);
  353 + ans.push_back(std::move(v));
  354 + }
  355 +
  356 + {
  357 + std::array<int64_t, 4> s{1, 1, left_context_len_[i], nonlin_attn_head_dim};
  358 + auto v =
  359 + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  360 + Fill(&v, 0);
  361 + ans.push_back(std::move(v));
  362 + }
  363 +
  364 + {
  365 + std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim};
  366 + auto v =
  367 + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  368 + Fill(&v, 0);
  369 + ans.push_back(std::move(v));
  370 + }
  371 +
  372 + {
  373 + std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim};
  374 + auto v =
  375 + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  376 + Fill(&v, 0);
  377 + ans.push_back(std::move(v));
  378 + }
  379 +
  380 + {
  381 + std::array<int64_t, 3> s{1, encoder_dims_[i], cnn_module_kernels_[i] / 2};
  382 + auto v =
  383 + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  384 + Fill(&v, 0);
  385 + ans.push_back(std::move(v));
  386 + }
  387 +
  388 + {
  389 + std::array<int64_t, 3> s{1, encoder_dims_[i], cnn_module_kernels_[i] / 2};
  390 + auto v =
  391 + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  392 + Fill(&v, 0);
  393 + ans.push_back(std::move(v));
  394 + }
  395 + }
  396 + }
  397 +
  398 + {
  399 + std::array<int64_t, 4> s{1, 128, 3, 19};
  400 + auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  401 + Fill(&v, 0);
  402 + ans.push_back(std::move(v));
  403 + }
  404 +
  405 + {
  406 + std::array<int64_t, 1> s{1};
  407 + auto v = Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
  408 + Fill<int64_t>(&v, 0);
  409 + ans.push_back(std::move(v));
  410 + }
  411 + return ans;
  412 +}
  413 +
  414 +std::pair<Ort::Value, std::vector<Ort::Value>>
  415 +OnlineZipformer2TransducerModel::RunEncoder(Ort::Value features,
  416 + std::vector<Ort::Value> states,
  417 + Ort::Value /* processed_frames */) {
  418 + std::vector<Ort::Value> encoder_inputs;
  419 + encoder_inputs.reserve(1 + states.size());
  420 +
  421 + encoder_inputs.push_back(std::move(features));
  422 + for (auto &v : states) {
  423 + encoder_inputs.push_back(std::move(v));
  424 + }
  425 +
  426 + auto encoder_out = encoder_sess_->Run(
  427 + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
  428 + encoder_inputs.size(), encoder_output_names_ptr_.data(),
  429 + encoder_output_names_ptr_.size());
  430 +
  431 + std::vector<Ort::Value> next_states;
  432 + next_states.reserve(states.size());
  433 +
  434 + for (int32_t i = 1; i != static_cast<int32_t>(encoder_out.size()); ++i) {
  435 + next_states.push_back(std::move(encoder_out[i]));
  436 + }
  437 + return {std::move(encoder_out[0]), std::move(next_states)};
  438 +}
  439 +
  440 +Ort::Value OnlineZipformer2TransducerModel::RunDecoder(
  441 + Ort::Value decoder_input) {
  442 + auto decoder_out = decoder_sess_->Run(
  443 + {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
  444 + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
  445 + return std::move(decoder_out[0]);
  446 +}
  447 +
  448 +Ort::Value OnlineZipformer2TransducerModel::RunJoiner(Ort::Value encoder_out,
  449 + Ort::Value decoder_out) {
  450 + std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
  451 + std::move(decoder_out)};
  452 + auto logit =
  453 + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
  454 + joiner_input.size(), joiner_output_names_ptr_.data(),
  455 + joiner_output_names_ptr_.size());
  456 +
  457 + return std::move(logit[0]);
  458 +}
  459 +
  460 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-zipformer2-transducer-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
  17 +#include "onnxruntime_cxx_api.h" // NOLINT
  18 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  19 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  20 +
  21 +namespace sherpa_onnx {
  22 +
  23 +class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
  24 + public:
  25 + explicit OnlineZipformer2TransducerModel(
  26 + const OnlineTransducerModelConfig &config);
  27 +
  28 +#if __ANDROID_API__ >= 9
  29 + OnlineZipformer2TransducerModel(AAssetManager *mgr,
  30 + const OnlineTransducerModelConfig &config);
  31 +#endif
  32 +
  33 + std::vector<Ort::Value> StackStates(
  34 + const std::vector<std::vector<Ort::Value>> &states) const override;
  35 +
  36 + std::vector<std::vector<Ort::Value>> UnStackStates(
  37 + const std::vector<Ort::Value> &states) const override;
  38 +
  39 + std::vector<Ort::Value> GetEncoderInitStates() override;
  40 +
  41 + std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
  42 + Ort::Value features, std::vector<Ort::Value> states,
  43 + Ort::Value processed_frames) override;
  44 +
  45 + Ort::Value RunDecoder(Ort::Value decoder_input) override;
  46 +
  47 + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
  48 +
  49 + int32_t ContextSize() const override { return context_size_; }
  50 +
  51 + int32_t ChunkSize() const override { return T_; }
  52 +
  53 + int32_t ChunkShift() const override { return decode_chunk_len_; }
  54 +
  55 + int32_t VocabSize() const override { return vocab_size_; }
  56 + OrtAllocator *Allocator() override { return allocator_; }
  57 +
  58 + private:
  59 + void InitEncoder(void *model_data, size_t model_data_length);
  60 + void InitDecoder(void *model_data, size_t model_data_length);
  61 + void InitJoiner(void *model_data, size_t model_data_length);
  62 +
  63 + private:
  64 + Ort::Env env_;
  65 + Ort::SessionOptions sess_opts_;
  66 + Ort::AllocatorWithDefaultOptions allocator_;
  67 +
  68 + std::unique_ptr<Ort::Session> encoder_sess_;
  69 + std::unique_ptr<Ort::Session> decoder_sess_;
  70 + std::unique_ptr<Ort::Session> joiner_sess_;
  71 +
  72 + std::vector<std::string> encoder_input_names_;
  73 + std::vector<const char *> encoder_input_names_ptr_;
  74 +
  75 + std::vector<std::string> encoder_output_names_;
  76 + std::vector<const char *> encoder_output_names_ptr_;
  77 +
  78 + std::vector<std::string> decoder_input_names_;
  79 + std::vector<const char *> decoder_input_names_ptr_;
  80 +
  81 + std::vector<std::string> decoder_output_names_;
  82 + std::vector<const char *> decoder_output_names_ptr_;
  83 +
  84 + std::vector<std::string> joiner_input_names_;
  85 + std::vector<const char *> joiner_input_names_ptr_;
  86 +
  87 + std::vector<std::string> joiner_output_names_;
  88 + std::vector<const char *> joiner_output_names_ptr_;
  89 +
  90 + OnlineTransducerModelConfig config_;
  91 +
  92 + std::vector<int32_t> encoder_dims_;
  93 + std::vector<int32_t> query_head_dims_;
  94 + std::vector<int32_t> value_head_dims_;
  95 + std::vector<int32_t> num_heads_;
  96 + std::vector<int32_t> num_encoder_layers_;
  97 + std::vector<int32_t> cnn_module_kernels_;
  98 + std::vector<int32_t> left_context_len_;
  99 +
  100 + int32_t T_ = 0;
  101 + int32_t decode_chunk_len_ = 0;
  102 +
  103 + int32_t context_size_ = 0;
  104 + int32_t vocab_size_ = 0;
  105 +};
  106 +
  107 +} // namespace sherpa_onnx
  108 +
  109 +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_