Fangjun Kuang
Committed by GitHub

Fix nemo streaming transducer greedy search (#944)

@@ -16,6 +16,45 @@ echo "PATH: $PATH" @@ -16,6 +16,45 @@ echo "PATH: $PATH"
16 which $EXE 16 which $EXE
17 17
18 log "------------------------------------------------------------" 18 log "------------------------------------------------------------"
  19 +log "Run NeMo transducer (English)"
  20 +log "------------------------------------------------------------"
  21 +repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
  22 +curl -SL -O $repo_url
  23 +tar xvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
  24 +rm sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
  25 +repo=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms
  26 +
  27 +log "Start testing ${repo_url}"
  28 +
  29 +waves=(
  30 +$repo/test_wavs/0.wav
  31 +$repo/test_wavs/1.wav
  32 +$repo/test_wavs/8k.wav
  33 +)
  34 +
  35 +for wave in ${waves[@]}; do
  36 + time $EXE \
  37 + --tokens=$repo/tokens.txt \
  38 + --encoder=$repo/encoder.onnx \
  39 + --decoder=$repo/decoder.onnx \
  40 + --joiner=$repo/joiner.onnx \
  41 + --num-threads=2 \
  42 + $wave
  43 +done
  44 +
  45 +time $EXE \
  46 + --tokens=$repo/tokens.txt \
  47 + --encoder=$repo/encoder.onnx \
  48 + --decoder=$repo/decoder.onnx \
  49 + --joiner=$repo/joiner.onnx \
  50 + --num-threads=2 \
  51 + $repo/test_wavs/0.wav \
  52 + $repo/test_wavs/1.wav \
  53 + $repo/test_wavs/8k.wav
  54 +
  55 +rm -rf $repo
  56 +
  57 +log "------------------------------------------------------------"
19 log "Run LSTM transducer (English)" 58 log "Run LSTM transducer (English)"
20 log "------------------------------------------------------------" 59 log "------------------------------------------------------------"
21 60
@@ -196,7 +196,6 @@ jobs: @@ -196,7 +196,6 @@ jobs:
196 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface 196 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
197 197
198 cd huggingface 198 cd huggingface
199 - git lfs pull  
200 mkdir -p aarch64 199 mkdir -p aarch64
201 200
202 cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64 201 cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64
@@ -187,7 +187,6 @@ jobs: @@ -187,7 +187,6 @@ jobs:
187 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface 187 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
188 188
189 cd huggingface 189 cd huggingface
190 - git lfs pull  
191 mkdir -p aarch64 190 mkdir -p aarch64
192 191
193 cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64 192 cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64
@@ -124,7 +124,6 @@ jobs: @@ -124,7 +124,6 @@ jobs:
124 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface 124 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
125 125
126 cd huggingface 126 cd huggingface
127 - git lfs pull  
128 127
129 cp -v ../sherpa-onnx-*-android.tar.bz2 ./ 128 cp -v ../sherpa-onnx-*-android.tar.bz2 ./
130 129
@@ -209,7 +209,6 @@ jobs: @@ -209,7 +209,6 @@ jobs:
209 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface 209 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
210 210
211 cd huggingface 211 cd huggingface
212 - git lfs pull  
213 mkdir -p arm32 212 mkdir -p arm32
214 213
215 cp -v ../sherpa-onnx-*.tar.bz2 ./arm32 214 cp -v ../sherpa-onnx-*.tar.bz2 ./arm32
@@ -138,7 +138,6 @@ jobs: @@ -138,7 +138,6 @@ jobs:
138 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface 138 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
139 139
140 cd huggingface 140 cd huggingface
141 - git lfs pull  
142 141
143 cp -v ../sherpa-onnx-*.tar.bz2 ./ 142 cp -v ../sherpa-onnx-*.tar.bz2 ./
144 143
@@ -242,7 +242,6 @@ jobs: @@ -242,7 +242,6 @@ jobs:
242 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface 242 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
243 243
244 cd huggingface 244 cd huggingface
245 - git lfs pull  
246 mkdir -p riscv64 245 mkdir -p riscv64
247 246
248 cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64 247 cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64
@@ -219,7 +219,6 @@ jobs: @@ -219,7 +219,6 @@ jobs:
219 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface 219 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
220 220
221 cd huggingface 221 cd huggingface
222 - git lfs pull  
223 mkdir -p win64 222 mkdir -p win64
224 223
225 cp -v ../sherpa-onnx-*.tar.bz2 ./win64 224 cp -v ../sherpa-onnx-*.tar.bz2 ./win64
@@ -221,7 +221,6 @@ jobs: @@ -221,7 +221,6 @@ jobs:
221 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface 221 GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
222 222
223 cd huggingface 223 cd huggingface
224 - git lfs pull  
225 mkdir -p win32 224 mkdir -p win32
226 225
227 cp -v ../sherpa-onnx-*.tar.bz2 ./win32 226 cp -v ../sherpa-onnx-*.tar.bz2 ./win32
@@ -14,19 +14,18 @@ namespace sherpa_onnx { @@ -14,19 +14,18 @@ namespace sherpa_onnx {
14 14
15 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( 15 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
16 const OnlineRecognizerConfig &config) { 16 const OnlineRecognizerConfig &config) {
17 -  
18 if (!config.model_config.transducer.encoder.empty()) { 17 if (!config.model_config.transducer.encoder.empty()) {
19 Ort::Env env(ORT_LOGGING_LEVEL_WARNING); 18 Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
20 - 19 +
21 auto decoder_model = ReadFile(config.model_config.transducer.decoder); 20 auto decoder_model = ReadFile(config.model_config.transducer.decoder);
22 - auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});  
23 - 21 + auto sess = std::make_unique<Ort::Session>(
  22 + env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
  23 +
24 size_t node_count = sess->GetOutputCount(); 24 size_t node_count = sess->GetOutputCount();
25 - 25 +
26 if (node_count == 1) { 26 if (node_count == 1) {
27 return std::make_unique<OnlineRecognizerTransducerImpl>(config); 27 return std::make_unique<OnlineRecognizerTransducerImpl>(config);
28 } else { 28 } else {
29 - SHERPA_ONNX_LOGE("Running streaming Nemo transducer model");  
30 return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config); 29 return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
31 } 30 }
32 } 31 }
@@ -50,12 +49,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( @@ -50,12 +49,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
50 AAssetManager *mgr, const OnlineRecognizerConfig &config) { 49 AAssetManager *mgr, const OnlineRecognizerConfig &config) {
51 if (!config.model_config.transducer.encoder.empty()) { 50 if (!config.model_config.transducer.encoder.empty()) {
52 Ort::Env env(ORT_LOGGING_LEVEL_WARNING); 51 Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
53 - 52 +
54 auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder); 53 auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
55 - auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});  
56 - 54 + auto sess = std::make_unique<Ort::Session>(
  55 + env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
  56 +
57 size_t node_count = sess->GetOutputCount(); 57 size_t node_count = sess->GetOutputCount();
58 - 58 +
59 if (node_count == 1) { 59 if (node_count == 1) {
60 return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config); 60 return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
61 } else { 61 } else {
@@ -35,18 +35,15 @@ @@ -35,18 +35,15 @@
35 35
36 namespace sherpa_onnx { 36 namespace sherpa_onnx {
37 37
38 -static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,  
39 - const SymbolTable &sym_table,  
40 - float frame_shift_ms,  
41 - int32_t subsampling_factor,  
42 - int32_t segment,  
43 - int32_t frames_since_start) { 38 +OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
  39 + const SymbolTable &sym_table,
  40 + float frame_shift_ms, int32_t subsampling_factor,
  41 + int32_t segment, int32_t frames_since_start) {
44 OnlineRecognizerResult r; 42 OnlineRecognizerResult r;
45 r.tokens.reserve(src.tokens.size()); 43 r.tokens.reserve(src.tokens.size());
46 r.timestamps.reserve(src.tokens.size()); 44 r.timestamps.reserve(src.tokens.size());
47 45
48 for (auto i : src.tokens) { 46 for (auto i : src.tokens) {
49 - if (i == -1) continue;  
50 auto sym = sym_table[i]; 47 auto sym = sym_table[i];
51 48
52 r.text.append(sym); 49 r.text.append(sym);
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 #ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ 6 #ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
7 #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ 7 #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
8 8
  9 +#include <algorithm>
9 #include <fstream> 10 #include <fstream>
10 #include <ios> 11 #include <ios>
11 #include <memory> 12 #include <memory>
@@ -32,23 +33,20 @@ @@ -32,23 +33,20 @@
32 namespace sherpa_onnx { 33 namespace sherpa_onnx {
33 34
34 // defined in ./online-recognizer-transducer-impl.h 35 // defined in ./online-recognizer-transducer-impl.h
35 -// static may or may not be here? TODDOs  
36 -static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,  
37 - const SymbolTable &sym_table,  
38 - float frame_shift_ms,  
39 - int32_t subsampling_factor,  
40 - int32_t segment,  
41 - int32_t frames_since_start); 36 +OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
  37 + const SymbolTable &sym_table,
  38 + float frame_shift_ms, int32_t subsampling_factor,
  39 + int32_t segment, int32_t frames_since_start);
42 40
43 class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { 41 class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
44 - public: 42 + public:
45 explicit OnlineRecognizerTransducerNeMoImpl( 43 explicit OnlineRecognizerTransducerNeMoImpl(
46 const OnlineRecognizerConfig &config) 44 const OnlineRecognizerConfig &config)
47 : config_(config), 45 : config_(config),
48 symbol_table_(config.model_config.tokens), 46 symbol_table_(config.model_config.tokens),
49 endpoint_(config_.endpoint_config), 47 endpoint_(config_.endpoint_config),
50 - model_(std::make_unique<OnlineTransducerNeMoModel>(  
51 - config.model_config)) { 48 + model_(
  49 + std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
52 if (config.decoding_method == "greedy_search") { 50 if (config.decoding_method == "greedy_search") {
53 decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>( 51 decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
54 model_.get(), config_.blank_penalty); 52 model_.get(), config_.blank_penalty);
@@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
73 model_.get(), config_.blank_penalty); 71 model_.get(), config_.blank_penalty);
74 } else { 72 } else {
75 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 73 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
76 - config.decoding_method.c_str()); 74 + config.decoding_method.c_str());
77 exit(-1); 75 exit(-1);
78 } 76 }
79 77
@@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
83 81
84 std::unique_ptr<OnlineStream> CreateStream() const override { 82 std::unique_ptr<OnlineStream> CreateStream() const override {
85 auto stream = std::make_unique<OnlineStream>(config_.feat_config); 83 auto stream = std::make_unique<OnlineStream>(config_.feat_config);
86 - stream->SetStates(model_->GetInitStates());  
87 InitOnlineStream(stream.get()); 84 InitOnlineStream(stream.get());
88 return stream; 85 return stream;
89 } 86 }
@@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
94 } 91 }
95 92
96 OnlineRecognizerResult GetResult(OnlineStream *s) const override { 93 OnlineRecognizerResult GetResult(OnlineStream *s) const override {
97 - OnlineTransducerDecoderResult decoder_result = s->GetResult();  
98 - decoder_->StripLeadingBlanks(&decoder_result);  
99 -  
100 // TODO(fangjun): Remember to change these constants if needed 94 // TODO(fangjun): Remember to change these constants if needed
101 int32_t frame_shift_ms = 10; 95 int32_t frame_shift_ms = 10;
102 - int32_t subsampling_factor = 8;  
103 - return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor,  
104 - s->GetCurrentSegment(), s->GetNumFramesSinceStart()); 96 + int32_t subsampling_factor = model_->SubsamplingFactor();
  97 + return Convert(s->GetResult(), symbol_table_, frame_shift_ms,
  98 + subsampling_factor, s->GetCurrentSegment(),
  99 + s->GetNumFramesSinceStart());
105 } 100 }
106 101
107 bool IsEndpoint(OnlineStream *s) const override { 102 bool IsEndpoint(OnlineStream *s) const override {
@@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
114 // frame shift is 10 milliseconds 109 // frame shift is 10 milliseconds
115 float frame_shift_in_seconds = 0.01; 110 float frame_shift_in_seconds = 0.01;
116 111
117 - // subsampling factor is 8  
118 - int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8; 112 + int32_t trailing_silence_frames =
  113 + s->GetResult().num_trailing_blanks * model_->SubsamplingFactor();
119 114
120 return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, 115 return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
121 frame_shift_in_seconds); 116 frame_shift_in_seconds);
@@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
126 // segment is incremented only when the last 121 // segment is incremented only when the last
127 // result is not empty 122 // result is not empty
128 const auto &r = s->GetResult(); 123 const auto &r = s->GetResult();
129 - if (!r.tokens.empty() && r.tokens.back() != 0) { 124 + if (!r.tokens.empty()) {
130 s->GetCurrentSegment() += 1; 125 s->GetCurrentSegment() += 1;
131 } 126 }
132 } 127 }
133 128
134 - // we keep the decoder_out  
135 - decoder_->UpdateDecoderOut(&s->GetResult());  
136 - Ort::Value decoder_out = std::move(s->GetResult().decoder_out); 129 + s->SetResult({});
  130 +
  131 + s->SetStates(model_->GetEncoderInitStates());
137 132
138 - auto r = decoder_->GetEmptyResult();  
139 -  
140 - s->SetResult(r);  
141 - s->GetResult().decoder_out = std::move(decoder_out); 133 + s->SetNeMoDecoderStates(model_->GetDecoderInitStates());
142 134
143 // Note: We only update counters. The underlying audio samples 135 // Note: We only update counters. The underlying audio samples
144 // are not discarded. 136 // are not discarded.
@@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
151 143
152 int32_t feature_dim = ss[0]->FeatureDim(); 144 int32_t feature_dim = ss[0]->FeatureDim();
153 145
154 - std::vector<OnlineTransducerDecoderResult> result(n);  
155 std::vector<float> features_vec(n * chunk_size * feature_dim); 146 std::vector<float> features_vec(n * chunk_size * feature_dim);
156 std::vector<std::vector<Ort::Value>> encoder_states(n); 147 std::vector<std::vector<Ort::Value>> encoder_states(n);
157 - 148 +
158 for (int32_t i = 0; i != n; ++i) { 149 for (int32_t i = 0; i != n; ++i) {
159 const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); 150 const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
160 std::vector<float> features = 151 std::vector<float> features =
@@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
166 std::copy(features.begin(), features.end(), 157 std::copy(features.begin(), features.end(),
167 features_vec.data() + i * chunk_size * feature_dim); 158 features_vec.data() + i * chunk_size * feature_dim);
168 159
169 - result[i] = std::move(ss[i]->GetResult());  
170 encoder_states[i] = std::move(ss[i]->GetStates()); 160 encoder_states[i] = std::move(ss[i]->GetStates());
171 -  
172 } 161 }
173 162
174 auto memory_info = 163 auto memory_info =
@@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
180 features_vec.size(), x_shape.data(), 169 features_vec.size(), x_shape.data(),
181 x_shape.size()); 170 x_shape.size());
182 171
183 - // Batch size is 1  
184 - auto states = std::move(encoder_states[0]);  
185 - int32_t num_states = states.size(); // num_states = 3 172 + auto states = model_->StackStates(std::move(encoder_states));
  173 + int32_t num_states = states.size(); // num_states = 3
