Committed by
GitHub
Ebranchformer (#1951)
* Add ebranchformer encoder.
* Extend the surfaced FeatureExtractorConfig so that ebranchformer feature extraction can be configured from Python; the GlobCmvn is not needed, as it is a module in the OnnxEncoder.
* Clean up the code.
* Integrate review remarks from Fangjun.
正在显示
8 个修改的文件
包含
609 行增加
和
5 行删除
| @@ -68,6 +68,7 @@ set(sources | @@ -68,6 +68,7 @@ set(sources | ||
| 68 | online-ctc-fst-decoder.cc | 68 | online-ctc-fst-decoder.cc |
| 69 | online-ctc-greedy-search-decoder.cc | 69 | online-ctc-greedy-search-decoder.cc |
| 70 | online-ctc-model.cc | 70 | online-ctc-model.cc |
| 71 | + online-ebranchformer-transducer-model.cc | ||
| 71 | online-lm-config.cc | 72 | online-lm-config.cc |
| 72 | online-lm.cc | 73 | online-lm.cc |
| 73 | online-lstm-transducer-model.cc | 74 | online-lstm-transducer-model.cc |
| @@ -48,7 +48,9 @@ std::string FeatureExtractorConfig::ToString() const { | @@ -48,7 +48,9 @@ std::string FeatureExtractorConfig::ToString() const { | ||
| 48 | os << "feature_dim=" << feature_dim << ", "; | 48 | os << "feature_dim=" << feature_dim << ", "; |
| 49 | os << "low_freq=" << low_freq << ", "; | 49 | os << "low_freq=" << low_freq << ", "; |
| 50 | os << "high_freq=" << high_freq << ", "; | 50 | os << "high_freq=" << high_freq << ", "; |
| 51 | - os << "dither=" << dither << ")"; | 51 | + os << "dither=" << dither << ", "; |
| 52 | + os << "normalize_samples=" << (normalize_samples ? "True" : "False") << ", "; | ||
| 53 | + os << "snip_edges=" << (snip_edges ? "True" : "False") << ")"; | ||
| 52 | 54 | ||
| 53 | return os.str(); | 55 | return os.str(); |
| 54 | } | 56 | } |
| 1 | +// sherpa-onnx/csrc/online-ebranchformer-transducer-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +// 2025 Brno University of Technology (author: Karel Vesely) | ||
| 5 | + | ||
| 6 | +#include "sherpa-onnx/csrc/online-ebranchformer-transducer-model.h" | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <cassert> | ||
| 10 | +#include <cmath> | ||
| 11 | +#include <memory> | ||
| 12 | +#include <numeric> | ||
| 13 | +#include <sstream> | ||
| 14 | +#include <string> | ||
| 15 | +#include <utility> | ||
| 16 | +#include <vector> | ||
| 17 | + | ||
| 18 | +#if __ANDROID_API__ >= 9 | ||
| 19 | +#include "android/asset_manager.h" | ||
| 20 | +#include "android/asset_manager_jni.h" | ||
| 21 | +#endif | ||
| 22 | + | ||
| 23 | +#if __OHOS__ | ||
| 24 | +#include "rawfile/raw_file_manager.h" | ||
| 25 | +#endif | ||
| 26 | + | ||
| 27 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 28 | +#include "sherpa-onnx/csrc/cat.h" | ||
| 29 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 30 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 31 | +#include "sherpa-onnx/csrc/online-transducer-decoder.h" | ||
| 32 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 33 | +#include "sherpa-onnx/csrc/session.h" | ||
| 34 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 35 | +#include "sherpa-onnx/csrc/unbind.h" | ||
| 36 | + | ||
| 37 | +namespace sherpa_onnx { | ||
| 38 | + | ||
| 39 | +OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel( | ||
| 40 | + const OnlineModelConfig &config) | ||
| 41 | + : env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 42 | + encoder_sess_opts_(GetSessionOptions(config)), | ||
| 43 | + decoder_sess_opts_(GetSessionOptions(config, "decoder")), | ||
| 44 | + joiner_sess_opts_(GetSessionOptions(config, "joiner")), | ||
| 45 | + config_(config), | ||
| 46 | + allocator_{} { | ||
| 47 | + { | ||
| 48 | + auto buf = ReadFile(config.transducer.encoder); | ||
| 49 | + InitEncoder(buf.data(), buf.size()); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + { | ||
| 53 | + auto buf = ReadFile(config.transducer.decoder); | ||
| 54 | + InitDecoder(buf.data(), buf.size()); | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + { | ||
| 58 | + auto buf = ReadFile(config.transducer.joiner); | ||
| 59 | + InitJoiner(buf.data(), buf.size()); | ||
| 60 | + } | ||
| 61 | +} | ||
| 62 | + | ||
| 63 | +template <typename Manager> | ||
| 64 | +OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel( | ||
| 65 | + Manager *mgr, const OnlineModelConfig &config) | ||
| 66 | + : env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 67 | + config_(config), | ||
| 68 | + encoder_sess_opts_(GetSessionOptions(config)), | ||
| 69 | + decoder_sess_opts_(GetSessionOptions(config)), | ||
| 70 | + joiner_sess_opts_(GetSessionOptions(config)), | ||
| 71 | + allocator_{} { | ||
| 72 | + { | ||
| 73 | + auto buf = ReadFile(mgr, config.transducer.encoder); | ||
| 74 | + InitEncoder(buf.data(), buf.size()); | ||
| 75 | + } | ||
| 76 | + | ||
| 77 | + { | ||
| 78 | + auto buf = ReadFile(mgr, config.transducer.decoder); | ||
| 79 | + InitDecoder(buf.data(), buf.size()); | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | + { | ||
| 83 | + auto buf = ReadFile(mgr, config.transducer.joiner); | ||
| 84 | + InitJoiner(buf.data(), buf.size()); | ||
| 85 | + } | ||
| 86 | +} | ||
| 87 | + | ||
| 88 | + | ||
// Creates the encoder Ort::Session from an in-memory model buffer and reads
// the streaming metadata that the rest of the class depends on:
//  - T_ / decode_chunk_len_ back ChunkSize() / ChunkShift() (see the header);
//  - num_hidden_layers_, num_heads_, head_dim_, left_context_len_,
//    intermediate_size_, csgu_kernel_size_, merge_conv_kernel_ and
//    hidden_size_ determine the shapes of the per-layer cache tensors
//    allocated in GetEncoderInitStates().
void OnlineEbranchformerTransducerModel::InitEncoder(void *model_data,
                                                     size_t model_data_length) {
  encoder_sess_ = std::make_unique<Ort::Session>(
      env_, model_data, model_data_length, encoder_sess_opts_);

  GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                &encoder_input_names_ptr_);

  GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                 &encoder_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---encoder---\n";
    PrintModelMetadata(os, meta_data);
#if __OHOS__
    // OHOS hilog hides plain %s; the %{public} qualifier makes it visible.
    SHERPA_ONNX_LOGE("%{public}s", os.str().c_str());
#else
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
#endif
  }

  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below

  SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
  SHERPA_ONNX_READ_META_DATA(T_, "T");

  SHERPA_ONNX_READ_META_DATA(num_hidden_layers_, "num_hidden_layers");
  SHERPA_ONNX_READ_META_DATA(hidden_size_, "hidden_size");
  SHERPA_ONNX_READ_META_DATA(intermediate_size_, "intermediate_size");
  SHERPA_ONNX_READ_META_DATA(csgu_kernel_size_, "csgu_kernel_size");
  SHERPA_ONNX_READ_META_DATA(merge_conv_kernel_, "merge_conv_kernel");
  SHERPA_ONNX_READ_META_DATA(left_context_len_, "left_context_len");
  SHERPA_ONNX_READ_META_DATA(num_heads_, "num_heads");
  SHERPA_ONNX_READ_META_DATA(head_dim_, "head_dim");

  if (config_.debug) {
#if __OHOS__
    SHERPA_ONNX_LOGE("T: %{public}d", T_);
    SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);

    SHERPA_ONNX_LOGE("num_hidden_layers_: %{public}d", num_hidden_layers_);
    SHERPA_ONNX_LOGE("hidden_size_: %{public}d", hidden_size_);
    SHERPA_ONNX_LOGE("intermediate_size_: %{public}d", intermediate_size_);
    SHERPA_ONNX_LOGE("csgu_kernel_size_: %{public}d", csgu_kernel_size_);
    SHERPA_ONNX_LOGE("merge_conv_kernel_: %{public}d", merge_conv_kernel_);
    SHERPA_ONNX_LOGE("left_context_len_: %{public}d", left_context_len_);
    SHERPA_ONNX_LOGE("num_heads_: %{public}d", num_heads_);
    SHERPA_ONNX_LOGE("head_dim_: %{public}d", head_dim_);
#else
    SHERPA_ONNX_LOGE("T: %d", T_);
    SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);

    SHERPA_ONNX_LOGE("num_hidden_layers_: %d", num_hidden_layers_);
    SHERPA_ONNX_LOGE("hidden_size_: %d", hidden_size_);
    SHERPA_ONNX_LOGE("intermediate_size_: %d", intermediate_size_);
    SHERPA_ONNX_LOGE("csgu_kernel_size_: %d", csgu_kernel_size_);
    SHERPA_ONNX_LOGE("merge_conv_kernel_: %d", merge_conv_kernel_);
    SHERPA_ONNX_LOGE("left_context_len_: %d", left_context_len_);
    SHERPA_ONNX_LOGE("num_heads_: %d", num_heads_);
    SHERPA_ONNX_LOGE("head_dim_: %d", head_dim_);
#endif
  }
}
| 155 | + | ||
| 156 | + | ||
| 157 | +void OnlineEbranchformerTransducerModel::InitDecoder(void *model_data, | ||
| 158 | + size_t model_data_length) { | ||
| 159 | + decoder_sess_ = std::make_unique<Ort::Session>( | ||
| 160 | + env_, model_data, model_data_length, decoder_sess_opts_); | ||
| 161 | + | ||
| 162 | + GetInputNames(decoder_sess_.get(), &decoder_input_names_, | ||
| 163 | + &decoder_input_names_ptr_); | ||
| 164 | + | ||
| 165 | + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, | ||
| 166 | + &decoder_output_names_ptr_); | ||
| 167 | + | ||
| 168 | + // get meta data | ||
| 169 | + Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata(); | ||
| 170 | + if (config_.debug) { | ||
| 171 | + std::ostringstream os; | ||
| 172 | + os << "---decoder---\n"; | ||
| 173 | + PrintModelMetadata(os, meta_data); | ||
| 174 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 175 | + } | ||
| 176 | + | ||
| 177 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 178 | + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); | ||
| 179 | + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); | ||
| 180 | +} | ||
| 181 | + | ||
| 182 | +void OnlineEbranchformerTransducerModel::InitJoiner(void *model_data, | ||
| 183 | + size_t model_data_length) { | ||
| 184 | + joiner_sess_ = std::make_unique<Ort::Session>( | ||
| 185 | + env_, model_data, model_data_length, joiner_sess_opts_); | ||
| 186 | + | ||
| 187 | + GetInputNames(joiner_sess_.get(), &joiner_input_names_, | ||
| 188 | + &joiner_input_names_ptr_); | ||
| 189 | + | ||
| 190 | + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, | ||
| 191 | + &joiner_output_names_ptr_); | ||
| 192 | + | ||
| 193 | + // get meta data | ||
| 194 | + Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata(); | ||
| 195 | + if (config_.debug) { | ||
| 196 | + std::ostringstream os; | ||
| 197 | + os << "---joiner---\n"; | ||
| 198 | + PrintModelMetadata(os, meta_data); | ||
| 199 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 200 | + } | ||
| 201 | +} | ||
| 202 | + | ||
| 203 | + | ||
| 204 | +std::vector<Ort::Value> OnlineEbranchformerTransducerModel::StackStates( | ||
| 205 | + const std::vector<std::vector<Ort::Value>> &states) const { | ||
| 206 | + int32_t batch_size = static_cast<int32_t>(states.size()); | ||
| 207 | + | ||
| 208 | + std::vector<const Ort::Value *> buf(batch_size); | ||
| 209 | + | ||
| 210 | + auto allocator = | ||
| 211 | + const_cast<OnlineEbranchformerTransducerModel *>(this)->allocator_; | ||
| 212 | + | ||
| 213 | + std::vector<Ort::Value> ans; | ||
| 214 | + int32_t num_states = static_cast<int32_t>(states[0].size()); | ||
| 215 | + ans.reserve(num_states); | ||
| 216 | + | ||
| 217 | + for (int32_t i = 0; i != num_hidden_layers_; ++i) { | ||
| 218 | + { // cached_key | ||
| 219 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 220 | + buf[n] = &states[n][4 * i]; | ||
| 221 | + } | ||
| 222 | + auto v = Cat(allocator, buf, /* axis */ 0); | ||
| 223 | + ans.push_back(std::move(v)); | ||
| 224 | + } | ||
| 225 | + { // cached_value | ||
| 226 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 227 | + buf[n] = &states[n][4 * i + 1]; | ||
| 228 | + } | ||
| 229 | + auto v = Cat(allocator, buf, 0); | ||
| 230 | + ans.push_back(std::move(v)); | ||
| 231 | + } | ||
| 232 | + { // cached_conv | ||
| 233 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 234 | + buf[n] = &states[n][4 * i + 2]; | ||
| 235 | + } | ||
| 236 | + auto v = Cat(allocator, buf, 0); | ||
| 237 | + ans.push_back(std::move(v)); | ||
| 238 | + } | ||
| 239 | + { // cached_conv_fusion | ||
| 240 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 241 | + buf[n] = &states[n][4 * i + 3]; | ||
| 242 | + } | ||
| 243 | + auto v = Cat(allocator, buf, 0); | ||
| 244 | + ans.push_back(std::move(v)); | ||
| 245 | + } | ||
| 246 | + } | ||
| 247 | + | ||
| 248 | + { // processed_lens | ||
| 249 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 250 | + buf[n] = &states[n][num_states - 1]; | ||
| 251 | + } | ||
| 252 | + auto v = Cat<int64_t>(allocator, buf, 0); | ||
| 253 | + ans.push_back(std::move(v)); | ||
| 254 | + } | ||
| 255 | + | ||
| 256 | + return ans; | ||
| 257 | +} | ||
| 258 | + | ||
| 259 | + | ||
| 260 | +std::vector<std::vector<Ort::Value>> | ||
| 261 | +OnlineEbranchformerTransducerModel::UnStackStates( | ||
| 262 | + const std::vector<Ort::Value> &states) const { | ||
| 263 | + | ||
| 264 | + assert(static_cast<int32_t>(states.size()) == num_hidden_layers_ * 4 + 1); | ||
| 265 | + | ||
| 266 | + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[0]; | ||
| 267 | + | ||
| 268 | + auto allocator = | ||
| 269 | + const_cast<OnlineEbranchformerTransducerModel *>(this)->allocator_; | ||
| 270 | + | ||
| 271 | + std::vector<std::vector<Ort::Value>> ans; | ||
| 272 | + ans.resize(batch_size); | ||
| 273 | + | ||
| 274 | + for (int32_t i = 0; i != num_hidden_layers_; ++i) { | ||
| 275 | + { // cached_key | ||
| 276 | + auto v = Unbind(allocator, &states[i * 4], /* axis */ 0); | ||
| 277 | + assert(static_cast<int32_t>(v.size()) == batch_size); | ||
| 278 | + | ||
| 279 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 280 | + ans[n].push_back(std::move(v[n])); | ||
| 281 | + } | ||
| 282 | + } | ||
| 283 | + { // cached_value | ||
| 284 | + auto v = Unbind(allocator, &states[i * 4 + 1], 0); | ||
| 285 | + assert(static_cast<int32_t>(v.size()) == batch_size); | ||
| 286 | + | ||
| 287 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 288 | + ans[n].push_back(std::move(v[n])); | ||
| 289 | + } | ||
| 290 | + } | ||
| 291 | + { // cached_conv | ||
| 292 | + auto v = Unbind(allocator, &states[i * 4 + 2], 0); | ||
| 293 | + assert(static_cast<int32_t>(v.size()) == batch_size); | ||
| 294 | + | ||
| 295 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 296 | + ans[n].push_back(std::move(v[n])); | ||
| 297 | + } | ||
| 298 | + } | ||
| 299 | + { // cached_conv_fusion | ||
| 300 | + auto v = Unbind(allocator, &states[i * 4 + 3], 0); | ||
| 301 | + assert(static_cast<int32_t>(v.size()) == batch_size); | ||
| 302 | + | ||
| 303 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 304 | + ans[n].push_back(std::move(v[n])); | ||
| 305 | + } | ||
| 306 | + } | ||
| 307 | + } | ||
| 308 | + | ||
| 309 | + { // processed_lens | ||
| 310 | + auto v = Unbind<int64_t>(allocator, &states.back(), 0); | ||
| 311 | + assert(static_cast<int32_t>(v.size()) == batch_size); | ||
| 312 | + | ||
| 313 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 314 | + ans[n].push_back(std::move(v[n])); | ||
| 315 | + } | ||
| 316 | + } | ||
| 317 | + | ||
| 318 | + return ans; | ||
| 319 | +} | ||
| 320 | + | ||
| 321 | + | ||
| 322 | +std::vector<Ort::Value> | ||
| 323 | +OnlineEbranchformerTransducerModel::GetEncoderInitStates() { | ||
| 324 | + std::vector<Ort::Value> ans; | ||
| 325 | + | ||
| 326 | + ans.reserve(num_hidden_layers_ * 4 + 1); | ||
| 327 | + | ||
| 328 | + int32_t left_context_conv = csgu_kernel_size_ - 1; | ||
| 329 | + int32_t channels_conv = intermediate_size_ / 2; | ||
| 330 | + | ||
| 331 | + int32_t left_context_conv_fusion = merge_conv_kernel_ - 1; | ||
| 332 | + int32_t channels_conv_fusion = 2 * hidden_size_; | ||
| 333 | + | ||
| 334 | + for (int32_t i = 0; i != num_hidden_layers_; ++i) { | ||
| 335 | + { // cached_key_{i} | ||
| 336 | + std::array<int64_t, 4> s{1, num_heads_, left_context_len_, head_dim_}; | ||
| 337 | + auto v = | ||
| 338 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 339 | + Fill(&v, 0); | ||
| 340 | + ans.push_back(std::move(v)); | ||
| 341 | + } | ||
| 342 | + | ||
| 343 | + { // cahced_value_{i} | ||
| 344 | + std::array<int64_t, 4> s{1, num_heads_, left_context_len_, head_dim_}; | ||
| 345 | + auto v = | ||
| 346 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 347 | + Fill(&v, 0); | ||
| 348 | + ans.push_back(std::move(v)); | ||
| 349 | + } | ||
| 350 | + | ||
| 351 | + { // cached_conv_{i} | ||
| 352 | + std::array<int64_t, 3> s{1, channels_conv, left_context_conv}; | ||
| 353 | + auto v = | ||
| 354 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 355 | + Fill(&v, 0); | ||
| 356 | + ans.push_back(std::move(v)); | ||
| 357 | + } | ||
| 358 | + | ||
| 359 | + { // cached_conv_fusion_{i} | ||
| 360 | + std::array<int64_t, 3> s{1, channels_conv_fusion, left_context_conv_fusion}; | ||
| 361 | + auto v = | ||
| 362 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 363 | + Fill(&v, 0); | ||
| 364 | + ans.push_back(std::move(v)); | ||
| 365 | + } | ||
| 366 | + } // num_hidden_layers_ | ||
| 367 | + | ||
| 368 | + { // processed_lens | ||
| 369 | + std::array<int64_t, 1> s{1}; | ||
| 370 | + auto v = Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size()); | ||
| 371 | + Fill<int64_t>(&v, 0); | ||
| 372 | + ans.push_back(std::move(v)); | ||
| 373 | + } | ||
| 374 | + | ||
| 375 | + return ans; | ||
| 376 | +} | ||
| 377 | + | ||
| 378 | + | ||
| 379 | +std::pair<Ort::Value, std::vector<Ort::Value>> | ||
| 380 | +OnlineEbranchformerTransducerModel::RunEncoder(Ort::Value features, | ||
| 381 | + std::vector<Ort::Value> states, | ||
| 382 | + Ort::Value /* processed_frames */) { | ||
| 383 | + std::vector<Ort::Value> encoder_inputs; | ||
| 384 | + encoder_inputs.reserve(1 + states.size()); | ||
| 385 | + | ||
| 386 | + encoder_inputs.push_back(std::move(features)); | ||
| 387 | + for (auto &v : states) { | ||
| 388 | + encoder_inputs.push_back(std::move(v)); | ||
| 389 | + } | ||
| 390 | + | ||
| 391 | + auto encoder_out = encoder_sess_->Run( | ||
| 392 | + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), | ||
| 393 | + encoder_inputs.size(), encoder_output_names_ptr_.data(), | ||
| 394 | + encoder_output_names_ptr_.size()); | ||
| 395 | + | ||
| 396 | + std::vector<Ort::Value> next_states; | ||
| 397 | + next_states.reserve(states.size()); | ||
| 398 | + | ||
| 399 | + for (int32_t i = 1; i != static_cast<int32_t>(encoder_out.size()); ++i) { | ||
| 400 | + next_states.push_back(std::move(encoder_out[i])); | ||
| 401 | + } | ||
| 402 | + return {std::move(encoder_out[0]), std::move(next_states)}; | ||
| 403 | +} | ||
| 404 | + | ||
| 405 | + | ||
| 406 | +Ort::Value OnlineEbranchformerTransducerModel::RunDecoder( | ||
| 407 | + Ort::Value decoder_input) { | ||
| 408 | + auto decoder_out = decoder_sess_->Run( | ||
| 409 | + {}, decoder_input_names_ptr_.data(), &decoder_input, 1, | ||
| 410 | + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); | ||
| 411 | + return std::move(decoder_out[0]); | ||
| 412 | +} | ||
| 413 | + | ||
| 414 | + | ||
| 415 | +Ort::Value OnlineEbranchformerTransducerModel::RunJoiner(Ort::Value encoder_out, | ||
| 416 | + Ort::Value decoder_out) { | ||
| 417 | + std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out), | ||
| 418 | + std::move(decoder_out)}; | ||
| 419 | + auto logit = | ||
| 420 | + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), | ||
| 421 | + joiner_input.size(), joiner_output_names_ptr_.data(), | ||
| 422 | + joiner_output_names_ptr_.size()); | ||
| 423 | + | ||
| 424 | + return std::move(logit[0]); | ||
| 425 | +} | ||
| 426 | + | ||
| 427 | + | ||
#if __ANDROID_API__ >= 9
// Explicit instantiation of the templated constructor for loading models
// through the Android asset manager.
template OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel(
    AAssetManager *mgr, const OnlineModelConfig &config);
