Sangeet Sagar
Committed by GitHub

Add C++ runtime for *streaming* faster conformer transducer from NeMo. (#889)

Co-authored-by: sangeet2020 <15uec053@gmail.com>
... ... @@ -74,6 +74,8 @@ set(sources
online-transducer-model-config.cc
online-transducer-model.cc
online-transducer-modified-beam-search-decoder.cc
online-transducer-nemo-model.cc
online-transducer-greedy-search-nemo-decoder.cc
online-wenet-ctc-model-config.cc
online-wenet-ctc-model.cc
online-zipformer-transducer-model.cc
... ...
... ... @@ -7,13 +7,28 @@
#include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
const OnlineRecognizerConfig &config) {
if (!config.model_config.transducer.encoder.empty()) {
// A stateless transducer decoder exports a single output node (decoder_out),
// while the stateful NeMo decoder also exports its next states, so we can
// dispatch on the output count of the decoder model.
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
auto decoder_model = ReadFile(config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(
    env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();

if (node_count == 1) {
  return std::make_unique<OnlineRecognizerTransducerImpl>(config);
} else {
  SHERPA_ONNX_LOGE("Running streaming NeMo transducer model");
  return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
}
}
if (!config.model_config.paraformer.encoder.empty()) {
... ... @@ -34,7 +49,18 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
AAssetManager *mgr, const OnlineRecognizerConfig &config) {
if (!config.model_config.transducer.encoder.empty()) {
// Dispatch on the number of decoder output nodes, as in Create() above.
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(
    env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();

if (node_count == 1) {
  return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
} else {
  return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(mgr, config);
}
}
if (!config.model_config.paraformer.encoder.empty()) {
... ...
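The dispatch above relies on a structural difference between the two decoder exports: a stateless transducer decoder has a single output node (decoder_out), while NeMo's stateful decoder also exports its next RNN states. A minimal standalone sketch of that probe, with a placeholder model path (not part of this commit):

// probe-decoder-outputs.cc: count the output nodes of an ONNX model.
// "decoder.onnx" is a placeholder path; build and link against onnxruntime.
#include <cstdio>

#include "onnxruntime_cxx_api.h"  // NOLINT

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::Session sess(env, "decoder.onnx", Ort::SessionOptions{});
  // 1 output  -> stateless decoder (regular streaming transducer)
  // >1 output -> stateful decoder (NeMo: decoder_out + next states)
  std::printf("output count: %zu\n", sess.GetOutputCount());
  return 0;
}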
... ... @@ -46,6 +46,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r.timestamps.reserve(src.tokens.size());
for (auto i : src.tokens) {
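// -1 is the padding value that the NeMo greedy-search decoder's
// GetEmptyResult() puts into `tokens`; it has no symbol-table entry.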
if (i == -1) continue;
auto sym = sym_table[i];
r.text.append(sym);
... ...
// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
//
// Copyright (c) 2022-2024 Xiaomi Corporation
// Copyright (c) 2024 Sangeet Sagar
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#include <fstream>
#include <ios>
#include <memory>
#include <regex> // NOLINT
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
// defined in ./online-recognizer-transducer-impl.h
// TODO: decide whether this forward declaration needs to be static.
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment,
int32_t frames_since_start);
class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
public:
explicit OnlineRecognizerTransducerNeMoImpl(
const OnlineRecognizerConfig &config)
: config_(config),
symbol_table_(config.model_config.tokens),
endpoint_(config_.endpoint_config),
model_(std::make_unique<OnlineTransducerNeMoModel>(
config.model_config)) {
if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
model_.get(), config_.blank_penalty);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
exit(-1);
}
PostInit();
}
#if __ANDROID_API__ >= 9
explicit OnlineRecognizerTransducerNeMoImpl(
AAssetManager *mgr, const OnlineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config),
model_(std::make_unique<OnlineTransducerNeMoModel>(
mgr, config.model_config)) {
if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
model_.get(), config_.blank_penalty);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
exit(-1);
}
PostInit();
}
#endif
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
stream->SetStates(model_->GetInitStates());
InitOnlineStream(stream.get());
return stream;
}
bool IsReady(OnlineStream *s) const override {
return s->GetNumProcessedFrames() + model_->ChunkSize() <
s->NumFramesReady();
}
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
OnlineTransducerDecoderResult decoder_result = s->GetResult();
decoder_->StripLeadingBlanks(&decoder_result);
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 8;
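    // With a 10 ms frame shift and a subsampling factor of 8, each encoder
    // output frame covers 80 ms of audio; Convert() combines these two values
    // to turn token frame indexes into timestamps in seconds.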
    return Convert(decoder_result, symbol_table_, frame_shift_ms,
                   subsampling_factor, s->GetCurrentSegment(),
                   s->GetNumFramesSinceStart());
}
bool IsEndpoint(OnlineStream *s) const override {
if (!config_.enable_endpoint) {
return false;
}
int32_t num_processed_frames = s->GetNumProcessedFrames();
// frame shift is 10 milliseconds
float frame_shift_in_seconds = 0.01;
// subsampling factor is 8
int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;
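    // i.e., each trailing blank accounts for 8 * 10 ms = 80 ms of silence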
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
frame_shift_in_seconds);
}
void Reset(OnlineStream *s) const override {
{
// segment is incremented only when the last
// result is not empty
const auto &r = s->GetResult();
if (!r.tokens.empty() && r.tokens.back() != 0) {
s->GetCurrentSegment() += 1;
}
}
// we keep the decoder_out
decoder_->UpdateDecoderOut(&s->GetResult());
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
auto r = decoder_->GetEmptyResult();
s->SetResult(r);
s->GetResult().decoder_out = std::move(decoder_out);
// Note: We only update counters. The underlying audio samples
// are not discarded.
s->Reset();
}
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();
int32_t feature_dim = ss[0]->FeatureDim();
std::vector<OnlineTransducerDecoderResult> result(n);
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> encoder_states(n);
for (int32_t i = 0; i != n; ++i) {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_size);
// Question: should num_processed_frames include chunk_shift?
ss[i]->GetNumProcessedFrames() += chunk_shift;
std::copy(features.begin(), features.end(),
features_vec.data() + i * chunk_size * feature_dim);
result[i] = std::move(ss[i]->GetResult());
encoder_states[i] = std::move(ss[i]->GetStates());
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
features_vec.size(), x_shape.data(),
x_shape.size());
// Note: only batch size 1 is supported here; we use the states of the
// first stream.
auto states = std::move(encoder_states[0]);
int32_t num_states = states.size();  // num_states = 3
auto t = model_->RunEncoder(std::move(x), std::move(states));
// t[0] encoder_out, float tensor, (batch_size, dim, T)
// t[1] next states
std::vector<Ort::Value> out_states;
out_states.reserve(num_states);
for (int32_t k = 1; k != num_states + 1; ++k) {
out_states.push_back(std::move(t[k]));
}
Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
// defined in online-transducer-greedy-search-nemo-decoder.h
// get the initial decoder states
std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();
// The decoder states are updated inside Decode() for each chunk. Decode()
// returns the states from the LAST chunk; we store them back so that the
// next chunk continues from them.
decoder_states = decoder_->Decode(std::move(encoder_out),
std::move(decoder_states),
&result, ss, n);
ss[0]->SetResult(result[0]);
ss[0]->SetStates(std::move(out_states));
}
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
stream->SetResult(r);
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
}
private:
void PostInit() {
config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();
config_.feat_config.low_freq = 0;
// config_.feat_config.high_freq = 8000;
config_.feat_config.is_librosa = true;
config_.feat_config.remove_dc_offset = false;
// config_.feat_config.window_type = "hann";
config_.feat_config.dither = 0;
int32_t vocab_size = model_->VocabSize();
// check the blank ID
if (!symbol_table_.Contains("<blk>")) {
SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
exit(-1);
}
if (symbol_table_["<blk>"] != vocab_size - 1) {
SHERPA_ONNX_LOGE("<blk> is not the last token!");
exit(-1);
}
if (symbol_table_.NumSymbols() != vocab_size) {
SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
symbol_table_.NumSymbols(), vocab_size);
exit(-1);
}
}
private:
OnlineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OnlineTransducerNeMoModel> model_;
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
Endpoint endpoint_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
\ No newline at end of file
... ...
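For orientation, here is a rough usage sketch of the new impl through the public OnlineRecognizer API. The model paths are placeholders and the driver is hypothetical, not part of this commit; it only assumes the existing sherpa-onnx API (CreateStream / AcceptWaveform / IsReady / DecodeStream / GetResult):

#include <vector>

#include "sherpa-onnx/csrc/online-recognizer.h"

// Hypothetical driver: decode 16 kHz mono float samples with the streaming
// NeMo transducer. All file names below are placeholders.
void DecodeWithNeMoTransducer(const std::vector<float> &samples) {
  sherpa_onnx::OnlineRecognizerConfig config;
  config.model_config.transducer.encoder = "encoder.onnx";
  config.model_config.transducer.decoder = "decoder.onnx";
  config.model_config.transducer.joiner = "joiner.onnx";
  config.model_config.tokens = "tokens.txt";
  config.decoding_method = "greedy_search";  // the only method supported here

  sherpa_onnx::OnlineRecognizer recognizer(config);
  auto s = recognizer.CreateStream();
  s->AcceptWaveform(16000, samples.data(),
                    static_cast<int32_t>(samples.size()));

  while (recognizer.IsReady(s.get())) {
    recognizer.DecodeStream(s.get());
  }
  auto result = recognizer.GetResult(s.get());  // result.text, result.timestamps
}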
... ... @@ -90,6 +90,12 @@ class OnlineStream::Impl {
std::vector<Ort::Value> &GetStates() { return states_; }
void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
decoder_states_ = std::move(decoder_states);
}
std::vector<Ort::Value> &GetNeMoDecoderStates() { return decoder_states_; }
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
std::vector<float> &GetParaformerFeatCache() {
... ... @@ -129,6 +135,7 @@ class OnlineStream::Impl {
TransducerKeywordResult empty_keyword_result_;
OnlineCtcDecoderResult ctc_result_;
std::vector<Ort::Value> states_; // states for transducer or ctc models
std::vector<Ort::Value> decoder_states_; // states for nemo transducer models
std::vector<float> paraformer_feat_cache_;
std::vector<float> paraformer_encoder_out_cache_;
std::vector<float> paraformer_alpha_cache_;
... ... @@ -218,6 +225,14 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
return impl_->GetStates();
}
void OnlineStream::SetNeMoDecoderStates(
    std::vector<Ort::Value> decoder_states) {
return impl_->SetNeMoDecoderStates(std::move(decoder_states));
}
std::vector<Ort::Value> &OnlineStream::GetNeMoDecoderStates() {
return impl_->GetNeMoDecoderStates();
}
const ContextGraphPtr &OnlineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
... ...
... ... @@ -91,6 +91,9 @@ class OnlineStream {
void SetStates(std::vector<Ort::Value> states);
std::vector<Ort::Value> &GetStates();
void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);
std::vector<Ort::Value> &GetNeMoDecoderStates();
/**
* Get the context graph corresponding to this stream.
*
... ...
// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Sangeet Sagar
#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
#include <algorithm>
#include <iterator>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
int32_t token, OrtAllocator *allocator) {
std::array<int64_t, 2> shape{1, 1};
Ort::Value decoder_input =
Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
std::array<int64_t, 1> length_shape{1};
Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
allocator, length_shape.data(), length_shape.size());
int32_t *p = decoder_input.GetTensorMutableData<int32_t>();
int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();
p[0] = token;
p_length[0] = 1;
return {std::move(decoder_input), std::move(decoder_input_length)};
}
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const {
int32_t context_size = 8;
int32_t blank_id = 0; // always 0
OnlineTransducerDecoderResult r;
r.tokens.resize(context_size, -1);
r.tokens.back() = blank_id;
return r;
}
static void UpdateCachedDecoderOut(
OrtAllocator *allocator, const Ort::Value *decoder_out,
std::vector<OnlineTransducerDecoderResult> *result) {
std::vector<int64_t> shape =
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 2> v_shape{1, shape[1]};
const float *src = decoder_out->GetTensorData<float>();
for (auto &r : *result) {
if (!r.decoder_out) {
r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
v_shape.size());
}
float *dst = r.decoder_out.GetTensorMutableData<float>();
std::copy(src, src + shape[1], dst);
src += shape[1];
}
}
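// Greedy RNN-T search over one chunk of encoder output: for each encoder
// frame, run the joiner against the current decoder output; whenever the
// argmax is a non-blank token, emit it and advance the stateful decoder.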
static std::vector<Ort::Value> DecodeOne(
    const float *encoder_out, int32_t num_rows, int32_t num_cols,
    OnlineTransducerNeMoModel *model, float blank_penalty,
    std::vector<Ort::Value> &decoder_states,
    std::vector<OnlineTransducerDecoderResult> *result) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
// OnlineTransducerDecoderResult result;
int32_t vocab_size = model->VocabSize();
int32_t blank_id = vocab_size - 1;
auto &r = (*result)[0];
Ort::Value decoder_out{nullptr};
auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
// decoder_input_pair[0]: decoder_input
// decoder_input_pair[1]: decoder_input_length (discarded)
// decoder_output_pair.second returns the next decoder state
std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
    model->RunDecoder(std::move(decoder_input_pair.first),
                      std::move(decoder_states));
// After the std::move above, decoder_states is empty; RunDecoder() returns
// the two next-state tensors in decoder_output_pair.second, which we store
// back below.
std::array<int64_t, 3> encoder_shape{1, num_cols, 1};
decoder_states = std::move(decoder_output_pair.second);
// Frame-wise greedy decoding over this chunk.
for (int32_t t = 0; t != num_rows; ++t) {
Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(encoder_out) + t * num_cols, num_cols,
encoder_shape.data(), encoder_shape.size());
Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
View(&decoder_output_pair.first));
float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty > 0) {
p_logit[blank_id] -= blank_penalty;
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
SHERPA_ONNX_LOGE("y=%d", y);
if (y != blank_id) {
r.tokens.push_back(y);
r.timestamps.push_back(t + r.frame_offset);
decoder_input_pair = BuildDecoderInput(y, model->Allocator());
// last decoder state becomes the current state for the first chunk
decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_states));
// Update the decoder states for the next chunk
decoder_states = std::move(decoder_output_pair.second);
}
}
decoder_out = std::move(decoder_output_pair.first);
// UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result);
// Update frame_offset
for (auto &r : *result) {
r.frame_offset += num_rows;
}
return std::move(decoder_states);
}
std::vector<Ort::Value> OnlineTransducerGreedySearchNeMoDecoder::Decode(
Ort::Value encoder_out,
std::vector<Ort::Value> decoder_states,
std::vector<OnlineTransducerDecoderResult> *result,
OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) {
auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
if (shape[0] != result->size()) {
SHERPA_ONNX_LOGE(
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d",
static_cast<int32_t>(shape[0]),
static_cast<int32_t>(result->size()));
exit(-1);
}
int32_t batch_size = static_cast<int32_t>(shape[0]); // bs = 1
int32_t dim1 = static_cast<int32_t>(shape[1]); // 2
int32_t dim2 = static_cast<int32_t>(shape[2]); // 512
// Define and initialize encoder_out_length
Ort::MemoryInfo memory_info =
    Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
int64_t length_value = 1;
std::vector<int64_t> length_shape = {1};
Ort::Value encoder_out_length = Ort::Value::CreateTensor<int64_t>(
    memory_info, &length_value, 1, length_shape.data(), length_shape.size());
const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
const float *p = encoder_out.GetTensorData<float>();
// std::vector<OnlineTransducerDecoderResult> ans(batch_size);
for (int32_t i = 0; i != batch_size; ++i) {
const float *this_p = p + dim1 * dim2 * i;
int32_t this_len = p_length[i];
// outputs the decoder state from last chunk.
    auto last_decoder_states = DecodeOne(this_p, this_len, dim2, model_,
                                         blank_penalty_, decoder_states,
                                         result);
// ans[i] = decode_result_pair.first;
decoder_states = std::move(last_decoder_states);
}
return decoder_states;
}
} // namespace sherpa_onnx
\ No newline at end of file
... ...
// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Sangeet Sagar
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
namespace sherpa_onnx {
class OnlineTransducerGreedySearchNeMoDecoder {
public:
OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model,
float blank_penalty)
: model_(model),
blank_penalty_(blank_penalty) {}
OnlineTransducerDecoderResult GetEmptyResult() const;
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {}
std::vector<Ort::Value> Decode(
Ort::Value encoder_out,
std::vector<Ort::Value> decoder_states,
std::vector<OnlineTransducerDecoderResult> *result,
OnlineStream **ss = nullptr, int32_t n = 0);
private:
OnlineTransducerNeMoModel *model_; // Not owned
float blank_penalty_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
... ...
// sherpa-onnx/csrc/online-transducer-nemo-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Sangeet Sagar
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
#include <assert.h>
#include <math.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/unbind.h"
namespace sherpa_onnx {
class OnlineTransducerNeMoModel::Impl {
public:
explicit Impl(const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.transducer.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.transducer.decoder);
InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.transducer.joiner);
InitJoiner(buf.data(), buf.size());
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.transducer.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.transducer.decoder);
InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.transducer.joiner);
InitJoiner(buf.data(), buf.size());
}
}
#endif
std::vector<Ort::Value> RunEncoder(Ort::Value features,
std::vector<Ort::Value> states) {
Ort::Value &cache_last_channel = states[0];
Ort::Value &cache_last_time = states[1];
Ort::Value &cache_last_channel_len = states[2];
int32_t batch_size = features.GetTensorTypeAndShapeInfo().GetShape()[0];
std::array<int64_t, 1> length_shape{batch_size};
Ort::Value length = Ort::Value::CreateTensor<int64_t>(
allocator_, length_shape.data(), length_shape.size());
int64_t *p_length = length.GetTensorMutableData<int64_t>();
std::fill(p_length, p_length + batch_size, ChunkSize());
// (B, T, C) -> (B, C, T)
features = Transpose12(allocator_, &features);
std::array<Ort::Value, 5> inputs = {
std::move(features), View(&length), std::move(cache_last_channel),
std::move(cache_last_time), std::move(cache_last_channel_len)};
    auto out = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
        encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
// out[0]: logit
// out[1]: logit_length
// out[2:] states_next
//
// we need to remove out[1]
std::vector<Ort::Value> ans;
ans.reserve(out.size() - 1);
for (int32_t i = 0; i != out.size(); ++i) {
if (i == 1) {
continue;
}
ans.push_back(std::move(out[i]));
}
return ans;
}
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
Ort::Value targets, std::vector<Ort::Value> states) {
    Ort::MemoryInfo memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    // Create the targets_length tensor holding a single int32_t value of 1
    int32_t length_value = 1;
    std::vector<int64_t> length_shape = {1};
    Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>(
        memory_info, &length_value, 1, length_shape.data(),
        length_shape.size());
std::vector<Ort::Value> decoder_inputs;
decoder_inputs.reserve(2 + states.size());
decoder_inputs.push_back(std::move(targets));
decoder_inputs.push_back(std::move(targets_length));
for (auto &s : states) {
decoder_inputs.push_back(std::move(s));
}
auto decoder_out = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
decoder_inputs.size(), decoder_output_names_ptr_.data(),
decoder_output_names_ptr_.size());
std::vector<Ort::Value> states_next;
states_next.reserve(states.size());
// decoder_out[0]: decoder_output
// decoder_out[1]: decoder_output_length (discarded)
// decoder_out[2:] states_next
for (int32_t i = 0; i != states.size(); ++i) {
states_next.push_back(std::move(decoder_out[i + 2]));
}
// we discard decoder_out[1]
return {std::move(decoder_out[0]), std::move(states_next)};
}
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
    auto logit = joiner_sess_->Run(
        {}, joiner_input_names_ptr_.data(), joiner_input.data(),
        joiner_input.size(), joiner_output_names_ptr_.data(),
        joiner_output_names_ptr_.size());
return std::move(logit[0]);
}
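  // The NeMo prediction network is an RNN (an LSTM in NeMo's default
  // configuration), so its state consists of two tensors of shape
  // (pred_rnn_layers, batch_size, pred_hidden), zero-initialized below.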
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
s0_shape.size());
Fill<float>(&s0, 0);
std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};
Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
s1_shape.size());
Fill<float>(&s1, 0);
std::vector<Ort::Value> states;
states.reserve(2);
states.push_back(std::move(s0));
states.push_back(std::move(s1));
return states;
}
int32_t ChunkSize() const { return window_size_; }
int32_t ChunkShift() const { return chunk_shift_; }
int32_t SubsamplingFactor() const { return subsampling_factor_; }
int32_t VocabSize() const { return vocab_size_; }
OrtAllocator *Allocator() const { return allocator_; }
std::string FeatureNormalizationMethod() const { return normalize_type_; }
// Return a vector containing 3 tensors
// - cache_last_channel
// - cache_last_time
// - cache_last_channel_len
std::vector<Ort::Value> GetInitStates() {
std::vector<Ort::Value> ans;
ans.reserve(3);
ans.push_back(View(&cache_last_channel_));
ans.push_back(View(&cache_last_time_));
ans.push_back(View(&cache_last_channel_len_));
return ans;
}
private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
// need to increase by 1 since the blank token is not included in computing
// vocab_size in NeMo.
vocab_size_ += 1;
SHERPA_ONNX_READ_META_DATA(window_size_, "window_size");
SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift");
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");
SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_,
"cache_last_channel_dim1");
SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_,
"cache_last_channel_dim2");
SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_,
"cache_last_channel_dim3");
SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1");
SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2");
SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3");
if (normalize_type_ == "NA") {
normalize_type_ = "";
}
InitStates();
}
void InitStates() {
std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_,
cache_last_channel_dim2_,
cache_last_channel_dim3_};
cache_last_channel_ = Ort::Value::CreateTensor<float>(
allocator_, cache_last_channel_shape.data(),
cache_last_channel_shape.size());
Fill<float>(&cache_last_channel_, 0);
std::array<int64_t, 4> cache_last_time_shape{
1, cache_last_time_dim1_, cache_last_time_dim2_, cache_last_time_dim3_};
cache_last_time_ = Ort::Value::CreateTensor<float>(
allocator_, cache_last_time_shape.data(), cache_last_time_shape.size());
Fill<float>(&cache_last_time_, 0);
int64_t shape = 1;
cache_last_channel_len_ =
Ort::Value::CreateTensor<int64_t>(allocator_, &shape, 1);
cache_last_channel_len_.GetTensorMutableData<int64_t>()[0] = 0;
}
void InitDecoder(void *model_data, size_t model_data_length) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
}
void InitJoiner(void *model_data, size_t model_data_length) {
joiner_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
&joiner_input_names_ptr_);
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
&joiner_output_names_ptr_);
}
private:
OnlineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::unique_ptr<Ort::Session> joiner_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<std::string> joiner_input_names_;
std::vector<const char *> joiner_input_names_ptr_;
std::vector<std::string> joiner_output_names_;
std::vector<const char *> joiner_output_names_ptr_;
int32_t window_size_;
int32_t chunk_shift_;
int32_t vocab_size_ = 0;
int32_t subsampling_factor_ = 8;
std::string normalize_type_;
int32_t pred_rnn_layers_ = -1;
int32_t pred_hidden_ = -1;
int32_t cache_last_channel_dim1_;
int32_t cache_last_channel_dim2_;
int32_t cache_last_channel_dim3_;
int32_t cache_last_time_dim1_;
int32_t cache_last_time_dim2_;
int32_t cache_last_time_dim3_;
Ort::Value cache_last_channel_{nullptr};
Ort::Value cache_last_time_{nullptr};
Ort::Value cache_last_channel_len_{nullptr};
};
OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
AAssetManager *mgr, const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default;
std::vector<Ort::Value>
OnlineTransducerNeMoModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> states) const {
return impl_->RunEncoder(std::move(features), std::move(states));
}
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets,
std::vector<Ort::Value> states) const {
return impl_->RunDecoder(std::move(targets), std::move(states));
}
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates(
int32_t batch_size) const {
return impl_->GetDecoderInitStates(batch_size);
}
Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) const {
return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
}
int32_t OnlineTransducerNeMoModel::ChunkSize() const {
return impl_->ChunkSize();
}
int32_t OnlineTransducerNeMoModel::ChunkShift() const {
return impl_->ChunkShift();
}
int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const {
return impl_->SubsamplingFactor();
}
int32_t OnlineTransducerNeMoModel::VocabSize() const {
return impl_->VocabSize();
}
OrtAllocator *OnlineTransducerNeMoModel::Allocator() const {
return impl_->Allocator();
}
std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const {
return impl_->FeatureNormalizationMethod();
}
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetInitStates() const {
return impl_->GetInitStates();
}
} // namespace sherpa_onnx
\ No newline at end of file
... ...
// sherpa-onnx/csrc/online-transducer-nemo-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Sangeet Sagar
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"
namespace sherpa_onnx {
// see
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40
// Its decoder is stateful, not stateless.
class OnlineTransducerNeMoModel {
public:
explicit OnlineTransducerNeMoModel(const OnlineModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineTransducerNeMoModel(AAssetManager *mgr,
const OnlineModelConfig &config);
#endif
~OnlineTransducerNeMoModel();
// A list of 3 tensors:
// - cache_last_channel
// - cache_last_time
// - cache_last_channel_len
std::vector<Ort::Value> GetInitStates() const;
/** Run the encoder.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param states It is from GetInitStates() or returned from this method.
*
* @return Return a vector containing:
* - ans[0]: encoder_out, a tensor of shape (N, T', encoder_out_dim)
* - ans[1:]: contains next states
*/
std::vector<Ort::Value> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states) const; // NOLINT
/** Run the decoder network.
*
* @param targets A int32 tensor of shape (batch_size, 1)
* @param states The states for the decoder model.
* @return Return a pair:
*           - ans.first: the decoder_out (a float tensor)
*           - ans.second: the next states
*/
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
Ort::Value targets, std::vector<Ort::Value> states) const;
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const;
/** Run the joint network.
*
* @param encoder_out Output of the encoder network.
* @param decoder_out Output of the decoder network.
* @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
*/
Ort::Value RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) const;
/** We send this number of feature frames to the encoder at a time. */
int32_t ChunkSize() const;
/** Number of input frames to discard after each call to RunEncoder.
*
* For instance, if we have 30 frames, chunk_size=8, chunk_shift=6.
*
* In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8.
* Then we discard frames 0~5 since chunk_shift is 6.
* In the second call of RunEncoder, we use frames 6~13; and then we discard
* frames 6~11.
* In the third call of RunEncoder, we use frames 12~19; and then we discard
* frames 12~17.
*
* Note: ChunkSize() - ChunkShift() == right context size
*/
int32_t ChunkShift() const;
/** Return the subsampling factor of the model.
*/
int32_t SubsamplingFactor() const;
int32_t VocabSize() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
// Possible values:
// - per_feature
// - all_features (not implemented yet)
// - fixed_mean (not implemented)
// - fixed_std (not implemented)
// - or just leave it to empty
// See
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
// for details
std::string FeatureNormalizationMethod() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
... ...
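The chunking arithmetic documented for ChunkShift() can be sanity-checked with a short standalone sketch (hypothetical, not part of this commit) that reproduces the 30-frame example from the header comment:

#include <cstdio>

// Enumerate the feature-frame windows sent to RunEncoder for the example in
// online-transducer-nemo-model.h: 30 frames, chunk_size = 8, chunk_shift = 6.
int main() {
  const int num_frames = 30;
  const int chunk_size = 8;   // ChunkSize()
  const int chunk_shift = 6;  // ChunkShift()
  for (int start = 0; start + chunk_size <= num_frames; start += chunk_shift) {
    // Frames [start, start + chunk_size) are fed to the encoder; frames
    // [start, start + chunk_shift) are then discarded. The remaining
    // chunk_size - chunk_shift = 2 frames overlap with the next chunk
    // (the "right context" mentioned in the header comment).
    std::printf("use frames %d~%d, then discard frames %d~%d\n",
                start, start + chunk_size - 1, start, start + chunk_shift - 1);
  }
  return 0;
}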