186 auto t = model_->RunEncoder(std::move(x), std::move(states)); 174 auto t = model_->RunEncoder(std::move(x), std::move(states));
187 // t[0] encoder_out, float tensor, (batch_size, dim, T) 175 // t[0] encoder_out, float tensor, (batch_size, dim, T)
188 // t[1] next states 176 // t[1] next states
189 - 177 +
190 std::vector<Ort::Value> out_states; 178 std::vector<Ort::Value> out_states;
191 out_states.reserve(num_states); 179 out_states.reserve(num_states);
192 - 180 +
193 for (int32_t k = 1; k != num_states + 1; ++k) { 181 for (int32_t k = 1; k != num_states + 1; ++k) {
194 out_states.push_back(std::move(t[k])); 182 out_states.push_back(std::move(t[k]));
195 } 183 }
196 184
  185 + auto unstacked_states = model_->UnStackStates(std::move(out_states));
  186 + for (int32_t i = 0; i != n; ++i) {
  187 + ss[i]->SetStates(std::move(unstacked_states[i]));
  188 + }
  189 +
197 Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); 190 Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
198 -  
199 - // defined in online-transducer-greedy-search-nemo-decoder.h  
200 - // get intial states of decoder.  
201 - std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();  
202 -  
203 - // Subsequent decoder states (for each chunks) are updated inside the Decode method.  
204 - // This returns the decoder state from the LAST chunk. We probably dont need it. So we can discard it.  
205 - decoder_states = decoder_->Decode(std::move(encoder_out),  
206 - std::move(decoder_states),  
207 - &result, ss, n);  
208 -  
209 - ss[0]->SetResult(result[0]);  
210 -  
211 - ss[0]->SetStates(std::move(out_states)); 191 +
  192 + decoder_->Decode(std::move(encoder_out), ss, n);