#endif

#if __OHOS__
// Explicit instantiation of the templated constructor for loading models
// through the OHOS raw-file resource manager.
template OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel(
    NativeResourceManager *mgr, const OnlineModelConfig &config);
#endif

}  // namespace sherpa_onnx
| 1 | +// sherpa-onnx/csrc/online-ebranchformer-transducer-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +// 2025 Brno University of Technology (author: Karel Vesely) | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <utility> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 14 | +#include "sherpa-onnx/csrc/online-model-config.h" | ||
| 15 | +#include "sherpa-onnx/csrc/online-transducer-model.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
// Streaming transducer (RNN-T style) model with an E-Branchformer encoder,
// backed by three onnxruntime sessions (encoder, decoder/predictor, joiner).
// Most configuration is read from the ONNX models' metadata in the Init*()
// methods defined in the .cc file.
class OnlineEbranchformerTransducerModel : public OnlineTransducerModel {
 public:
  // Loads the encoder/decoder/joiner models from the file paths in
  // config.transducer.
  explicit OnlineEbranchformerTransducerModel(const OnlineModelConfig &config);

  // Same as above, but reads the model files through an asset/resource
  // manager (Android AAssetManager or OHOS NativeResourceManager).
  template <typename Manager>
  OnlineEbranchformerTransducerModel(Manager *mgr,
                                     const OnlineModelConfig &config);

