Karel Vesely
Committed by GitHub

Ebranchformer (#1951)

* Add the Ebranchformer encoder

* Extend the surfaced FeatureExtractorConfig

- so that Ebranchformer feature extraction can be configured from Python
  (see the Python sketch after this list)
- GlobCmvn is not needed, since it is a module inside the OnnxEncoder

* Clean up the code

* Integrate remarks from Fangjun
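
For reference, a minimal Python sketch (not part of this change set) of configuring
the newly surfaced options. It assumes FeatureExtractorConfig is re-exported by the
sherpa_onnx Python package (it is bound in the pybind layer below); the field values
are illustrative only:

    import sherpa_onnx

    config = sherpa_onnx.FeatureExtractorConfig(
        sampling_rate=16000,
        feature_dim=80,
        dither=0.0,
        normalize_samples=False,  # Ebranchformer-style features use +/- 32k samples
        snip_edges=True,          # set to match the training-time feature extraction
    )
    print(config)  # __str__ now also reports normalize_samples and snip_edges
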
@@ -68,6 +68,7 @@ set(sources
   online-ctc-fst-decoder.cc
   online-ctc-greedy-search-decoder.cc
   online-ctc-model.cc
+  online-ebranchformer-transducer-model.cc
   online-lm-config.cc
   online-lm.cc
   online-lstm-transducer-model.cc
@@ -48,7 +48,9 @@ std::string FeatureExtractorConfig::ToString() const {
   os << "feature_dim=" << feature_dim << ", ";
   os << "low_freq=" << low_freq << ", ";
   os << "high_freq=" << high_freq << ", ";
-  os << "dither=" << dither << ")";
+  os << "dither=" << dither << ", ";
+  os << "normalize_samples=" << (normalize_samples ? "True" : "False") << ", ";
+  os << "snip_edges=" << (snip_edges ? "True" : "False") << ")";

   return os.str();
 }
  1 +// sherpa-onnx/csrc/online-ebranchformer-transducer-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +// 2025 Brno University of Technology (author: Karel Vesely)
  5 +
  6 +#include "sherpa-onnx/csrc/online-ebranchformer-transducer-model.h"
  7 +
  8 +#include <algorithm>
  9 +#include <cassert>
  10 +#include <cmath>
  11 +#include <memory>
  12 +#include <numeric>
  13 +#include <sstream>
  14 +#include <string>
  15 +#include <utility>
  16 +#include <vector>
  17 +
  18 +#if __ANDROID_API__ >= 9
  19 +#include "android/asset_manager.h"
  20 +#include "android/asset_manager_jni.h"
  21 +#endif
  22 +
  23 +#if __OHOS__
  24 +#include "rawfile/raw_file_manager.h"
  25 +#endif
  26 +
  27 +#include "onnxruntime_cxx_api.h" // NOLINT
  28 +#include "sherpa-onnx/csrc/cat.h"
  29 +#include "sherpa-onnx/csrc/file-utils.h"
  30 +#include "sherpa-onnx/csrc/macros.h"
  31 +#include "sherpa-onnx/csrc/online-transducer-decoder.h"
  32 +#include "sherpa-onnx/csrc/onnx-utils.h"
  33 +#include "sherpa-onnx/csrc/session.h"
  34 +#include "sherpa-onnx/csrc/text-utils.h"
  35 +#include "sherpa-onnx/csrc/unbind.h"
  36 +
  37 +namespace sherpa_onnx {
  38 +
  39 +OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel(
  40 + const OnlineModelConfig &config)
  41 + : env_(ORT_LOGGING_LEVEL_ERROR),
  42 + encoder_sess_opts_(GetSessionOptions(config)),
  43 + decoder_sess_opts_(GetSessionOptions(config, "decoder")),
  44 + joiner_sess_opts_(GetSessionOptions(config, "joiner")),
  45 + config_(config),
  46 + allocator_{} {
  47 + {
  48 + auto buf = ReadFile(config.transducer.encoder);
  49 + InitEncoder(buf.data(), buf.size());
  50 + }
  51 +
  52 + {
  53 + auto buf = ReadFile(config.transducer.decoder);
  54 + InitDecoder(buf.data(), buf.size());
  55 + }
  56 +
  57 + {
  58 + auto buf = ReadFile(config.transducer.joiner);
  59 + InitJoiner(buf.data(), buf.size());
  60 + }
  61 +}
  62 +
  63 +template <typename Manager>
  64 +OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel(
  65 + Manager *mgr, const OnlineModelConfig &config)
  66 + : env_(ORT_LOGGING_LEVEL_ERROR),
  67 + config_(config),
  68 + encoder_sess_opts_(GetSessionOptions(config)),
  69 + decoder_sess_opts_(GetSessionOptions(config, "decoder")),
  70 + joiner_sess_opts_(GetSessionOptions(config, "joiner")),
  71 + allocator_{} {
  72 + {
  73 + auto buf = ReadFile(mgr, config.transducer.encoder);
  74 + InitEncoder(buf.data(), buf.size());
  75 + }
  76 +
  77 + {
  78 + auto buf = ReadFile(mgr, config.transducer.decoder);
  79 + InitDecoder(buf.data(), buf.size());
  80 + }
  81 +
  82 + {
  83 + auto buf = ReadFile(mgr, config.transducer.joiner);
  84 + InitJoiner(buf.data(), buf.size());
  85 + }
  86 +}
  87 +
  88 +
  89 +void OnlineEbranchformerTransducerModel::InitEncoder(void *model_data,
  90 + size_t model_data_length) {
  91 + encoder_sess_ = std::make_unique<Ort::Session>(
  92 + env_, model_data, model_data_length, encoder_sess_opts_);
  93 +
  94 + GetInputNames(encoder_sess_.get(), &encoder_input_names_,
  95 + &encoder_input_names_ptr_);
  96 +
  97 + GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
  98 + &encoder_output_names_ptr_);
  99 +
  100 + // get meta data
  101 + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
  102 + if (config_.debug) {
  103 + std::ostringstream os;
  104 + os << "---encoder---\n";
  105 + PrintModelMetadata(os, meta_data);
  106 +#if __OHOS__
  107 + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str());
  108 +#else
  109 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  110 +#endif
  111 + }
  112 +
  113 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  114 +
  115 + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
  116 + SHERPA_ONNX_READ_META_DATA(T_, "T");
  117 +
  118 + SHERPA_ONNX_READ_META_DATA(num_hidden_layers_, "num_hidden_layers");
  119 + SHERPA_ONNX_READ_META_DATA(hidden_size_, "hidden_size");
  120 + SHERPA_ONNX_READ_META_DATA(intermediate_size_, "intermediate_size");
  121 + SHERPA_ONNX_READ_META_DATA(csgu_kernel_size_, "csgu_kernel_size");
  122 + SHERPA_ONNX_READ_META_DATA(merge_conv_kernel_, "merge_conv_kernel");
  123 + SHERPA_ONNX_READ_META_DATA(left_context_len_, "left_context_len");
  124 + SHERPA_ONNX_READ_META_DATA(num_heads_, "num_heads");
  125 + SHERPA_ONNX_READ_META_DATA(head_dim_, "head_dim");
  126 +
  127 + if (config_.debug) {
  128 +#if __OHOS__
  129 + SHERPA_ONNX_LOGE("T: %{public}d", T_);
  130 + SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
  131 +
  132 + SHERPA_ONNX_LOGE("num_hidden_layers_: %{public}d", num_hidden_layers_);
  133 + SHERPA_ONNX_LOGE("hidden_size_: %{public}d", hidden_size_);
  134 + SHERPA_ONNX_LOGE("intermediate_size_: %{public}d", intermediate_size_);
  135 + SHERPA_ONNX_LOGE("csgu_kernel_size_: %{public}d", csgu_kernel_size_);
  136 + SHERPA_ONNX_LOGE("merge_conv_kernel_: %{public}d", merge_conv_kernel_);
  137 + SHERPA_ONNX_LOGE("left_context_len_: %{public}d", left_context_len_);
  138 + SHERPA_ONNX_LOGE("num_heads_: %{public}d", num_heads_);
  139 + SHERPA_ONNX_LOGE("head_dim_: %{public}d", head_dim_);
  140 +#else
  141 + SHERPA_ONNX_LOGE("T: %d", T_);
  142 + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
  143 +
  144 + SHERPA_ONNX_LOGE("num_hidden_layers_: %d", num_hidden_layers_);
  145 + SHERPA_ONNX_LOGE("hidden_size_: %d", hidden_size_);
  146 + SHERPA_ONNX_LOGE("intermediate_size_: %d", intermediate_size_);
  147 + SHERPA_ONNX_LOGE("csgu_kernel_size_: %d", csgu_kernel_size_);
  148 + SHERPA_ONNX_LOGE("merge_conv_kernel_: %d", merge_conv_kernel_);
  149 + SHERPA_ONNX_LOGE("left_context_len_: %d", left_context_len_);
  150 + SHERPA_ONNX_LOGE("num_heads_: %d", num_heads_);
  151 + SHERPA_ONNX_LOGE("head_dim_: %d", head_dim_);
  152 +#endif
  153 + }
  154 +}
  155 +
  156 +
  157 +void OnlineEbranchformerTransducerModel::InitDecoder(void *model_data,
  158 + size_t model_data_length) {
  159 + decoder_sess_ = std::make_unique<Ort::Session>(
  160 + env_, model_data, model_data_length, decoder_sess_opts_);
  161 +
  162 + GetInputNames(decoder_sess_.get(), &decoder_input_names_,
  163 + &decoder_input_names_ptr_);
  164 +
  165 + GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
  166 + &decoder_output_names_ptr_);
  167 +
  168 + // get meta data
  169 + Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
  170 + if (config_.debug) {
  171 + std::ostringstream os;
  172 + os << "---decoder---\n";
  173 + PrintModelMetadata(os, meta_data);
  174 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  175 + }
  176 +
  177 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  178 + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  179 + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
  180 +}
  181 +
  182 +void OnlineEbranchformerTransducerModel::InitJoiner(void *model_data,
  183 + size_t model_data_length) {
  184 + joiner_sess_ = std::make_unique<Ort::Session>(
  185 + env_, model_data, model_data_length, joiner_sess_opts_);
  186 +
  187 + GetInputNames(joiner_sess_.get(), &joiner_input_names_,
  188 + &joiner_input_names_ptr_);
  189 +
  190 + GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
  191 + &joiner_output_names_ptr_);
  192 +
  193 + // get meta data
  194 + Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
  195 + if (config_.debug) {
  196 + std::ostringstream os;
  197 + os << "---joiner---\n";
  198 + PrintModelMetadata(os, meta_data);
  199 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  200 + }
  201 +}
  202 +
  203 +
  204 +std::vector<Ort::Value> OnlineEbranchformerTransducerModel::StackStates(
  205 + const std::vector<std::vector<Ort::Value>> &states) const {
  206 + int32_t batch_size = static_cast<int32_t>(states.size());
  207 +
  208 + std::vector<const Ort::Value *> buf(batch_size);
  209 +
  210 + auto allocator =
  211 + const_cast<OnlineEbranchformerTransducerModel *>(this)->allocator_;
  212 +
  213 + std::vector<Ort::Value> ans;
  214 + int32_t num_states = static_cast<int32_t>(states[0].size());
  215 + ans.reserve(num_states);
  216 +
  217 + for (int32_t i = 0; i != num_hidden_layers_; ++i) {
  218 + { // cached_key
  219 + for (int32_t n = 0; n != batch_size; ++n) {
  220 + buf[n] = &states[n][4 * i];
  221 + }
  222 + auto v = Cat(allocator, buf, /* axis */ 0);
  223 + ans.push_back(std::move(v));
  224 + }
  225 + { // cached_value
  226 + for (int32_t n = 0; n != batch_size; ++n) {
  227 + buf[n] = &states[n][4 * i + 1];
  228 + }
  229 + auto v = Cat(allocator, buf, 0);
  230 + ans.push_back(std::move(v));
  231 + }
  232 + { // cached_conv
  233 + for (int32_t n = 0; n != batch_size; ++n) {
  234 + buf[n] = &states[n][4 * i + 2];
  235 + }
  236 + auto v = Cat(allocator, buf, 0);
  237 + ans.push_back(std::move(v));
  238 + }
  239 + { // cached_conv_fusion
  240 + for (int32_t n = 0; n != batch_size; ++n) {
  241 + buf[n] = &states[n][4 * i + 3];
  242 + }
  243 + auto v = Cat(allocator, buf, 0);
  244 + ans.push_back(std::move(v));
  245 + }
  246 + }
  247 +
  248 + { // processed_lens
  249 + for (int32_t n = 0; n != batch_size; ++n) {
  250 + buf[n] = &states[n][num_states - 1];
  251 + }
  252 + auto v = Cat<int64_t>(allocator, buf, 0);
  253 + ans.push_back(std::move(v));
  254 + }
  255 +
  256 + return ans;
  257 +}
  258 +
  259 +
  260 +std::vector<std::vector<Ort::Value>>
  261 +OnlineEbranchformerTransducerModel::UnStackStates(
  262 + const std::vector<Ort::Value> &states) const {
  263 +
  264 + assert(static_cast<int32_t>(states.size()) == num_hidden_layers_ * 4 + 1);
  265 +
  266 + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[0];
  267 +
  268 + auto allocator =
  269 + const_cast<OnlineEbranchformerTransducerModel *>(this)->allocator_;
  270 +
  271 + std::vector<std::vector<Ort::Value>> ans;
  272 + ans.resize(batch_size);
  273 +
  274 + for (int32_t i = 0; i != num_hidden_layers_; ++i) {
  275 + { // cached_key
  276 + auto v = Unbind(allocator, &states[i * 4], /* axis */ 0);
  277 + assert(static_cast<int32_t>(v.size()) == batch_size);
  278 +
  279 + for (int32_t n = 0; n != batch_size; ++n) {
  280 + ans[n].push_back(std::move(v[n]));
  281 + }
  282 + }
  283 + { // cached_value
  284 + auto v = Unbind(allocator, &states[i * 4 + 1], 0);
  285 + assert(static_cast<int32_t>(v.size()) == batch_size);
  286 +
  287 + for (int32_t n = 0; n != batch_size; ++n) {
  288 + ans[n].push_back(std::move(v[n]));
  289 + }
  290 + }
  291 + { // cached_conv
  292 + auto v = Unbind(allocator, &states[i * 4 + 2], 0);
  293 + assert(static_cast<int32_t>(v.size()) == batch_size);
  294 +
  295 + for (int32_t n = 0; n != batch_size; ++n) {
  296 + ans[n].push_back(std::move(v[n]));
  297 + }
  298 + }
  299 + { // cached_conv_fusion
  300 + auto v = Unbind(allocator, &states[i * 4 + 3], 0);
  301 + assert(static_cast<int32_t>(v.size()) == batch_size);
  302 +
  303 + for (int32_t n = 0; n != batch_size; ++n) {
  304 + ans[n].push_back(std::move(v[n]));
  305 + }
  306 + }
  307 + }
  308 +
  309 + { // processed_lens
  310 + auto v = Unbind<int64_t>(allocator, &states.back(), 0);
  311 + assert(static_cast<int32_t>(v.size()) == batch_size);
  312 +
  313 + for (int32_t n = 0; n != batch_size; ++n) {
  314 + ans[n].push_back(std::move(v[n]));
  315 + }
  316 + }
  317 +
  318 + return ans;
  319 +}
  320 +
  321 +
  322 +std::vector<Ort::Value>
  323 +OnlineEbranchformerTransducerModel::GetEncoderInitStates() {
  324 + std::vector<Ort::Value> ans;
  325 +
  326 + ans.reserve(num_hidden_layers_ * 4 + 1);
  327 +
  328 + int32_t left_context_conv = csgu_kernel_size_ - 1;
  329 + int32_t channels_conv = intermediate_size_ / 2;
  330 +
  331 + int32_t left_context_conv_fusion = merge_conv_kernel_ - 1;
  332 + int32_t channels_conv_fusion = 2 * hidden_size_;
  333 +
  334 + for (int32_t i = 0; i != num_hidden_layers_; ++i) {
  335 + { // cached_key_{i}
  336 + std::array<int64_t, 4> s{1, num_heads_, left_context_len_, head_dim_};
  337 + auto v =
  338 + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  339 + Fill(&v, 0);
  340 + ans.push_back(std::move(v));
  341 + }
  342 +
  343 + { // cached_value_{i}
  344 + std::array<int64_t, 4> s{1, num_heads_, left_context_len_, head_dim_};
  345 + auto v =
  346 + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  347 + Fill(&v, 0);
  348 + ans.push_back(std::move(v));
  349 + }
  350 +
  351 + { // cached_conv_{i}
  352 + std::array<int64_t, 3> s{1, channels_conv, left_context_conv};
  353 + auto v =
  354 + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  355 + Fill(&v, 0);
  356 + ans.push_back(std::move(v));
  357 + }
  358 +
  359 + { // cached_conv_fusion_{i}
  360 + std::array<int64_t, 3> s{1, channels_conv_fusion, left_context_conv_fusion};
  361 + auto v =
  362 + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  363 + Fill(&v, 0);
  364 + ans.push_back(std::move(v));
  365 + }
  366 + } // num_hidden_layers_
  367 +
  368 + { // processed_lens
  369 + std::array<int64_t, 1> s{1};
  370 + auto v = Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
  371 + Fill<int64_t>(&v, 0);
  372 + ans.push_back(std::move(v));
  373 + }
  374 +
  375 + return ans;
  376 +}
  377 +
  378 +
  379 +std::pair<Ort::Value, std::vector<Ort::Value>>
  380 +OnlineEbranchformerTransducerModel::RunEncoder(Ort::Value features,
  381 + std::vector<Ort::Value> states,
  382 + Ort::Value /* processed_frames */) {
  383 + std::vector<Ort::Value> encoder_inputs;
  384 + encoder_inputs.reserve(1 + states.size());
  385 +
  386 + encoder_inputs.push_back(std::move(features));
  387 + for (auto &v : states) {
  388 + encoder_inputs.push_back(std::move(v));
  389 + }
  390 +
  391 + auto encoder_out = encoder_sess_->Run(
  392 + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
  393 + encoder_inputs.size(), encoder_output_names_ptr_.data(),
  394 + encoder_output_names_ptr_.size());
  395 +
  396 + std::vector<Ort::Value> next_states;
  397 + next_states.reserve(states.size());
  398 +
  399 + for (int32_t i = 1; i != static_cast<int32_t>(encoder_out.size()); ++i) {
  400 + next_states.push_back(std::move(encoder_out[i]));
  401 + }
  402 + return {std::move(encoder_out[0]), std::move(next_states)};
  403 +}
  404 +
  405 +
  406 +Ort::Value OnlineEbranchformerTransducerModel::RunDecoder(
  407 + Ort::Value decoder_input) {
  408 + auto decoder_out = decoder_sess_->Run(
  409 + {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
  410 + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
  411 + return std::move(decoder_out[0]);
  412 +}
  413 +
  414 +
  415 +Ort::Value OnlineEbranchformerTransducerModel::RunJoiner(Ort::Value encoder_out,
  416 + Ort::Value decoder_out) {
  417 + std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
  418 + std::move(decoder_out)};
  419 + auto logit =
  420 + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
  421 + joiner_input.size(), joiner_output_names_ptr_.data(),
  422 + joiner_output_names_ptr_.size());
  423 +
  424 + return std::move(logit[0]);
  425 +}
  426 +
  427 +
  428 +#if __ANDROID_API__ >= 9
  429 +template OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel(
  430 + AAssetManager *mgr, const OnlineModelConfig &config);
  431 +#endif
  432 +
  433 +#if __OHOS__
  434 +template OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel(
  435 + NativeResourceManager *mgr, const OnlineModelConfig &config);
  436 +#endif
  437 +
  438 +} // namespace sherpa_onnx
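
Note: InitEncoder() above reads its configuration from the encoder's ONNX metadata
(decode_chunk_len, T, num_hidden_layers, hidden_size, intermediate_size,
csgu_kernel_size, merge_conv_kernel, left_context_len, num_heads, head_dim), and
OnlineTransducerModel::Create() (diff further below) selects this class from a
model_type entry equal to "ebranchformer". A minimal sketch of how an export script
could attach these keys with the onnx Python package; the file name and all values
are illustrative and must match the actual exported encoder:

    import onnx

    model = onnx.load("encoder.onnx")  # hypothetical path
    meta = {
        "model_type": "ebranchformer",  # consumed by GetModelType() below
        "decode_chunk_len": "64",       # illustrative values only
        "T": "77",
        "num_hidden_layers": "12",
        "hidden_size": "512",
        "intermediate_size": "2048",
        "csgu_kernel_size": "31",
        "merge_conv_kernel": "31",
        "left_context_len": "128",
        "num_heads": "8",
        "head_dim": "64",
    }
    del model.metadata_props[:]          # drop any stale entries
    for key, value in meta.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = value
    onnx.save(model, "encoder.onnx")
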
  1 +// sherpa-onnx/csrc/online-ebranchformer-transducer-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +// 2025 Brno University of Technology (author: Karel Vesely)
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <utility>
  11 +#include <vector>
  12 +
  13 +#include "onnxruntime_cxx_api.h" // NOLINT
  14 +#include "sherpa-onnx/csrc/online-model-config.h"
  15 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  16 +
  17 +namespace sherpa_onnx {
  18 +
  19 +class OnlineEbranchformerTransducerModel : public OnlineTransducerModel {
  20 + public:
  21 + explicit OnlineEbranchformerTransducerModel(const OnlineModelConfig &config);
  22 +
  23 + template <typename Manager>
  24 + OnlineEbranchformerTransducerModel(Manager *mgr,
  25 + const OnlineModelConfig &config);
  26 +
  27 + std::vector<Ort::Value> StackStates(
  28 + const std::vector<std::vector<Ort::Value>> &states) const override;
  29 +
  30 + std::vector<std::vector<Ort::Value>> UnStackStates(
  31 + const std::vector<Ort::Value> &states) const override;
  32 +
  33 + std::vector<Ort::Value> GetEncoderInitStates() override;
  34 +
  35 + void SetFeatureDim(int32_t feature_dim) override {
  36 + feature_dim_ = feature_dim;
  37 + }
  38 +
  39 + std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
  40 + Ort::Value features, std::vector<Ort::Value> states,
  41 + Ort::Value processed_frames) override;
  42 +
  43 + Ort::Value RunDecoder(Ort::Value decoder_input) override;
  44 +
  45 + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
  46 +
  47 + int32_t ContextSize() const override { return context_size_; }
  48 +
  49 + int32_t ChunkSize() const override { return T_; }
  50 +
  51 + int32_t ChunkShift() const override { return decode_chunk_len_; }
  52 +
  53 + int32_t VocabSize() const override { return vocab_size_; }
  54 + OrtAllocator *Allocator() override { return allocator_; }
  55 +
  56 + private:
  57 + void InitEncoder(void *model_data, size_t model_data_length);
  58 + void InitDecoder(void *model_data, size_t model_data_length);
  59 + void InitJoiner(void *model_data, size_t model_data_length);
  60 +
  61 + private:
  62 + Ort::Env env_;
  63 + Ort::SessionOptions encoder_sess_opts_;
  64 + Ort::SessionOptions decoder_sess_opts_;
  65 + Ort::SessionOptions joiner_sess_opts_;
  66 +
  67 + Ort::AllocatorWithDefaultOptions allocator_;
  68 +
  69 + std::unique_ptr<Ort::Session> encoder_sess_;
  70 + std::unique_ptr<Ort::Session> decoder_sess_;
  71 + std::unique_ptr<Ort::Session> joiner_sess_;
  72 +
  73 + std::vector<std::string> encoder_input_names_;
  74 + std::vector<const char *> encoder_input_names_ptr_;
  75 +
  76 + std::vector<std::string> encoder_output_names_;
  77 + std::vector<const char *> encoder_output_names_ptr_;
  78 +
  79 + std::vector<std::string> decoder_input_names_;
  80 + std::vector<const char *> decoder_input_names_ptr_;
  81 +
  82 + std::vector<std::string> decoder_output_names_;
  83 + std::vector<const char *> decoder_output_names_ptr_;
  84 +
  85 + std::vector<std::string> joiner_input_names_;
  86 + std::vector<const char *> joiner_input_names_ptr_;
  87 +
  88 + std::vector<std::string> joiner_output_names_;
  89 + std::vector<const char *> joiner_output_names_ptr_;
  90 +
  91 + OnlineModelConfig config_;
  92 +
  93 + int32_t decode_chunk_len_ = 0;
  94 + int32_t T_ = 0;
  95 +
  96 + int32_t num_hidden_layers_ = 0;
  97 + int32_t hidden_size_ = 0;
  98 + int32_t intermediate_size_ = 0;
  99 + int32_t csgu_kernel_size_ = 0;
  100 + int32_t merge_conv_kernel_ = 0;
  101 + int32_t left_context_len_ = 0;
  102 + int32_t num_heads_ = 0;
  103 + int32_t head_dim_ = 0;
  104 +
  105 + int32_t context_size_ = 0;
  106 + int32_t vocab_size_ = 0;
  107 + int32_t feature_dim_ = 80;
  108 +};
  109 +
  110 +} // namespace sherpa_onnx
  111 +
  112 +#endif // SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_
@@ -21,6 +21,7 @@
 #include "sherpa-onnx/csrc/file-utils.h"
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
+#include "sherpa-onnx/csrc/online-ebranchformer-transducer-model.h"
 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
 #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
 #include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h"
@@ -30,6 +31,7 @@ namespace {

 enum class ModelType : std::uint8_t {
   kConformer,
+  kEbranchformer,
   kLstm,
   kZipformer,
   kZipformer2,
@@ -74,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,

   if (model_type == "conformer") {
     return ModelType::kConformer;
+  } else if (model_type == "ebranchformer") {
+    return ModelType::kEbranchformer;
   } else if (model_type == "lstm") {
     return ModelType::kLstm;
   } else if (model_type == "zipformer") {
@@ -92,6 +96,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
   const auto &model_type = config.model_type;
   if (model_type == "conformer") {
     return std::make_unique<OnlineConformerTransducerModel>(config);
+  } else if (model_type == "ebranchformer") {
+    return std::make_unique<OnlineEbranchformerTransducerModel>(config);
   } else if (model_type == "lstm") {
     return std::make_unique<OnlineLstmTransducerModel>(config);
   } else if (model_type == "zipformer") {
@@ -115,6 +121,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
   switch (model_type) {
     case ModelType::kConformer:
       return std::make_unique<OnlineConformerTransducerModel>(config);
+    case ModelType::kEbranchformer:
+      return std::make_unique<OnlineEbranchformerTransducerModel>(config);
     case ModelType::kLstm:
       return std::make_unique<OnlineLstmTransducerModel>(config);
     case ModelType::kZipformer:
@@ -171,6 +179,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
   const auto &model_type = config.model_type;
   if (model_type == "conformer") {
     return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
+  } else if (model_type == "ebranchformer") {
+    return std::make_unique<OnlineEbranchformerTransducerModel>(mgr, config);
   } else if (model_type == "lstm") {
     return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
   } else if (model_type == "zipformer") {
@@ -190,6 +200,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
   switch (model_type) {
     case ModelType::kConformer:
       return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
+    case ModelType::kEbranchformer:
+      return std::make_unique<OnlineEbranchformerTransducerModel>(mgr, config);
     case ModelType::kLstm:
       return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
     case ModelType::kZipformer:
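
The dispatch above relies either on config.model_type being set to "ebranchformer"
or on GetModelType() finding that string in the encoder's metadata. A quick, hedged
way to inspect what a given encoder advertises, using onnxruntime from Python (the
file name is a placeholder):

    import onnxruntime as ort

    sess = ort.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])
    meta = sess.get_modelmeta().custom_metadata_map
    print(meta.get("model_type"))  # should print "ebranchformer" for this model
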
@@ -11,15 +11,21 @@ namespace sherpa_onnx {
 static void PybindFeatureExtractorConfig(py::module *m) {
   using PyClass = FeatureExtractorConfig;
   py::class_<PyClass>(*m, "FeatureExtractorConfig")
-      .def(py::init<int32_t, int32_t, float, float, float>(),
-           py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
-           py::arg("low_freq") = 20.0f, py::arg("high_freq") = -400.0f,
-           py::arg("dither") = 0.0f)
+      .def(py::init<int32_t, int32_t, float, float, float, bool, bool>(),
+           py::arg("sampling_rate") = 16000,
+           py::arg("feature_dim") = 80,
+           py::arg("low_freq") = 20.0f,
+           py::arg("high_freq") = -400.0f,
+           py::arg("dither") = 0.0f,
+           py::arg("normalize_samples") = true,
+           py::arg("snip_edges") = false)
       .def_readwrite("sampling_rate", &PyClass::sampling_rate)
       .def_readwrite("feature_dim", &PyClass::feature_dim)
       .def_readwrite("low_freq", &PyClass::low_freq)
       .def_readwrite("high_freq", &PyClass::high_freq)
       .def_readwrite("dither", &PyClass::dither)
+      .def_readwrite("normalize_samples", &PyClass::normalize_samples)
+      .def_readwrite("snip_edges", &PyClass::snip_edges)
       .def("__str__", &PyClass::ToString);
 }

@@ -22,6 +22,23 @@ Args:
     to the range [-1, 1].
 )";

+
+constexpr const char *kGetFramesUsage = R"(
+Get n frames starting from the given frame index.
+(Hint: intended for debugging, e.g. for comparing FBANK features across pipelines.)
+
+Args:
+  frame_index:
+    The starting frame index.
+  n:
+    Number of frames to get.
+Return:
+  Return a 2-D tensor of shape (n, feature_dim),
+  flattened into a 1-D vector (in row-major order).
+  Unflatten it in Python with:
+  `features = np.reshape(arr, (n, feature_dim))`
+)";
+
 void PybindOnlineStream(py::module *m) {
   using PyClass = OnlineStream;
   py::class_<PyClass>(*m, "OnlineStream")
@@ -34,6 +51,9 @@ void PybindOnlineStream(py::module *m) {
            py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
            py::call_guard<py::gil_scoped_release>())
       .def("input_finished", &PyClass::InputFinished,
+           py::call_guard<py::gil_scoped_release>())
+      .def("get_frames", &PyClass::GetFrames,
+           py::arg("frame_index"), py::arg("n"), kGetFramesUsage,
            py::call_guard<py::gil_scoped_release>());
 }

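
A short sketch of the debugging workflow get_frames is meant for; only
get_frames(frame_index, n) itself comes from this change, the surrounding
recognizer/stream calls are the usual sherpa_onnx Python API, and the numbers
are illustrative:

    import numpy as np

    # stream = recognizer.create_stream()
    # stream.accept_waveform(16000, samples)
    n = 10
    feature_dim = 80  # must match the feature extractor configuration
    arr = stream.get_frames(frame_index=0, n=n)   # flat, row-major
    features = np.reshape(arr, (n, feature_dim))  # (n, feature_dim) FBANK frames
    print(features.shape)
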
@@ -50,6 +50,8 @@ class OnlineRecognizer(object):
         low_freq: float = 20.0,
         high_freq: float = -400.0,
         dither: float = 0.0,
+        normalize_samples: bool = True,
+        snip_edges: bool = False,
         enable_endpoint_detection: bool = False,
         rule1_min_trailing_silence: float = 2.4,
         rule2_min_trailing_silence: float = 1.2,
@@ -118,6 +120,15 @@ class OnlineRecognizer(object):
             By default the audio samples are in range [-1,+1],
             so dithering constant 0.00003 is a good value,
             equivalent to the default 1.0 from kaldi
+          normalize_samples:
+            True if audio samples are in the range [-1, 1] (default; zipformer
+            features), False if they are in the +/- 32k range (ebranchformer).
+          snip_edges:
+            How the end of the signal is handled in Kaldi feature extraction.
+            If True, only frames that fit completely within the signal are
+            output, so the number of frames depends on the frame length. If
+            False, it depends only on the frame shift and the data is
+            reflected at the ends.
           enable_endpoint_detection:
             True to enable endpoint detection. False to disable endpoint
             detection.
@@ -248,6 +259,8 @@ class OnlineRecognizer(object):

         feat_config = FeatureExtractorConfig(
             sampling_rate=sample_rate,
+            normalize_samples=normalize_samples,
+            snip_edges=snip_edges,
             feature_dim=feature_dim,
             low_freq=low_freq,
             high_freq=high_freq,
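
Putting the pieces together, a hedged end-to-end sketch of creating a streaming
recognizer for an Ebranchformer transducer with the new feature options. The model
paths are placeholders, and OnlineRecognizer.from_transducer is assumed to be the
entry point that carries the keyword arguments shown in the diff above:

    import sherpa_onnx

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens="tokens.txt",
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        model_type="ebranchformer",
        # New in this PR: Ebranchformer features expect un-normalized (+/- 32k)
        # samples; set snip_edges to match the training-time feature extraction.
        normalize_samples=False,
        snip_edges=True,
    )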