212 } 193 }
213 194
214 void InitOnlineStream(OnlineStream *stream) const { 195 void InitOnlineStream(OnlineStream *stream) const {
215 - auto r = decoder_->GetEmptyResult(); 196 + // set encoder states
  197 + stream->SetStates(model_->GetEncoderInitStates());
216 198
217 - stream->SetResult(r);  
218 - stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1)); 199 + // set decoder states
  200 + stream->SetNeMoDecoderStates(model_->GetDecoderInitStates());
219 } 201 }
220 202
221 private: 203 private:
@@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
250 symbol_table_.NumSymbols(), vocab_size); 232 symbol_table_.NumSymbols(), vocab_size);
251 exit(-1); 233 exit(-1);
252 } 234 }
253 -  
254 } 235 }
255 236
256 private: 237 private:
@@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
259 std::unique_ptr<OnlineTransducerNeMoModel> model_; 240 std::unique_ptr<OnlineTransducerNeMoModel> model_;
260 std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_; 241 std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
261 Endpoint endpoint_; 242 Endpoint endpoint_;
262 -  
263 }; 243 };
264 244
265 } // namespace sherpa_onnx 245 } // namespace sherpa_onnx
266 246
267 -#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_  
  247 +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
@@ -225,7 +225,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() { @@ -225,7 +225,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
225 return impl_->GetStates(); 225 return impl_->GetStates();
226 } 226 }
227 227
228 -void OnlineStream::SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) { 228 +void OnlineStream::SetNeMoDecoderStates(
  229 + std::vector<Ort::Value> decoder_states) {
229 return impl_->SetNeMoDecoderStates(std::move(decoder_states)); 230 return impl_->SetNeMoDecoderStates(std::move(decoder_states));
230 } 231 }
231 232
@@ -91,8 +91,8 @@ class OnlineStream { @@ -91,8 +91,8 @@ class OnlineStream {
91 void SetStates(std::vector<Ort::Value> states); 91 void SetStates(std::vector<Ort::Value> states);
92 std::vector<Ort::Value> &GetStates(); 92 std::vector<Ort::Value> &GetStates();
93 93
94 - void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);  
95 - std::vector<Ort::Value> &GetNeMoDecoderStates(); 94 + void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);
  95 + std::vector<Ort::Value> &GetNeMoDecoderStates();
