Committed by
GitHub
support streaming zipformer2 (#185)
Co-authored-by: danfu <danfu@tencent.com>
正在显示 4 个修改的文件，包含 578 行增加和 0 行删除。
| @@ -48,6 +48,7 @@ set(sources | @@ -48,6 +48,7 @@ set(sources | ||
| 48 | online-transducer-model.cc | 48 | online-transducer-model.cc |
| 49 | online-transducer-modified-beam-search-decoder.cc | 49 | online-transducer-modified-beam-search-decoder.cc |
| 50 | online-zipformer-transducer-model.cc | 50 | online-zipformer-transducer-model.cc |
| 51 | + online-zipformer2-transducer-model.cc | ||
| 51 | onnx-utils.cc | 52 | onnx-utils.cc |
| 52 | session.cc | 53 | session.cc |
| 53 | packed-sequence.cc | 54 | packed-sequence.cc |
| @@ -18,6 +18,7 @@ | @@ -18,6 +18,7 @@ | ||
| 18 | #include "sherpa-onnx/csrc/online-conformer-transducer-model.h" | 18 | #include "sherpa-onnx/csrc/online-conformer-transducer-model.h" |
| 19 | #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" | 19 | #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" |
| 20 | #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" | 20 | #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" |
| 21 | +#include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h" | ||
| 21 | #include "sherpa-onnx/csrc/onnx-utils.h" | 22 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 22 | 23 | ||
| 23 | namespace { | 24 | namespace { |
| @@ -26,6 +27,7 @@ enum class ModelType { | @@ -26,6 +27,7 @@ enum class ModelType { | ||
| 26 | kConformer, | 27 | kConformer, |
| 27 | kLstm, | 28 | kLstm, |
| 28 | kZipformer, | 29 | kZipformer, |
| 30 | + kZipformer2, | ||
| 29 | kUnkown, | 31 | kUnkown, |
| 30 | }; | 32 | }; |
| 31 | 33 | ||
| @@ -65,6 +67,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -65,6 +67,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 65 | return ModelType::kLstm; | 67 | return ModelType::kLstm; |
| 66 | } else if (model_type.get() == std::string("zipformer")) { | 68 | } else if (model_type.get() == std::string("zipformer")) { |
| 67 | return ModelType::kZipformer; | 69 | return ModelType::kZipformer; |
| 70 | + } else if (model_type.get() == std::string("zipformer2")) { | ||
| 71 | + return ModelType::kZipformer2; | ||
| 68 | } else { | 72 | } else { |
| 69 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 73 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); |
| 70 | return ModelType::kUnkown; | 74 | return ModelType::kUnkown; |
| @@ -88,6 +92,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -88,6 +92,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 88 | return std::make_unique<OnlineLstmTransducerModel>(config); | 92 | return std::make_unique<OnlineLstmTransducerModel>(config); |
| 89 | case ModelType::kZipformer: | 93 | case ModelType::kZipformer: |
| 90 | return std::make_unique<OnlineZipformerTransducerModel>(config); | 94 | return std::make_unique<OnlineZipformerTransducerModel>(config); |
| 95 | + case ModelType::kZipformer2: | ||
| 96 | + return std::make_unique<OnlineZipformer2TransducerModel>(config); | ||
| 91 | case ModelType::kUnkown: | 97 | case ModelType::kUnkown: |
| 92 | SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); | 98 | SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); |
| 93 | return nullptr; | 99 | return nullptr; |
| @@ -144,6 +150,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -144,6 +150,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 144 | return std::make_unique<OnlineLstmTransducerModel>(mgr, config); | 150 | return std::make_unique<OnlineLstmTransducerModel>(mgr, config); |
| 145 | case ModelType::kZipformer: | 151 | case ModelType::kZipformer: |
| 146 | return std::make_unique<OnlineZipformerTransducerModel>(mgr, config); | 152 | return std::make_unique<OnlineZipformerTransducerModel>(mgr, config); |
| 153 | + case ModelType::kZipformer2: | ||
| 154 | + return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config); | ||
| 147 | case ModelType::kUnkown: | 155 | case ModelType::kUnkown: |
| 148 | SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); | 156 | SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); |
| 149 | return nullptr; | 157 | return nullptr; |
| 1 | +// sherpa-onnx/csrc/online-zipformer2-transducer-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h" | ||
| 6 | + | ||
| 7 | +#include <assert.h> | ||
| 8 | +#include <math.h> | ||
| 9 | + | ||
| 10 | +#include <algorithm> | ||
| 11 | +#include <memory> | ||
| 12 | +#include <sstream> | ||
| 13 | +#include <string> | ||
| 14 | +#include <utility> | ||
| 15 | +#include <vector> | ||
| 16 | +#include <numeric> | ||
| 17 | + | ||
| 18 | +#if __ANDROID_API__ >= 9 | ||
| 19 | +#include "android/asset_manager.h" | ||
| 20 | +#include "android/asset_manager_jni.h" | ||
| 21 | +#endif | ||
| 22 | + | ||
| 23 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 24 | +#include "sherpa-onnx/csrc/cat.h" | ||
| 25 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 26 | +#include "sherpa-onnx/csrc/online-transducer-decoder.h" | ||
| 27 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 28 | +#include "sherpa-onnx/csrc/session.h" | ||
| 29 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 30 | +#include "sherpa-onnx/csrc/unbind.h" | ||
| 31 | + | ||
| 32 | +namespace sherpa_onnx { | ||
| 33 | + | ||
| 34 | +OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( | ||
| 35 | + const OnlineTransducerModelConfig &config) | ||
| 36 | + : env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 37 | + config_(config), | ||
| 38 | + sess_opts_(GetSessionOptions(config)), | ||
| 39 | + allocator_{} { | ||
| 40 | + { | ||
| 41 | + auto buf = ReadFile(config.encoder_filename); | ||
| 42 | + InitEncoder(buf.data(), buf.size()); | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + { | ||
| 46 | + auto buf = ReadFile(config.decoder_filename); | ||
| 47 | + InitDecoder(buf.data(), buf.size()); | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + { | ||
| 51 | + auto buf = ReadFile(config.joiner_filename); | ||
| 52 | + InitJoiner(buf.data(), buf.size()); | ||
| 53 | + } | ||
| 54 | +} | ||
| 55 | + | ||
#if __ANDROID_API__ >= 9
// Construct from models bundled as Android assets (read via `mgr`).
OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
    AAssetManager *mgr, const OnlineTransducerModelConfig &config)
    // Initializer list kept in member declaration order
    // (env_, sess_opts_, allocator_, ..., config_) to avoid -Wreorder.
    : env_(ORT_LOGGING_LEVEL_WARNING),
      sess_opts_(GetSessionOptions(config)),
      allocator_{},
      config_(config) {
  {
    auto buf = ReadFile(mgr, config.encoder_filename);
    InitEncoder(buf.data(), buf.size());
  }

  {
    auto buf = ReadFile(mgr, config.decoder_filename);
    InitDecoder(buf.data(), buf.size());
  }

  {
    auto buf = ReadFile(mgr, config.joiner_filename);
    InitJoiner(buf.data(), buf.size());
  }
}
#endif
| 79 | + | ||
| 80 | +void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, | ||
| 81 | + size_t model_data_length) { | ||
| 82 | + encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data, | ||
| 83 | + model_data_length, sess_opts_); | ||
| 84 | + | ||
| 85 | + GetInputNames(encoder_sess_.get(), &encoder_input_names_, | ||
| 86 | + &encoder_input_names_ptr_); | ||
| 87 | + | ||
| 88 | + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, | ||
| 89 | + &encoder_output_names_ptr_); | ||
| 90 | + | ||
| 91 | + // get meta data | ||
| 92 | + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); | ||
| 93 | + if (config_.debug) { | ||
| 94 | + std::ostringstream os; | ||
| 95 | + os << "---encoder---\n"; | ||
| 96 | + PrintModelMetadata(os, meta_data); | ||
| 97 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 101 | + SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims"); | ||
| 102 | + SHERPA_ONNX_READ_META_DATA_VEC(query_head_dims_, "query_head_dims"); | ||
| 103 | + SHERPA_ONNX_READ_META_DATA_VEC(value_head_dims_, "value_head_dims"); | ||
| 104 | + SHERPA_ONNX_READ_META_DATA_VEC(num_heads_, "num_heads"); | ||
| 105 | + SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers"); | ||
| 106 | + SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels"); | ||
| 107 | + SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len"); | ||
| 108 | + | ||
| 109 | + SHERPA_ONNX_READ_META_DATA(T_, "T"); | ||
| 110 | + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); | ||
| 111 | + | ||
| 112 | + if (config_.debug) { | ||
| 113 | + auto print = [](const std::vector<int32_t> &v, const char *name) { | ||
| 114 | + fprintf(stderr, "%s: ", name); | ||
| 115 | + for (auto i : v) { | ||
| 116 | + fprintf(stderr, "%d ", i); | ||
| 117 | + } | ||
| 118 | + fprintf(stderr, "\n"); | ||
| 119 | + }; | ||
| 120 | + print(encoder_dims_, "encoder_dims"); | ||
| 121 | + print(query_head_dims_, "query_head_dims"); | ||
| 122 | + print(value_head_dims_, "value_head_dims"); | ||
| 123 | + print(num_heads_, "num_heads"); | ||
| 124 | + print(num_encoder_layers_, "num_encoder_layers"); | ||
| 125 | + print(cnn_module_kernels_, "cnn_module_kernels"); | ||
| 126 | + print(left_context_len_, "left_context_len"); | ||
| 127 | + SHERPA_ONNX_LOGE("T: %d", T_); | ||
| 128 | + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); | ||
| 129 | + } | ||
| 130 | +} | ||
| 131 | + | ||
| 132 | +void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, | ||
| 133 | + size_t model_data_length) { | ||
| 134 | + decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data, | ||
| 135 | + model_data_length, sess_opts_); | ||
| 136 | + | ||
| 137 | + GetInputNames(decoder_sess_.get(), &decoder_input_names_, | ||
| 138 | + &decoder_input_names_ptr_); | ||
| 139 | + | ||
| 140 | + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, | ||
| 141 | + &decoder_output_names_ptr_); | ||
| 142 | + | ||
| 143 | + // get meta data | ||
| 144 | + Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata(); | ||
| 145 | + if (config_.debug) { | ||
| 146 | + std::ostringstream os; | ||
| 147 | + os << "---decoder---\n"; | ||
| 148 | + PrintModelMetadata(os, meta_data); | ||
| 149 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 150 | + } | ||
| 151 | + | ||
| 152 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 153 | + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); | ||
| 154 | + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); | ||
| 155 | +} | ||
| 156 | + | ||
| 157 | +void OnlineZipformer2TransducerModel::InitJoiner(void *model_data, | ||
| 158 | + size_t model_data_length) { | ||
| 159 | + joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data, | ||
| 160 | + model_data_length, sess_opts_); | ||
| 161 | + | ||
| 162 | + GetInputNames(joiner_sess_.get(), &joiner_input_names_, | ||
| 163 | + &joiner_input_names_ptr_); | ||
| 164 | + | ||
| 165 | + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, | ||
| 166 | + &joiner_output_names_ptr_); | ||
| 167 | + | ||
| 168 | + // get meta data | ||
| 169 | + Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata(); | ||
| 170 | + if (config_.debug) { | ||
| 171 | + std::ostringstream os; | ||
| 172 | + os << "---joiner---\n"; | ||
| 173 | + PrintModelMetadata(os, meta_data); | ||
| 174 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 175 | + } | ||
| 176 | +} | ||
| 177 | + | ||
| 178 | +std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( | ||
| 179 | + const std::vector<std::vector<Ort::Value>> &states) const { | ||
| 180 | + int32_t batch_size = static_cast<int32_t>(states.size()); | ||
| 181 | + int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size()); | ||
| 182 | + | ||
| 183 | + std::vector<const Ort::Value *> buf(batch_size); | ||
| 184 | + | ||
| 185 | + std::vector<Ort::Value> ans; | ||
| 186 | + int32_t num_states = static_cast<int32_t>(states[0].size()); | ||
| 187 | + ans.reserve(num_states); | ||
| 188 | + | ||
| 189 | + for (int32_t i = 0; i != (num_states - 2) / 6; ++i) { | ||
| 190 | + { | ||
| 191 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 192 | + buf[n] = &states[n][6 * i]; | ||
| 193 | + } | ||
| 194 | + auto v = Cat(allocator_, buf, 1); | ||
| 195 | + ans.push_back(std::move(v)); | ||
| 196 | + } | ||
| 197 | + { | ||
| 198 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 199 | + buf[n] = &states[n][6 * i + 1]; | ||
| 200 | + } | ||
| 201 | + auto v = Cat(allocator_, buf, 1); | ||
| 202 | + ans.push_back(std::move(v)); | ||
| 203 | + } | ||
| 204 | + { | ||
| 205 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 206 | + buf[n] = &states[n][6 * i + 2]; | ||
| 207 | + } | ||
| 208 | + auto v = Cat(allocator_, buf, 1); | ||
| 209 | + ans.push_back(std::move(v)); | ||
| 210 | + } | ||
| 211 | + { | ||
| 212 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 213 | + buf[n] = &states[n][6 * i + 3]; | ||
| 214 | + } | ||
| 215 | + auto v = Cat(allocator_, buf, 1); | ||
| 216 | + ans.push_back(std::move(v)); | ||
| 217 | + } | ||
| 218 | + { | ||
| 219 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 220 | + buf[n] = &states[n][6 * i + 4]; | ||
| 221 | + } | ||
| 222 | + auto v = Cat(allocator_, buf, 0); | ||
| 223 | + ans.push_back(std::move(v)); | ||
| 224 | + } | ||
| 225 | + { | ||
| 226 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 227 | + buf[n] = &states[n][6 * i + 5]; | ||
| 228 | + } | ||
| 229 | + auto v = Cat(allocator_, buf, 0); | ||
| 230 | + ans.push_back(std::move(v)); | ||
| 231 | + } | ||
| 232 | + } | ||
| 233 | + | ||
| 234 | + { | ||
| 235 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 236 | + buf[n] = &states[n][num_states - 2]; | ||
| 237 | + } | ||
| 238 | + auto v = Cat(allocator_, buf, 0); | ||
| 239 | + ans.push_back(std::move(v)); | ||
| 240 | + } | ||
| 241 | + | ||
| 242 | + { | ||
| 243 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 244 | + buf[n] = &states[n][num_states - 1]; | ||
| 245 | + } | ||
| 246 | + auto v = Cat<int64_t>(allocator_, buf, 0); | ||
| 247 | + ans.push_back(std::move(v)); | ||
| 248 | + } | ||
| 249 | + return ans; | ||
| 250 | +} | ||
| 251 | + | ||
| 252 | +std::vector<std::vector<Ort::Value>> | ||
| 253 | +OnlineZipformer2TransducerModel::UnStackStates( | ||
| 254 | + const std::vector<Ort::Value> &states) const { | ||
| 255 | + int32_t m = std::accumulate(num_encoder_layers_.begin(), num_encoder_layers_.end(), 0); | ||
| 256 | + assert(states.size() == m * 6 + 2); | ||
| 257 | + | ||
| 258 | + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 259 | + int32_t num_encoders = num_encoder_layers_.size(); | ||
| 260 | + | ||
| 261 | + std::vector<std::vector<Ort::Value>> ans; | ||
| 262 | + ans.resize(batch_size); | ||
| 263 | + | ||
| 264 | + for (int32_t i = 0; i != m; ++i) { | ||
| 265 | + { | ||
| 266 | + auto v = Unbind(allocator_, &states[i * 6], 1); | ||
| 267 | + assert(v.size() == batch_size); | ||
| 268 | + | ||
| 269 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 270 | + ans[n].push_back(std::move(v[n])); | ||
| 271 | + } | ||
| 272 | + } | ||
| 273 | + { | ||
| 274 | + auto v = Unbind(allocator_, &states[i * 6 + 1], 1); | ||
| 275 | + assert(v.size() == batch_size); | ||
| 276 | + | ||
| 277 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 278 | + ans[n].push_back(std::move(v[n])); | ||
| 279 | + } | ||
| 280 | + } | ||
| 281 | + { | ||
| 282 | + auto v = Unbind(allocator_, &states[i * 6 + 2], 1); | ||
| 283 | + assert(v.size() == batch_size); | ||
| 284 | + | ||
| 285 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 286 | + ans[n].push_back(std::move(v[n])); | ||
| 287 | + } | ||
| 288 | + } | ||
| 289 | + { | ||
| 290 | + auto v = Unbind(allocator_, &states[i * 6 + 3], 1); | ||
| 291 | + assert(v.size() == batch_size); | ||
| 292 | + | ||
| 293 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 294 | + ans[n].push_back(std::move(v[n])); | ||
| 295 | + } | ||
| 296 | + } | ||
| 297 | + { | ||
| 298 | + auto v = Unbind(allocator_, &states[i * 6 + 4], 0); | ||
| 299 | + assert(v.size() == batch_size); | ||
| 300 | + | ||
| 301 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 302 | + ans[n].push_back(std::move(v[n])); | ||
| 303 | + } | ||
| 304 | + } | ||
| 305 | + { | ||
| 306 | + auto v = Unbind(allocator_, &states[i * 6 + 5], 0); | ||
| 307 | + assert(v.size() == batch_size); | ||
| 308 | + | ||
| 309 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 310 | + ans[n].push_back(std::move(v[n])); | ||
| 311 | + } | ||
| 312 | + } | ||
| 313 | + } | ||
| 314 | + | ||
| 315 | + { | ||
| 316 | + auto v = Unbind(allocator_, &states[m * 6], 0); | ||
| 317 | + assert(v.size() == batch_size); | ||
| 318 | + | ||
| 319 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 320 | + ans[n].push_back(std::move(v[n])); | ||
| 321 | + } | ||
| 322 | + } | ||
| 323 | + { | ||
| 324 | + auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0); | ||
| 325 | + assert(v.size() == batch_size); | ||
| 326 | + | ||
| 327 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 328 | + ans[n].push_back(std::move(v[n])); | ||
| 329 | + } | ||
| 330 | + } | ||
| 331 | + | ||
| 332 | + return ans; | ||
| 333 | +} | ||
| 334 | + | ||
| 335 | +std::vector<Ort::Value> OnlineZipformer2TransducerModel::GetEncoderInitStates() { | ||
| 336 | + std::vector<Ort::Value> ans; | ||
| 337 | + int32_t n = static_cast<int32_t>(encoder_dims_.size()); | ||
| 338 | + int32_t m = std::accumulate(num_encoder_layers_.begin(), num_encoder_layers_.end(), 0); | ||
| 339 | + ans.reserve(m * 6 + 2); | ||
| 340 | + | ||
| 341 | + for (int32_t i = 0; i != n; ++i) { | ||
| 342 | + int32_t num_layers = num_encoder_layers_[i]; | ||
| 343 | + int32_t key_dim = query_head_dims_[i] * num_heads_[i]; | ||
| 344 | + int32_t value_dim = value_head_dims_[i] * num_heads_[i]; | ||
| 345 | + int32_t nonlin_attn_head_dim = 3 * encoder_dims_[i] / 4; | ||
| 346 | + | ||
| 347 | + for (int32_t j = 0; j != num_layers; ++j) { | ||
| 348 | + { | ||
| 349 | + std::array<int64_t, 3> s{left_context_len_[i], 1, key_dim}; | ||
| 350 | + auto v = | ||
| 351 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 352 | + Fill(&v, 0); | ||
| 353 | + ans.push_back(std::move(v)); | ||
| 354 | + } | ||
| 355 | + | ||
| 356 | + { | ||
| 357 | + std::array<int64_t, 4> s{1, 1, left_context_len_[i], nonlin_attn_head_dim}; | ||
| 358 | + auto v = | ||
| 359 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 360 | + Fill(&v, 0); | ||
| 361 | + ans.push_back(std::move(v)); | ||
| 362 | + } | ||
| 363 | + | ||
| 364 | + { | ||
| 365 | + std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim}; | ||
| 366 | + auto v = | ||
| 367 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 368 | + Fill(&v, 0); | ||
| 369 | + ans.push_back(std::move(v)); | ||
| 370 | + } | ||
| 371 | + | ||
| 372 | + { | ||
| 373 | + std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim}; | ||
| 374 | + auto v = | ||
| 375 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 376 | + Fill(&v, 0); | ||
| 377 | + ans.push_back(std::move(v)); | ||
| 378 | + } | ||
| 379 | + | ||
| 380 | + { | ||
| 381 | + std::array<int64_t, 3> s{1, encoder_dims_[i], cnn_module_kernels_[i] / 2}; | ||
| 382 | + auto v = | ||
| 383 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 384 | + Fill(&v, 0); | ||
| 385 | + ans.push_back(std::move(v)); | ||
| 386 | + } | ||
| 387 | + | ||
| 388 | + { | ||
| 389 | + std::array<int64_t, 3> s{1, encoder_dims_[i], cnn_module_kernels_[i] / 2}; | ||
| 390 | + auto v = | ||
| 391 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 392 | + Fill(&v, 0); | ||
| 393 | + ans.push_back(std::move(v)); | ||
| 394 | + } | ||
| 395 | + } | ||
| 396 | + } | ||
| 397 | + | ||
| 398 | + { | ||
| 399 | + std::array<int64_t, 4> s{1, 128, 3, 19}; | ||
| 400 | + auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 401 | + Fill(&v, 0); | ||
| 402 | + ans.push_back(std::move(v)); | ||
| 403 | + } | ||
| 404 | + | ||
| 405 | + { | ||
| 406 | + std::array<int64_t, 1> s{1}; | ||
| 407 | + auto v = Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size()); | ||
| 408 | + Fill<int64_t>(&v, 0); | ||
| 409 | + ans.push_back(std::move(v)); | ||
| 410 | + } | ||
| 411 | + return ans; | ||
| 412 | +} | ||
| 413 | + | ||
| 414 | +std::pair<Ort::Value, std::vector<Ort::Value>> | ||
| 415 | +OnlineZipformer2TransducerModel::RunEncoder(Ort::Value features, | ||
| 416 | + std::vector<Ort::Value> states, | ||
| 417 | + Ort::Value /* processed_frames */) { | ||
| 418 | + std::vector<Ort::Value> encoder_inputs; | ||
| 419 | + encoder_inputs.reserve(1 + states.size()); | ||
| 420 | + | ||
| 421 | + encoder_inputs.push_back(std::move(features)); | ||
| 422 | + for (auto &v : states) { | ||
| 423 | + encoder_inputs.push_back(std::move(v)); | ||
| 424 | + } | ||
| 425 | + | ||
| 426 | + auto encoder_out = encoder_sess_->Run( | ||
| 427 | + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), | ||
| 428 | + encoder_inputs.size(), encoder_output_names_ptr_.data(), | ||
| 429 | + encoder_output_names_ptr_.size()); | ||
| 430 | + | ||
| 431 | + std::vector<Ort::Value> next_states; | ||
| 432 | + next_states.reserve(states.size()); | ||
| 433 | + | ||
| 434 | + for (int32_t i = 1; i != static_cast<int32_t>(encoder_out.size()); ++i) { | ||
| 435 | + next_states.push_back(std::move(encoder_out[i])); | ||
| 436 | + } | ||
| 437 | + return {std::move(encoder_out[0]), std::move(next_states)}; | ||
| 438 | +} | ||
| 439 | + | ||
| 440 | +Ort::Value OnlineZipformer2TransducerModel::RunDecoder( | ||
| 441 | + Ort::Value decoder_input) { | ||
| 442 | + auto decoder_out = decoder_sess_->Run( | ||
| 443 | + {}, decoder_input_names_ptr_.data(), &decoder_input, 1, | ||
| 444 | + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); | ||
| 445 | + return std::move(decoder_out[0]); | ||
| 446 | +} | ||
| 447 | + | ||
| 448 | +Ort::Value OnlineZipformer2TransducerModel::RunJoiner(Ort::Value encoder_out, | ||
| 449 | + Ort::Value decoder_out) { | ||
| 450 | + std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out), | ||
| 451 | + std::move(decoder_out)}; | ||
| 452 | + auto logit = | ||
| 453 | + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), | ||
| 454 | + joiner_input.size(), joiner_output_names_ptr_.data(), | ||
| 455 | + joiner_output_names_ptr_.size()); | ||
| 456 | + | ||
| 457 | + return std::move(logit[0]); | ||
| 458 | +} | ||
| 459 | + | ||
| 460 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/online-zipformer2-transducer-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 17 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 18 | +#include "sherpa-onnx/csrc/online-transducer-model-config.h" | ||
| 19 | +#include "sherpa-onnx/csrc/online-transducer-model.h" | ||
| 20 | + | ||
| 21 | +namespace sherpa_onnx { | ||
| 22 | + | ||
| 23 | +class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | ||
| 24 | + public: | ||
| 25 | + explicit OnlineZipformer2TransducerModel( | ||
| 26 | + const OnlineTransducerModelConfig &config); | ||
| 27 | + | ||
| 28 | +#if __ANDROID_API__ >= 9 | ||
| 29 | + OnlineZipformer2TransducerModel(AAssetManager *mgr, | ||
| 30 | + const OnlineTransducerModelConfig &config); | ||
| 31 | +#endif | ||
| 32 | + | ||
| 33 | + std::vector<Ort::Value> StackStates( | ||
| 34 | + const std::vector<std::vector<Ort::Value>> &states) const override; | ||
| 35 | + | ||
| 36 | + std::vector<std::vector<Ort::Value>> UnStackStates( | ||
| 37 | + const std::vector<Ort::Value> &states) const override; | ||
| 38 | + | ||
| 39 | + std::vector<Ort::Value> GetEncoderInitStates() override; | ||
| 40 | + | ||
| 41 | + std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( | ||
| 42 | + Ort::Value features, std::vector<Ort::Value> states, | ||
| 43 | + Ort::Value processed_frames) override; | ||
| 44 | + | ||
| 45 | + Ort::Value RunDecoder(Ort::Value decoder_input) override; | ||
| 46 | + | ||
| 47 | + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override; | ||
| 48 | + | ||
| 49 | + int32_t ContextSize() const override { return context_size_; } | ||
| 50 | + | ||
| 51 | + int32_t ChunkSize() const override { return T_; } | ||
| 52 | + | ||
| 53 | + int32_t ChunkShift() const override { return decode_chunk_len_; } | ||
| 54 | + | ||
| 55 | + int32_t VocabSize() const override { return vocab_size_; } | ||
| 56 | + OrtAllocator *Allocator() override { return allocator_; } | ||
| 57 | + | ||
| 58 | + private: | ||
| 59 | + void InitEncoder(void *model_data, size_t model_data_length); | ||
| 60 | + void InitDecoder(void *model_data, size_t model_data_length); | ||
| 61 | + void InitJoiner(void *model_data, size_t model_data_length); | ||
| 62 | + | ||
| 63 | + private: | ||
| 64 | + Ort::Env env_; | ||
| 65 | + Ort::SessionOptions sess_opts_; | ||
| 66 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 67 | + | ||
| 68 | + std::unique_ptr<Ort::Session> encoder_sess_; | ||
| 69 | + std::unique_ptr<Ort::Session> decoder_sess_; | ||
| 70 | + std::unique_ptr<Ort::Session> joiner_sess_; | ||
| 71 | + | ||
| 72 | + std::vector<std::string> encoder_input_names_; | ||
| 73 | + std::vector<const char *> encoder_input_names_ptr_; | ||
| 74 | + | ||
| 75 | + std::vector<std::string> encoder_output_names_; | ||
| 76 | + std::vector<const char *> encoder_output_names_ptr_; | ||
| 77 | + | ||
| 78 | + std::vector<std::string> decoder_input_names_; | ||
| 79 | + std::vector<const char *> decoder_input_names_ptr_; | ||
| 80 | + | ||
| 81 | + std::vector<std::string> decoder_output_names_; | ||
| 82 | + std::vector<const char *> decoder_output_names_ptr_; | ||
| 83 | + | ||
| 84 | + std::vector<std::string> joiner_input_names_; | ||
| 85 | + std::vector<const char *> joiner_input_names_ptr_; | ||
| 86 | + | ||
| 87 | + std::vector<std::string> joiner_output_names_; | ||
| 88 | + std::vector<const char *> joiner_output_names_ptr_; | ||
| 89 | + | ||
| 90 | + OnlineTransducerModelConfig config_; | ||
| 91 | + | ||
| 92 | + std::vector<int32_t> encoder_dims_; | ||
| 93 | + std::vector<int32_t> query_head_dims_; | ||
| 94 | + std::vector<int32_t> value_head_dims_; | ||
| 95 | + std::vector<int32_t> num_heads_; | ||
| 96 | + std::vector<int32_t> num_encoder_layers_; | ||
| 97 | + std::vector<int32_t> cnn_module_kernels_; | ||
| 98 | + std::vector<int32_t> left_context_len_; | ||
| 99 | + | ||
| 100 | + int32_t T_ = 0; | ||
| 101 | + int32_t decode_chunk_len_ = 0; | ||
| 102 | + | ||
| 103 | + int32_t context_size_ = 0; | ||
| 104 | + int32_t vocab_size_ = 0; | ||
| 105 | +}; | ||
| 106 | + | ||
| 107 | +} // namespace sherpa_onnx | ||
| 108 | + | ||
| 109 | +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_ |
-
请 注册 或 登录 后发表评论