Add C++ runtime for *streaming* faster conformer transducer from NeMo. (#889)
Committed by GitHub
Co-authored-by: sangeet2020 <15uec053@gmail.com>
Showing 10 changed files with 1,119 additions and 2 deletions.
sherpa-onnx/csrc/CMakeLists.txt
@@ -74,6 +74,8 @@ set(sources
   online-transducer-model-config.cc
   online-transducer-model.cc
   online-transducer-modified-beam-search-decoder.cc
+  online-transducer-nemo-model.cc
+  online-transducer-greedy-search-nemo-decoder.cc
   online-wenet-ctc-model-config.cc
   online-wenet-ctc-model.cc
   online-zipformer-transducer-model.cc
sherpa-onnx/csrc/online-recognizer-impl.cc
@@ -7,13 +7,28 @@
 #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
 #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
 #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
+#include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
 
 namespace sherpa_onnx {
 
 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     const OnlineRecognizerConfig &config) {
   if (!config.model_config.transducer.encoder.empty()) {
-    return std::make_unique<OnlineRecognizerTransducerImpl>(config);
+    Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+
+    auto decoder_model = ReadFile(config.model_config.transducer.decoder);
+    auto sess = std::make_unique<Ort::Session>(
+        env, decoder_model.data(), decoder_model.size(),
+        Ort::SessionOptions{});
+
+    size_t node_count = sess->GetOutputCount();
+
+    if (node_count == 1) {
+      return std::make_unique<OnlineRecognizerTransducerImpl>(config);
+    } else {
+      SHERPA_ONNX_LOGE("Running streaming NeMo transducer model");
+      return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
+    }
   }
 
   if (!config.model_config.paraformer.encoder.empty()) {
@@ -34,7 +49,18 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     AAssetManager *mgr, const OnlineRecognizerConfig &config) {
   if (!config.model_config.transducer.encoder.empty()) {
-    return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
+    Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+
+    auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
+    auto sess = std::make_unique<Ort::Session>(
+        env, decoder_model.data(), decoder_model.size(),
+        Ort::SessionOptions{});
+
+    size_t node_count = sess->GetOutputCount();
+
+    if (node_count == 1) {
+      return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
+    } else {
+      return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(mgr, config);
+    }
   }
 
   if (!config.model_config.paraformer.encoder.empty()) {
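
The dispatch above relies on a structural difference between the two decoder exports: a stateless (zipformer-style) transducer decoder has a single output, decoder_out, while NeMo's stateful decoder also returns its hidden states, so its output count is greater than one. A minimal standalone sketch of the same probe, assuming a local decoder.onnx file (the path and program are illustrative, not part of the PR):

// Hedged sketch of the output-count probe; "decoder.onnx" is a placeholder.
#include <cstdio>
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::SessionOptions opts;
  Ort::Session sess(env, "decoder.onnx", opts);  // hypothetical local file

  size_t node_count = sess.GetOutputCount();
  std::printf("decoder has %zu output(s)\n", node_count);
  // 1 output  -> stateless decoder, zipformer-style transducer
  // >1 output -> stateful NeMo decoder (decoder_out + hidden states)
  return 0;
}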
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
@@ -46,6 +46,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
   r.timestamps.reserve(src.tokens.size());
 
   for (auto i : src.tokens) {
+    if (i == -1) continue;
     auto sym = sym_table[i];
 
     r.text.append(sym);
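
Context for the -1 check: the NeMo greedy-search decoder introduced later in this PR pads its empty result with -1 placeholder tokens (see GetEmptyResult() in online-transducer-greedy-search-nemo-decoder.cc below), so Convert() must skip them instead of looking them up in the symbol table.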
sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h (new file)

// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
//
// Copyright (c)  2022-2024  Xiaomi Corporation
// Copyright (c)  2024  Sangeet Sagar

#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_

#include <array>
#include <fstream>
#include <ios>
#include <memory>
#include <regex>  // NOLINT
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/utils.h"

namespace sherpa_onnx {

// defined in ./online-recognizer-transducer-impl.h
// TODO: check whether the static qualifier is needed here.
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
                                      const SymbolTable &sym_table,
                                      float frame_shift_ms,
                                      int32_t subsampling_factor,
                                      int32_t segment,
                                      int32_t frames_since_start);

class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
 public:
  explicit OnlineRecognizerTransducerNeMoImpl(
      const OnlineRecognizerConfig &config)
      : config_(config),
        symbol_table_(config.model_config.tokens),
        model_(std::make_unique<OnlineTransducerNeMoModel>(
            config.model_config)),
        endpoint_(config_.endpoint_config) {
    if (config.decoding_method == "greedy_search") {
      decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
          model_.get(), config_.blank_penalty);
    } else {
      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                       config.decoding_method.c_str());
      exit(-1);
    }
    PostInit();
  }

#if __ANDROID_API__ >= 9
  explicit OnlineRecognizerTransducerNeMoImpl(
      AAssetManager *mgr, const OnlineRecognizerConfig &config)
      : config_(config),
        symbol_table_(mgr, config.model_config.tokens),
        model_(std::make_unique<OnlineTransducerNeMoModel>(
            mgr, config.model_config)),
        endpoint_(config_.endpoint_config) {
    if (config.decoding_method == "greedy_search") {
      decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
          model_.get(), config_.blank_penalty);
    } else {
      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                       config.decoding_method.c_str());
      exit(-1);
    }

    PostInit();
  }
#endif

  std::unique_ptr<OnlineStream> CreateStream() const override {
    auto stream = std::make_unique<OnlineStream>(config_.feat_config);
    stream->SetStates(model_->GetInitStates());
    InitOnlineStream(stream.get());
    return stream;
  }

  bool IsReady(OnlineStream *s) const override {
    return s->GetNumProcessedFrames() + model_->ChunkSize() <
           s->NumFramesReady();
  }

  OnlineRecognizerResult GetResult(OnlineStream *s) const override {
    OnlineTransducerDecoderResult decoder_result = s->GetResult();
    decoder_->StripLeadingBlanks(&decoder_result);

    // TODO(fangjun): Remember to change these constants if needed
    int32_t frame_shift_ms = 10;
    int32_t subsampling_factor = 8;
    return Convert(decoder_result, symbol_table_, frame_shift_ms,
                   subsampling_factor, s->GetCurrentSegment(),
                   s->GetNumFramesSinceStart());
  }

  bool IsEndpoint(OnlineStream *s) const override {
    if (!config_.enable_endpoint) {
      return false;
    }

    int32_t num_processed_frames = s->GetNumProcessedFrames();

    // frame shift is 10 milliseconds
    float frame_shift_in_seconds = 0.01;

    // subsampling factor is 8
    int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;

    return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
                                frame_shift_in_seconds);
  }

  void Reset(OnlineStream *s) const override {
    {
      // segment is incremented only when the last
      // result is not empty
      const auto &r = s->GetResult();
      if (!r.tokens.empty() && r.tokens.back() != 0) {
        s->GetCurrentSegment() += 1;
      }
    }

    // we keep the decoder_out
    decoder_->UpdateDecoderOut(&s->GetResult());
    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);

    auto r = decoder_->GetEmptyResult();

    s->SetResult(r);
    s->GetResult().decoder_out = std::move(decoder_out);

    // Note: We only update counters. The underlying audio samples
    // are not discarded.
    s->Reset();
  }

  void DecodeStreams(OnlineStream **ss, int32_t n) const override {
    int32_t chunk_size = model_->ChunkSize();
    int32_t chunk_shift = model_->ChunkShift();

    int32_t feature_dim = ss[0]->FeatureDim();

    std::vector<OnlineTransducerDecoderResult> result(n);
    std::vector<float> features_vec(n * chunk_size * feature_dim);
    std::vector<std::vector<Ort::Value>> encoder_states(n);

    for (int32_t i = 0; i != n; ++i) {
      const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
      std::vector<float> features =
          ss[i]->GetFrames(num_processed_frames, chunk_size);

      // Question: should num_processed_frames include chunk_shift?
      ss[i]->GetNumProcessedFrames() += chunk_shift;

      std::copy(features.begin(), features.end(),
                features_vec.data() + i * chunk_size * feature_dim);

      result[i] = std::move(ss[i]->GetResult());
      encoder_states[i] = std::move(ss[i]->GetStates());
    }

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};

    Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
                                            features_vec.size(), x_shape.data(),
                                            x_shape.size());

    // Batch size is 1
    auto states = std::move(encoder_states[0]);
    int32_t num_states = static_cast<int32_t>(states.size());  // num_states = 3
    auto t = model_->RunEncoder(std::move(x), std::move(states));
    // t[0]: encoder_out, a float tensor of shape (batch_size, dim, T)
    // t[1:]: next states

    std::vector<Ort::Value> out_states;
    out_states.reserve(num_states);

    for (int32_t k = 1; k != num_states + 1; ++k) {
      out_states.push_back(std::move(t[k]));
    }

    Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);

    // defined in online-transducer-greedy-search-nemo-decoder.h
    // get the initial decoder states.
    std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();

    // Subsequent decoder states (for each chunk) are updated inside the
    // Decode method. It returns the decoder states from the LAST chunk;
    // we probably don't need them, so they can be discarded.
    decoder_states = decoder_->Decode(std::move(encoder_out),
                                      std::move(decoder_states),
                                      &result, ss, n);

    ss[0]->SetResult(result[0]);

    ss[0]->SetStates(std::move(out_states));
  }

  void InitOnlineStream(OnlineStream *stream) const {
    auto r = decoder_->GetEmptyResult();

    stream->SetResult(r);
    stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
  }

 private:
  void PostInit() {
    config_.feat_config.nemo_normalize_type =
        model_->FeatureNormalizationMethod();

    config_.feat_config.low_freq = 0;
    // config_.feat_config.high_freq = 8000;
    config_.feat_config.is_librosa = true;
    config_.feat_config.remove_dc_offset = false;
    // config_.feat_config.window_type = "hann";
    config_.feat_config.dither = 0;

    int32_t vocab_size = model_->VocabSize();

    // check the blank ID
    if (!symbol_table_.Contains("<blk>")) {
      SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
      exit(-1);
    }

    if (symbol_table_["<blk>"] != vocab_size - 1) {
      SHERPA_ONNX_LOGE("<blk> is not the last token!");
      exit(-1);
    }

    if (symbol_table_.NumSymbols() != vocab_size) {
      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
                       symbol_table_.NumSymbols(), vocab_size);
      exit(-1);
    }
  }

 private:
  OnlineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OnlineTransducerNeMoModel> model_;
  std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
  Endpoint endpoint_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
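
For context, a hedged sketch of how a caller typically drives this implementation through the public OnlineRecognizer API; the model/token paths and the dummy audio are placeholders, not part of this PR, and the calls mirror the overrides above (CreateStream, IsReady, DecodeStreams via DecodeStream, GetResult):

// Hedged usage sketch, assuming exported NeMo ONNX files on disk.
#include <vector>
#include "sherpa-onnx/csrc/online-recognizer.h"

void RunStreamingNeMo() {
  sherpa_onnx::OnlineRecognizerConfig config;
  config.model_config.transducer.encoder = "encoder.onnx";  // placeholder
  config.model_config.transducer.decoder = "decoder.onnx";  // placeholder
  config.model_config.transducer.joiner = "joiner.onnx";    // placeholder
  config.model_config.tokens = "tokens.txt";                // placeholder
  config.decoding_method = "greedy_search";  // the only method supported here

  sherpa_onnx::OnlineRecognizer recognizer(config);
  auto stream = recognizer.CreateStream();

  std::vector<float> samples(16000, 0.0f);  // 1 s of silence as dummy input
  stream->AcceptWaveform(16000, samples.data(),
                         static_cast<int32_t>(samples.size()));

  while (recognizer.IsReady(stream.get())) {
    recognizer.DecodeStream(stream.get());  // one chunk per call
  }
  auto result = recognizer.GetResult(stream.get());  // text, tokens, timestamps
}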
sherpa-onnx/csrc/online-stream.cc
@@ -90,6 +90,12 @@ class OnlineStream::Impl {
 
   std::vector<Ort::Value> &GetStates() { return states_; }
 
+  void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
+    decoder_states_ = std::move(decoder_states);
+  }
+
+  std::vector<Ort::Value> &GetNeMoDecoderStates() { return decoder_states_; }
+
   const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
 
   std::vector<float> &GetParaformerFeatCache() {

@@ -129,6 +135,7 @@ class OnlineStream::Impl {
   TransducerKeywordResult empty_keyword_result_;
   OnlineCtcDecoderResult ctc_result_;
   std::vector<Ort::Value> states_;  // states for transducer or ctc models
+  std::vector<Ort::Value> decoder_states_;  // states for NeMo transducer models
   std::vector<float> paraformer_feat_cache_;
   std::vector<float> paraformer_encoder_out_cache_;
   std::vector<float> paraformer_alpha_cache_;
@@ -218,6 +225,14 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
   return impl_->GetStates();
 }
 
+void OnlineStream::SetNeMoDecoderStates(
+    std::vector<Ort::Value> decoder_states) {
+  return impl_->SetNeMoDecoderStates(std::move(decoder_states));
+}
+
+std::vector<Ort::Value> &OnlineStream::GetNeMoDecoderStates() {
+  return impl_->GetNeMoDecoderStates();
+}
+
 const ContextGraphPtr &OnlineStream::GetContextGraph() const {
   return impl_->GetContextGraph();
 }
sherpa-onnx/csrc/online-stream.h
@@ -91,6 +91,9 @@ class OnlineStream {
   void SetStates(std::vector<Ort::Value> states);
   std::vector<Ort::Value> &GetStates();
 
+  void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);
+  std::vector<Ort::Value> &GetNeMoDecoderStates();
+
   /**
    * Get the context graph corresponding to this stream.
    *
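
These accessors follow the existing pimpl pattern of OnlineStream: the public methods forward to OnlineStream::Impl, and the NeMo decoder states are stored per stream next to the existing encoder states_, so each stream carries its own stateful-decoder (LSTM) hidden state across chunks.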
sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc (new file)

// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
//
// Copyright (c)  2024  Xiaomi Corporation
// Copyright (c)  2024  Sangeet Sagar

#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"

#include <algorithm>
#include <array>
#include <iterator>
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
    int32_t token, OrtAllocator *allocator) {
  std::array<int64_t, 2> shape{1, 1};

  Ort::Value decoder_input =
      Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());

  std::array<int64_t, 1> length_shape{1};
  Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
      allocator, length_shape.data(), length_shape.size());

  int32_t *p = decoder_input.GetTensorMutableData<int32_t>();
  int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();

  p[0] = token;
  p_length[0] = 1;

  return {std::move(decoder_input), std::move(decoder_input_length)};
}

OnlineTransducerDecoderResult
OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const {
  int32_t context_size = 8;
  int32_t blank_id = 0;  // always 0
  OnlineTransducerDecoderResult r;
  r.tokens.resize(context_size, -1);
  r.tokens.back() = blank_id;

  return r;
}

static void UpdateCachedDecoderOut(
    OrtAllocator *allocator, const Ort::Value *decoder_out,
    std::vector<OnlineTransducerDecoderResult> *result) {
  std::vector<int64_t> shape =
      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  std::array<int64_t, 2> v_shape{1, shape[1]};

  const float *src = decoder_out->GetTensorData<float>();
  for (auto &r : *result) {
    if (!r.decoder_out) {
      r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
                                                      v_shape.size());
    }

    float *dst = r.decoder_out.GetTensorMutableData<float>();
    std::copy(src, src + shape[1], dst);
    src += shape[1];
  }
}

std::vector<Ort::Value> DecodeOne(
    const float *encoder_out, int32_t num_rows, int32_t num_cols,
    OnlineTransducerNeMoModel *model, float blank_penalty,
    std::vector<Ort::Value> &decoder_states,
    std::vector<OnlineTransducerDecoderResult> *result) {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  int32_t vocab_size = model->VocabSize();
  int32_t blank_id = vocab_size - 1;

  auto &r = (*result)[0];
  Ort::Value decoder_out{nullptr};

  auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
  // decoder_input_pair.first: decoder_input
  // decoder_input_pair.second: decoder_input_length (discarded)

  // decoder_output_pair.second contains the next decoder states
  std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
      model->RunDecoder(std::move(decoder_input_pair.first),
                        std::move(decoder_states));

  std::array<int64_t, 3> encoder_shape{1, num_cols, 1};

  decoder_states = std::move(decoder_output_pair.second);

  // frame-wise greedy decoding over the encoder output
  for (int32_t t = 0; t != num_rows; ++t) {
    Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
        memory_info, const_cast<float *>(encoder_out) + t * num_cols, num_cols,
        encoder_shape.data(), encoder_shape.size());

    Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
                                        View(&decoder_output_pair.first));

    float *p_logit = logit.GetTensorMutableData<float>();
    if (blank_penalty > 0) {
      p_logit[blank_id] -= blank_penalty;
    }

    auto y = static_cast<int32_t>(std::distance(
        static_cast<const float *>(p_logit),
        std::max_element(static_cast<const float *>(p_logit),
                         static_cast<const float *>(p_logit) + vocab_size)));

    if (y != blank_id) {
      r.tokens.push_back(y);
      r.timestamps.push_back(t + r.frame_offset);

      decoder_input_pair = BuildDecoderInput(y, model->Allocator());

      // the current decoder states become the input states of the next run
      decoder_output_pair =
          model->RunDecoder(std::move(decoder_input_pair.first),
                            std::move(decoder_states));

      // update the decoder states for the next chunk
      decoder_states = std::move(decoder_output_pair.second);
    }
  }

  decoder_out = std::move(decoder_output_pair.first);
  // UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result);

  // update frame_offset
  for (auto &r : *result) {
    r.frame_offset += num_rows;
  }

  return std::move(decoder_states);
}

std::vector<Ort::Value> OnlineTransducerGreedySearchNeMoDecoder::Decode(
    Ort::Value encoder_out, std::vector<Ort::Value> decoder_states,
    std::vector<OnlineTransducerDecoderResult> *result,
    OnlineStream ** /*ss = nullptr*/, int32_t /*n = 0*/) {
  auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();

  if (static_cast<int32_t>(shape[0]) != static_cast<int32_t>(result->size())) {
    SHERPA_ONNX_LOGE(
        "Size mismatch! encoder_out.size(0) %d, result.size(0): %d",
        static_cast<int32_t>(shape[0]), static_cast<int32_t>(result->size()));
    exit(-1);
  }

  int32_t batch_size = static_cast<int32_t>(shape[0]);  // batch_size is 1
  int32_t dim1 = static_cast<int32_t>(shape[1]);  // T, e.g., 2 frames per chunk
  int32_t dim2 = static_cast<int32_t>(shape[2]);  // encoder dim, e.g., 512

  // define and initialize encoder_out_length
  Ort::MemoryInfo memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

  int64_t length_value = 1;
  std::vector<int64_t> length_shape = {1};

  Ort::Value encoder_out_length = Ort::Value::CreateTensor<int64_t>(
      memory_info, &length_value, 1, length_shape.data(), length_shape.size());

  const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
  const float *p = encoder_out.GetTensorData<float>();

  for (int32_t i = 0; i != batch_size; ++i) {
    const float *this_p = p + dim1 * dim2 * i;
    int32_t this_len = p_length[i];

    // DecodeOne() returns the decoder states from the last chunk
    auto last_decoder_states = DecodeOne(this_p, this_len, dim2, model_,
                                         blank_penalty_, decoder_states,
                                         result);
    decoder_states = std::move(last_decoder_states);
  }

  return decoder_states;
}

}  // namespace sherpa_onnx
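
In short, DecodeOne() implements frame-synchronous greedy RNN-T search for a single stream: for each encoder frame, the joiner combines that frame with the cached decoder output; if the argmax over the (optionally blank-penalized) logits is not blank, the token is appended with its timestamp and the stateful decoder is advanced; otherwise decoding simply moves on to the next frame. Note that blank is taken to be vocab_size - 1 here, which is exactly what PostInit() in online-recognizer-transducer-nemo-impl.h verifies with its <blk> checks.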
sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h (new file)

// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
//
// Copyright (c)  2024  Xiaomi Corporation
// Copyright (c)  2024  Sangeet Sagar

#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_

#include <vector>

#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"

namespace sherpa_onnx {

class OnlineTransducerGreedySearchNeMoDecoder {
 public:
  OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model,
                                          float blank_penalty)
      : model_(model), blank_penalty_(blank_penalty) {}

  OnlineTransducerDecoderResult GetEmptyResult() const;
  void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
  void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {}

  std::vector<Ort::Value> Decode(
      Ort::Value encoder_out, std::vector<Ort::Value> decoder_states,
      std::vector<OnlineTransducerDecoderResult> *result,
      OnlineStream **ss = nullptr, int32_t n = 0);

 private:
  OnlineTransducerNeMoModel *model_;  // Not owned
  float blank_penalty_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
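
A note on the signature: Ort::Value is move-only, which is why Decode() takes decoder_states by value and returns the updated states rather than mutating a caller-owned copy in place. The caller (DecodeStreams() above) moves the stream's cached states in and moves the returned states back out.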
sherpa-onnx/csrc/online-transducer-nemo-model.cc (new file)

// sherpa-onnx/csrc/online-transducer-nemo-model.cc
//
// Copyright (c)  2024  Xiaomi Corporation
// Copyright (c)  2024  Sangeet Sagar

#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"

#include <assert.h>
#include <math.h>

#include <algorithm>
#include <array>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/unbind.h"

namespace sherpa_onnx {

class OnlineTransducerNeMoModel::Impl {
 public:
  explicit Impl(const OnlineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.transducer.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.transducer.decoder);
      InitDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.transducer.joiner);
      InitJoiner(buf.data(), buf.size());
    }
  }

#if __ANDROID_API__ >= 9
  Impl(AAssetManager *mgr, const OnlineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.transducer.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.transducer.decoder);
      InitDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.transducer.joiner);
      InitJoiner(buf.data(), buf.size());
    }
  }
#endif

  std::vector<Ort::Value> RunEncoder(Ort::Value features,
                                     std::vector<Ort::Value> states) {
    Ort::Value &cache_last_channel = states[0];
    Ort::Value &cache_last_time = states[1];
    Ort::Value &cache_last_channel_len = states[2];

    int32_t batch_size = features.GetTensorTypeAndShapeInfo().GetShape()[0];

    std::array<int64_t, 1> length_shape{batch_size};

    Ort::Value length = Ort::Value::CreateTensor<int64_t>(
        allocator_, length_shape.data(), length_shape.size());

    int64_t *p_length = length.GetTensorMutableData<int64_t>();

    std::fill(p_length, p_length + batch_size, ChunkSize());

    // (B, T, C) -> (B, C, T)
    features = Transpose12(allocator_, &features);

    std::array<Ort::Value, 5> inputs = {
        std::move(features), View(&length), std::move(cache_last_channel),
        std::move(cache_last_time), std::move(cache_last_channel_len)};

    auto out = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
        encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
    // out[0]: logit
    // out[1]: logit_length
    // out[2:]: states_next
    //
    // we need to remove out[1]

    std::vector<Ort::Value> ans;
    ans.reserve(out.size() - 1);

    for (int32_t i = 0; i != static_cast<int32_t>(out.size()); ++i) {
      if (i == 1) {
        continue;
      }

      ans.push_back(std::move(out[i]));
    }

    return ans;
  }

  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
      Ort::Value targets, std::vector<Ort::Value> states) {
    Ort::MemoryInfo memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

    // Create the tensor with a single int32_t value of 1
    int32_t length_value = 1;
    std::vector<int64_t> length_shape = {1};

    Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>(
        memory_info, &length_value, 1, length_shape.data(),
        length_shape.size());

    std::vector<Ort::Value> decoder_inputs;
    decoder_inputs.reserve(2 + states.size());

    decoder_inputs.push_back(std::move(targets));
    decoder_inputs.push_back(std::move(targets_length));

    for (auto &s : states) {
      decoder_inputs.push_back(std::move(s));
    }

    auto decoder_out = decoder_sess_->Run(
        {}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
        decoder_inputs.size(), decoder_output_names_ptr_.data(),
        decoder_output_names_ptr_.size());

    std::vector<Ort::Value> states_next;
    states_next.reserve(states.size());

    // decoder_out[0]: decoder_output
    // decoder_out[1]: decoder_output_length (discarded)
    // decoder_out[2:]: states_next

    for (int32_t i = 0; i != static_cast<int32_t>(states.size()); ++i) {
      states_next.push_back(std::move(decoder_out[i + 2]));
    }

    // we discard decoder_out[1]
    return {std::move(decoder_out[0]), std::move(states_next)};
  }

  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
    std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
                                              std::move(decoder_out)};
    auto logit = joiner_sess_->Run(
        {}, joiner_input_names_ptr_.data(), joiner_input.data(),
        joiner_input.size(), joiner_output_names_ptr_.data(),
        joiner_output_names_ptr_.size());

    return std::move(logit[0]);
  }

  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
    std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
    Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
                                                    s0_shape.size());

    Fill<float>(&s0, 0);

    std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};

    Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
                                                    s1_shape.size());

    Fill<float>(&s1, 0);

    std::vector<Ort::Value> states;

    states.reserve(2);
    states.push_back(std::move(s0));
    states.push_back(std::move(s1));

    return states;
  }

  int32_t ChunkSize() const { return window_size_; }

  int32_t ChunkShift() const { return chunk_shift_; }

  int32_t SubsamplingFactor() const { return subsampling_factor_; }

  int32_t VocabSize() const { return vocab_size_; }

  OrtAllocator *Allocator() const { return allocator_; }

  std::string FeatureNormalizationMethod() const { return normalize_type_; }

  // Return a vector containing 3 tensors:
  // - cache_last_channel
  // - cache_last_time
  // - cache_last_channel_len
  std::vector<Ort::Value> GetInitStates() {
    std::vector<Ort::Value> ans;
    ans.reserve(3);
    ans.push_back(View(&cache_last_channel_));
    ans.push_back(View(&cache_last_time_));
    ans.push_back(View(&cache_last_channel_len_));

    return ans;
  }

 private:
  void InitEncoder(void *model_data, size_t model_data_length) {
    encoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                  &encoder_input_names_ptr_);

    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                   &encoder_output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---encoder---\n";
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");

    // need to increase by 1 since the blank token is not included in
    // computing vocab_size in NeMo.
    vocab_size_ += 1;

    SHERPA_ONNX_READ_META_DATA(window_size_, "window_size");
    SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift");
    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
    SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
    SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
    SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");

    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_,
                               "cache_last_channel_dim1");
    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_,
                               "cache_last_channel_dim2");
    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_,
                               "cache_last_channel_dim3");
    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1");
    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2");
    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3");

    if (normalize_type_ == "NA") {
      normalize_type_ = "";
    }

    InitStates();
  }

  void InitStates() {
    std::array<int64_t, 4> cache_last_channel_shape{
        1, cache_last_channel_dim1_, cache_last_channel_dim2_,
        cache_last_channel_dim3_};

    cache_last_channel_ = Ort::Value::CreateTensor<float>(
        allocator_, cache_last_channel_shape.data(),
        cache_last_channel_shape.size());

    Fill<float>(&cache_last_channel_, 0);

    std::array<int64_t, 4> cache_last_time_shape{
        1, cache_last_time_dim1_, cache_last_time_dim2_,
        cache_last_time_dim3_};

    cache_last_time_ = Ort::Value::CreateTensor<float>(
        allocator_, cache_last_time_shape.data(),
        cache_last_time_shape.size());

    Fill<float>(&cache_last_time_, 0);

    int64_t shape = 1;
    cache_last_channel_len_ =
        Ort::Value::CreateTensor<int64_t>(allocator_, &shape, 1);

    cache_last_channel_len_.GetTensorMutableData<int64_t>()[0] = 0;
  }

  void InitDecoder(void *model_data, size_t model_data_length) {
    decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                  &decoder_input_names_ptr_);

    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                   &decoder_output_names_ptr_);
  }

  void InitJoiner(void *model_data, size_t model_data_length) {
    joiner_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(joiner_sess_.get(), &joiner_input_names_,
                  &joiner_input_names_ptr_);

    GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
                   &joiner_output_names_ptr_);
  }

 private:
  OnlineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;
  std::unique_ptr<Ort::Session> joiner_sess_;

  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  std::vector<std::string> joiner_input_names_;
  std::vector<const char *> joiner_input_names_ptr_;

  std::vector<std::string> joiner_output_names_;
  std::vector<const char *> joiner_output_names_ptr_;

  int32_t window_size_;
  int32_t chunk_shift_;
  int32_t vocab_size_ = 0;
  int32_t subsampling_factor_ = 8;
  std::string normalize_type_;
  int32_t pred_rnn_layers_ = -1;
  int32_t pred_hidden_ = -1;

  int32_t cache_last_channel_dim1_;
  int32_t cache_last_channel_dim2_;
  int32_t cache_last_channel_dim3_;
  int32_t cache_last_time_dim1_;
  int32_t cache_last_time_dim2_;
  int32_t cache_last_time_dim3_;

  Ort::Value cache_last_channel_{nullptr};
  Ort::Value cache_last_time_{nullptr};
  Ort::Value cache_last_channel_len_{nullptr};
};

OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
    const OnlineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
    AAssetManager *mgr, const OnlineModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default;

std::vector<Ort::Value> OnlineTransducerNeMoModel::RunEncoder(
    Ort::Value features, std::vector<Ort::Value> states) const {
  return impl_->RunEncoder(std::move(features), std::move(states));
}

std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets,
                                      std::vector<Ort::Value> states) const {
  return impl_->RunDecoder(std::move(targets), std::move(states));
}

std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates(
    int32_t batch_size) const {
  return impl_->GetDecoderInitStates(batch_size);
}

Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
                                                Ort::Value decoder_out) const {
  return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
}

int32_t OnlineTransducerNeMoModel::ChunkSize() const {
  return impl_->ChunkSize();
}

int32_t OnlineTransducerNeMoModel::ChunkShift() const {
  return impl_->ChunkShift();
}

int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const {
  return impl_->SubsamplingFactor();
}

int32_t OnlineTransducerNeMoModel::VocabSize() const {
  return impl_->VocabSize();
}

OrtAllocator *OnlineTransducerNeMoModel::Allocator() const {
  return impl_->Allocator();
}

std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const {
  return impl_->FeatureNormalizationMethod();
}

std::vector<Ort::Value> OnlineTransducerNeMoModel::GetInitStates() const {
  return impl_->GetInitStates();
}

}  // namespace sherpa_onnx
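
For orientation, a sketch of how the cache tensors flow through one encoder step. This mirrors DecodeStreams() above rather than adding anything new; it assumes `model` points at an OnlineTransducerNeMoModel and `x` is an (N, T, C) feature tensor:

// Hedged fragment: one chunk through the cache-aware streaming encoder.
auto states = model->GetInitStates();  // 3 zero-initialized cache tensors
auto out = model->RunEncoder(std::move(x), std::move(states));
// out[0]  : encoder_out of shape (N, dim, T); Transpose12() converts it
//           to (N, T, dim) before greedy search
// out[1:] : cache_last_channel, cache_last_time, cache_last_channel_len,
//           to be fed back in with the next chunk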
sherpa-onnx/csrc/online-transducer-nemo-model.h (new file)

// sherpa-onnx/csrc/online-transducer-nemo-model.h
//
// Copyright (c)  2024  Xiaomi Corporation
// Copyright (c)  2024  Sangeet Sagar

#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"

namespace sherpa_onnx {

// see
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40
// Its decoder is stateful, not stateless.
class OnlineTransducerNeMoModel {
 public:
  explicit OnlineTransducerNeMoModel(const OnlineModelConfig &config);

#if __ANDROID_API__ >= 9
  OnlineTransducerNeMoModel(AAssetManager *mgr,
                            const OnlineModelConfig &config);
#endif

  ~OnlineTransducerNeMoModel();

  // A list of 3 tensors:
  // - cache_last_channel
  // - cache_last_time
  // - cache_last_channel_len
  std::vector<Ort::Value> GetInitStates() const;

  /** Run the encoder.
   *
   * @param features  A tensor of shape (N, T, C). It is changed in-place.
   * @param states  It is from GetInitStates() or returned from this method.
   *
   * @return Return a vector containing:
   *           - ans[0]: encoder_out, a tensor of shape (N, T', encoder_out_dim)
   *           - ans[1:]: the next states
   */
  std::vector<Ort::Value> RunEncoder(Ort::Value features,
                                     std::vector<Ort::Value> states) const;

  /** Run the decoder network.
   *
   * @param targets  An int32 tensor of shape (batch_size, 1)
   * @param states  The states for the decoder model.
   * @return Return a pair containing:
   *           - ans.first: decoder_out (a float tensor)
   *           - ans.second: the next decoder states
   */
  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
      Ort::Value targets, std::vector<Ort::Value> states) const;

  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const;

  /** Run the joint network.
   *
   * @param encoder_out  Output of the encoder network.
   * @param decoder_out  Output of the decoder network.
   * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
   */
  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const;

  /** We send this number of feature frames to the encoder at a time. */
  int32_t ChunkSize() const;

  /** Number of input frames to discard after each call to RunEncoder.
   *
   * For instance, if we have 30 frames, chunk_size=8, chunk_shift=6.
   *
   * In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8.
   * Then we discard frames 0~5 since chunk_shift is 6.
   * In the second call of RunEncoder, we use frames 6~13; and then we discard
   * frames 6~11.
   * In the third call of RunEncoder, we use frames 12~19; and then we discard
   * frames 12~17.
   *
   * Note: ChunkSize() - ChunkShift() == right context size
   */
  int32_t ChunkShift() const;

  /** Return the subsampling factor of the model. */
  int32_t SubsamplingFactor() const;

  int32_t VocabSize() const;

  /** Return an allocator for allocating memory. */
  OrtAllocator *Allocator() const;

  // Possible values:
  // - per_feature
  // - all_features (not implemented yet)
  // - fixed_mean (not implemented)
  // - fixed_std (not implemented)
  // - or just leave it empty
  // See
  // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
  // for details
  std::string FeatureNormalizationMethod() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_