Committed by
GitHub
Fix nemo streaming transducer greedy search (#944)
Showing 18 changed files with 320 additions and 290 deletions
| @@ -16,6 +16,45 @@ echo "PATH: $PATH" | @@ -16,6 +16,45 @@ echo "PATH: $PATH" | ||
| 16 | which $EXE | 16 | which $EXE |
| 17 | 17 | ||
| 18 | log "------------------------------------------------------------" | 18 | log "------------------------------------------------------------" |
| 19 | +log "Run NeMo transducer (English)" | ||
| 20 | +log "------------------------------------------------------------" | ||
| 21 | +repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2 | ||
| 22 | +curl -SL -O $repo_url | ||
| 23 | +tar xvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2 | ||
| 24 | +rm sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2 | ||
| 25 | +repo=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms | ||
| 26 | + | ||
| 27 | +log "Start testing ${repo_url}" | ||
| 28 | + | ||
| 29 | +waves=( | ||
| 30 | +$repo/test_wavs/0.wav | ||
| 31 | +$repo/test_wavs/1.wav | ||
| 32 | +$repo/test_wavs/8k.wav | ||
| 33 | +) | ||
| 34 | + | ||
| 35 | +for wave in ${waves[@]}; do | ||
| 36 | + time $EXE \ | ||
| 37 | + --tokens=$repo/tokens.txt \ | ||
| 38 | + --encoder=$repo/encoder.onnx \ | ||
| 39 | + --decoder=$repo/decoder.onnx \ | ||
| 40 | + --joiner=$repo/joiner.onnx \ | ||
| 41 | + --num-threads=2 \ | ||
| 42 | + $wave | ||
| 43 | +done | ||
| 44 | + | ||
| 45 | +time $EXE \ | ||
| 46 | + --tokens=$repo/tokens.txt \ | ||
| 47 | + --encoder=$repo/encoder.onnx \ | ||
| 48 | + --decoder=$repo/decoder.onnx \ | ||
| 49 | + --joiner=$repo/joiner.onnx \ | ||
| 50 | + --num-threads=2 \ | ||
| 51 | + $repo/test_wavs/0.wav \ | ||
| 52 | + $repo/test_wavs/1.wav \ | ||
| 53 | + $repo/test_wavs/8k.wav | ||
| 54 | + | ||
| 55 | +rm -rf $repo | ||
| 56 | + | ||
| 57 | +log "------------------------------------------------------------" | ||
| 19 | log "Run LSTM transducer (English)" | 58 | log "Run LSTM transducer (English)" |
| 20 | log "------------------------------------------------------------" | 59 | log "------------------------------------------------------------" |
| 21 | 60 |
| @@ -196,7 +196,6 @@ jobs: | @@ -196,7 +196,6 @@ jobs: | ||
| 196 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface | 196 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface |
| 197 | 197 | ||
| 198 | cd huggingface | 198 | cd huggingface |
| 199 | - git lfs pull | ||
| 200 | mkdir -p aarch64 | 199 | mkdir -p aarch64 |
| 201 | 200 | ||
| 202 | cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64 | 201 | cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64 |
| @@ -187,7 +187,6 @@ jobs: | @@ -187,7 +187,6 @@ jobs: | ||
| 187 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface | 187 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface |
| 188 | 188 | ||
| 189 | cd huggingface | 189 | cd huggingface |
| 190 | - git lfs pull | ||
| 191 | mkdir -p aarch64 | 190 | mkdir -p aarch64 |
| 192 | 191 | ||
| 193 | cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64 | 192 | cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64 |
| @@ -124,7 +124,6 @@ jobs: | @@ -124,7 +124,6 @@ jobs: | ||
| 124 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface | 124 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface |
| 125 | 125 | ||
| 126 | cd huggingface | 126 | cd huggingface |
| 127 | - git lfs pull | ||
| 128 | 127 | ||
| 129 | cp -v ../sherpa-onnx-*-android.tar.bz2 ./ | 128 | cp -v ../sherpa-onnx-*-android.tar.bz2 ./ |
| 130 | 129 |
| @@ -209,7 +209,6 @@ jobs: | @@ -209,7 +209,6 @@ jobs: | ||
| 209 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface | 209 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface |
| 210 | 210 | ||
| 211 | cd huggingface | 211 | cd huggingface |
| 212 | - git lfs pull | ||
| 213 | mkdir -p arm32 | 212 | mkdir -p arm32 |
| 214 | 213 | ||
| 215 | cp -v ../sherpa-onnx-*.tar.bz2 ./arm32 | 214 | cp -v ../sherpa-onnx-*.tar.bz2 ./arm32 |
| @@ -138,7 +138,6 @@ jobs: | @@ -138,7 +138,6 @@ jobs: | ||
| 138 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface | 138 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface |
| 139 | 139 | ||
| 140 | cd huggingface | 140 | cd huggingface |
| 141 | - git lfs pull | ||
| 142 | 141 | ||
| 143 | cp -v ../sherpa-onnx-*.tar.bz2 ./ | 142 | cp -v ../sherpa-onnx-*.tar.bz2 ./ |
| 144 | 143 |
| @@ -242,7 +242,6 @@ jobs: | @@ -242,7 +242,6 @@ jobs: | ||
| 242 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface | 242 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface |
| 243 | 243 | ||
| 244 | cd huggingface | 244 | cd huggingface |
| 245 | - git lfs pull | ||
| 246 | mkdir -p riscv64 | 245 | mkdir -p riscv64 |
| 247 | 246 | ||
| 248 | cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64 | 247 | cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64 |
| @@ -219,7 +219,6 @@ jobs: | @@ -219,7 +219,6 @@ jobs: | ||
| 219 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface | 219 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface |
| 220 | 220 | ||
| 221 | cd huggingface | 221 | cd huggingface |
| 222 | - git lfs pull | ||
| 223 | mkdir -p win64 | 222 | mkdir -p win64 |
| 224 | 223 | ||
| 225 | cp -v ../sherpa-onnx-*.tar.bz2 ./win64 | 224 | cp -v ../sherpa-onnx-*.tar.bz2 ./win64 |
| @@ -221,7 +221,6 @@ jobs: | @@ -221,7 +221,6 @@ jobs: | ||
| 221 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface | 221 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface |
| 222 | 222 | ||
| 223 | cd huggingface | 223 | cd huggingface |
| 224 | - git lfs pull | ||
| 225 | mkdir -p win32 | 224 | mkdir -p win32 |
| 226 | 225 | ||
| 227 | cp -v ../sherpa-onnx-*.tar.bz2 ./win32 | 226 | cp -v ../sherpa-onnx-*.tar.bz2 ./win32 |
| @@ -14,19 +14,18 @@ namespace sherpa_onnx { | @@ -14,19 +14,18 @@ namespace sherpa_onnx { | ||
| 14 | 14 | ||
| 15 | std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | 15 | std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( |
| 16 | const OnlineRecognizerConfig &config) { | 16 | const OnlineRecognizerConfig &config) { |
| 17 | - | ||
| 18 | if (!config.model_config.transducer.encoder.empty()) { | 17 | if (!config.model_config.transducer.encoder.empty()) { |
| 19 | Ort::Env env(ORT_LOGGING_LEVEL_WARNING); | 18 | Ort::Env env(ORT_LOGGING_LEVEL_WARNING); |
| 20 | - | 19 | + |
| 21 | auto decoder_model = ReadFile(config.model_config.transducer.decoder); | 20 | auto decoder_model = ReadFile(config.model_config.transducer.decoder); |
| 22 | - auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); | ||
| 23 | - | 21 | + auto sess = std::make_unique<Ort::Session>( |
| 22 | + env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); | ||
| 23 | + | ||
| 24 | size_t node_count = sess->GetOutputCount(); | 24 | size_t node_count = sess->GetOutputCount(); |
| 25 | - | 25 | + |
| 26 | if (node_count == 1) { | 26 | if (node_count == 1) { |
| 27 | return std::make_unique<OnlineRecognizerTransducerImpl>(config); | 27 | return std::make_unique<OnlineRecognizerTransducerImpl>(config); |
| 28 | } else { | 28 | } else { |
| 29 | - SHERPA_ONNX_LOGE("Running streaming Nemo transducer model"); | ||
| 30 | return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config); | 29 | return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config); |
| 31 | } | 30 | } |
| 32 | } | 31 | } |
| @@ -50,12 +49,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | @@ -50,12 +49,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | ||
| 50 | AAssetManager *mgr, const OnlineRecognizerConfig &config) { | 49 | AAssetManager *mgr, const OnlineRecognizerConfig &config) { |
| 51 | if (!config.model_config.transducer.encoder.empty()) { | 50 | if (!config.model_config.transducer.encoder.empty()) { |
| 52 | Ort::Env env(ORT_LOGGING_LEVEL_WARNING); | 51 | Ort::Env env(ORT_LOGGING_LEVEL_WARNING); |
| 53 | - | 52 | + |
| 54 | auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder); | 53 | auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder); |
| 55 | - auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); | ||
| 56 | - | 54 | + auto sess = std::make_unique<Ort::Session>( |
| 55 | + env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); | ||
| 56 | + | ||
| 57 | size_t node_count = sess->GetOutputCount(); | 57 | size_t node_count = sess->GetOutputCount(); |
| 58 | - | 58 | + |
| 59 | if (node_count == 1) { | 59 | if (node_count == 1) { |
| 60 | return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config); | 60 | return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config); |
| 61 | } else { | 61 | } else { |
| @@ -35,18 +35,15 @@ | @@ -35,18 +35,15 @@ | ||
| 35 | 35 | ||
| 36 | namespace sherpa_onnx { | 36 | namespace sherpa_onnx { |
| 37 | 37 | ||
| 38 | -static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | ||
| 39 | - const SymbolTable &sym_table, | ||
| 40 | - float frame_shift_ms, | ||
| 41 | - int32_t subsampling_factor, | ||
| 42 | - int32_t segment, | ||
| 43 | - int32_t frames_since_start) { | 38 | +OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, |
| 39 | + const SymbolTable &sym_table, | ||
| 40 | + float frame_shift_ms, int32_t subsampling_factor, | ||
| 41 | + int32_t segment, int32_t frames_since_start) { | ||
| 44 | OnlineRecognizerResult r; | 42 | OnlineRecognizerResult r; |
| 45 | r.tokens.reserve(src.tokens.size()); | 43 | r.tokens.reserve(src.tokens.size()); |
| 46 | r.timestamps.reserve(src.tokens.size()); | 44 | r.timestamps.reserve(src.tokens.size()); |
| 47 | 45 | ||
| 48 | for (auto i : src.tokens) { | 46 | for (auto i : src.tokens) { |
| 49 | - if (i == -1) continue; | ||
| 50 | auto sym = sym_table[i]; | 47 | auto sym = sym_table[i]; |
| 51 | 48 | ||
| 52 | r.text.append(sym); | 49 | r.text.append(sym); |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | #ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ | 6 | #ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ |
| 7 | #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ | 7 | #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ |
| 8 | 8 | ||
| 9 | +#include <algorithm> | ||
| 9 | #include <fstream> | 10 | #include <fstream> |
| 10 | #include <ios> | 11 | #include <ios> |
| 11 | #include <memory> | 12 | #include <memory> |
| @@ -32,23 +33,20 @@ | @@ -32,23 +33,20 @@ | ||
| 32 | namespace sherpa_onnx { | 33 | namespace sherpa_onnx { |
| 33 | 34 | ||
| 34 | // defined in ./online-recognizer-transducer-impl.h | 35 | // defined in ./online-recognizer-transducer-impl.h |
| 35 | -// static may or may not be here? TODDOs | ||
| 36 | -static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | ||
| 37 | - const SymbolTable &sym_table, | ||
| 38 | - float frame_shift_ms, | ||
| 39 | - int32_t subsampling_factor, | ||
| 40 | - int32_t segment, | ||
| 41 | - int32_t frames_since_start); | 36 | +OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, |
| 37 | + const SymbolTable &sym_table, | ||
| 38 | + float frame_shift_ms, int32_t subsampling_factor, | ||
| 39 | + int32_t segment, int32_t frames_since_start); | ||
| 42 | 40 | ||
| 43 | class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | 41 | class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { |
| 44 | - public: | 42 | + public: |
| 45 | explicit OnlineRecognizerTransducerNeMoImpl( | 43 | explicit OnlineRecognizerTransducerNeMoImpl( |
| 46 | const OnlineRecognizerConfig &config) | 44 | const OnlineRecognizerConfig &config) |
| 47 | : config_(config), | 45 | : config_(config), |
| 48 | symbol_table_(config.model_config.tokens), | 46 | symbol_table_(config.model_config.tokens), |
| 49 | endpoint_(config_.endpoint_config), | 47 | endpoint_(config_.endpoint_config), |
| 50 | - model_(std::make_unique<OnlineTransducerNeMoModel>( | ||
| 51 | - config.model_config)) { | 48 | + model_( |
| 49 | + std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) { | ||
| 52 | if (config.decoding_method == "greedy_search") { | 50 | if (config.decoding_method == "greedy_search") { |
| 53 | decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>( | 51 | decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>( |
| 54 | model_.get(), config_.blank_penalty); | 52 | model_.get(), config_.blank_penalty); |
| @@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 73 | model_.get(), config_.blank_penalty); | 71 | model_.get(), config_.blank_penalty); |
| 74 | } else { | 72 | } else { |
| 75 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 73 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 76 | - config.decoding_method.c_str()); | 74 | + config.decoding_method.c_str()); |
| 77 | exit(-1); | 75 | exit(-1); |
| 78 | } | 76 | } |
| 79 | 77 | ||
| @@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 83 | 81 | ||
| 84 | std::unique_ptr<OnlineStream> CreateStream() const override { | 82 | std::unique_ptr<OnlineStream> CreateStream() const override { |
| 85 | auto stream = std::make_unique<OnlineStream>(config_.feat_config); | 83 | auto stream = std::make_unique<OnlineStream>(config_.feat_config); |
| 86 | - stream->SetStates(model_->GetInitStates()); | ||
| 87 | InitOnlineStream(stream.get()); | 84 | InitOnlineStream(stream.get()); |
| 88 | return stream; | 85 | return stream; |
| 89 | } | 86 | } |
| @@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 94 | } | 91 | } |
| 95 | 92 | ||
| 96 | OnlineRecognizerResult GetResult(OnlineStream *s) const override { | 93 | OnlineRecognizerResult GetResult(OnlineStream *s) const override { |
| 97 | - OnlineTransducerDecoderResult decoder_result = s->GetResult(); | ||
| 98 | - decoder_->StripLeadingBlanks(&decoder_result); | ||
| 99 | - | ||
| 100 | // TODO(fangjun): Remember to change these constants if needed | 94 | // TODO(fangjun): Remember to change these constants if needed |
| 101 | int32_t frame_shift_ms = 10; | 95 | int32_t frame_shift_ms = 10; |
| 102 | - int32_t subsampling_factor = 8; | ||
| 103 | - return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor, | ||
| 104 | - s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | 96 | + int32_t subsampling_factor = model_->SubsamplingFactor(); |
| 97 | + return Convert(s->GetResult(), symbol_table_, frame_shift_ms, | ||
| 98 | + subsampling_factor, s->GetCurrentSegment(), | ||
| 99 | + s->GetNumFramesSinceStart()); | ||
| 105 | } | 100 | } |
| 106 | 101 | ||
| 107 | bool IsEndpoint(OnlineStream *s) const override { | 102 | bool IsEndpoint(OnlineStream *s) const override { |
| @@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 114 | // frame shift is 10 milliseconds | 109 | // frame shift is 10 milliseconds |
| 115 | float frame_shift_in_seconds = 0.01; | 110 | float frame_shift_in_seconds = 0.01; |
| 116 | 111 | ||
| 117 | - // subsampling factor is 8 | ||
| 118 | - int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8; | 112 | + int32_t trailing_silence_frames = |
| 113 | + s->GetResult().num_trailing_blanks * model_->SubsamplingFactor(); | ||
| 119 | 114 | ||
| 120 | return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, | 115 | return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, |
| 121 | frame_shift_in_seconds); | 116 | frame_shift_in_seconds); |
| @@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 126 | // segment is incremented only when the last | 121 | // segment is incremented only when the last |
| 127 | // result is not empty | 122 | // result is not empty |
| 128 | const auto &r = s->GetResult(); | 123 | const auto &r = s->GetResult(); |
| 129 | - if (!r.tokens.empty() && r.tokens.back() != 0) { | 124 | + if (!r.tokens.empty()) { |
| 130 | s->GetCurrentSegment() += 1; | 125 | s->GetCurrentSegment() += 1; |
| 131 | } | 126 | } |
| 132 | } | 127 | } |
| 133 | 128 | ||
| 134 | - // we keep the decoder_out | ||
| 135 | - decoder_->UpdateDecoderOut(&s->GetResult()); | ||
| 136 | - Ort::Value decoder_out = std::move(s->GetResult().decoder_out); | 129 | + s->SetResult({}); |
| 130 | + | ||
| 131 | + s->SetStates(model_->GetEncoderInitStates()); | ||
| 137 | 132 | ||
| 138 | - auto r = decoder_->GetEmptyResult(); | ||
| 139 | - | ||
| 140 | - s->SetResult(r); | ||
| 141 | - s->GetResult().decoder_out = std::move(decoder_out); | 133 | + s->SetNeMoDecoderStates(model_->GetDecoderInitStates()); |
| 142 | 134 | ||
| 143 | // Note: We only update counters. The underlying audio samples | 135 | // Note: We only update counters. The underlying audio samples |
| 144 | // are not discarded. | 136 | // are not discarded. |
| @@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 151 | 143 | ||
| 152 | int32_t feature_dim = ss[0]->FeatureDim(); | 144 | int32_t feature_dim = ss[0]->FeatureDim(); |
| 153 | 145 | ||
| 154 | - std::vector<OnlineTransducerDecoderResult> result(n); | ||
| 155 | std::vector<float> features_vec(n * chunk_size * feature_dim); | 146 | std::vector<float> features_vec(n * chunk_size * feature_dim); |
| 156 | std::vector<std::vector<Ort::Value>> encoder_states(n); | 147 | std::vector<std::vector<Ort::Value>> encoder_states(n); |
| 157 | - | 148 | + |
| 158 | for (int32_t i = 0; i != n; ++i) { | 149 | for (int32_t i = 0; i != n; ++i) { |
| 159 | const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); | 150 | const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); |
| 160 | std::vector<float> features = | 151 | std::vector<float> features = |
| @@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 166 | std::copy(features.begin(), features.end(), | 157 | std::copy(features.begin(), features.end(), |
| 167 | features_vec.data() + i * chunk_size * feature_dim); | 158 | features_vec.data() + i * chunk_size * feature_dim); |
| 168 | 159 | ||
| 169 | - result[i] = std::move(ss[i]->GetResult()); | ||
| 170 | encoder_states[i] = std::move(ss[i]->GetStates()); | 160 | encoder_states[i] = std::move(ss[i]->GetStates()); |
| 171 | - | ||
| 172 | } | 161 | } |
| 173 | 162 | ||
| 174 | auto memory_info = | 163 | auto memory_info = |
| @@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 180 | features_vec.size(), x_shape.data(), | 169 | features_vec.size(), x_shape.data(), |
| 181 | x_shape.size()); | 170 | x_shape.size()); |
| 182 | 171 | ||
| 183 | - // Batch size is 1 | ||
| 184 | - auto states = std::move(encoder_states[0]); | ||
| 185 | - int32_t num_states = states.size(); // num_states = 3 | 172 | + auto states = model_->StackStates(std::move(encoder_states)); |
| 173 | + int32_t num_states = states.size(); // num_states = 3 | ||
| 186 | auto t = model_->RunEncoder(std::move(x), std::move(states)); | 174 | auto t = model_->RunEncoder(std::move(x), std::move(states)); |
| 187 | // t[0] encoder_out, float tensor, (batch_size, dim, T) | 175 | // t[0] encoder_out, float tensor, (batch_size, dim, T) |
| 188 | // t[1] next states | 176 | // t[1] next states |
| 189 | - | 177 | + |
| 190 | std::vector<Ort::Value> out_states; | 178 | std::vector<Ort::Value> out_states; |
| 191 | out_states.reserve(num_states); | 179 | out_states.reserve(num_states); |
| 192 | - | 180 | + |
| 193 | for (int32_t k = 1; k != num_states + 1; ++k) { | 181 | for (int32_t k = 1; k != num_states + 1; ++k) { |
| 194 | out_states.push_back(std::move(t[k])); | 182 | out_states.push_back(std::move(t[k])); |
| 195 | } | 183 | } |
| 196 | 184 | ||
| 185 | + auto unstacked_states = model_->UnStackStates(std::move(out_states)); | ||
| 186 | + for (int32_t i = 0; i != n; ++i) { | ||
| 187 | + ss[i]->SetStates(std::move(unstacked_states[i])); | ||
| 188 | + } | ||
| 189 | + | ||
| 197 | Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); | 190 | Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); |
| 198 | - | ||
| 199 | - // defined in online-transducer-greedy-search-nemo-decoder.h | ||
| 200 | - // get intial states of decoder. | ||
| 201 | - std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates(); | ||
| 202 | - | ||
| 203 | - // Subsequent decoder states (for each chunks) are updated inside the Decode method. | ||
| 204 | - // This returns the decoder state from the LAST chunk. We probably dont need it. So we can discard it. | ||
| 205 | - decoder_states = decoder_->Decode(std::move(encoder_out), | ||
| 206 | - std::move(decoder_states), | ||
| 207 | - &result, ss, n); | ||
| 208 | - | ||
| 209 | - ss[0]->SetResult(result[0]); | ||
| 210 | - | ||
| 211 | - ss[0]->SetStates(std::move(out_states)); | 191 | + |
| 192 | + decoder_->Decode(std::move(encoder_out), ss, n); | ||
| 212 | } | 193 | } |
| 213 | 194 | ||
| 214 | void InitOnlineStream(OnlineStream *stream) const { | 195 | void InitOnlineStream(OnlineStream *stream) const { |
| 215 | - auto r = decoder_->GetEmptyResult(); | 196 | + // set encoder states |
| 197 | + stream->SetStates(model_->GetEncoderInitStates()); | ||
| 216 | 198 | ||
| 217 | - stream->SetResult(r); | ||
| 218 | - stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1)); | 199 | + // set decoder states |
| 200 | + stream->SetNeMoDecoderStates(model_->GetDecoderInitStates()); | ||
| 219 | } | 201 | } |
| 220 | 202 | ||
| 221 | private: | 203 | private: |
| @@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 250 | symbol_table_.NumSymbols(), vocab_size); | 232 | symbol_table_.NumSymbols(), vocab_size); |
| 251 | exit(-1); | 233 | exit(-1); |
| 252 | } | 234 | } |
| 253 | - | ||
| 254 | } | 235 | } |
| 255 | 236 | ||
| 256 | private: | 237 | private: |
| @@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 259 | std::unique_ptr<OnlineTransducerNeMoModel> model_; | 240 | std::unique_ptr<OnlineTransducerNeMoModel> model_; |
| 260 | std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_; | 241 | std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_; |
| 261 | Endpoint endpoint_; | 242 | Endpoint endpoint_; |
| 262 | - | ||
| 263 | }; | 243 | }; |
| 264 | 244 | ||
| 265 | } // namespace sherpa_onnx | 245 | } // namespace sherpa_onnx |
| 266 | 246 | ||
| 267 | -#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ | ||
| 247 | +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ |
| @@ -225,7 +225,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() { | @@ -225,7 +225,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() { | ||
| 225 | return impl_->GetStates(); | 225 | return impl_->GetStates(); |
| 226 | } | 226 | } |
| 227 | 227 | ||
| 228 | -void OnlineStream::SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) { | 228 | +void OnlineStream::SetNeMoDecoderStates( |
| 229 | + std::vector<Ort::Value> decoder_states) { | ||
| 229 | return impl_->SetNeMoDecoderStates(std::move(decoder_states)); | 230 | return impl_->SetNeMoDecoderStates(std::move(decoder_states)); |
| 230 | } | 231 | } |
| 231 | 232 |
| @@ -91,8 +91,8 @@ class OnlineStream { | @@ -91,8 +91,8 @@ class OnlineStream { | ||
| 91 | void SetStates(std::vector<Ort::Value> states); | 91 | void SetStates(std::vector<Ort::Value> states); |
| 92 | std::vector<Ort::Value> &GetStates(); | 92 | std::vector<Ort::Value> &GetStates(); |
| 93 | 93 | ||
| 94 | - void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states); | ||
| 95 | - std::vector<Ort::Value> &GetNeMoDecoderStates(); | 94 | + void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states); |
| 95 | + std::vector<Ort::Value> &GetNeMoDecoderStates(); | ||
| 96 | 96 | ||
| 97 | /** | 97 | /** |
| 98 | * Get the context graph corresponding to this stream. | 98 | * Get the context graph corresponding to this stream. |
| @@ -10,103 +10,64 @@ | @@ -10,103 +10,64 @@ | ||
| 10 | #include <utility> | 10 | #include <utility> |
| 11 | 11 | ||
| 12 | #include "sherpa-onnx/csrc/macros.h" | 12 | #include "sherpa-onnx/csrc/macros.h" |
| 13 | +#include "sherpa-onnx/csrc/online-stream.h" | ||
| 13 | #include "sherpa-onnx/csrc/onnx-utils.h" | 14 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 14 | 15 | ||
| 15 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 16 | 17 | ||
| 17 | -static std::pair<Ort::Value, Ort::Value> BuildDecoderInput( | ||
| 18 | - int32_t token, OrtAllocator *allocator) { | 18 | +static Ort::Value BuildDecoderInput(int32_t token, OrtAllocator *allocator) { |
| 19 | std::array<int64_t, 2> shape{1, 1}; | 19 | std::array<int64_t, 2> shape{1, 1}; |
| 20 | 20 | ||
| 21 | Ort::Value decoder_input = | 21 | Ort::Value decoder_input = |
| 22 | Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size()); | 22 | Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size()); |
| 23 | 23 | ||
| 24 | - std::array<int64_t, 1> length_shape{1}; | ||
| 25 | - Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>( | ||
| 26 | - allocator, length_shape.data(), length_shape.size()); | ||
| 27 | - | ||
| 28 | int32_t *p = decoder_input.GetTensorMutableData<int32_t>(); | 24 | int32_t *p = decoder_input.GetTensorMutableData<int32_t>(); |
| 29 | 25 | ||
| 30 | - int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>(); | ||
| 31 | - | ||
| 32 | p[0] = token; | 26 | p[0] = token; |
| 33 | 27 | ||
| 34 | - p_length[0] = 1; | ||
| 35 | - | ||
| 36 | - return {std::move(decoder_input), std::move(decoder_input_length)}; | ||
| 37 | -} | ||
| 38 | - | ||
| 39 | - | ||
| 40 | -OnlineTransducerDecoderResult | ||
| 41 | -OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const { | ||
| 42 | - int32_t context_size = 8; | ||
| 43 | - int32_t blank_id = 0; // always 0 | ||
| 44 | - OnlineTransducerDecoderResult r; | ||
| 45 | - r.tokens.resize(context_size, -1); | ||
| 46 | - r.tokens.back() = blank_id; | ||
| 47 | - | ||
| 48 | - return r; | 28 | + return decoder_input; |
| 49 | } | 29 | } |
| 50 | 30 | ||
| 51 | -static void UpdateCachedDecoderOut( | ||
| 52 | - OrtAllocator *allocator, const Ort::Value *decoder_out, | ||
| 53 | - std::vector<OnlineTransducerDecoderResult> *result) { | ||
| 54 | - std::vector<int64_t> shape = | ||
| 55 | - decoder_out->GetTensorTypeAndShapeInfo().GetShape(); | 31 | +static void DecodeOne(const float *encoder_out, int32_t num_rows, |
| 32 | + int32_t num_cols, OnlineTransducerNeMoModel *model, | ||
| 33 | + float blank_penalty, OnlineStream *s) { | ||
| 56 | auto memory_info = | 34 | auto memory_info = |
| 57 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 35 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); |
| 58 | - std::array<int64_t, 2> v_shape{1, shape[1]}; | ||
| 59 | 36 | ||
| 60 | - const float *src = decoder_out->GetTensorData<float>(); | ||
| 61 | - for (auto &r : *result) { | ||
| 62 | - if (!r.decoder_out) { | ||
| 63 | - r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(), | ||
| 64 | - v_shape.size()); | ||
| 65 | - } | 37 | + int32_t vocab_size = model->VocabSize(); |
| 38 | + int32_t blank_id = vocab_size - 1; | ||
| 66 | 39 | ||
| 67 | - float *dst = r.decoder_out.GetTensorMutableData<float>(); | ||
| 68 | - std::copy(src, src + shape[1], dst); | ||
| 69 | - src += shape[1]; | ||
| 70 | - } | ||
| 71 | -} | 40 | + auto &r = s->GetResult(); |
| 72 | 41 | ||
| 73 | -std::vector<Ort::Value> DecodeOne( | ||
| 74 | - const float *encoder_out, int32_t num_rows, int32_t num_cols, | ||
| 75 | - OnlineTransducerNeMoModel *model, float blank_penalty, | ||
| 76 | - std::vector<Ort::Value>& decoder_states, | ||
| 77 | - std::vector<OnlineTransducerDecoderResult> *result) { | 42 | + Ort::Value decoder_out{nullptr}; |
| 78 | 43 | ||
| 79 | - auto memory_info = | ||
| 80 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 44 | + auto decoder_input = BuildDecoderInput( |
| 45 | + r.tokens.empty() ? blank_id : r.tokens.back(), model->Allocator()); | ||
| 81 | 46 | ||
| 82 | - // OnlineTransducerDecoderResult result; | ||
| 83 | - int32_t vocab_size = model->VocabSize(); | ||
| 84 | - int32_t blank_id = vocab_size - 1; | ||
| 85 | - | ||
| 86 | - auto &r = (*result)[0]; | ||
| 87 | - Ort::Value decoder_out{nullptr}; | 47 | + std::vector<Ort::Value> &last_decoder_states = s->GetNeMoDecoderStates(); |
| 88 | 48 | ||
| 89 | - auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); | ||
| 90 | - // decoder_input_pair[0]: decoder_input | ||
| 91 | - // decoder_input_pair[1]: decoder_input_length (discarded) | 49 | + std::vector<Ort::Value> tmp_decoder_states; |
| 50 | + tmp_decoder_states.reserve(last_decoder_states.size()); | ||
| 51 | + for (auto &v : last_decoder_states) { | ||
| 52 | + tmp_decoder_states.push_back(View(&v)); | ||
| 53 | + } | ||
| 92 | 54 | ||
| 93 | // decoder_output_pair.second returns the next decoder state | 55 | // decoder_output_pair.second returns the next decoder state |
| 94 | std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair = | 56 | std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair = |
| 95 | - model->RunDecoder(std::move(decoder_input_pair.first), | ||
| 96 | - std::move(decoder_states)); // here decoder_states = {len=0, cap=0}. But decoder_output_pair= {first, second: {len=2, cap=2}} // ATTN | 57 | + model->RunDecoder(std::move(decoder_input), |
| 58 | + std::move(tmp_decoder_states)); | ||
| 97 | 59 | ||
| 98 | std::array<int64_t, 3> encoder_shape{1, num_cols, 1}; | 60 | std::array<int64_t, 3> encoder_shape{1, num_cols, 1}; |
| 99 | 61 | ||
| 100 | - decoder_states = std::move(decoder_output_pair.second); | 62 | + bool emitted = false; |
| 101 | 63 | ||
| 102 | - // TODO: Inside this loop, I need to framewise decoding. | ||
| 103 | for (int32_t t = 0; t != num_rows; ++t) { | 64 | for (int32_t t = 0; t != num_rows; ++t) { |
| 104 | Ort::Value cur_encoder_out = Ort::Value::CreateTensor( | 65 | Ort::Value cur_encoder_out = Ort::Value::CreateTensor( |
| 105 | memory_info, const_cast<float *>(encoder_out) + t * num_cols, num_cols, | 66 | memory_info, const_cast<float *>(encoder_out) + t * num_cols, num_cols, |
| 106 | encoder_shape.data(), encoder_shape.size()); | 67 | encoder_shape.data(), encoder_shape.size()); |
| 107 | 68 | ||
| 108 | Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), | 69 | Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), |
| 109 | - View(&decoder_output_pair.first)); | 70 | + View(&decoder_output_pair.first)); |
| 110 | 71 | ||
| 111 | float *p_logit = logit.GetTensorMutableData<float>(); | 72 | float *p_logit = logit.GetTensorMutableData<float>(); |
| 112 | if (blank_penalty > 0) { | 73 | if (blank_penalty > 0) { |
| @@ -117,82 +78,52 @@ std::vector<Ort::Value> DecodeOne( | @@ -117,82 +78,52 @@ std::vector<Ort::Value> DecodeOne( | ||
| 117 | static_cast<const float *>(p_logit), | 78 | static_cast<const float *>(p_logit), |
| 118 | std::max_element(static_cast<const float *>(p_logit), | 79 | std::max_element(static_cast<const float *>(p_logit), |
| 119 | static_cast<const float *>(p_logit) + vocab_size))); | 80 | static_cast<const float *>(p_logit) + vocab_size))); |
| 120 | - SHERPA_ONNX_LOGE("y=%d", y); | 81 | + |
| 121 | if (y != blank_id) { | 82 | if (y != blank_id) { |
| 83 | + emitted = true; | ||
| 122 | r.tokens.push_back(y); | 84 | r.tokens.push_back(y); |
| 123 | r.timestamps.push_back(t + r.frame_offset); | 85 | r.timestamps.push_back(t + r.frame_offset); |
| 86 | + r.num_trailing_blanks = 0; | ||
| 124 | 87 | ||
| 125 | - decoder_input_pair = BuildDecoderInput(y, model->Allocator()); | 88 | + decoder_input = BuildDecoderInput(y, model->Allocator()); |
| 126 | 89 | ||
| 127 | // last decoder state becomes the current state for the first chunk | 90 | // last decoder state becomes the current state for the first chunk |
| 128 | - decoder_output_pair = | ||
| 129 | - model->RunDecoder(std::move(decoder_input_pair.first), | ||
| 130 | - std::move(decoder_states)); | ||
| 131 | - | ||
| 132 | - // Update the decoder states for the next chunk | ||
| 133 | - decoder_states = std::move(decoder_output_pair.second); | 91 | + decoder_output_pair = model->RunDecoder( |
| 92 | + std::move(decoder_input), std::move(decoder_output_pair.second)); | ||
| 93 | + } else { | ||
| 94 | + ++r.num_trailing_blanks; | ||
| 134 | } | 95 | } |
| 135 | } | 96 | } |
| 136 | 97 | ||
| 137 | - decoder_out = std::move(decoder_output_pair.first); | ||
| 138 | -// UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result); | ||
| 139 | - | ||
| 140 | - // Update frame_offset | ||
| 141 | - for (auto &r : *result) { | ||
| 142 | - r.frame_offset += num_rows; | 98 | + if (emitted) { |
| 99 | + s->SetNeMoDecoderStates(std::move(decoder_output_pair.second)); | ||
| 143 | } | 100 | } |
| 144 | 101 | ||
| 145 | - return std::move(decoder_states); | 102 | + r.frame_offset += num_rows; |
| 146 | } | 103 | } |
| 147 | 104 | ||
| 148 | - | ||
| 149 | -std::vector<Ort::Value> OnlineTransducerGreedySearchNeMoDecoder::Decode( | ||
| 150 | - Ort::Value encoder_out, | ||
| 151 | - std::vector<Ort::Value> decoder_states, | ||
| 152 | - std::vector<OnlineTransducerDecoderResult> *result, | ||
| 153 | - OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) { | ||
| 154 | - | 105 | +void OnlineTransducerGreedySearchNeMoDecoder::Decode(Ort::Value encoder_out, |
| 106 | + OnlineStream **ss, | ||
| 107 | + int32_t n) const { | ||
| 155 | auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); | 108 | auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); |
| 109 | + int32_t batch_size = static_cast<int32_t>(shape[0]); // bs = 1 | ||
| 156 | 110 | ||
| 157 | - if (shape[0] != result->size()) { | ||
| 158 | - SHERPA_ONNX_LOGE( | ||
| 159 | - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d", | ||
| 160 | - static_cast<int32_t>(shape[0]), | ||
| 161 | - static_cast<int32_t>(result->size())); | 111 | + if (batch_size != n) { |
| 112 | + SHERPA_ONNX_LOGE("Size mismatch! encoder_out.size(0) %d, n: %d", | ||
| 113 | + static_cast<int32_t>(shape[0]), n); | ||
| 162 | exit(-1); | 114 | exit(-1); |
| 163 | } | 115 | } |
| 164 | 116 | ||
| 165 | - int32_t batch_size = static_cast<int32_t>(shape[0]); // bs = 1 | ||
| 166 | - int32_t dim1 = static_cast<int32_t>(shape[1]); // 2 | ||
| 167 | - int32_t dim2 = static_cast<int32_t>(shape[2]); // 512 | ||
| 168 | - | ||
| 169 | - // Define and initialize encoder_out_length | ||
| 170 | - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); | ||
| 171 | - | ||
| 172 | - int64_t length_value = 1; | ||
| 173 | - std::vector<int64_t> length_shape = {1}; | ||
| 174 | - | ||
| 175 | - Ort::Value encoder_out_length = Ort::Value::CreateTensor<int64_t>( | ||
| 176 | - memory_info, &length_value, 1, length_shape.data(), length_shape.size() | ||
| 177 | - ); | ||
| 178 | - | ||
| 179 | - const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>(); | ||
| 180 | - const float *p = encoder_out.GetTensorData<float>(); | 117 | + int32_t dim1 = static_cast<int32_t>(shape[1]); // T |
| 118 | + int32_t dim2 = static_cast<int32_t>(shape[2]); // encoder_out_dim | ||
| 181 | 119 | ||
| 182 | - // std::vector<OnlineTransducerDecoderResult> ans(batch_size); | 120 | + const float *p = encoder_out.GetTensorData<float>(); |
| 183 | 121 | ||
| 184 | for (int32_t i = 0; i != batch_size; ++i) { | 122 | for (int32_t i = 0; i != batch_size; ++i) { |
| 185 | const float *this_p = p + dim1 * dim2 * i; | 123 | const float *this_p = p + dim1 * dim2 * i; |
| 186 | - int32_t this_len = p_length[i]; | ||
| 187 | 124 | ||
| 188 | - // outputs the decoder state from last chunk. | ||
| 189 | - auto last_decoder_states = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states, result); | ||
| 190 | - // ans[i] = decode_result_pair.first; | ||
| 191 | - decoder_states = std::move(last_decoder_states); | 125 | + DecodeOne(this_p, dim1, dim2, model_, blank_penalty_, ss[i]); |
| 192 | } | 126 | } |
| 193 | - | ||
| 194 | - return decoder_states; | ||
| 195 | - | ||
| 196 | } | 127 | } |
| 197 | 128 | ||
| 198 | -} // namespace sherpa_onnx | ||
| 129 | +} // namespace sherpa_onnx |
| @@ -7,27 +7,22 @@ | @@ -7,27 +7,22 @@ | ||
| 7 | #define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ | 7 | #define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ |
| 8 | 8 | ||
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | + | ||
| 10 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 11 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 11 | #include "sherpa-onnx/csrc/online-transducer-nemo-model.h" | 12 | #include "sherpa-onnx/csrc/online-transducer-nemo-model.h" |
| 12 | 13 | ||
| 13 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 14 | 15 | ||
| 16 | +class OnlineStream; | ||
| 17 | + | ||
| 15 | class OnlineTransducerGreedySearchNeMoDecoder { | 18 | class OnlineTransducerGreedySearchNeMoDecoder { |
| 16 | public: | 19 | public: |
| 17 | OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model, | 20 | OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model, |
| 18 | float blank_penalty) | 21 | float blank_penalty) |
| 19 | - : model_(model), | ||
| 20 | - blank_penalty_(blank_penalty) {} | ||
| 21 | - | ||
| 22 | - OnlineTransducerDecoderResult GetEmptyResult() const; | ||
| 23 | - void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {} | ||
| 24 | - void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {} | ||
| 25 | - | ||
| 26 | - std::vector<Ort::Value> Decode( | ||
| 27 | - Ort::Value encoder_out, | ||
| 28 | - std::vector<Ort::Value> decoder_states, | ||
| 29 | - std::vector<OnlineTransducerDecoderResult> *result, | ||
| 30 | - OnlineStream **ss = nullptr, int32_t n = 0); | 22 | + : model_(model), blank_penalty_(blank_penalty) {} |
| 23 | + | ||
| 24 | + // @param n number of elements in ss | ||
| 25 | + void Decode(Ort::Value encoder_out, OnlineStream **ss, int32_t n) const; | ||
| 31 | 26 | ||
| 32 | private: | 27 | private: |
| 33 | OnlineTransducerNeMoModel *model_; // Not owned | 28 | OnlineTransducerNeMoModel *model_; // Not owned |
| @@ -37,4 +32,3 @@ class OnlineTransducerGreedySearchNeMoDecoder { | @@ -37,4 +32,3 @@ class OnlineTransducerGreedySearchNeMoDecoder { | ||
| 37 | } // namespace sherpa_onnx | 32 | } // namespace sherpa_onnx |
| 38 | 33 | ||
| 39 | #endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ | 34 | #endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ |
| 40 | - |
| @@ -54,7 +54,7 @@ class OnlineTransducerNeMoModel::Impl { | @@ -54,7 +54,7 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 54 | InitJoiner(buf.data(), buf.size()); | 54 | InitJoiner(buf.data(), buf.size()); |
| 55 | } | 55 | } |
| 56 | } | 56 | } |
| 57 | - | 57 | + |
| 58 | #if __ANDROID_API__ >= 9 | 58 | #if __ANDROID_API__ >= 9 |
| 59 | Impl(AAssetManager *mgr, const OnlineModelConfig &config) | 59 | Impl(AAssetManager *mgr, const OnlineModelConfig &config) |
| 60 | : config_(config), | 60 | : config_(config), |
| @@ -79,7 +79,7 @@ class OnlineTransducerNeMoModel::Impl { | @@ -79,7 +79,7 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 79 | #endif | 79 | #endif |
| 80 | 80 | ||
| 81 | std::vector<Ort::Value> RunEncoder(Ort::Value features, | 81 | std::vector<Ort::Value> RunEncoder(Ort::Value features, |
| 82 | - std::vector<Ort::Value> states) { | 82 | + std::vector<Ort::Value> states) { |
| 83 | Ort::Value &cache_last_channel = states[0]; | 83 | Ort::Value &cache_last_channel = states[0]; |
| 84 | Ort::Value &cache_last_time = states[1]; | 84 | Ort::Value &cache_last_time = states[1]; |
| 85 | Ort::Value &cache_last_channel_len = states[2]; | 85 | Ort::Value &cache_last_channel_len = states[2]; |
| @@ -102,9 +102,9 @@ class OnlineTransducerNeMoModel::Impl { | @@ -102,9 +102,9 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 102 | std::move(features), View(&length), std::move(cache_last_channel), | 102 | std::move(features), View(&length), std::move(cache_last_channel), |
| 103 | std::move(cache_last_time), std::move(cache_last_channel_len)}; | 103 | std::move(cache_last_time), std::move(cache_last_channel_len)}; |
| 104 | 104 | ||
| 105 | - auto out = | ||
| 106 | - encoder_sess_->Run({}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 107 | - encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size()); | 105 | + auto out = encoder_sess_->Run( |
| 106 | + {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 107 | + encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size()); | ||
| 108 | // out[0]: logit | 108 | // out[0]: logit |
| 109 | // out[1] logit_length | 109 | // out[1] logit_length |
| 110 | // out[2:] states_next | 110 | // out[2:] states_next |
| @@ -127,17 +127,19 @@ class OnlineTransducerNeMoModel::Impl { | @@ -127,17 +127,19 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 127 | 127 | ||
| 128 | std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder( | 128 | std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder( |
| 129 | Ort::Value targets, std::vector<Ort::Value> states) { | 129 | Ort::Value targets, std::vector<Ort::Value> states) { |
| 130 | - | ||
| 131 | - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); | 130 | + Ort::MemoryInfo memory_info = |
| 131 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); | ||
| 132 | + | ||
| 133 | + auto shape = targets.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 134 | + int32_t batch_size = static_cast<int32_t>(shape[0]); | ||
| 132 | 135 | ||
| 133 | - // Create the tensor with a single int32_t value of 1 | ||
| 134 | - int32_t length_value = 1; | ||
| 135 | - std::vector<int64_t> length_shape = {1}; | 136 | + std::vector<int64_t> length_shape = {batch_size}; |
| 137 | + std::vector<int32_t> length_value(batch_size, 1); | ||
| 136 | 138 | ||
| 137 | Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>( | 139 | Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>( |
| 138 | - memory_info, &length_value, 1, length_shape.data(), length_shape.size() | ||
| 139 | - ); | ||
| 140 | - | 140 | + memory_info, length_value.data(), batch_size, length_shape.data(), |
| 141 | + length_shape.size()); | ||
| 142 | + | ||
| 141 | std::vector<Ort::Value> decoder_inputs; | 143 | std::vector<Ort::Value> decoder_inputs; |
| 142 | decoder_inputs.reserve(2 + states.size()); | 144 | decoder_inputs.reserve(2 + states.size()); |
| 143 | 145 | ||
| @@ -171,35 +173,21 @@ class OnlineTransducerNeMoModel::Impl { | @@ -171,35 +173,21 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 171 | Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { | 173 | Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { |
| 172 | std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out), | 174 | std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out), |
| 173 | std::move(decoder_out)}; | 175 | std::move(decoder_out)}; |
| 174 | - auto logit = | ||
| 175 | - joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), | ||
| 176 | - joiner_input.size(), joiner_output_names_ptr_.data(), | ||
| 177 | - joiner_output_names_ptr_.size()); | 176 | + auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(), |
| 177 | + joiner_input.data(), joiner_input.size(), | ||
| 178 | + joiner_output_names_ptr_.data(), | ||
| 179 | + joiner_output_names_ptr_.size()); | ||
| 178 | 180 | ||
| 179 | return std::move(logit[0]); | 181 | return std::move(logit[0]); |
| 180 | -} | ||
| 181 | - | ||
| 182 | - std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const { | ||
| 183 | - std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; | ||
| 184 | - Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(), | ||
| 185 | - s0_shape.size()); | ||
| 186 | - | ||
| 187 | - Fill<float>(&s0, 0); | ||
| 188 | - | ||
| 189 | - std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_}; | ||
| 190 | - | ||
| 191 | - Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(), | ||
| 192 | - s1_shape.size()); | ||
| 193 | - | ||
| 194 | - Fill<float>(&s1, 0); | ||
| 195 | - | ||
| 196 | - std::vector<Ort::Value> states; | 182 | + } |
| 197 | 183 | ||
| 198 | - states.reserve(2); | ||
| 199 | - states.push_back(std::move(s0)); | ||
| 200 | - states.push_back(std::move(s1)); | 184 | + std::vector<Ort::Value> GetDecoderInitStates() { |
| 185 | + std::vector<Ort::Value> ans; | ||
| 186 | + ans.reserve(2); | ||
| 187 | + ans.push_back(View(&lstm0_)); | ||
| 188 | + ans.push_back(View(&lstm1_)); | ||
| 201 | 189 | ||
| 202 | - return states; | 190 | + return ans; |
| 203 | } | 191 | } |
| 204 | 192 | ||
| 205 | int32_t ChunkSize() const { return window_size_; } | 193 | int32_t ChunkSize() const { return window_size_; } |
| @@ -207,7 +195,7 @@ class OnlineTransducerNeMoModel::Impl { | @@ -207,7 +195,7 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 207 | int32_t ChunkShift() const { return chunk_shift_; } | 195 | int32_t ChunkShift() const { return chunk_shift_; } |
| 208 | 196 | ||
| 209 | int32_t SubsamplingFactor() const { return subsampling_factor_; } | 197 | int32_t SubsamplingFactor() const { return subsampling_factor_; } |
| 210 | - | 198 | + |
| 211 | int32_t VocabSize() const { return vocab_size_; } | 199 | int32_t VocabSize() const { return vocab_size_; } |
| 212 | 200 | ||
| 213 | OrtAllocator *Allocator() const { return allocator_; } | 201 | OrtAllocator *Allocator() const { return allocator_; } |
| @@ -218,7 +206,7 @@ class OnlineTransducerNeMoModel::Impl { | @@ -218,7 +206,7 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 218 | // - cache_last_channel | 206 | // - cache_last_channel |
| 219 | // - cache_last_time_ | 207 | // - cache_last_time_ |
| 220 | // - cache_last_channel_len | 208 | // - cache_last_channel_len |
| 221 | - std::vector<Ort::Value> GetInitStates() { | 209 | + std::vector<Ort::Value> GetEncoderInitStates() { |
| 222 | std::vector<Ort::Value> ans; | 210 | std::vector<Ort::Value> ans; |
| 223 | ans.reserve(3); | 211 | ans.reserve(3); |
| 224 | ans.push_back(View(&cache_last_channel_)); | 212 | ans.push_back(View(&cache_last_channel_)); |
| @@ -228,7 +216,75 @@ class OnlineTransducerNeMoModel::Impl { | @@ -228,7 +216,75 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 228 | return ans; | 216 | return ans; |
| 229 | } | 217 | } |
| 230 | 218 | ||
| 231 | -private: | 219 | + std::vector<Ort::Value> StackStates( |
| 220 | + std::vector<std::vector<Ort::Value>> states) const { | ||
| 221 | + int32_t batch_size = static_cast<int32_t>(states.size()); | ||
| 222 | + if (batch_size == 1) { | ||
| 223 | + return std::move(states[0]); | ||
| 224 | + } | ||
| 225 | + | ||
| 226 | + std::vector<Ort::Value> ans; | ||
| 227 | + | ||
| 228 | + // stack cache_last_channel | ||
| 229 | + std::vector<const Ort::Value *> buf(batch_size); | ||
| 230 | + | ||
| 231 | + // there are 3 states to be stacked | ||
| 232 | + for (int32_t i = 0; i != 3; ++i) { | ||
| 233 | + buf.clear(); | ||
| 234 | + buf.reserve(batch_size); | ||
| 235 | + | ||
| 236 | + for (int32_t b = 0; b != batch_size; ++b) { | ||
| 237 | + assert(states[b].size() == 3); | ||
| 238 | + buf.push_back(&states[b][i]); | ||
| 239 | + } | ||
| 240 | + | ||
| 241 | + Ort::Value c{nullptr}; | ||
| 242 | + if (i == 2) { | ||
| 243 | + c = Cat<int64_t>(allocator_, buf, 0); | ||
| 244 | + } else { | ||
| 245 | + c = Cat(allocator_, buf, 0); | ||
| 246 | + } | ||
| 247 | + | ||
| 248 | + ans.push_back(std::move(c)); | ||
| 249 | + } | ||
| 250 | + | ||
| 251 | + return ans; | ||
| 252 | + } | ||
| 253 | + | ||
| 254 | + std::vector<std::vector<Ort::Value>> UnStackStates( | ||
| 255 | + std::vector<Ort::Value> states) const { | ||
| 256 | + assert(states.size() == 3); | ||
| 257 | + | ||
| 258 | + std::vector<std::vector<Ort::Value>> ans; | ||
| 259 | + | ||
| 260 | + auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape(); | ||
| 261 | + int32_t batch_size = shape[0]; | ||
| 262 | + ans.resize(batch_size); | ||
| 263 | + | ||
| 264 | + if (batch_size == 1) { | ||
| 265 | + ans[0] = std::move(states); | ||
| 266 | + return ans; | ||
| 267 | + } | ||
| 268 | + | ||
| 269 | + for (int32_t i = 0; i != 3; ++i) { | ||
| 270 | + std::vector<Ort::Value> v; | ||
| 271 | + if (i == 2) { | ||
| 272 | + v = Unbind<int64_t>(allocator_, &states[i], 0); | ||
| 273 | + } else { | ||
| 274 | + v = Unbind(allocator_, &states[i], 0); | ||
| 275 | + } | ||
| 276 | + | ||
| 277 | + assert(v.size() == batch_size); | ||
| 278 | + | ||
| 279 | + for (int32_t b = 0; b != batch_size; ++b) { | ||
| 280 | + ans[b].push_back(std::move(v[b])); | ||
| 281 | + } | ||
| 282 | + } | ||
| 283 | + | ||
| 284 | + return ans; | ||
| 285 | + } | ||
| 286 | + | ||
| 287 | + private: | ||
| 232 | void InitEncoder(void *model_data, size_t model_data_length) { | 288 | void InitEncoder(void *model_data, size_t model_data_length) { |
| 233 | encoder_sess_ = std::make_unique<Ort::Session>( | 289 | encoder_sess_ = std::make_unique<Ort::Session>( |
| 234 | env_, model_data, model_data_length, sess_opts_); | 290 | env_, model_data, model_data_length, sess_opts_); |
| @@ -276,10 +332,10 @@ private: | @@ -276,10 +332,10 @@ private: | ||
| 276 | normalize_type_ = ""; | 332 | normalize_type_ = ""; |
| 277 | } | 333 | } |
| 278 | 334 | ||
| 279 | - InitStates(); | 335 | + InitEncoderStates(); |
| 280 | } | 336 | } |
| 281 | - | ||
| 282 | - void InitStates() { | 337 | + |
| 338 | + void InitEncoderStates() { | ||
| 283 | std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_, | 339 | std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_, |
| 284 | cache_last_channel_dim2_, | 340 | cache_last_channel_dim2_, |
| 285 | cache_last_channel_dim3_}; | 341 | cache_last_channel_dim3_}; |
| @@ -313,7 +369,25 @@ private: | @@ -313,7 +369,25 @@ private: | ||
| 313 | &decoder_input_names_ptr_); | 369 | &decoder_input_names_ptr_); |
| 314 | 370 | ||
| 315 | GetOutputNames(decoder_sess_.get(), &decoder_output_names_, | 371 | GetOutputNames(decoder_sess_.get(), &decoder_output_names_, |
| 316 | - &decoder_output_names_ptr_); | 372 | + &decoder_output_names_ptr_); |
| 373 | + | ||
| 374 | + InitDecoderStates(); | ||
| 375 | + } | ||
| 376 | + | ||
| 377 | + void InitDecoderStates() { | ||
| 378 | + int32_t batch_size = 1; | ||
| 379 | + std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; | ||
| 380 | + lstm0_ = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(), | ||
| 381 | + s0_shape.size()); | ||
| 382 | + | ||
| 383 | + Fill<float>(&lstm0_, 0); | ||
| 384 | + | ||
| 385 | + std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_}; | ||
| 386 | + | ||
| 387 | + lstm1_ = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(), | ||
| 388 | + s1_shape.size()); | ||
| 389 | + | ||
| 390 | + Fill<float>(&lstm1_, 0); | ||
| 317 | } | 391 | } |
| 318 | 392 | ||
| 319 | void InitJoiner(void *model_data, size_t model_data_length) { | 393 | void InitJoiner(void *model_data, size_t model_data_length) { |
| @@ -324,7 +398,7 @@ private: | @@ -324,7 +398,7 @@ private: | ||
| 324 | &joiner_input_names_ptr_); | 398 | &joiner_input_names_ptr_); |
| 325 | 399 | ||
| 326 | GetOutputNames(joiner_sess_.get(), &joiner_output_names_, | 400 | GetOutputNames(joiner_sess_.get(), &joiner_output_names_, |
| 327 | - &joiner_output_names_ptr_); | 401 | + &joiner_output_names_ptr_); |
| 328 | } | 402 | } |
| 329 | 403 | ||
| 330 | private: | 404 | private: |
| @@ -363,6 +437,7 @@ private: | @@ -363,6 +437,7 @@ private: | ||
| 363 | int32_t pred_rnn_layers_ = -1; | 437 | int32_t pred_rnn_layers_ = -1; |
| 364 | int32_t pred_hidden_ = -1; | 438 | int32_t pred_hidden_ = -1; |
| 365 | 439 | ||
| 440 | + // encoder states | ||
| 366 | int32_t cache_last_channel_dim1_; | 441 | int32_t cache_last_channel_dim1_; |
| 367 | int32_t cache_last_channel_dim2_; | 442 | int32_t cache_last_channel_dim2_; |
| 368 | int32_t cache_last_channel_dim3_; | 443 | int32_t cache_last_channel_dim3_; |
| @@ -370,9 +445,14 @@ private: | @@ -370,9 +445,14 @@ private: | ||
| 370 | int32_t cache_last_time_dim2_; | 445 | int32_t cache_last_time_dim2_; |
| 371 | int32_t cache_last_time_dim3_; | 446 | int32_t cache_last_time_dim3_; |
| 372 | 447 | ||
| 448 | + // init encoder states | ||
| 373 | Ort::Value cache_last_channel_{nullptr}; | 449 | Ort::Value cache_last_channel_{nullptr}; |
| 374 | Ort::Value cache_last_time_{nullptr}; | 450 | Ort::Value cache_last_time_{nullptr}; |
| 375 | Ort::Value cache_last_channel_len_{nullptr}; | 451 | Ort::Value cache_last_channel_len_{nullptr}; |
| 452 | + | ||
| 453 | + // init decoder states | ||
| 454 | + Ort::Value lstm0_{nullptr}; | ||
| 455 | + Ort::Value lstm1_{nullptr}; | ||
| 376 | }; | 456 | }; |
| 377 | 457 | ||
| 378 | OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( | 458 | OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( |
| @@ -387,10 +467,9 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( | @@ -387,10 +467,9 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( | ||
| 387 | 467 | ||
| 388 | OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; | 468 | OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; |
| 389 | 469 | ||
| 390 | -std::vector<Ort::Value> | ||
| 391 | -OnlineTransducerNeMoModel::RunEncoder(Ort::Value features, | ||
| 392 | - std::vector<Ort::Value> states) const { | ||
| 393 | - return impl_->RunEncoder(std::move(features), std::move(states)); | 470 | +std::vector<Ort::Value> OnlineTransducerNeMoModel::RunEncoder( |
| 471 | + Ort::Value features, std::vector<Ort::Value> states) const { | ||
| 472 | + return impl_->RunEncoder(std::move(features), std::move(states)); | ||
| 394 | } | 473 | } |
| 395 | 474 | ||
| 396 | std::pair<Ort::Value, std::vector<Ort::Value>> | 475 | std::pair<Ort::Value, std::vector<Ort::Value>> |
| @@ -399,9 +478,9 @@ OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets, | @@ -399,9 +478,9 @@ OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets, | ||
| 399 | return impl_->RunDecoder(std::move(targets), std::move(states)); | 478 | return impl_->RunDecoder(std::move(targets), std::move(states)); |
| 400 | } | 479 | } |
| 401 | 480 | ||
| 402 | -std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates( | ||
| 403 | - int32_t batch_size) const { | ||
| 404 | - return impl_->GetDecoderInitStates(batch_size); | 481 | +std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates() |
| 482 | + const { | ||
| 483 | + return impl_->GetDecoderInitStates(); | ||
| 405 | } | 484 | } |
| 406 | 485 | ||
| 407 | Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, | 486 | Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, |
| @@ -409,14 +488,13 @@ Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, | @@ -409,14 +488,13 @@ Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, | ||
| 409 | return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); | 488 | return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); |
| 410 | } | 489 | } |
| 411 | 490 | ||
| 491 | +int32_t OnlineTransducerNeMoModel::ChunkSize() const { | ||
| 492 | + return impl_->ChunkSize(); | ||
| 493 | +} | ||
| 412 | 494 | ||
| 413 | -int32_t OnlineTransducerNeMoModel::ChunkSize() const { | ||
| 414 | - return impl_->ChunkSize(); | ||
| 415 | - } | ||
| 416 | - | ||
| 417 | -int32_t OnlineTransducerNeMoModel::ChunkShift() const { | ||
| 418 | - return impl_->ChunkShift(); | ||
| 419 | - } | 495 | +int32_t OnlineTransducerNeMoModel::ChunkShift() const { |
| 496 | + return impl_->ChunkShift(); | ||
| 497 | +} | ||
| 420 | 498 | ||
| 421 | int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const { | 499 | int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const { |
| 422 | return impl_->SubsamplingFactor(); | 500 | return impl_->SubsamplingFactor(); |
| @@ -434,8 +512,19 @@ std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const { | @@ -434,8 +512,19 @@ std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const { | ||
| 434 | return impl_->FeatureNormalizationMethod(); | 512 | return impl_->FeatureNormalizationMethod(); |
| 435 | } | 513 | } |
| 436 | 514 | ||
| 437 | -std::vector<Ort::Value> OnlineTransducerNeMoModel::GetInitStates() const { | ||
| 438 | - return impl_->GetInitStates(); | 515 | +std::vector<Ort::Value> OnlineTransducerNeMoModel::GetEncoderInitStates() |
| 516 | + const { | ||
| 517 | + return impl_->GetEncoderInitStates(); | ||
| 518 | +} | ||
| 519 | + | ||
| 520 | +std::vector<Ort::Value> OnlineTransducerNeMoModel::StackStates( | ||
| 521 | + std::vector<std::vector<Ort::Value>> states) const { | ||
| 522 | + return impl_->StackStates(std::move(states)); | ||
| 523 | +} | ||
| 524 | + | ||
| 525 | +std::vector<std::vector<Ort::Value>> OnlineTransducerNeMoModel::UnStackStates( | ||
| 526 | + std::vector<Ort::Value> states) const { | ||
| 527 | + return impl_->UnStackStates(std::move(states)); | ||
| 439 | } | 528 | } |
| 440 | 529 | ||
| 441 | -} // namespace sherpa_onnx | ||
| 530 | +} // namespace sherpa_onnx |
| @@ -32,22 +32,31 @@ class OnlineTransducerNeMoModel { | @@ -32,22 +32,31 @@ class OnlineTransducerNeMoModel { | ||
| 32 | OnlineTransducerNeMoModel(AAssetManager *mgr, | 32 | OnlineTransducerNeMoModel(AAssetManager *mgr, |
| 33 | const OnlineModelConfig &config); | 33 | const OnlineModelConfig &config); |
| 34 | #endif | 34 | #endif |
| 35 | - | 35 | + |
| 36 | ~OnlineTransducerNeMoModel(); | 36 | ~OnlineTransducerNeMoModel(); |
| 37 | - // A list of 3 tensors: | 37 | + // A list of 3 tensors: |
| 38 | // - cache_last_channel | 38 | // - cache_last_channel |
| 39 | // - cache_last_time | 39 | // - cache_last_time |
| 40 | // - cache_last_channel_len | 40 | // - cache_last_channel_len |
| 41 | - std::vector<Ort::Value> GetInitStates() const; | 41 | + std::vector<Ort::Value> GetEncoderInitStates() const; |
| 42 | + | ||
| 43 | + // stack encoder states | ||
| 44 | + std::vector<Ort::Value> StackStates( | ||
| 45 | + std::vector<std::vector<Ort::Value>> states) const; | ||
| 46 | + | ||
| 47 | + // unstack encoder states | ||
| 48 | + std::vector<std::vector<Ort::Value>> UnStackStates( | ||
| 49 | + std::vector<Ort::Value> states) const; | ||
| 42 | 50 | ||
| 43 | /** Run the encoder. | 51 | /** Run the encoder. |
| 44 | * | 52 | * |
| 45 | * @param features A tensor of shape (N, T, C). It is changed in-place. | 53 | * @param features A tensor of shape (N, T, C). It is changed in-place. |
| 46 | - * @param states It is from GetInitStates() or returned from this method. | ||
| 47 | - * | 54 | + * @param states It is from GetEncoderInitStates() or returned from this |
| 55 | + * method. | ||
| 56 | + * | ||
| 48 | * @return Return a tuple containing: | 57 | * @return Return a tuple containing: |
| 49 | - * - ans[0]: encoder_out, a tensor of shape (N, T', encoder_out_dim) | ||
| 50 | - * - ans[1:]: contains next states | 58 | + * - ans[0]: encoder_out, a tensor of shape (N, encoder_out_dim, T') |
| 59 | + * - ans[1:]: contains next states | ||
| 51 | */ | 60 | */ |
| 52 | std::vector<Ort::Value> RunEncoder( | 61 | std::vector<Ort::Value> RunEncoder( |
| 53 | Ort::Value features, std::vector<Ort::Value> states) const; // NOLINT | 62 | Ort::Value features, std::vector<Ort::Value> states) const; // NOLINT |
| @@ -63,7 +72,7 @@ class OnlineTransducerNeMoModel { | @@ -63,7 +72,7 @@ class OnlineTransducerNeMoModel { | ||
| 63 | std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder( | 72 | std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder( |
| 64 | Ort::Value targets, std::vector<Ort::Value> states) const; | 73 | Ort::Value targets, std::vector<Ort::Value> states) const; |
| 65 | 74 | ||
| 66 | - std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const; | 75 | + std::vector<Ort::Value> GetDecoderInitStates() const; |
| 67 | 76 | ||
| 68 | /** Run the joint network. | 77 | /** Run the joint network. |
| 69 | * | 78 | * |
| @@ -71,9 +80,7 @@ class OnlineTransducerNeMoModel { | @@ -71,9 +80,7 @@ class OnlineTransducerNeMoModel { | ||
| 71 | * @param decoder_out Output of the decoder network. | 80 | * @param decoder_out Output of the decoder network. |
| 72 | * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. | 81 | * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. |
| 73 | */ | 82 | */ |
| 74 | - Ort::Value RunJoiner(Ort::Value encoder_out, | ||
| 75 | - Ort::Value decoder_out) const; | ||
| 76 | - | 83 | + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const; |
| 77 | 84 | ||
| 78 | /** We send this number of feature frames to the encoder at a time. */ | 85 | /** We send this number of feature frames to the encoder at a time. */ |
| 79 | int32_t ChunkSize() const; | 86 | int32_t ChunkSize() const; |
| @@ -114,10 +121,10 @@ class OnlineTransducerNeMoModel { | @@ -114,10 +121,10 @@ class OnlineTransducerNeMoModel { | ||
| 114 | // for details | 121 | // for details |
| 115 | std::string FeatureNormalizationMethod() const; | 122 | std::string FeatureNormalizationMethod() const; |
| 116 | 123 | ||
| 117 | - private: | ||
| 118 | - class Impl; | ||
| 119 | - std::unique_ptr<Impl> impl_; | ||
| 120 | - }; | 124 | + private: |
| 125 | + class Impl; | ||
| 126 | + std::unique_ptr<Impl> impl_; | ||
| 127 | +}; | ||
| 121 | 128 | ||
| 122 | } // namespace sherpa_onnx | 129 | } // namespace sherpa_onnx |
| 123 | 130 |
-
请 注册 或 登录 后发表评论