Sangeet Sagar
Committed by GitHub

Add C++ runtime for *streaming* faster conformer transducer from NeMo. (#889)

Co-authored-by: sangeet2020 <15uec053@gmail.com>
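
Below is a minimal end-to-end usage sketch of the code path this PR adds. The model file names are hypothetical placeholders, and the OnlineRecognizer API shown already exists in sherpa-onnx (it is not added by this PR):

// Minimal usage sketch (hypothetical file names).
#include <vector>
#include "sherpa-onnx/csrc/online-recognizer.h"

int main() {
  sherpa_onnx::OnlineRecognizerConfig config;
  config.model_config.transducer.encoder = "encoder.onnx";  // placeholder
  config.model_config.transducer.decoder = "decoder.onnx";  // placeholder
  config.model_config.transducer.joiner = "joiner.onnx";    // placeholder
  config.model_config.tokens = "tokens.txt";
  config.decoding_method = "greedy_search";  // the only method this PR supports

  sherpa_onnx::OnlineRecognizer recognizer(config);
  auto stream = recognizer.CreateStream();

  std::vector<float> samples(16000, 0.0f);  // 1 s of silence at 16 kHz
  stream->AcceptWaveform(16000, samples.data(),
                         static_cast<int32_t>(samples.size()));
  stream->InputFinished();

  while (recognizer.IsReady(stream.get())) {
    recognizer.DecodeStream(stream.get());
  }
  auto result = recognizer.GetResult(stream.get());
  (void)result;  // result.text holds the transcript
  return 0;
}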
@@ -74,6 +74,8 @@ set(sources
   online-transducer-model-config.cc
   online-transducer-model.cc
   online-transducer-modified-beam-search-decoder.cc
+  online-transducer-nemo-model.cc
+  online-transducer-greedy-search-nemo-decoder.cc
   online-wenet-ctc-model-config.cc
   online-wenet-ctc-model.cc
   online-zipformer-transducer-model.cc
@@ -7,13 +7,28 @@
 #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
 #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
 #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
+#include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
 
 namespace sherpa_onnx {
 
 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     const OnlineRecognizerConfig &config) {
   if (!config.model_config.transducer.encoder.empty()) {
-    return std::make_unique<OnlineRecognizerTransducerImpl>(config);
+    Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+
+    auto decoder_model = ReadFile(config.model_config.transducer.decoder);
+    auto sess = std::make_unique<Ort::Session>(
+        env, decoder_model.data(), decoder_model.size(),
+        Ort::SessionOptions{});
+
+    size_t node_count = sess->GetOutputCount();
+
+    if (node_count == 1) {
+      return std::make_unique<OnlineRecognizerTransducerImpl>(config);
+    } else {
+      SHERPA_ONNX_LOGE("Running streaming NeMo transducer model");
+      return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
+    }
   }
 
   if (!config.model_config.paraformer.encoder.empty()) {
@@ -34,7 +49,18 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     AAssetManager *mgr, const OnlineRecognizerConfig &config) {
   if (!config.model_config.transducer.encoder.empty()) {
-    return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
+    Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+
+    auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
+    auto sess = std::make_unique<Ort::Session>(
+        env, decoder_model.data(), decoder_model.size(),
+        Ort::SessionOptions{});
+
+    size_t node_count = sess->GetOutputCount();
+
+    if (node_count == 1) {
+      return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
+    } else {
+      return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(mgr, config);
+    }
   }
 
   if (!config.model_config.paraformer.encoder.empty()) {
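
The dispatch above tells the two transducer flavors apart by probing the decoder model: the stateless (zipformer-style) decoder exports a single output, decoder_out, while NeMo's stateful LSTM decoder also exports a length and its next states, so its output count is greater than one. A standalone sketch of the same probe (the model path is a placeholder; note that Ort::Session takes a wide-character path on Windows, so this assumes a POSIX build):

// Standalone version of the decoder output-count probe.
#include <cstdio>
#include "onnxruntime_cxx_api.h"

int main(int argc, char *argv[]) {
  const char *model = argc > 1 ? argv[1] : "decoder.onnx";  // placeholder path
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::Session sess(env, model, Ort::SessionOptions{});
  size_t n = sess.GetOutputCount();
  // 1 output  -> stateless (zipformer-style) decoder
  // >1 output -> stateful NeMo decoder (decoder_out, length, LSTM states)
  std::printf("%s: %zu output(s)\n", model, n);
  return n == 1 ? 0 : 1;
}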
@@ -46,6 +46,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
   r.timestamps.reserve(src.tokens.size());
 
   for (auto i : src.tokens) {
+    if (i == -1) continue;
     auto sym = sym_table[i];
 
     r.text.append(sym);
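
The -1 guard is needed because the NeMo greedy decoder pads its empty result with -1 placeholders: per GetEmptyResult() later in this PR, a fresh result is {-1, -1, -1, -1, -1, -1, -1, 0} (context_size = 8, blank_id = 0), and the padding entries must never be used to index the symbol table. A small self-contained sketch of the same filtering:

// Skip GetEmptyResult() padding before symbol lookup.
#include <cstdint>
#include <vector>

std::vector<int32_t> DropPadding(const std::vector<int32_t> &tokens) {
  std::vector<int32_t> out;
  for (auto i : tokens) {
    if (i == -1) continue;  // -1 is padding, not a real token
    out.push_back(i);
  }
  return out;
}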
+// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
+//
+// Copyright (c) 2022-2024 Xiaomi Corporation
+// Copyright (c) 2024 Sangeet Sagar
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
+
+#include <array>
+#include <fstream>
+#include <ios>
+#include <memory>
+#include <regex>  // NOLINT
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-recognizer-impl.h"
+#include "sherpa-onnx/csrc/online-recognizer.h"
+#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
+#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+#include "sherpa-onnx/csrc/transpose.h"
+#include "sherpa-onnx/csrc/utils.h"
+
+namespace sherpa_onnx {
+
+// defined in ./online-recognizer-transducer-impl.h
+// TODO: decide whether this forward declaration needs to be static.
+static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
+                                      const SymbolTable &sym_table,
+                                      float frame_shift_ms,
+                                      int32_t subsampling_factor,
+                                      int32_t segment,
+                                      int32_t frames_since_start);
+
+class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
+ public:
+  explicit OnlineRecognizerTransducerNeMoImpl(
+      const OnlineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(config.model_config.tokens),
+        endpoint_(config_.endpoint_config),
+        model_(std::make_unique<OnlineTransducerNeMoModel>(
+            config.model_config)) {
+    if (config.decoding_method == "greedy_search") {
+      decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
+          model_.get(), config_.blank_penalty);
+    } else {
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config.decoding_method.c_str());
+      exit(-1);
+    }
+    PostInit();
+  }
+
+#if __ANDROID_API__ >= 9
+  explicit OnlineRecognizerTransducerNeMoImpl(
+      AAssetManager *mgr, const OnlineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(mgr, config.model_config.tokens),
+        endpoint_(config_.endpoint_config),
+        model_(std::make_unique<OnlineTransducerNeMoModel>(
+            mgr, config.model_config)) {
+    if (config.decoding_method == "greedy_search") {
+      decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
+          model_.get(), config_.blank_penalty);
+    } else {
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config.decoding_method.c_str());
+      exit(-1);
+    }
+
+    PostInit();
+  }
+#endif
+
+  std::unique_ptr<OnlineStream> CreateStream() const override {
+    auto stream = std::make_unique<OnlineStream>(config_.feat_config);
+    stream->SetStates(model_->GetInitStates());
+    InitOnlineStream(stream.get());
+    return stream;
+  }
+
+  bool IsReady(OnlineStream *s) const override {
+    return s->GetNumProcessedFrames() + model_->ChunkSize() <
+           s->NumFramesReady();
+  }
+
+  OnlineRecognizerResult GetResult(OnlineStream *s) const override {
+    OnlineTransducerDecoderResult decoder_result = s->GetResult();
+    decoder_->StripLeadingBlanks(&decoder_result);
+
+    // TODO(fangjun): Remember to change these constants if needed
+    int32_t frame_shift_ms = 10;
+    int32_t subsampling_factor = 8;
+    return Convert(decoder_result, symbol_table_, frame_shift_ms,
+                   subsampling_factor, s->GetCurrentSegment(),
+                   s->GetNumFramesSinceStart());
+  }
+
+  bool IsEndpoint(OnlineStream *s) const override {
+    if (!config_.enable_endpoint) {
+      return false;
+    }
+
+    int32_t num_processed_frames = s->GetNumProcessedFrames();
+
+    // frame shift is 10 milliseconds
+    float frame_shift_in_seconds = 0.01;
+
+    // subsampling factor is 8
+    int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;
+
+    return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
+                                frame_shift_in_seconds);
+  }
+
+  void Reset(OnlineStream *s) const override {
+    {
+      // segment is incremented only when the last
+      // result is not empty
+      const auto &r = s->GetResult();
+      if (!r.tokens.empty() && r.tokens.back() != 0) {
+        s->GetCurrentSegment() += 1;
+      }
+    }
+
+    // we keep the decoder_out
+    decoder_->UpdateDecoderOut(&s->GetResult());
+    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
+
+    auto r = decoder_->GetEmptyResult();
+
+    s->SetResult(r);
+    s->GetResult().decoder_out = std::move(decoder_out);
+
+    // Note: We only update counters. The underlying audio samples
+    // are not discarded.
+    s->Reset();
+  }
+
+  void DecodeStreams(OnlineStream **ss, int32_t n) const override {
+    int32_t chunk_size = model_->ChunkSize();
+    int32_t chunk_shift = model_->ChunkShift();
+
+    int32_t feature_dim = ss[0]->FeatureDim();
+
+    std::vector<OnlineTransducerDecoderResult> result(n);
+    std::vector<float> features_vec(n * chunk_size * feature_dim);
+    std::vector<std::vector<Ort::Value>> encoder_states(n);
+
+    for (int32_t i = 0; i != n; ++i) {
+      const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
+      std::vector<float> features =
+          ss[i]->GetFrames(num_processed_frames, chunk_size);
+
+      // Question: should num_processed_frames include chunk_shift?
+      ss[i]->GetNumProcessedFrames() += chunk_shift;
+
+      std::copy(features.begin(), features.end(),
+                features_vec.data() + i * chunk_size * feature_dim);
+
+      result[i] = std::move(ss[i]->GetResult());
+      encoder_states[i] = std::move(ss[i]->GetStates());
+    }
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
+
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
+                                            features_vec.size(),
+                                            x_shape.data(), x_shape.size());
+
+    // Batch size is 1
+    auto states = std::move(encoder_states[0]);
+    int32_t num_states = states.size();  // num_states = 3
+    auto t = model_->RunEncoder(std::move(x), std::move(states));
+    // t[0]: encoder_out, a float tensor of shape (batch_size, dim, T)
+    // t[1:]: next states
+
+    std::vector<Ort::Value> out_states;
+    out_states.reserve(num_states);
+
+    for (int32_t k = 1; k != num_states + 1; ++k) {
+      out_states.push_back(std::move(t[k]));
+    }
+
+    Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
+
+    // defined in online-transducer-greedy-search-nemo-decoder.h
+    // get the initial decoder states.
+    std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();
+
+    // Subsequent decoder states (for each chunk) are updated inside the
+    // Decode method. It returns the decoder states from the LAST chunk;
+    // we probably don't need them, so they could be discarded.
+    decoder_states = decoder_->Decode(std::move(encoder_out),
+                                      std::move(decoder_states),
+                                      &result, ss, n);
+
+    ss[0]->SetResult(result[0]);
+
+    ss[0]->SetStates(std::move(out_states));
+  }
+
+  void InitOnlineStream(OnlineStream *stream) const {
+    auto r = decoder_->GetEmptyResult();
+
+    stream->SetResult(r);
+    stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
+  }
+
+ private:
+  void PostInit() {
+    config_.feat_config.nemo_normalize_type =
+        model_->FeatureNormalizationMethod();
+
+    config_.feat_config.low_freq = 0;
+    // config_.feat_config.high_freq = 8000;
+    config_.feat_config.is_librosa = true;
+    config_.feat_config.remove_dc_offset = false;
+    // config_.feat_config.window_type = "hann";
+    config_.feat_config.dither = 0;
+
+    int32_t vocab_size = model_->VocabSize();
+
+    // check the blank ID
+    if (!symbol_table_.Contains("<blk>")) {
+      SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
+      exit(-1);
+    }
+
+    if (symbol_table_["<blk>"] != vocab_size - 1) {
+      SHERPA_ONNX_LOGE("<blk> is not the last token!");
+      exit(-1);
+    }
+
+    if (symbol_table_.NumSymbols() != vocab_size) {
+      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
+                       symbol_table_.NumSymbols(), vocab_size);
+      exit(-1);
+    }
+  }
+
+ private:
+  OnlineRecognizerConfig config_;
+  SymbolTable symbol_table_;
+  std::unique_ptr<OnlineTransducerNeMoModel> model_;
+  std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
+  Endpoint endpoint_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
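
GetResult() above hard-codes frame_shift_ms = 10 and subsampling_factor = 8, so each token timestamp, stored as an encoder-frame index, converts to seconds as index * 8 * 10 / 1000. A sketch of that arithmetic, assuming Convert() applies exactly these two constants:

// Token timestamp -> seconds, per the constants in GetResult() above.
#include <cstdint>

float TokenTimeInSeconds(int32_t encoder_frame_index) {
  const int32_t frame_shift_ms = 10;     // feature frame shift
  const int32_t subsampling_factor = 8;  // NeMo fast conformer: 8x
  return encoder_frame_index * subsampling_factor * frame_shift_ms / 1000.0f;
}
// e.g. a token at encoder frame 25 -> 25 * 8 * 10 / 1000 = 2.0 s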
@@ -90,6 +90,12 @@ class OnlineStream::Impl {
 
   std::vector<Ort::Value> &GetStates() { return states_; }
 
+  void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
+    decoder_states_ = std::move(decoder_states);
+  }
+
+  std::vector<Ort::Value> &GetNeMoDecoderStates() { return decoder_states_; }
+
   const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
 
   std::vector<float> &GetParaformerFeatCache() {
@@ -129,6 +135,7 @@ class OnlineStream::Impl {
   TransducerKeywordResult empty_keyword_result_;
   OnlineCtcDecoderResult ctc_result_;
   std::vector<Ort::Value> states_;  // states for transducer or ctc models
+  std::vector<Ort::Value> decoder_states_;  // states for NeMo transducer models
   std::vector<float> paraformer_feat_cache_;
   std::vector<float> paraformer_encoder_out_cache_;
   std::vector<float> paraformer_alpha_cache_;
@@ -218,6 +225,14 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
   return impl_->GetStates();
 }
 
+void OnlineStream::SetNeMoDecoderStates(
+    std::vector<Ort::Value> decoder_states) {
+  return impl_->SetNeMoDecoderStates(std::move(decoder_states));
+}
+
+std::vector<Ort::Value> &OnlineStream::GetNeMoDecoderStates() {
+  return impl_->GetNeMoDecoderStates();
+}
+
 const ContextGraphPtr &OnlineStream::GetContextGraph() const {
   return impl_->GetContextGraph();
 }
@@ -91,6 +91,9 @@ class OnlineStream {
   void SetStates(std::vector<Ort::Value> states);
   std::vector<Ort::Value> &GetStates();
 
+  void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);
+  std::vector<Ort::Value> &GetNeMoDecoderStates();
+
   /**
    * Get the context graph corresponding to this stream.
    *
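
For reference, this is how the NeMo recognizer wires the new accessors together (cf. InitOnlineStream() and DecodeStreams() in online-recognizer-transducer-nemo-impl.h above); a small sketch:

// Each stream carries its own copy of the stateful decoder's LSTM states.
#include <vector>
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"

void AttachDecoderStates(sherpa_onnx::OnlineStream *stream,
                         const sherpa_onnx::OnlineTransducerNeMoModel &model) {
  stream->SetNeMoDecoderStates(model.GetDecoderInitStates(/*batch_size=*/1));
  std::vector<Ort::Value> &states = stream->GetNeMoDecoderStates();
  (void)states;  // handed to the greedy-search decoder during DecodeStreams()
}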
+// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+// Copyright (c) 2024 Sangeet Sagar
+
+#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
+
+#include <algorithm>
+#include <array>
+#include <iterator>
+#include <utility>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
+    int32_t token, OrtAllocator *allocator) {
+  std::array<int64_t, 2> shape{1, 1};
+
+  Ort::Value decoder_input =
+      Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
+
+  std::array<int64_t, 1> length_shape{1};
+  Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
+      allocator, length_shape.data(), length_shape.size());
+
+  int32_t *p = decoder_input.GetTensorMutableData<int32_t>();
+  int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();
+
+  p[0] = token;
+  p_length[0] = 1;
+
+  return {std::move(decoder_input), std::move(decoder_input_length)};
+}
+
+OnlineTransducerDecoderResult
+OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const {
+  int32_t context_size = 8;
+  int32_t blank_id = 0;  // always 0
+  OnlineTransducerDecoderResult r;
+  r.tokens.resize(context_size, -1);
+  r.tokens.back() = blank_id;
+
+  return r;
+}
+
+static void UpdateCachedDecoderOut(
+    OrtAllocator *allocator, const Ort::Value *decoder_out,
+    std::vector<OnlineTransducerDecoderResult> *result) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  std::array<int64_t, 2> v_shape{1, shape[1]};
+
+  const float *src = decoder_out->GetTensorData<float>();
+  for (auto &r : *result) {
+    if (!r.decoder_out) {
+      r.decoder_out = Ort::Value::CreateTensor<float>(
+          allocator, v_shape.data(), v_shape.size());
+    }
+
+    float *dst = r.decoder_out.GetTensorMutableData<float>();
+    std::copy(src, src + shape[1], dst);
+    src += shape[1];
+  }
+}
+
+std::vector<Ort::Value> DecodeOne(
+    const float *encoder_out, int32_t num_rows, int32_t num_cols,
+    OnlineTransducerNeMoModel *model, float blank_penalty,
+    std::vector<Ort::Value> &decoder_states,
+    std::vector<OnlineTransducerDecoderResult> *result) {
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  // OnlineTransducerDecoderResult result;
+  int32_t vocab_size = model->VocabSize();
+  int32_t blank_id = vocab_size - 1;
+
+  auto &r = (*result)[0];
+  Ort::Value decoder_out{nullptr};
+
+  auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
+  // decoder_input_pair.first: decoder_input
+  // decoder_input_pair.second: decoder_input_length (discarded)
+
+  // decoder_output_pair.second holds the next decoder states
+  std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
+      model->RunDecoder(std::move(decoder_input_pair.first),
+                        std::move(decoder_states));
+
+  std::array<int64_t, 3> encoder_shape{1, num_cols, 1};
+
+  decoder_states = std::move(decoder_output_pair.second);
+
+  // TODO: do frame-wise decoding inside this loop.
+  for (int32_t t = 0; t != num_rows; ++t) {
+    Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
+        memory_info, const_cast<float *>(encoder_out) + t * num_cols, num_cols,
+        encoder_shape.data(), encoder_shape.size());
+
+    Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
+                                        View(&decoder_output_pair.first));
+
+    float *p_logit = logit.GetTensorMutableData<float>();
+    if (blank_penalty > 0) {
+      p_logit[blank_id] -= blank_penalty;
+    }
+
+    auto y = static_cast<int32_t>(std::distance(
+        static_cast<const float *>(p_logit),
+        std::max_element(static_cast<const float *>(p_logit),
+                         static_cast<const float *>(p_logit) + vocab_size)));
+    SHERPA_ONNX_LOGE("y=%d", y);
+    if (y != blank_id) {
+      r.tokens.push_back(y);
+      r.timestamps.push_back(t + r.frame_offset);
+
+      decoder_input_pair = BuildDecoderInput(y, model->Allocator());
+
+      // the decoder state from the last emission becomes the current state
+      decoder_output_pair =
+          model->RunDecoder(std::move(decoder_input_pair.first),
+                            std::move(decoder_states));
+
+      // Update the decoder states for the next frame
+      decoder_states = std::move(decoder_output_pair.second);
+    }
+  }
+
+  decoder_out = std::move(decoder_output_pair.first);
+  // UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result);
+
+  // Update frame_offset
+  for (auto &r : *result) {
+    r.frame_offset += num_rows;
+  }
+
+  return std::move(decoder_states);
+}
+
+std::vector<Ort::Value> OnlineTransducerGreedySearchNeMoDecoder::Decode(
+    Ort::Value encoder_out, std::vector<Ort::Value> decoder_states,
+    std::vector<OnlineTransducerDecoderResult> *result,
+    OnlineStream ** /*ss = nullptr*/, int32_t /*n = 0*/) {
+  auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
+
+  if (shape[0] != static_cast<int64_t>(result->size())) {
+    SHERPA_ONNX_LOGE(
+        "Size mismatch! encoder_out.size(0) %d, result.size(0): %d",
+        static_cast<int32_t>(shape[0]), static_cast<int32_t>(result->size()));
+    exit(-1);
+  }
+
+  int32_t batch_size = static_cast<int32_t>(shape[0]);  // bs = 1
+  int32_t dim1 = static_cast<int32_t>(shape[1]);        // 2
+  int32_t dim2 = static_cast<int32_t>(shape[2]);        // 512
+
+  // Define and initialize encoder_out_length
+  Ort::MemoryInfo memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+
+  int64_t length_value = 1;
+  std::vector<int64_t> length_shape = {1};
+
+  Ort::Value encoder_out_length = Ort::Value::CreateTensor<int64_t>(
+      memory_info, &length_value, 1, length_shape.data(), length_shape.size());
+
+  const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
+  const float *p = encoder_out.GetTensorData<float>();
+
+  // std::vector<OnlineTransducerDecoderResult> ans(batch_size);
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    const float *this_p = p + dim1 * dim2 * i;
+    int32_t this_len = p_length[i];
+
+    // returns the decoder states from the last chunk
+    auto last_decoder_states = DecodeOne(this_p, this_len, dim2, model_,
+                                         blank_penalty_, decoder_states,
+                                         result);
+    // ans[i] = decode_result_pair.first;
+    decoder_states = std::move(last_decoder_states);
+  }
+
+  return decoder_states;
+}
+
+}  // namespace sherpa_onnx
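
In outline, DecodeOne() above implements a simplified RNNT greedy search that emits at most one non-blank token per encoder frame (textbook RNNT greedy search instead loops on the same frame until blank is emitted). A self-contained restatement with the ONNX plumbing stripped out; Joiner and DecoderStep are toy stand-ins for model->RunJoiner() / model->RunDecoder():

// Simplified restatement of the greedy loop in DecodeOne().
#include <algorithm>
#include <cstdint>
#include <vector>

struct DecState { std::vector<float> h, c; };  // stands in for the LSTM states

static std::vector<float> Joiner(int32_t /*frame*/, const DecState &) {
  return std::vector<float>(128, 0.0f);  // toy: uniform logits -> always blank
}
static DecState DecoderStep(int32_t /*token*/, const DecState &s) { return s; }

std::vector<int32_t> GreedyDecode(int32_t num_frames, int32_t blank_id,
                                  float blank_penalty, DecState state) {
  std::vector<int32_t> tokens;
  for (int32_t t = 0; t != num_frames; ++t) {
    std::vector<float> logit = Joiner(t, state);
    logit[blank_id] -= blank_penalty;  // optional blank penalty
    int32_t y = static_cast<int32_t>(
        std::max_element(logit.begin(), logit.end()) - logit.begin());
    if (y != blank_id) {  // emit the token and advance the decoder
      tokens.push_back(y);
      state = DecoderStep(y, state);
    }                     // on blank: keep the decoder state frozen
  }
  return tokens;
}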
+// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+// Copyright (c) 2024 Sangeet Sagar
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
+#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
+
+namespace sherpa_onnx {
+
+class OnlineTransducerGreedySearchNeMoDecoder {
+ public:
+  OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model,
+                                          float blank_penalty)
+      : model_(model), blank_penalty_(blank_penalty) {}
+
+  OnlineTransducerDecoderResult GetEmptyResult() const;
+  void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
+  void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {}
+
+  std::vector<Ort::Value> Decode(
+      Ort::Value encoder_out, std::vector<Ort::Value> decoder_states,
+      std::vector<OnlineTransducerDecoderResult> *result,
+      OnlineStream **ss = nullptr, int32_t n = 0);
+
+ private:
+  OnlineTransducerNeMoModel *model_;  // Not owned
+  float blank_penalty_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
+// sherpa-onnx/csrc/online-transducer-nemo-model.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+// Copyright (c) 2024 Sangeet Sagar
+
+#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
+
+#include <assert.h>
+#include <math.h>
+
+#include <algorithm>
+#include <array>
+#include <memory>
+#include <numeric>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/cat.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/transpose.h"
+#include "sherpa-onnx/csrc/unbind.h"
+
+namespace sherpa_onnx {
+
+class OnlineTransducerNeMoModel::Impl {
+ public:
+  explicit Impl(const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_WARNING),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.transducer.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.transducer.decoder);
+      InitDecoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.transducer.joiner);
+      InitJoiner(buf.data(), buf.size());
+    }
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_WARNING),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(mgr, config.transducer.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.transducer.decoder);
+      InitDecoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.transducer.joiner);
+      InitJoiner(buf.data(), buf.size());
+    }
+  }
+#endif
+
+  std::vector<Ort::Value> RunEncoder(Ort::Value features,
+                                     std::vector<Ort::Value> states) {
+    Ort::Value &cache_last_channel = states[0];
+    Ort::Value &cache_last_time = states[1];
+    Ort::Value &cache_last_channel_len = states[2];
+
+    int32_t batch_size = features.GetTensorTypeAndShapeInfo().GetShape()[0];
+
+    std::array<int64_t, 1> length_shape{batch_size};
+
+    Ort::Value length = Ort::Value::CreateTensor<int64_t>(
+        allocator_, length_shape.data(), length_shape.size());
+
+    int64_t *p_length = length.GetTensorMutableData<int64_t>();
+
+    std::fill(p_length, p_length + batch_size, ChunkSize());
+
+    // (B, T, C) -> (B, C, T)
+    features = Transpose12(allocator_, &features);
+
+    std::array<Ort::Value, 5> inputs = {
+        std::move(features), View(&length), std::move(cache_last_channel),
+        std::move(cache_last_time), std::move(cache_last_channel_len)};
+
+    auto out = encoder_sess_->Run(
+        {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
+        encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
+    // out[0]: logit
+    // out[1]: logit_length
+    // out[2:]: states_next
+    //
+    // we need to remove out[1]
+
+    std::vector<Ort::Value> ans;
+    ans.reserve(out.size() - 1);
+
+    for (size_t i = 0; i != out.size(); ++i) {
+      if (i == 1) {
+        continue;
+      }
+
+      ans.push_back(std::move(out[i]));
+    }
+
+    return ans;
+  }
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
+      Ort::Value targets, std::vector<Ort::Value> states) {
+    Ort::MemoryInfo memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+
+    // Create the tensor with a single int32_t value of 1
+    int32_t length_value = 1;
+    std::vector<int64_t> length_shape = {1};
+
+    Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>(
+        memory_info, &length_value, 1, length_shape.data(),
+        length_shape.size());
+
+    std::vector<Ort::Value> decoder_inputs;
+    decoder_inputs.reserve(2 + states.size());
+
+    decoder_inputs.push_back(std::move(targets));
+    decoder_inputs.push_back(std::move(targets_length));
+
+    for (auto &s : states) {
+      decoder_inputs.push_back(std::move(s));
+    }
+
+    auto decoder_out = decoder_sess_->Run(
+        {}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
+        decoder_inputs.size(), decoder_output_names_ptr_.data(),
+        decoder_output_names_ptr_.size());
+
+    std::vector<Ort::Value> states_next;
+    states_next.reserve(states.size());
+
+    // decoder_out[0]: decoder_output
+    // decoder_out[1]: decoder_output_length (discarded)
+    // decoder_out[2:]: states_next
+
+    for (size_t i = 0; i != states.size(); ++i) {
+      states_next.push_back(std::move(decoder_out[i + 2]));
+    }
+
+    // we discard decoder_out[1]
+    return {std::move(decoder_out[0]), std::move(states_next)};
+  }
+
+  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
+    std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
+                                              std::move(decoder_out)};
+    auto logit = joiner_sess_->Run(
+        {}, joiner_input_names_ptr_.data(), joiner_input.data(),
+        joiner_input.size(), joiner_output_names_ptr_.data(),
+        joiner_output_names_ptr_.size());
+
+    return std::move(logit[0]);
+  }
+
+  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
+    std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size,
+                                    pred_hidden_};
+    Ort::Value s0 = Ort::Value::CreateTensor<float>(
+        allocator_, s0_shape.data(), s0_shape.size());
+
+    Fill<float>(&s0, 0);
+
+    std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size,
+                                    pred_hidden_};
+
+    Ort::Value s1 = Ort::Value::CreateTensor<float>(
+        allocator_, s1_shape.data(), s1_shape.size());
+
+    Fill<float>(&s1, 0);
+
+    std::vector<Ort::Value> states;
+
+    states.reserve(2);
+    states.push_back(std::move(s0));
+    states.push_back(std::move(s1));
+
+    return states;
+  }
+
+  int32_t ChunkSize() const { return window_size_; }
+
+  int32_t ChunkShift() const { return chunk_shift_; }
+
+  int32_t SubsamplingFactor() const { return subsampling_factor_; }
+
+  int32_t VocabSize() const { return vocab_size_; }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+  std::string FeatureNormalizationMethod() const { return normalize_type_; }
+
+  // Return a vector containing 3 tensors:
+  // - cache_last_channel
+  // - cache_last_time
+  // - cache_last_channel_len
+  std::vector<Ort::Value> GetInitStates() {
+    std::vector<Ort::Value> ans;
+    ans.reserve(3);
+    ans.push_back(View(&cache_last_channel_));
+    ans.push_back(View(&cache_last_time_));
+    ans.push_back(View(&cache_last_channel_len_));
+
+    return ans;
+  }
+
+ private:
+  void InitEncoder(void *model_data, size_t model_data_length) {
+    encoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                  &encoder_input_names_ptr_);
+
+    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                   &encoder_output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      os << "---encoder---\n";
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+
+    // need to increase by 1 since the blank token is not included in
+    // computing vocab_size in NeMo.
+    vocab_size_ += 1;
+
+    SHERPA_ONNX_READ_META_DATA(window_size_, "window_size");
+    SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift");
+    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
+    SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
+    SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
+    SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");
+
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_,
+                               "cache_last_channel_dim1");
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_,
+                               "cache_last_channel_dim2");
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_,
+                               "cache_last_channel_dim3");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3");
+
+    if (normalize_type_ == "NA") {
+      normalize_type_ = "";
+    }
+
+    InitStates();
+  }
+
+  void InitStates() {
+    std::array<int64_t, 4> cache_last_channel_shape{
+        1, cache_last_channel_dim1_, cache_last_channel_dim2_,
+        cache_last_channel_dim3_};
+
+    cache_last_channel_ = Ort::Value::CreateTensor<float>(
+        allocator_, cache_last_channel_shape.data(),
+        cache_last_channel_shape.size());
+
+    Fill<float>(&cache_last_channel_, 0);
+
+    std::array<int64_t, 4> cache_last_time_shape{
+        1, cache_last_time_dim1_, cache_last_time_dim2_,
+        cache_last_time_dim3_};
+
+    cache_last_time_ = Ort::Value::CreateTensor<float>(
+        allocator_, cache_last_time_shape.data(),
+        cache_last_time_shape.size());
+
+    Fill<float>(&cache_last_time_, 0);
+
+    int64_t shape = 1;
+    cache_last_channel_len_ =
+        Ort::Value::CreateTensor<int64_t>(allocator_, &shape, 1);
+
+    cache_last_channel_len_.GetTensorMutableData<int64_t>()[0] = 0;
+  }
+
+  void InitDecoder(void *model_data, size_t model_data_length) {
+    decoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                  &decoder_input_names_ptr_);
+
+    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                   &decoder_output_names_ptr_);
+  }
+
+  void InitJoiner(void *model_data, size_t model_data_length) {
+    joiner_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(joiner_sess_.get(), &joiner_input_names_,
+                  &joiner_input_names_ptr_);
+
+    GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
+                   &joiner_output_names_ptr_);
+  }
+
+ private:
+  OnlineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> encoder_sess_;
+  std::unique_ptr<Ort::Session> decoder_sess_;
+  std::unique_ptr<Ort::Session> joiner_sess_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::vector<std::string> decoder_input_names_;
+  std::vector<const char *> decoder_input_names_ptr_;
+
+  std::vector<std::string> decoder_output_names_;
+  std::vector<const char *> decoder_output_names_ptr_;
+
+  std::vector<std::string> joiner_input_names_;
+  std::vector<const char *> joiner_input_names_ptr_;
+
+  std::vector<std::string> joiner_output_names_;
+  std::vector<const char *> joiner_output_names_ptr_;
+
+  int32_t window_size_;
+  int32_t chunk_shift_;
+  int32_t vocab_size_ = 0;
+  int32_t subsampling_factor_ = 8;
+  std::string normalize_type_;
+  int32_t pred_rnn_layers_ = -1;
+  int32_t pred_hidden_ = -1;
+
+  int32_t cache_last_channel_dim1_;
+  int32_t cache_last_channel_dim2_;
+  int32_t cache_last_channel_dim3_;
+  int32_t cache_last_time_dim1_;
+  int32_t cache_last_time_dim2_;
+  int32_t cache_last_time_dim3_;
+
+  Ort::Value cache_last_channel_{nullptr};
+  Ort::Value cache_last_time_{nullptr};
+  Ort::Value cache_last_channel_len_{nullptr};
+};
+
+OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
+    const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
+    AAssetManager *mgr, const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default;
+
+std::vector<Ort::Value> OnlineTransducerNeMoModel::RunEncoder(
+    Ort::Value features, std::vector<Ort::Value> states) const {
+  return impl_->RunEncoder(std::move(features), std::move(states));
+}
+
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets,
+                                      std::vector<Ort::Value> states) const {
+  return impl_->RunDecoder(std::move(targets), std::move(states));
+}
+
+std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates(
+    int32_t batch_size) const {
+  return impl_->GetDecoderInitStates(batch_size);
+}
+
+Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
+                                                Ort::Value decoder_out) const {
+  return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
+}
+
+int32_t OnlineTransducerNeMoModel::ChunkSize() const {
+  return impl_->ChunkSize();
+}
+
+int32_t OnlineTransducerNeMoModel::ChunkShift() const {
+  return impl_->ChunkShift();
+}
+
+int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const {
+  return impl_->SubsamplingFactor();
+}
+
+int32_t OnlineTransducerNeMoModel::VocabSize() const {
+  return impl_->VocabSize();
+}
+
+OrtAllocator *OnlineTransducerNeMoModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const {
+  return impl_->FeatureNormalizationMethod();
+}
+
+std::vector<Ort::Value> OnlineTransducerNeMoModel::GetInitStates() const {
+  return impl_->GetInitStates();
+}
+
+}  // namespace sherpa_onnx
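
Putting the Impl pieces together, one encoder step through the model looks roughly like the sketch below (shapes per the comments above; `chunk` is a float tensor of shape (1, ChunkSize(), feature_dim) the caller prepares):

// One streaming encoder step, as a sketch.
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"

std::pair<Ort::Value, std::vector<Ort::Value>> EncodeOneChunk(
    const sherpa_onnx::OnlineTransducerNeMoModel &model, Ort::Value chunk,
    std::vector<Ort::Value> states /* 3 cache tensors; see GetInitStates() */) {
  auto out = model.RunEncoder(std::move(chunk), std::move(states));
  // out[0]: encoder_out of shape (1, dim, T'); out[1:]: the next states
  Ort::Value encoder_out = std::move(out[0]);
  std::vector<Ort::Value> next_states;
  for (size_t i = 1; i != out.size(); ++i) {
    next_states.push_back(std::move(out[i]));  // feed back on the next chunk
  }
  return {std::move(encoder_out), std::move(next_states)};
}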
+// sherpa-onnx/csrc/online-transducer-nemo-model.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+// Copyright (c) 2024 Sangeet Sagar
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-model-config.h"
+
+namespace sherpa_onnx {
+
+// see
+// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40
+// Its decoder is stateful, not stateless.
+class OnlineTransducerNeMoModel {
+ public:
+  explicit OnlineTransducerNeMoModel(const OnlineModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OnlineTransducerNeMoModel(AAssetManager *mgr,
+                            const OnlineModelConfig &config);
+#endif
+
+  ~OnlineTransducerNeMoModel();
+
+  // A list of 3 tensors:
+  // - cache_last_channel
+  // - cache_last_time
+  // - cache_last_channel_len
+  std::vector<Ort::Value> GetInitStates() const;
+
+  /** Run the encoder.
+   *
+   * @param features A tensor of shape (N, T, C). It is changed in-place.
+   * @param states It is from GetInitStates() or returned from this method.
+   *
+   * @return Return a vector containing:
+   *   - ans[0]: encoder_out, a tensor of shape (N, T', encoder_out_dim)
+   *   - ans[1:]: the next states
+   */
+  std::vector<Ort::Value> RunEncoder(Ort::Value features,
+                                     std::vector<Ort::Value> states) const;
+
+  /** Run the decoder network.
+   *
+   * @param targets An int32 tensor of shape (batch_size, 1)
+   * @param states The states for the decoder model.
+   * @return Return a pair containing:
+   *   - first: decoder_out (a float tensor)
+   *   - second: the next states
+   */
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
+      Ort::Value targets, std::vector<Ort::Value> states) const;
+
+  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const;
+
+  /** Run the joint network.
+   *
+   * @param encoder_out Output of the encoder network.
+   * @param decoder_out Output of the decoder network.
+   * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
+   */
+  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const;
+
+  /** We send this number of feature frames to the encoder at a time. */
+  int32_t ChunkSize() const;
+
+  /** Number of input frames to discard after each call to RunEncoder.
+   *
+   * For instance, if we have 30 frames, chunk_size=8, chunk_shift=6.
+   *
+   * In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8.
+   * Then we discard frames 0~5 since chunk_shift is 6.
+   * In the second call of RunEncoder, we use frames 6~13; and then we discard
+   * frames 6~11.
+   * In the third call of RunEncoder, we use frames 12~19; and then we discard
+   * frames 12~17.
+   *
+   * Note: ChunkSize() - ChunkShift() == right context size
+   */
+  int32_t ChunkShift() const;
+
+  /** Return the subsampling factor of the model. */
+  int32_t SubsamplingFactor() const;
+
+  int32_t VocabSize() const;
+
+  /** Return an allocator for allocating memory. */
+  OrtAllocator *Allocator() const;
+
+  // Possible values:
+  // - per_feature
+  // - all_features (not implemented yet)
+  // - fixed_mean (not implemented)
+  // - fixed_std (not implemented)
+  // - or just leave it empty
+  // See
+  // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
+  // for details
+  std::string FeatureNormalizationMethod() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
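
To make the ChunkSize()/ChunkShift() bookkeeping above concrete, here is the window arithmetic from the header's example (chunk_size = 8, chunk_shift = 6, 30 frames) as a tiny self-contained program:

// The window arithmetic from the ChunkShift() comment above.
#include <cstdio>

int main() {
  const int chunk_size = 8, chunk_shift = 6, num_frames = 30;
  for (int start = 0; start + chunk_size <= num_frames; start += chunk_shift) {
    // each call consumes [start, start + chunk_size), then advances by
    // chunk_shift; the overlap chunk_size - chunk_shift = 2 is right context
    std::printf("frames %d..%d\n", start, start + chunk_size - 1);
  }
  return 0;  // prints 0..7, 6..13, 12..19, 18..25
}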