96 96
97 /** 97 /**
98 * Get the context graph corresponding to this stream. 98 * Get the context graph corresponding to this stream.
@@ -10,103 +10,64 @@ @@ -10,103 +10,64 @@
10 #include <utility> 10 #include <utility>
11 11
12 #include "sherpa-onnx/csrc/macros.h" 12 #include "sherpa-onnx/csrc/macros.h"
  13 +#include "sherpa-onnx/csrc/online-stream.h"
13 #include "sherpa-onnx/csrc/onnx-utils.h" 14 #include "sherpa-onnx/csrc/onnx-utils.h"
14 15
15 namespace sherpa_onnx { 16 namespace sherpa_onnx {
16 17
17 -static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(  
18 - int32_t token, OrtAllocator *allocator) { 18 +static Ort::Value BuildDecoderInput(int32_t token, OrtAllocator *allocator) {
19 std::array<int64_t, 2> shape{1, 1}; 19 std::array<int64_t, 2> shape{1, 1};
20 20
21 Ort::Value decoder_input = 21 Ort::Value decoder_input =
22 Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size()); 22 Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
23 23
24 - std::array<int64_t, 1> length_shape{1};  
25 - Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(  
26 - allocator, length_shape.data(), length_shape.size());  
27 -  
28 int32_t *p = decoder_input.GetTensorMutableData<int32_t>(); 24 int32_t *p = decoder_input.GetTensorMutableData<int32_t>();
29 25
30 - int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();  
31 -  
32 p[0] = token; 26 p[0] = token;
33 27
34 - p_length[0] = 1;  
35 -  
36 - return {std::move(decoder_input), std::move(decoder_input_length)};  
37 -}  
38 -  
39 -  
40 -OnlineTransducerDecoderResult  
41 -OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const {  
42 - int32_t context_size = 8;  
43 - int32_t blank_id = 0; // always 0  
44 - OnlineTransducerDecoderResult r;  
45 - r.tokens.resize(context_size, -1);  
46 - r.tokens.back() = blank_id;  
47 -  
48 - return r; 28 + return decoder_input;
49 } 29 }
50 30
51 -static void UpdateCachedDecoderOut(  
52 - OrtAllocator *allocator, const Ort::Value *decoder_out,  
53 - std::vector<OnlineTransducerDecoderResult> *result) {  
54 - std::vector<int64_t> shape =  
55 - decoder_out->GetTensorTypeAndShapeInfo().GetShape(); 31 +static void DecodeOne(const float *encoder_out, int32_t num_rows,
  32 + int32_t num_cols, OnlineTransducerNeMoModel *model,
  33 + float blank_penalty, OnlineStream *s) {
56 auto memory_info = 34 auto memory_info =
57 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 35 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
58 - std::array<int64_t, 2> v_shape{1, shape[1]};  
59 36
60 - const float *src = decoder_out->GetTensorData<float>();  
61 - for (auto &r : *result) {  
62 - if (!r.decoder_out) {  
63 - r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),  
64 - v_shape.size());  
65 - } 37 + int32_t vocab_size = model->VocabSize();
  38 + int32_t blank_id = vocab_size - 1;
66 39
67 - float *dst = r.decoder_out.GetTensorMutableData<float>();  
68 - std::copy(src, src + shape[1], dst);  
69 - src += shape[1];  
70 - }  
71 -} 40 + auto &r = s->GetResult();
72 41
73 -std::vector<Ort::Value> DecodeOne(  
74 - const float *encoder_out, int32_t num_rows, int32_t num_cols,  
75 - OnlineTransducerNeMoModel *model, float blank_penalty,  
76 - std::vector<Ort::Value>& decoder_states,  
77 - std::vector<OnlineTransducerDecoderResult> *result) { 42 + Ort::Value decoder_out{nullptr};
78 43
79 - auto memory_info =  
80 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 44 + auto decoder_input = BuildDecoderInput(
  45 + r.tokens.empty() ? blank_id : r.tokens.back(), model->Allocator());
81 46
82 - // OnlineTransducerDecoderResult result;  
83 - int32_t vocab_size = model->VocabSize();  
84 - int32_t blank_id = vocab_size - 1;  
85 -  
86 - auto &r = (*result)[0];  
87 - Ort::Value decoder_out{nullptr}; 47 + std::vector<Ort::Value> &last_decoder_states = s->GetNeMoDecoderStates();
88 48
89 - auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());  
90 - // decoder_input_pair[0]: decoder_input  
91 - // decoder_input_pair[1]: decoder_input_length (discarded) 49 + std::vector<Ort::Value> tmp_decoder_states;
  50 + tmp_decoder_states.reserve(last_decoder_states.size());
  51 + for (auto &v : last_decoder_states) {
  52 + tmp_decoder_states.push_back(View(&v));
  53 + }
92 54
93 // decoder_output_pair.second returns the next decoder state 55 // decoder_output_pair.second returns the next decoder state
94 std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair = 56 std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
95 - model->RunDecoder(std::move(decoder_input_pair.first),  
96 - std::move(decoder_states)); // here decoder_states = {len=0, cap=0}. But decoder_output_pair= {first, second: {len=2, cap=2}} // ATTN 57 + model->RunDecoder(std::move(decoder_input),
  58 + std::move(tmp_decoder_states));
97 59
98 std::array<int64_t, 3> encoder_shape{1, num_cols, 1}; 60 std::array<int64_t, 3> encoder_shape{1, num_cols, 1};
99 61
100 - decoder_states = std::move(decoder_output_pair.second); 62 + bool emitted = false;
101 63
102 - // TODO: Inside this loop, I need to framewise decoding.  
103 for (int32_t t = 0; t != num_rows; ++t) { 64 for (int32_t t = 0; t != num_rows; ++t) {
104 Ort::Value cur_encoder_out = Ort::Value::CreateTensor( 65 Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
105 memory_info, const_cast<float *>(encoder_out) + t * num_cols, num_cols, 66 memory_info, const_cast<float *>(encoder_out) + t * num_cols, num_cols,
106 encoder_shape.data(), encoder_shape.size()); 67 encoder_shape.data(), encoder_shape.size());
107 68
108 Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), 69 Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
109 - View(&decoder_output_pair.first)); 70 + View(&decoder_output_pair.first));
110 71
111 float *p_logit = logit.GetTensorMutableData<float>(); 72 float *p_logit = logit.GetTensorMutableData<float>();
112 if (blank_penalty > 0) { 73 if (blank_penalty > 0) {
@@ -117,82 +78,52 @@ std::vector<Ort::Value> DecodeOne( @@ -117,82 +78,52 @@ std::vector<Ort::Value> DecodeOne(
117 static_cast<const float *>(p_logit), 78 static_cast<const float *>(p_logit),
118 std::max_element(static_cast<const float *>(p_logit), 79 std::max_element(static_cast<const float *>(p_logit),
119 static_cast<const float *>(p_logit) + vocab_size))); 80 static_cast<const float *>(p_logit) + vocab_size)));
120 - SHERPA_ONNX_LOGE("y=%d", y); 81 +
121 if (y != blank_id) { 82 if (y != blank_id) {
  83 + emitted = true;
122 r.tokens.push_back(y); 84 r.tokens.push_back(y);
123 r.timestamps.push_back(t + r.frame_offset); 85 r.timestamps.push_back(t + r.frame_offset);
  86 + r.num_trailing_blanks = 0;
124 87
125 - decoder_input_pair = BuildDecoderInput(y, model->Allocator()); 88 + decoder_input = BuildDecoderInput(y, model->Allocator());
126 89
127 // last decoder state becomes the current state for the first chunk 90 // last decoder state becomes the current state for the first chunk
128 - decoder_output_pair =  
129 - model->RunDecoder(std::move(decoder_input_pair.first),  
130 - std::move(decoder_states));  
131 -  
132 - // Update the decoder states for the next chunk  
133 - decoder_states = std::move(decoder_output_pair.second); 91 + decoder_output_pair = model->RunDecoder(
  92 + std::move(decoder_input), std::move(decoder_output_pair.second));
  93 + } else {
  94 + ++r.num_trailing_blanks;
134 } 95 }
135 } 96 }
136 97
137 - decoder_out = std::move(decoder_output_pair.first);  
138 -// UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result);  
139 -  
140 - // Update frame_offset  
141 - for (auto &r : *result) {  
142 - r.frame_offset += num_rows; 98 + if (emitted) {
  99 + s->SetNeMoDecoderStates(std::move(decoder_output_pair.second));
143 } 100 }
144 101
145 - return std::move(decoder_states); 102 + r.frame_offset += num_rows;
146 } 103 }
147 104
148 -  
149 -std::vector<Ort::Value> OnlineTransducerGreedySearchNeMoDecoder::Decode(  
150 - Ort::Value encoder_out,  
151 - std::vector<Ort::Value> decoder_states,  
152 - std::vector<OnlineTransducerDecoderResult> *result,  
153 - OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) {  
154 - 105 +void OnlineTransducerGreedySearchNeMoDecoder::Decode(Ort::Value encoder_out,
  106 + OnlineStream **ss,
  107 + int32_t n) const {
155 auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); 108 auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
  109 + int32_t batch_size = static_cast<int32_t>(shape[0]); // bs = 1
156 110
157 - if (shape[0] != result->size()) {  
158 - SHERPA_ONNX_LOGE(  
159 - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d",  
160 - static_cast<int32_t>(shape[0]),  
161 - static_cast<int32_t>(result->size())); 111 + if (batch_size != n) {
  112 + SHERPA_ONNX_LOGE("Size mismatch! encoder_out.size(0) %d, n: %d",
  113 + static_cast<int32_t>(shape[0]), n);
162 exit(-1); 114 exit(-1);
163 } 115 }
164 116
165 - int32_t batch_size = static_cast<int32_t>(shape[0]); // bs = 1  
166 - int32_t dim1 = static_cast<int32_t>(shape[1]); // 2  
167 - int32_t dim2 = static_cast<int32_t>(shape[2]); // 512  
168 -  
169 - // Define and initialize encoder_out_length  
170 - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);  
171 -  
172 - int64_t length_value = 1;  
173 - std::vector<int64_t> length_shape = {1};  
174 -  
175 - Ort::Value encoder_out_length = Ort::Value::CreateTensor<int64_t>(  
176 - memory_info, &length_value, 1, length_shape.data(), length_shape.size()  
177 - );  
178 -  
179 - const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();  
180 - const float *p = encoder_out.GetTensorData<float>(); 117 + int32_t dim1 = static_cast<int32_t>(shape[1]); // T
  118 + int32_t dim2 = static_cast<int32_t>(shape[2]); // encoder_out_dim
181 119
182 - // std::vector<OnlineTransducerDecoderResult> ans(batch_size); 120 + const float *p = encoder_out.GetTensorData<float>();
183 121
184 for (int32_t i = 0; i != batch_size; ++i) { 122 for (int32_t i = 0; i != batch_size; ++i) {
185 const float *this_p = p + dim1 * dim2 * i; 123 const float *this_p = p + dim1 * dim2 * i;
186 - int32_t this_len = p_length[i];  
187 124
188 - // outputs the decoder state from last chunk.  
189 - auto last_decoder_states = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states, result);  
190 - // ans[i] = decode_result_pair.first;  
191 - decoder_states = std::move(last_decoder_states); 125 + DecodeOne(this_p, dim1, dim2, model_, blank_penalty_, ss[i]);
192 } 126 }
193 -  
194 - return decoder_states;  
195 -  
196 } 127 }
197 128
198 -} // namespace sherpa_onnx  
  129 +} // namespace sherpa_onnx
@@ -7,27 +7,22 @@ @@ -7,27 +7,22 @@
7 #define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ 7 #define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
8 8
9 #include <vector> 9 #include <vector>
  10 +
10 #include "sherpa-onnx/csrc/online-transducer-decoder.h" 11 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
11 #include "sherpa-onnx/csrc/online-transducer-nemo-model.h" 12 #include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
12 13
13 namespace sherpa_onnx { 14 namespace sherpa_onnx {
14 15
  16 +class OnlineStream;
  17 +
15 class OnlineTransducerGreedySearchNeMoDecoder { 18 class OnlineTransducerGreedySearchNeMoDecoder {
16 public: 19 public:
17 OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model, 20 OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model,
18 float blank_penalty) 21 float blank_penalty)
19 - : model_(model),  
20 - blank_penalty_(blank_penalty) {}  
21 -  
22 - OnlineTransducerDecoderResult GetEmptyResult() const;  
23 - void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}  
24 - void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {}  
25 -  
26 - std::vector<Ort::Value> Decode(  
27 - Ort::Value encoder_out,  
28 - std::vector<Ort::Value> decoder_states,  
29 - std::vector<OnlineTransducerDecoderResult> *result,  
30 - OnlineStream **ss = nullptr, int32_t n = 0); 22 + : model_(model), blank_penalty_(blank_penalty) {}
  23 +
  24 + // @param n number of elements in ss
  25 + void Decode(Ort::Value encoder_out, OnlineStream **ss, int32_t n) const;
31 26
32 private: 27 private:
33 OnlineTransducerNeMoModel *model_; // Not owned 28 OnlineTransducerNeMoModel *model_; // Not owned
@@ -37,4 +32,3 @@ class OnlineTransducerGreedySearchNeMoDecoder { @@ -37,4 +32,3 @@ class OnlineTransducerGreedySearchNeMoDecoder {
37 } // namespace sherpa_onnx 32 } // namespace sherpa_onnx
38 33
39 #endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ 34 #endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
40 -  
@@ -54,7 +54,7 @@ class OnlineTransducerNeMoModel::Impl { @@ -54,7 +54,7 @@ class OnlineTransducerNeMoModel::Impl {
54 InitJoiner(buf.data(), buf.size()); 54 InitJoiner(buf.data(), buf.size());
55 } 55 }
56 } 56 }
57 - 57 +
58 #if __ANDROID_API__ >= 9 58 #if __ANDROID_API__ >= 9
59 Impl(AAssetManager *mgr, const OnlineModelConfig &config) 59 Impl(AAssetManager *mgr, const OnlineModelConfig &config)
60 : config_(config), 60 : config_(config),
@@ -79,7 +79,7 @@ class OnlineTransducerNeMoModel::Impl { @@ -79,7 +79,7 @@ class OnlineTransducerNeMoModel::Impl {
79 #endif 79 #endif
80 80
81 std::vector<Ort::Value> RunEncoder(Ort::Value features, 81 std::vector<Ort::Value> RunEncoder(Ort::Value features,
82 - std::vector<Ort::Value> states) { 82 + std::vector<Ort::Value> states) {
83 Ort::Value &cache_last_channel = states[0]; 83 Ort::Value &cache_last_channel = states[0];
84 Ort::Value &cache_last_time = states[1]; 84 Ort::Value &cache_last_time = states[1];
85 Ort::Value &cache_last_channel_len = states[2]; 85 Ort::Value &cache_last_channel_len = states[2];
@@ -102,9 +102,9 @@ class OnlineTransducerNeMoModel::Impl { @@ -102,9 +102,9 @@ class OnlineTransducerNeMoModel::Impl {
102 std::move(features), View(&length), std::move(cache_last_channel), 102 std::move(features), View(&length), std::move(cache_last_channel),
103 std::move(cache_last_time), std::move(cache_last_channel_len)}; 103 std::move(cache_last_time), std::move(cache_last_channel_len)};
104 104
105 - auto out =  
106 - encoder_sess_->Run({}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),  
107 - encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size()); 105 + auto out = encoder_sess_->Run(
  106 + {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
  107 + encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
108 // out[0]: logit 108 // out[0]: logit
109 // out[1] logit_length 109 // out[1] logit_length
110 // out[2:] states_next 110 // out[2:] states_next
@@ -127,17 +127,19 @@ class OnlineTransducerNeMoModel::Impl { @@ -127,17 +127,19 @@ class OnlineTransducerNeMoModel::Impl {
127 127
128 std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder( 128 std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
129 Ort::Value targets, std::vector<Ort::Value> states) { 129 Ort::Value targets, std::vector<Ort::Value> states) {
130 -  
131 - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 130 + Ort::MemoryInfo memory_info =
  131 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  132 +
  133 + auto shape = targets.GetTensorTypeAndShapeInfo().GetShape();
  134 + int32_t batch_size = static_cast<int32_t>(shape[0]);
132 135
133 - // Create the tensor with a single int32_t value of 1  
134 - int32_t length_value = 1;  
135 - std::vector<int64_t> length_shape = {1}; 136 + std::vector<int64_t> length_shape = {batch_size};
  137 + std::vector<int32_t> length_value(batch_size, 1);
136 138
137 Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>( 139 Ort::Value targets_length = Ort::Value::CreateTensor<int32_t>(
138 - memory_info, &length_value, 1, length_shape.data(), length_shape.size()  
139 - );  
140 - 140 + memory_info, length_value.data(), batch_size, length_shape.data(),
  141 + length_shape.size());
  142 +
141 std::vector<Ort::Value> decoder_inputs; 143 std::vector<Ort::Value> decoder_inputs;
142 decoder_inputs.reserve(2 + states.size()); 144 decoder_inputs.reserve(2 + states.size());
143 145
@@ -171,35 +173,21 @@ class OnlineTransducerNeMoModel::Impl { @@ -171,35 +173,21 @@ class OnlineTransducerNeMoModel::Impl {
171 Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { 173 Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
172 std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out), 174 std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
173 std::move(decoder_out)}; 175 std::move(decoder_out)};
174 - auto logit =  
175 - joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),  
176 - joiner_input.size(), joiner_output_names_ptr_.data(),  
177 - joiner_output_names_ptr_.size()); 176 + auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(),
  177 + joiner_input.data(), joiner_input.size(),
  178 + joiner_output_names_ptr_.data(),
  179 + joiner_output_names_ptr_.size());
178 180
179 return std::move(logit[0]); 181 return std::move(logit[0]);
180 -}  
181 -  
182 - std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {  
183 - std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};  
184 - Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),  
185 - s0_shape.size());  
186 -  
187 - Fill<float>(&s0, 0);  
188 -  
189 - std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};  
190 -  
191 - Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),  
192 - s1_shape.size());  
193 -  
194 - Fill<float>(&s1, 0);  
195 -  
196 - std::vector<Ort::Value> states; 182 + }
197 183
198 - states.reserve(2);  
199 - states.push_back(std::move(s0));  
200 - states.push_back(std::move(s1)); 184 + std::vector<Ort::Value> GetDecoderInitStates() {
  185 + std::vector<Ort::Value> ans;
  186 + ans.reserve(2);
  187 + ans.push_back(View(&lstm0_));
  188 + ans.push_back(View(&lstm1_));
201 189
202 - return states; 190 + return ans;
203 } 191 }
204 192
205 int32_t ChunkSize() const { return window_size_; } 193 int32_t ChunkSize() const { return window_size_; }
@@ -207,7 +195,7 @@ class OnlineTransducerNeMoModel::Impl { @@ -207,7 +195,7 @@ class OnlineTransducerNeMoModel::Impl {
207 int32_t ChunkShift() const { return chunk_shift_; } 195 int32_t ChunkShift() const { return chunk_shift_; }
208 196
209 int32_t SubsamplingFactor() const { return subsampling_factor_; } 197 int32_t SubsamplingFactor() const { return subsampling_factor_; }
210 - 198 +
211 int32_t VocabSize() const { return vocab_size_; } 199 int32_t VocabSize() const { return vocab_size_; }
212 200
213 OrtAllocator *Allocator() const { return allocator_; } 201 OrtAllocator *Allocator() const { return allocator_; }
@@ -218,7 +206,7 @@ class OnlineTransducerNeMoModel::Impl { @@ -218,7 +206,7 @@ class OnlineTransducerNeMoModel::Impl {
218 // - cache_last_channel 206 // - cache_last_channel
219 // - cache_last_time_ 207 // - cache_last_time_
220 // - cache_last_channel_len 208 // - cache_last_channel_len
221 - std::vector<Ort::Value> GetInitStates() { 209 + std::vector<Ort::Value> GetEncoderInitStates() {
222 std::vector<Ort::Value> ans; 210 std::vector<Ort::Value> ans;
223 ans.reserve(3); 211 ans.reserve(3);
224 ans.push_back(View(&cache_last_channel_)); 212 ans.push_back(View(&cache_last_channel_));
@@ -228,7 +216,75 @@ class OnlineTransducerNeMoModel::Impl { @@ -228,7 +216,75 @@ class OnlineTransducerNeMoModel::Impl {
228 return ans; 216 return ans;
229 } 217 }
230 218
231 -private: 219 + std::vector<Ort::Value> StackStates(
  220 + std::vector<std::vector<Ort::Value>> states) const {
  221 + int32_t batch_size = static_cast<int32_t>(states.size());
  222 + if (batch_size == 1) {
  223 + return std::move(states[0]);
  224 + }
  225 +
  226 + std::vector<Ort::Value> ans;
  227 +
  228 + // stack cache_last_channel
  229 + std::vector<const Ort::Value *> buf(batch_size);
  230 +
  231 + // there are 3 states to be stacked
  232 + for (int32_t i = 0; i != 3; ++i) {
  233 + buf.clear();
  234 + buf.reserve(batch_size);
  235 +
  236 + for (int32_t b = 0; b != batch_size; ++b) {
  237 + assert(states[b].size() == 3);
  238 + buf.push_back(&states[b][i]);
  239 + }
  240 +
  241 + Ort::Value c{nullptr};
  242 + if (i == 2) {
  243 + c = Cat<int64_t>(allocator_, buf, 0);
  244 + } else {
  245 + c = Cat(allocator_, buf, 0);
  246 + }
  247 +
  248 + ans.push_back(std::move(c));
  249 + }
  250 +
  251 + return ans;
  252 + }
  253 +
  254 + std::vector<std::vector<Ort::Value>> UnStackStates(
  255 + std::vector<Ort::Value> states) const {
  256 + assert(states.size() == 3);
  257 +
  258 + std::vector<std::vector<Ort::Value>> ans;
  259 +
  260 + auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
  261 + int32_t batch_size = shape[0];
  262 + ans.resize(batch_size);
  263 +
  264 + if (batch_size == 1) {
  265 + ans[0] = std::move(states);
  266 + return ans;
  267 + }
  268 +
  269 + for (int32_t i = 0; i != 3; ++i) {
  270 + std::vector<Ort::Value> v;
  271 + if (i == 2) {
  272 + v = Unbind<int64_t>(allocator_, &states[i], 0);
  273 + } else {
  274 + v = Unbind(allocator_, &states[i], 0);
  275 + }
  276 +
  277 + assert(v.size() == batch_size);
  278 +
  279 + for (int32_t b = 0; b != batch_size; ++b) {
  280 + ans[b].push_back(std::move(v[b]));
  281 + }
  282 + }
  283 +
  284 + return ans;
  285 + }
  286 +
  287 + private:
232 void InitEncoder(void *model_data, size_t model_data_length) { 288 void InitEncoder(void *model_data, size_t model_data_length) {
233 encoder_sess_ = std::make_unique<Ort::Session>( 289 encoder_sess_ = std::make_unique<Ort::Session>(
234 env_, model_data, model_data_length, sess_opts_); 290 env_, model_data, model_data_length, sess_opts_);
@@ -276,10 +332,10 @@ private: @@ -276,10 +332,10 @@ private:
276 normalize_type_ = ""; 332 normalize_type_ = "";
277 } 333 }
278 334
279 - InitStates(); 335 + InitEncoderStates();
280 } 336 }
281 -  
282 - void InitStates() { 337 +
  338 + void InitEncoderStates() {
283 std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_, 339 std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_,
284 cache_last_channel_dim2_, 340 cache_last_channel_dim2_,
285 cache_last_channel_dim3_}; 341 cache_last_channel_dim3_};
@@ -313,7 +369,25 @@ private: @@ -313,7 +369,25 @@ private:
313 &decoder_input_names_ptr_); 369 &decoder_input_names_ptr_);
314 370
315 GetOutputNames(decoder_sess_.get(), &decoder_output_names_, 371 GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
316 - &decoder_output_names_ptr_); 372 + &decoder_output_names_ptr_);
  373 +
  374 + InitDecoderStates();
  375 + }
  376 +
  377 + void InitDecoderStates() {
  378 + int32_t batch_size = 1;
  379 + std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
  380 + lstm0_ = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
  381 + s0_shape.size());
  382 +
  383 + Fill<float>(&lstm0_, 0);
  384 +
  385 + std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size, pred_hidden_};
  386 +
  387 + lstm1_ = Ort::Value::CreateTensor<float>(allocator_, s1_shape.data(),
  388 + s1_shape.size());
  389 +
  390 + Fill<float>(&lstm1_, 0);
317 } 391 }
318 392
319 void InitJoiner(void *model_data, size_t model_data_length) { 393 void InitJoiner(void *model_data, size_t model_data_length) {
@@ -324,7 +398,7 @@ private: @@ -324,7 +398,7 @@ private:
324 &joiner_input_names_ptr_); 398 &joiner_input_names_ptr_);
325 399
326 GetOutputNames(joiner_sess_.get(), &joiner_output_names_, 400 GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
327 - &joiner_output_names_ptr_); 401 + &joiner_output_names_ptr_);
328 } 402 }
329 403
330 private: 404 private:
@@ -363,6 +437,7 @@ private: @@ -363,6 +437,7 @@ private:
363 int32_t pred_rnn_layers_ = -1; 437 int32_t pred_rnn_layers_ = -1;
364 int32_t pred_hidden_ = -1; 438 int32_t pred_hidden_ = -1;
365 439
  440 + // encoder states
366 int32_t cache_last_channel_dim1_; 441 int32_t cache_last_channel_dim1_;
367 int32_t cache_last_channel_dim2_; 442 int32_t cache_last_channel_dim2_;
368 int32_t cache_last_channel_dim3_; 443 int32_t cache_last_channel_dim3_;
@@ -370,9 +445,14 @@ private: @@ -370,9 +445,14 @@ private:
370 int32_t cache_last_time_dim2_; 445 int32_t cache_last_time_dim2_;
371 int32_t cache_last_time_dim3_; 446 int32_t cache_last_time_dim3_;
372 447
  448 + // init encoder states
373 Ort::Value cache_last_channel_{nullptr}; 449 Ort::Value cache_last_channel_{nullptr};
374 Ort::Value cache_last_time_{nullptr}; 450 Ort::Value cache_last_time_{nullptr};
375 Ort::Value cache_last_channel_len_{nullptr}; 451 Ort::Value cache_last_channel_len_{nullptr};
  452 +
  453 + // init decoder states
  454 + Ort::Value lstm0_{nullptr};
  455 + Ort::Value lstm1_{nullptr};
376 }; 456 };
377 457
378 OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( 458 OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
@@ -387,10 +467,9 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( @@ -387,10 +467,9 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
387 467
388 OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; 468 OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default;
389 469
390 -std::vector<Ort::Value>  
391 -OnlineTransducerNeMoModel::RunEncoder(Ort::Value features,  
392 - std::vector<Ort::Value> states) const {  
393 - return impl_->RunEncoder(std::move(features), std::move(states)); 470 +std::vector<Ort::Value> OnlineTransducerNeMoModel::RunEncoder(
  471 + Ort::Value features, std::vector<Ort::Value> states) const {
  472 + return impl_->RunEncoder(std::move(features), std::move(states));
394 } 473 }
395 474
396 std::pair<Ort::Value, std::vector<Ort::Value>> 475 std::pair<Ort::Value, std::vector<Ort::Value>>
@@ -399,9 +478,9 @@ OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets, @@ -399,9 +478,9 @@ OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets,
399 return impl_->RunDecoder(std::move(targets), std::move(states)); 478 return impl_->RunDecoder(std::move(targets), std::move(states));
400 } 479 }
401 480
402 -std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates(  
403 - int32_t batch_size) const {  
404 - return impl_->GetDecoderInitStates(batch_size); 481 +std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates()
  482 + const {
  483 + return impl_->GetDecoderInitStates();
405 } 484 }
406 485
407 Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, 486 Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
@@ -409,14 +488,13 @@ Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, @@ -409,14 +488,13 @@ Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
409 return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); 488 return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
410 } 489 }
411 490
  491 +int32_t OnlineTransducerNeMoModel::ChunkSize() const {
  492 + return impl_->ChunkSize();
  493 +}
412 494
413 -int32_t OnlineTransducerNeMoModel::ChunkSize() const {  
414 - return impl_->ChunkSize();  
415 - }  
416 -  
417 -int32_t OnlineTransducerNeMoModel::ChunkShift() const {  
418 - return impl_->ChunkShift();  
419 - } 495 +int32_t OnlineTransducerNeMoModel::ChunkShift() const {
  496 + return impl_->ChunkShift();
  497 +}
420 498
421 int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const { 499 int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const {
422 return impl_->SubsamplingFactor(); 500 return impl_->SubsamplingFactor();
@@ -434,8 +512,19 @@ std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const { @@ -434,8 +512,19 @@ std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const {
434 return impl_->FeatureNormalizationMethod(); 512 return impl_->FeatureNormalizationMethod();
435 } 513 }
436 514
437 -std::vector<Ort::Value> OnlineTransducerNeMoModel::GetInitStates() const {  
438 - return impl_->GetInitStates(); 515 +std::vector<Ort::Value> OnlineTransducerNeMoModel::GetEncoderInitStates()
  516 + const {
  517 + return impl_->GetEncoderInitStates();
  518 +}
  519 +
  520 +std::vector<Ort::Value> OnlineTransducerNeMoModel::StackStates(
  521 + std::vector<std::vector<Ort::Value>> states) const {
  522 + return impl_->StackStates(std::move(states));
  523 +}
  524 +
  525 +std::vector<std::vector<Ort::Value>> OnlineTransducerNeMoModel::UnStackStates(
  526 + std::vector<Ort::Value> states) const {
  527 + return impl_->UnStackStates(std::move(states));
439 } 528 }
440 529
441 -} // namespace sherpa_onnx  
  530 +} // namespace sherpa_onnx
@@ -32,22 +32,31 @@ class OnlineTransducerNeMoModel { @@ -32,22 +32,31 @@ class OnlineTransducerNeMoModel {
32 OnlineTransducerNeMoModel(AAssetManager *mgr, 32 OnlineTransducerNeMoModel(AAssetManager *mgr,
33 const OnlineModelConfig &config); 33 const OnlineModelConfig &config);
34 #endif 34 #endif
35 - 35 +
36 ~OnlineTransducerNeMoModel(); 36 ~OnlineTransducerNeMoModel();
37 - // A list of 3 tensors: 37 + // A list of 3 tensors:
38 // - cache_last_channel 38 // - cache_last_channel
39 // - cache_last_time 39 // - cache_last_time
40 // - cache_last_channel_len 40 // - cache_last_channel_len
41 - std::vector<Ort::Value> GetInitStates() const; 41 + std::vector<Ort::Value> GetEncoderInitStates() const;
  42 +
  43 + // stack encoder states
  44 + std::vector<Ort::Value> StackStates(
  45 + std::vector<std::vector<Ort::Value>> states) const;
  46 +
  47 + // unstack encoder states
  48 + std::vector<std::vector<Ort::Value>> UnStackStates(
  49 + std::vector<Ort::Value> states) const;
42 50
43 /** Run the encoder. 51 /** Run the encoder.
44 * 52 *
45 * @param features A tensor of shape (N, T, C). It is changed in-place. 53 * @param features A tensor of shape (N, T, C). It is changed in-place.
46 - * @param states It is from GetInitStates() or returned from this method.  
47 - * 54 + * @param states It is from GetEncoderInitStates() or returned from this
  55 + * method.
  56 + *
48 * @return Return a tuple containing: 57 * @return Return a tuple containing:
49 - * - ans[0]: encoder_out, a tensor of shape (N, T', encoder_out_dim)  
50 - * - ans[1:]: contains next states 58 + * - ans[0]: encoder_out, a tensor of shape (N, encoder_out_dim, T')
  59 + * - ans[1:]: contains next states
51 */ 60 */
52 std::vector<Ort::Value> RunEncoder( 61 std::vector<Ort::Value> RunEncoder(
53 Ort::Value features, std::vector<Ort::Value> states) const; // NOLINT 62 Ort::Value features, std::vector<Ort::Value> states) const; // NOLINT
@@ -63,7 +72,7 @@ class OnlineTransducerNeMoModel { @@ -63,7 +72,7 @@ class OnlineTransducerNeMoModel {
63 std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder( 72 std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
64 Ort::Value targets, std::vector<Ort::Value> states) const; 73 Ort::Value targets, std::vector<Ort::Value> states) const;
65 74
66 - std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const; 75 + std::vector<Ort::Value> GetDecoderInitStates() const;
67 76
68 /** Run the joint network. 77 /** Run the joint network.
69 * 78 *
@@ -71,9 +80,7 @@ class OnlineTransducerNeMoModel { @@ -71,9 +80,7 @@ class OnlineTransducerNeMoModel {
71 * @param decoder_out Output of the decoder network. 80 * @param decoder_out Output of the decoder network.
72 * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. 81 * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
73 */ 82 */
74 - Ort::Value RunJoiner(Ort::Value encoder_out,  
75 - Ort::Value decoder_out) const;  
76 - 83 + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const;
77 84
78 /** We send this number of feature frames to the encoder at a time. */ 85 /** We send this number of feature frames to the encoder at a time. */
79 int32_t ChunkSize() const; 86 int32_t ChunkSize() const;
@@ -114,10 +121,10 @@ class OnlineTransducerNeMoModel { @@ -114,10 +121,10 @@ class OnlineTransducerNeMoModel {
114 // for details 121 // for details
115 std::string FeatureNormalizationMethod() const; 122 std::string FeatureNormalizationMethod() const;
116 123
117 - private:  
118 - class Impl;  
119 - std::unique_ptr<Impl> impl_;  
120 - }; 124 + private:
  125 + class Impl;
  126 + std::unique_ptr<Impl> impl_;
  127 +};
121 128
122 } // namespace sherpa_onnx 129 } // namespace sherpa_onnx
123 130