  // Concatenates per-stream states into one batched state vector
  // (batch axis 0). Inverse of UnStackStates().
  std::vector<Ort::Value> StackStates(
      const std::vector<std::vector<Ort::Value>> &states) const override;

  // Splits a batched state vector back into per-stream state vectors.
  std::vector<std::vector<Ort::Value>> UnStackStates(
      const std::vector<Ort::Value> &states) const override;

  // Returns the zero-initialized encoder state for a single stream.
  std::vector<Ort::Value> GetEncoderInitStates() override;

  void SetFeatureDim(int32_t feature_dim) override {
    feature_dim_ = feature_dim;
  }

  // Runs one encoder step; returns (encoder_out, next_states).
  // processed_frames is unused: processed_lens is carried inside `states`.
  std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
      Ort::Value features, std::vector<Ort::Value> states,
      Ort::Value processed_frames) override;

  Ort::Value RunDecoder(Ort::Value decoder_input) override;

  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;

  // Number of previous tokens the decoder conditions on (from metadata).
  int32_t ContextSize() const override { return context_size_; }

  // Number of feature frames consumed per RunEncoder() call (metadata "T").
  int32_t ChunkSize() const override { return T_; }

  // Frame shift between consecutive chunks (metadata "decode_chunk_len").
  int32_t ChunkShift() const override { return decode_chunk_len_; }

  int32_t VocabSize() const override { return vocab_size_; }
  OrtAllocator *Allocator() override { return allocator_; }

 private:
  // Each Init* creates the corresponding Ort::Session from an in-memory model
  // buffer and caches its I/O names; encoder/decoder also read metadata.
  void InitEncoder(void *model_data, size_t model_data_length);
  void InitDecoder(void *model_data, size_t model_data_length);
  void InitJoiner(void *model_data, size_t model_data_length);

 private:
  Ort::Env env_;
  Ort::SessionOptions encoder_sess_opts_;
  Ort::SessionOptions decoder_sess_opts_;
  Ort::SessionOptions joiner_sess_opts_;

  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;
  std::unique_ptr<Ort::Session> joiner_sess_;

  // Input/output tensor names per session; the *_ptr_ vectors hold raw
  // pointers into the string vectors for the onnxruntime Run() API.
  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  std::vector<std::string> joiner_input_names_;
  std::vector<const char *> joiner_input_names_ptr_;

  std::vector<std::string> joiner_output_names_;
  std::vector<const char *> joiner_output_names_ptr_;

  OnlineModelConfig config_;

  // Read from the encoder model's metadata in InitEncoder().
  int32_t decode_chunk_len_ = 0;
  int32_t T_ = 0;

  int32_t num_hidden_layers_ = 0;
  int32_t hidden_size_ = 0;
  int32_t intermediate_size_ = 0;
  int32_t csgu_kernel_size_ = 0;
  int32_t merge_conv_kernel_ = 0;
  int32_t left_context_len_ = 0;
  int32_t num_heads_ = 0;
  int32_t head_dim_ = 0;

  // Read from the decoder model's metadata in InitDecoder().
  int32_t context_size_ = 0;
  int32_t vocab_size_ = 0;
  int32_t feature_dim_ = 80;  // overridable via SetFeatureDim()
};
| 109 | + | ||
| 110 | +} // namespace sherpa_onnx | ||
| 111 | + | ||
| 112 | +#endif // SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_ |
| @@ -21,6 +21,7 @@ | @@ -21,6 +21,7 @@ | ||
| 21 | #include "sherpa-onnx/csrc/file-utils.h" | 21 | #include "sherpa-onnx/csrc/file-utils.h" |
| 22 | #include "sherpa-onnx/csrc/macros.h" | 22 | #include "sherpa-onnx/csrc/macros.h" |
| 23 | #include "sherpa-onnx/csrc/online-conformer-transducer-model.h" | 23 | #include "sherpa-onnx/csrc/online-conformer-transducer-model.h" |
| 24 | +#include "sherpa-onnx/csrc/online-ebranchformer-transducer-model.h" | ||
| 24 | #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" | 25 | #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" |
| 25 | #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" | 26 | #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" |
| 26 | #include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h" | 27 | #include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h" |
| @@ -30,6 +31,7 @@ namespace { | @@ -30,6 +31,7 @@ namespace { | ||
| 30 | 31 | ||
| 31 | enum class ModelType : std::uint8_t { | 32 | enum class ModelType : std::uint8_t { |
| 32 | kConformer, | 33 | kConformer, |
| 34 | + kEbranchformer, | ||
| 33 | kLstm, | 35 | kLstm, |
| 34 | kZipformer, | 36 | kZipformer, |
| 35 | kZipformer2, | 37 | kZipformer2, |
| @@ -74,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -74,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 74 | 76 | ||
| 75 | if (model_type == "conformer") { | 77 | if (model_type == "conformer") { |
| 76 | return ModelType::kConformer; | 78 | return ModelType::kConformer; |
| 79 | + } else if (model_type == "ebranchformer") { | ||
| 80 | + return ModelType::kEbranchformer; | ||
| 77 | } else if (model_type == "lstm") { | 81 | } else if (model_type == "lstm") { |
| 78 | return ModelType::kLstm; | 82 | return ModelType::kLstm; |
| 79 | } else if (model_type == "zipformer") { | 83 | } else if (model_type == "zipformer") { |
| @@ -92,6 +96,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -92,6 +96,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 92 | const auto &model_type = config.model_type; | 96 | const auto &model_type = config.model_type; |
| 93 | if (model_type == "conformer") { | 97 | if (model_type == "conformer") { |
| 94 | return std::make_unique<OnlineConformerTransducerModel>(config); | 98 | return std::make_unique<OnlineConformerTransducerModel>(config); |
| 99 | + } else if (model_type == "ebranchformer") { | ||
| 100 | + return std::make_unique<OnlineEbranchformerTransducerModel>(config); | ||
| 95 | } else if (model_type == "lstm") { | 101 | } else if (model_type == "lstm") { |
| 96 | return std::make_unique<OnlineLstmTransducerModel>(config); | 102 | return std::make_unique<OnlineLstmTransducerModel>(config); |
| 97 | } else if (model_type == "zipformer") { | 103 | } else if (model_type == "zipformer") { |
| @@ -115,6 +121,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -115,6 +121,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 115 | switch (model_type) { | 121 | switch (model_type) { |
| 116 | case ModelType::kConformer: | 122 | case ModelType::kConformer: |
| 117 | return std::make_unique<OnlineConformerTransducerModel>(config); | 123 | return std::make_unique<OnlineConformerTransducerModel>(config); |
| 124 | + case ModelType::kEbranchformer: | ||
| 125 | + return std::make_unique<OnlineEbranchformerTransducerModel>(config); | ||
| 118 | case ModelType::kLstm: | 126 | case ModelType::kLstm: |
| 119 | return std::make_unique<OnlineLstmTransducerModel>(config); | 127 | return std::make_unique<OnlineLstmTransducerModel>(config); |
| 120 | case ModelType::kZipformer: | 128 | case ModelType::kZipformer: |
| @@ -171,6 +179,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -171,6 +179,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 171 | const auto &model_type = config.model_type; | 179 | const auto &model_type = config.model_type; |
| 172 | if (model_type == "conformer") { | 180 | if (model_type == "conformer") { |
| 173 | return std::make_unique<OnlineConformerTransducerModel>(mgr, config); | 181 | return std::make_unique<OnlineConformerTransducerModel>(mgr, config); |
| 182 | + } else if (model_type == "ebranchformer") { | ||
| 183 | + return std::make_unique<OnlineEbranchformerTransducerModel>(mgr, config); | ||
| 174 | } else if (model_type == "lstm") { | 184 | } else if (model_type == "lstm") { |
| 175 | return std::make_unique<OnlineLstmTransducerModel>(mgr, config); | 185 | return std::make_unique<OnlineLstmTransducerModel>(mgr, config); |
| 176 | } else if (model_type == "zipformer") { | 186 | } else if (model_type == "zipformer") { |
| @@ -190,6 +200,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -190,6 +200,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 190 | switch (model_type) { | 200 | switch (model_type) { |
| 191 | case ModelType::kConformer: | 201 | case ModelType::kConformer: |
| 192 | return std::make_unique<OnlineConformerTransducerModel>(mgr, config); | 202 | return std::make_unique<OnlineConformerTransducerModel>(mgr, config); |
| 203 | + case ModelType::kEbranchformer: | ||
| 204 | + return std::make_unique<OnlineEbranchformerTransducerModel>(mgr, config); | ||
| 193 | case ModelType::kLstm: | 205 | case ModelType::kLstm: |
| 194 | return std::make_unique<OnlineLstmTransducerModel>(mgr, config); | 206 | return std::make_unique<OnlineLstmTransducerModel>(mgr, config); |
| 195 | case ModelType::kZipformer: | 207 | case ModelType::kZipformer: |
| @@ -11,15 +11,21 @@ namespace sherpa_onnx { | @@ -11,15 +11,21 @@ namespace sherpa_onnx { | ||
| 11 | static void PybindFeatureExtractorConfig(py::module *m) { | 11 | static void PybindFeatureExtractorConfig(py::module *m) { |
| 12 | using PyClass = FeatureExtractorConfig; | 12 | using PyClass = FeatureExtractorConfig; |
| 13 | py::class_<PyClass>(*m, "FeatureExtractorConfig") | 13 | py::class_<PyClass>(*m, "FeatureExtractorConfig") |
| 14 | - .def(py::init<int32_t, int32_t, float, float, float>(), | ||
| 15 | - py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80, | ||
| 16 | - py::arg("low_freq") = 20.0f, py::arg("high_freq") = -400.0f, | ||
| 17 | - py::arg("dither") = 0.0f) | 14 | + .def(py::init<int32_t, int32_t, float, float, float, bool, bool>(), |
| 15 | + py::arg("sampling_rate") = 16000, | ||
| 16 | + py::arg("feature_dim") = 80, | ||
| 17 | + py::arg("low_freq") = 20.0f, | ||
| 18 | + py::arg("high_freq") = -400.0f, | ||
| 19 | + py::arg("dither") = 0.0f, | ||
| 20 | + py::arg("normalize_samples") = true, | ||
| 21 | + py::arg("snip_edges") = false) | ||
| 18 | .def_readwrite("sampling_rate", &PyClass::sampling_rate) | 22 | .def_readwrite("sampling_rate", &PyClass::sampling_rate) |
| 19 | .def_readwrite("feature_dim", &PyClass::feature_dim) | 23 | .def_readwrite("feature_dim", &PyClass::feature_dim) |
| 20 | .def_readwrite("low_freq", &PyClass::low_freq) | 24 | .def_readwrite("low_freq", &PyClass::low_freq) |
| 21 | .def_readwrite("high_freq", &PyClass::high_freq) | 25 | .def_readwrite("high_freq", &PyClass::high_freq) |
| 22 | .def_readwrite("dither", &PyClass::dither) | 26 | .def_readwrite("dither", &PyClass::dither) |
| 27 | + .def_readwrite("normalize_samples", &PyClass::normalize_samples) | ||
| 28 | + .def_readwrite("snip_edges", &PyClass::snip_edges) | ||
| 23 | .def("__str__", &PyClass::ToString); | 29 | .def("__str__", &PyClass::ToString); |
| 24 | } | 30 | } |
| 25 | 31 |
| @@ -22,6 +22,23 @@ Args: | @@ -22,6 +22,23 @@ Args: | ||
| 22 | to the range [-1, 1]. | 22 | to the range [-1, 1]. |
| 23 | )"; | 23 | )"; |
| 24 | 24 | ||
| 25 | + | ||
| 26 | +constexpr const char *kGetFramesUsage = R"( | ||
| 27 | +Get n frames starting from the given frame index. | ||
| 28 | +(hint: intended for debugging, for comparing FBANK features across pipelines) | ||
| 29 | + | ||
| 30 | +Args: | ||
| 31 | + frame_index: | ||
| 32 | + The starting frame index | ||
| 33 | + n: | ||
| 34 | + Number of frames to get. | ||
| 35 | +Return: | ||
| 36 | + Return a 2-D tensor of shape (n, feature_dim). | ||
| 37 | + which is flattened into a 1-D vector (flattened in row major). | ||
| 38 | + Unflatten in python with: | ||
| 39 | + `features = np.reshape(arr, (n, feature_dim))` | ||
| 40 | +)"; | ||
| 41 | + | ||
| 25 | void PybindOnlineStream(py::module *m) { | 42 | void PybindOnlineStream(py::module *m) { |
| 26 | using PyClass = OnlineStream; | 43 | using PyClass = OnlineStream; |
| 27 | py::class_<PyClass>(*m, "OnlineStream") | 44 | py::class_<PyClass>(*m, "OnlineStream") |
| @@ -34,6 +51,9 @@ void PybindOnlineStream(py::module *m) { | @@ -34,6 +51,9 @@ void PybindOnlineStream(py::module *m) { | ||
| 34 | py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage, | 51 | py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage, |
| 35 | py::call_guard<py::gil_scoped_release>()) | 52 | py::call_guard<py::gil_scoped_release>()) |
| 36 | .def("input_finished", &PyClass::InputFinished, | 53 | .def("input_finished", &PyClass::InputFinished, |
| 54 | + py::call_guard<py::gil_scoped_release>()) | ||
| 55 | + .def("get_frames", &PyClass::GetFrames, | ||
| 56 | + py::arg("frame_index"), py::arg("n"), kGetFramesUsage, | ||
| 37 | py::call_guard<py::gil_scoped_release>()); | 57 | py::call_guard<py::gil_scoped_release>()); |
| 38 | } | 58 | } |
| 39 | 59 |
| @@ -50,6 +50,8 @@ class OnlineRecognizer(object): | @@ -50,6 +50,8 @@ class OnlineRecognizer(object): | ||
| 50 | low_freq: float = 20.0, | 50 | low_freq: float = 20.0, |
| 51 | high_freq: float = -400.0, | 51 | high_freq: float = -400.0, |
| 52 | dither: float = 0.0, | 52 | dither: float = 0.0, |
| 53 | + normalize_samples: bool = True, | ||
| 54 | + snip_edges: bool = False, | ||
| 53 | enable_endpoint_detection: bool = False, | 55 | enable_endpoint_detection: bool = False, |
| 54 | rule1_min_trailing_silence: float = 2.4, | 56 | rule1_min_trailing_silence: float = 2.4, |
| 55 | rule2_min_trailing_silence: float = 1.2, | 57 | rule2_min_trailing_silence: float = 1.2, |
| @@ -118,6 +120,15 @@ class OnlineRecognizer(object): | @@ -118,6 +120,15 @@ class OnlineRecognizer(object): | ||
| 118 | By default the audio samples are in range [-1,+1], | 120 | By default the audio samples are in range [-1,+1], |
| 119 | so dithering constant 0.00003 is a good value, | 121 | so dithering constant 0.00003 is a good value, |
| 120 | equivalent to the default 1.0 from kaldi | 122 | equivalent to the default 1.0 from kaldi |
| 123 | + normalize_samples: | ||
| 124 | + True for +/- 1.0 range of audio samples (default, zipformer feats), | ||
| 125 | + False for +/- 32k samples (ebranchformer features). | ||
| 126 | + snip_edges: | ||
| 127 | + handling of end of audio signal in kaldi feature extraction. | ||
| 128 | + If true, end effects will be handled by outputting only frames that | ||
| 129 | + completely fit in the file, and the number of frames depends on the | ||
| 130 | + frame-length. If false, the number of frames depends only on the | ||
| 131 | + frame-shift, and we reflect the data at the ends. | ||
| 121 | enable_endpoint_detection: | 132 | enable_endpoint_detection: |
| 122 | True to enable endpoint detection. False to disable endpoint | 133 | True to enable endpoint detection. False to disable endpoint |
| 123 | detection. | 134 | detection. |
| @@ -248,6 +259,8 @@ class OnlineRecognizer(object): | @@ -248,6 +259,8 @@ class OnlineRecognizer(object): | ||
| 248 | 259 | ||
| 249 | feat_config = FeatureExtractorConfig( | 260 | feat_config = FeatureExtractorConfig( |
| 250 | sampling_rate=sample_rate, | 261 | sampling_rate=sample_rate, |
| 262 | + normalize_samples=normalize_samples, | ||
| 263 | + snip_edges=snip_edges, | ||
| 251 | feature_dim=feature_dim, | 264 | feature_dim=feature_dim, |
| 252 | low_freq=low_freq, | 265 | low_freq=low_freq, |
| 253 | high_freq=high_freq, | 266 | high_freq=high_freq, |
-
请 注册 或 登录 后发表评论