Jingzhao Ou
Committed by GitHub

Stack and streaming conformer support (#141)

* added csrc/stack.cc

* stack: added checks

* added copyright info

* passed cpp style checks

* formatted code

* added some support for the streaming conformer model (not verified)

* code lint

* made more progress with streaming conformer support (not working yet)

* passed style check

* changes as suggested by @csukuangfj

* added some debug info

* fixed style check

* Use Cat to replace Stack

* remove debug statements

---------

Co-authored-by: Jingzhao Ou (jou2019) <jou2019@cisco.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
@@ -34,6 +34,7 @@ set(sources @@ -34,6 +34,7 @@ set(sources
34 offline-transducer-model-config.cc 34 offline-transducer-model-config.cc
35 offline-transducer-model.cc 35 offline-transducer-model.cc
36 offline-transducer-modified-beam-search-decoder.cc 36 offline-transducer-modified-beam-search-decoder.cc
  37 + online-conformer-transducer-model.cc
37 online-lm.cc 38 online-lm.cc
38 online-lm-config.cc 39 online-lm-config.cc
39 online-lstm-transducer-model.cc 40 online-lstm-transducer-model.cc
@@ -52,6 +53,7 @@ set(sources @@ -52,6 +53,7 @@ set(sources
52 parse-options.cc 53 parse-options.cc
53 resample.cc 54 resample.cc
54 slice.cc 55 slice.cc
  56 + stack.cc
55 symbol-table.cc 57 symbol-table.cc
56 text-utils.cc 58 text-utils.cc
57 transpose.cc 59 transpose.cc
@@ -241,6 +243,7 @@ if(SHERPA_ONNX_ENABLE_TESTS) @@ -241,6 +243,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
241 packed-sequence-test.cc 243 packed-sequence-test.cc
242 pad-sequence-test.cc 244 pad-sequence-test.cc
243 slice-test.cc 245 slice-test.cc
  246 + stack-test.cc
244 transpose-test.cc 247 transpose-test.cc
245 unbind-test.cc 248 unbind-test.cc
246 ) 249 )
  1 +// sherpa-onnx/csrc/online-conformer-transducer-model.cc
  2 +//
  3 +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
  4 +
  5 +#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
  6 +
  7 +#include <assert.h>
  8 +
  9 +#include <algorithm>
  10 +#include <memory>
  11 +#include <sstream>
  12 +#include <iostream>
  13 +#include <string>
  14 +#include <utility>
  15 +#include <vector>
  16 +
  17 +#if __ANDROID_API__ >= 9
  18 +#include "android/asset_manager.h"
  19 +#include "android/asset_manager_jni.h"
  20 +#endif
  21 +
  22 +#include "onnxruntime_cxx_api.h" // NOLINT
  23 +#include "sherpa-onnx/csrc/cat.h"
  24 +#include "sherpa-onnx/csrc/macros.h"
  25 +#include "sherpa-onnx/csrc/online-transducer-decoder.h"
  26 +#include "sherpa-onnx/csrc/onnx-utils.h"
  27 +#include "sherpa-onnx/csrc/text-utils.h"
  28 +#include "sherpa-onnx/csrc/unbind.h"
  29 +
  30 +namespace sherpa_onnx {
  31 +
  32 +OnlineConformerTransducerModel::OnlineConformerTransducerModel(
  33 + const OnlineTransducerModelConfig &config)
  34 + : env_(ORT_LOGGING_LEVEL_WARNING),
  35 + config_(config),
  36 + sess_opts_{},
  37 + allocator_{} {
  38 + sess_opts_.SetIntraOpNumThreads(config.num_threads);
  39 + sess_opts_.SetInterOpNumThreads(config.num_threads);
  40 +
  41 + {
  42 + auto buf = ReadFile(config.encoder_filename);
  43 + InitEncoder(buf.data(), buf.size());
  44 + }
  45 +
  46 + {
  47 + auto buf = ReadFile(config.decoder_filename);
  48 + InitDecoder(buf.data(), buf.size());
  49 + }
  50 +
  51 + {
  52 + auto buf = ReadFile(config.joiner_filename);
  53 + InitJoiner(buf.data(), buf.size());
  54 + }
  55 +}
  56 +
#if __ANDROID_API__ >= 9
/// Android variant: model files are read through the asset manager @p mgr
/// instead of the filesystem. Initializer list reordered to match the
/// member declaration order in the header (avoids -Wreorder; see the
/// non-Android constructor).
OnlineConformerTransducerModel::OnlineConformerTransducerModel(
    AAssetManager *mgr, const OnlineTransducerModelConfig &config)
    : env_(ORT_LOGGING_LEVEL_WARNING),
      sess_opts_{},
      allocator_{},
      config_(config) {
  sess_opts_.SetIntraOpNumThreads(config.num_threads);
  sess_opts_.SetInterOpNumThreads(config.num_threads);

  {
    auto buf = ReadFile(mgr, config.encoder_filename);
    InitEncoder(buf.data(), buf.size());
  }

  {
    auto buf = ReadFile(mgr, config.decoder_filename);
    InitDecoder(buf.data(), buf.size());
  }

  {
    auto buf = ReadFile(mgr, config.joiner_filename);
    InitJoiner(buf.data(), buf.size());
  }
}
#endif
  83 +
/// Creates the encoder session from an in-memory ONNX model and reads the
/// conformer-specific metadata (layer count, chunk sizes, context, etc.)
/// that the streaming decode loop needs.
void OnlineConformerTransducerModel::InitEncoder(void *model_data,
                                                 size_t model_data_length) {
  encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                 model_data_length, sess_opts_);

  // Cache input/output names once; Run() takes raw const char* arrays.
  GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                &encoder_input_names_ptr_);

  GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                 &encoder_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---encoder---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  // The macro below presumably reads the named key from `meta_data` via
  // `allocator` — it relies on both local names; do not rename them.
  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
  SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
  SHERPA_ONNX_READ_META_DATA(T_, "T");
  SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
  SHERPA_ONNX_READ_META_DATA(left_context_, "left_context");
  SHERPA_ONNX_READ_META_DATA(encoder_dim_, "encoder_dim");
  SHERPA_ONNX_READ_META_DATA(pad_length_, "pad_length");
  SHERPA_ONNX_READ_META_DATA(cnn_module_kernel_, "cnn_module_kernel");
}
  113 +
/// Creates the decoder (prediction network) session and reads the vocab
/// size and context size from its metadata.
void OnlineConformerTransducerModel::InitDecoder(void *model_data,
                                                 size_t model_data_length) {
  decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                 model_data_length, sess_opts_);

  // Cache input/output names once; Run() takes raw const char* arrays.
  GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                &decoder_input_names_ptr_);

  GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                 &decoder_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---decoder---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  // The macro relies on the local names `allocator` and `meta_data`.
  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
  SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
}
  138 +
/// Creates the joiner session. The joiner has no model-specific metadata to
/// read; only its input/output names are cached.
void OnlineConformerTransducerModel::InitJoiner(void *model_data,
                                                size_t model_data_length) {
  joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                model_data_length, sess_opts_);

  // Cache input/output names once; Run() takes raw const char* arrays.
  GetInputNames(joiner_sess_.get(), &joiner_input_names_,
                &joiner_input_names_ptr_);

  GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
                 &joiner_output_names_ptr_);

  // get meta data (printed only in debug mode)
  Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---joiner---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }
}
  159 +
  160 +std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
  161 + const std::vector<std::vector<Ort::Value>> &states) const {
  162 + int32_t batch_size = static_cast<int32_t>(states.size());
  163 +
  164 + std::vector<const Ort::Value *> attn_vec(batch_size);
  165 + std::vector<const Ort::Value *> conv_vec(batch_size);
  166 +
  167 + for (int32_t i = 0; i != batch_size; ++i) {
  168 + assert(states[i].size() == 2);
  169 + attn_vec[i] = &states[i][0];
  170 + conv_vec[i] = &states[i][1];
  171 + }
  172 +
  173 + Ort::Value attn = Cat(allocator_, attn_vec, 2);
  174 + Ort::Value conv = Cat(allocator_, conv_vec, 2);
  175 +
  176 + std::vector<Ort::Value> ans;
  177 + ans.reserve(2);
  178 + ans.push_back(std::move(attn));
  179 + ans.push_back(std::move(conv));
  180 +
  181 + return ans;
  182 +}
  183 +
  184 +std::vector<std::vector<Ort::Value>>
  185 +OnlineConformerTransducerModel::UnStackStates(
  186 + const std::vector<Ort::Value> &states) const {
  187 + const int32_t batch_size =
  188 + states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
  189 + assert(states.size() == 2);
  190 +
  191 + std::vector<std::vector<Ort::Value>> ans(batch_size);
  192 +
  193 + std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);
  194 + std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2);
  195 +
  196 + assert(attn_vec.size() == batch_size);
  197 + assert(conv_vec.size() == batch_size);
  198 +
  199 + for (int32_t i = 0; i != batch_size; ++i) {
  200 + ans[i].push_back(std::move(attn_vec[i]));
  201 + ans[i].push_back(std::move(conv_vec[i]));
  202 + }
  203 +
  204 + return ans;
  205 +}
  206 +
  207 +std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() {
  208 + // Please see
  209 + // https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203
  210 + // for details
  211 + constexpr int32_t kBatchSize = 1;
  212 + std::array<int64_t, 4> h_shape{
  213 + num_encoder_layers_, left_context_, kBatchSize, encoder_dim_};
  214 + Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
  215 + h_shape.size());
  216 +
  217 + Fill<float>(&h, 0);
  218 +
  219 + std::array<int64_t, 4> c_shape{num_encoder_layers_, cnn_module_kernel_ - 1,
  220 + kBatchSize, encoder_dim_};
  221 +
  222 + Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
  223 + c_shape.size());
  224 +
  225 + Fill<float>(&c, 0);
  226 +
  227 + std::vector<Ort::Value> states;
  228 +
  229 + states.reserve(2);
  230 + states.push_back(std::move(h));
  231 + states.push_back(std::move(c));
  232 +
  233 + return states;
  234 +}
  235 +
/// Runs one encoder chunk.
///
/// Input order is fixed by the exported model: features, attn state,
/// conv state, processed_frames. Output order: encoder_out followed by the
/// two next-chunk state tensors.
///
/// @param features  (N, T, C) feature tensor; consumed.
/// @param states    {attn, conv} from the previous chunk; consumed.
/// @param processed_frames  1-D int64 tensor of per-stream frame counts.
/// @return {encoder_out, next_states}.
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineConformerTransducerModel::RunEncoder(Ort::Value features,
                                           std::vector<Ort::Value> states,
                                           Ort::Value processed_frames) {
  std::array<Ort::Value, 4> encoder_inputs = {
      std::move(features),
      std::move(states[0]),
      std::move(states[1]),
      std::move(processed_frames)};

  auto encoder_out = encoder_sess_->Run(
      {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
      encoder_inputs.size(), encoder_output_names_ptr_.data(),
      encoder_output_names_ptr_.size());

  // Outputs [1] and [2] are the updated attn/conv states for the next chunk.
  std::vector<Ort::Value> next_states;
  next_states.reserve(2);
  next_states.push_back(std::move(encoder_out[1]));
  next_states.push_back(std::move(encoder_out[2]));

  return {std::move(encoder_out[0]), std::move(next_states)};
}
  258 +
  259 +Ort::Value OnlineConformerTransducerModel::RunDecoder(
  260 + Ort::Value decoder_input) {
  261 + auto decoder_out = decoder_sess_->Run(
  262 + {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
  263 + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
  264 + return std::move(decoder_out[0]);
  265 +}
  266 +
  267 +Ort::Value OnlineConformerTransducerModel::RunJoiner(Ort::Value encoder_out,
  268 + Ort::Value decoder_out) {
  269 + std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
  270 + std::move(decoder_out)};
  271 + auto logit =
  272 + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
  273 + joiner_input.size(), joiner_output_names_ptr_.data(),
  274 + joiner_output_names_ptr_.size());
  275 +
  276 + return std::move(logit[0]);
  277 +}
  278 +
  279 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-conformer-transducer-model.h
  2 +//
  3 +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <utility>
  11 +#include <vector>
  12 +
  13 +#if __ANDROID_API__ >= 9
  14 +#include "android/asset_manager.h"
  15 +#include "android/asset_manager_jni.h"
  16 +#endif
  17 +
  18 +#include "onnxruntime_cxx_api.h" // NOLINT
  19 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  20 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  21 +
  22 +namespace sherpa_onnx {
  23 +
/// Streaming conformer transducer (encoder + decoder + joiner) backed by
/// three onnxruntime sessions. Implements the OnlineTransducerModel
/// interface used by the online recognizer.
class OnlineConformerTransducerModel : public OnlineTransducerModel {
 public:
  explicit OnlineConformerTransducerModel(
      const OnlineTransducerModelConfig &config);

#if __ANDROID_API__ >= 9
  // Android variant: reads model files through the asset manager.
  OnlineConformerTransducerModel(AAssetManager *mgr,
                                 const OnlineTransducerModelConfig &config);
#endif

  // Batches per-stream {attn, conv} state pairs into two tensors.
  std::vector<Ort::Value> StackStates(
      const std::vector<std::vector<Ort::Value>> &states) const override;

  // Inverse of StackStates().
  std::vector<std::vector<Ort::Value>> UnStackStates(
      const std::vector<Ort::Value> &states) const override;

  // Zero-initialized {attn, conv} states for a single stream.
  std::vector<Ort::Value> GetEncoderInitStates() override;

  std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
      Ort::Value features, std::vector<Ort::Value> states,
      Ort::Value processed_frames) override;

  Ort::Value RunDecoder(Ort::Value decoder_input) override;

  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;

  int32_t ContextSize() const override { return context_size_; }

  int32_t ChunkSize() const override { return T_; }

  int32_t ChunkShift() const override { return decode_chunk_len_; }

  int32_t VocabSize() const override { return vocab_size_; }
  OrtAllocator *Allocator() override { return allocator_; }

 private:
  // Create the sessions from in-memory model data and read their metadata.
  void InitEncoder(void *model_data, size_t model_data_length);
  void InitDecoder(void *model_data, size_t model_data_length);
  void InitJoiner(void *model_data, size_t model_data_length);

 private:
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;
  std::unique_ptr<Ort::Session> joiner_sess_;

  // Input/output names plus parallel const char* views for Ort::Session::Run.
  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  std::vector<std::string> joiner_input_names_;
  std::vector<const char *> joiner_input_names_ptr_;

  std::vector<std::string> joiner_output_names_;
  std::vector<const char *> joiner_output_names_ptr_;

  OnlineTransducerModelConfig config_;

  // Values below are read from the encoder/decoder model metadata.
  int32_t num_encoder_layers_ = 0;
  int32_t T_ = 0;
  int32_t decode_chunk_len_ = 0;
  int32_t cnn_module_kernel_ = 0;
  int32_t context_size_ = 0;
  int32_t left_context_ = 0;
  // TODO(jingzhaoou): to retrieve from model metadata
  int32_t right_context_ = 4;
  int32_t encoder_dim_ = 0;
  int32_t pad_length_ = 0;
  int32_t vocab_size_ = 0;
};
  105 +
  106 +} // namespace sherpa_onnx
  107 +
  108 +#endif // SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
@@ -227,7 +227,8 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() { @@ -227,7 +227,8 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
227 227
228 std::pair<Ort::Value, std::vector<Ort::Value>> 228 std::pair<Ort::Value, std::vector<Ort::Value>>
229 OnlineLstmTransducerModel::RunEncoder(Ort::Value features, 229 OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
230 - std::vector<Ort::Value> states) { 230 + std::vector<Ort::Value> states,
  231 + Ort::Value /* processed_frames */) {
231 std::array<Ort::Value, 3> encoder_inputs = { 232 std::array<Ort::Value, 3> encoder_inputs = {
232 std::move(features), std::move(states[0]), std::move(states[1])}; 233 std::move(features), std::move(states[0]), std::move(states[1])};
233 234
@@ -38,7 +38,8 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { @@ -38,7 +38,8 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
38 std::vector<Ort::Value> GetEncoderInitStates() override; 38 std::vector<Ort::Value> GetEncoderInitStates() override;
39 39
40 std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( 40 std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
41 - Ort::Value features, std::vector<Ort::Value> states) override; 41 + Ort::Value features, std::vector<Ort::Value> states,
  42 + Ort::Value processed_frames) override;
42 43
43 Ort::Value RunDecoder(Ort::Value decoder_input) override; 44 Ort::Value RunDecoder(Ort::Value decoder_input) override;
44 45
@@ -9,6 +9,7 @@ @@ -9,6 +9,7 @@
9 9
10 #include <algorithm> 10 #include <algorithm>
11 #include <iomanip> 11 #include <iomanip>
  12 +#include <iostream>
12 #include <memory> 13 #include <memory>
13 #include <sstream> 14 #include <sstream>
14 #include <utility> 15 #include <utility>
@@ -187,11 +188,14 @@ class OnlineRecognizer::Impl { @@ -187,11 +188,14 @@ class OnlineRecognizer::Impl {
187 std::vector<OnlineTransducerDecoderResult> results(n); 188 std::vector<OnlineTransducerDecoderResult> results(n);
188 std::vector<float> features_vec(n * chunk_size * feature_dim); 189 std::vector<float> features_vec(n * chunk_size * feature_dim);
189 std::vector<std::vector<Ort::Value>> states_vec(n); 190 std::vector<std::vector<Ort::Value>> states_vec(n);
  191 + std::vector<int64_t> all_processed_frames(n);
190 192
191 for (int32_t i = 0; i != n; ++i) { 193 for (int32_t i = 0; i != n; ++i) {
  194 + const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
192 std::vector<float> features = 195 std::vector<float> features =
193 - ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size); 196 + ss[i]->GetFrames(num_processed_frames, chunk_size);
194 197
  198 + // Question: should num_processed_frames include chunk_shift?
195 ss[i]->GetNumProcessedFrames() += chunk_shift; 199 ss[i]->GetNumProcessedFrames() += chunk_shift;
196 200
197 std::copy(features.begin(), features.end(), 201 std::copy(features.begin(), features.end(),
@@ -199,6 +203,7 @@ class OnlineRecognizer::Impl { @@ -199,6 +203,7 @@ class OnlineRecognizer::Impl {
199 203
200 results[i] = std::move(ss[i]->GetResult()); 204 results[i] = std::move(ss[i]->GetResult());
201 states_vec[i] = std::move(ss[i]->GetStates()); 205 states_vec[i] = std::move(ss[i]->GetStates());
  206 + all_processed_frames[i] = num_processed_frames;
202 } 207 }
203 208
204 auto memory_info = 209 auto memory_info =
@@ -210,9 +215,20 @@ class OnlineRecognizer::Impl { @@ -210,9 +215,20 @@ class OnlineRecognizer::Impl {
210 features_vec.size(), x_shape.data(), 215 features_vec.size(), x_shape.data(),
211 x_shape.size()); 216 x_shape.size());
212 217
  218 + std::array<int64_t, 1> processed_frames_shape{
  219 + static_cast<int64_t>(all_processed_frames.size())};
  220 +
  221 + Ort::Value processed_frames = Ort::Value::CreateTensor(
  222 + memory_info,
  223 + all_processed_frames.data(),
  224 + all_processed_frames.size(),
  225 + processed_frames_shape.data(),
  226 + processed_frames_shape.size());
  227 +
213 auto states = model_->StackStates(states_vec); 228 auto states = model_->StackStates(states_vec);
214 229
215 - auto pair = model_->RunEncoder(std::move(x), std::move(states)); 230 + auto pair = model_->RunEncoder(
  231 + std::move(x), std::move(states), std::move(processed_frames));
216 232
217 decoder_->Decode(std::move(pair.first), &results); 233 decoder_->Decode(std::move(pair.first), &results);
218 234
@@ -10,11 +10,13 @@ @@ -10,11 +10,13 @@
10 #endif 10 #endif
11 11
12 #include <algorithm> 12 #include <algorithm>
  13 +#include <iostream>
13 #include <memory> 14 #include <memory>
14 #include <sstream> 15 #include <sstream>
15 #include <string> 16 #include <string>
16 17
17 #include "sherpa-onnx/csrc/macros.h" 18 #include "sherpa-onnx/csrc/macros.h"
  19 +#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
18 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" 20 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
19 #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" 21 #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
20 #include "sherpa-onnx/csrc/onnx-utils.h" 22 #include "sherpa-onnx/csrc/onnx-utils.h"
@@ -22,6 +24,7 @@ @@ -22,6 +24,7 @@
22 namespace { 24 namespace {
23 25
24 enum class ModelType { 26 enum class ModelType {
  27 + kConformer,
25 kLstm, 28 kLstm,
26 kZipformer, 29 kZipformer,
27 kUnkown, 30 kUnkown,
@@ -57,7 +60,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -57,7 +60,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
57 return ModelType::kUnkown; 60 return ModelType::kUnkown;
58 } 61 }
59 62
60 - if (model_type.get() == std::string("lstm")) { 63 + if (model_type.get() == std::string("conformer")) {
  64 + return ModelType::kConformer;
  65 + } else if (model_type.get() == std::string("lstm")) {
61 return ModelType::kLstm; 66 return ModelType::kLstm;
62 } else if (model_type.get() == std::string("zipformer")) { 67 } else if (model_type.get() == std::string("zipformer")) {
63 return ModelType::kZipformer; 68 return ModelType::kZipformer;
@@ -78,6 +83,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( @@ -78,6 +83,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
78 } 83 }
79 84
80 switch (model_type) { 85 switch (model_type) {
  86 + case ModelType::kConformer:
  87 + return std::make_unique<OnlineConformerTransducerModel>(config);
81 case ModelType::kLstm: 88 case ModelType::kLstm:
82 return std::make_unique<OnlineLstmTransducerModel>(config); 89 return std::make_unique<OnlineLstmTransducerModel>(config);
83 case ModelType::kZipformer: 90 case ModelType::kZipformer:
@@ -132,6 +139,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( @@ -132,6 +139,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
132 auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); 139 auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
133 140
134 switch (model_type) { 141 switch (model_type) {
  142 + case ModelType::kConformer:
  143 + return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
135 case ModelType::kLstm: 144 case ModelType::kLstm:
136 return std::make_unique<OnlineLstmTransducerModel>(mgr, config); 145 return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
137 case ModelType::kZipformer: 146 case ModelType::kZipformer:
@@ -64,6 +64,7 @@ class OnlineTransducerModel { @@ -64,6 +64,7 @@ class OnlineTransducerModel {
64 * 64 *
65 * @param features A tensor of shape (N, T, C). It is changed in-place. 65 * @param features A tensor of shape (N, T, C). It is changed in-place.
66 * @param states Encoder state of the previous chunk. It is changed in-place. 66 * @param states Encoder state of the previous chunk. It is changed in-place.
  67 + * @param processed_frames Processed frames before subsampling. It is a 1-D tensor with data type int64_t.
67 * 68 *
68 * @return Return a tuple containing: 69 * @return Return a tuple containing:
69 * - encoder_out, a tensor of shape (N, T', encoder_out_dim) 70 * - encoder_out, a tensor of shape (N, T', encoder_out_dim)
@@ -71,7 +72,8 @@ class OnlineTransducerModel { @@ -71,7 +72,8 @@ class OnlineTransducerModel {
71 */ 72 */
72 virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( 73 virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
73 Ort::Value features, 74 Ort::Value features,
74 - std::vector<Ort::Value> states) = 0; // NOLINT 75 + std::vector<Ort::Value> states,
  76 + Ort::Value processed_frames) = 0; // NOLINT
75 77
76 /** Run the decoder network. 78 /** Run the decoder network.
77 * 79 *
@@ -434,7 +434,8 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() { @@ -434,7 +434,8 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {
434 434
435 std::pair<Ort::Value, std::vector<Ort::Value>> 435 std::pair<Ort::Value, std::vector<Ort::Value>>
436 OnlineZipformerTransducerModel::RunEncoder(Ort::Value features, 436 OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
437 - std::vector<Ort::Value> states) { 437 + std::vector<Ort::Value> states,
  438 + Ort::Value /* processed_frames */) {
438 std::vector<Ort::Value> encoder_inputs; 439 std::vector<Ort::Value> encoder_inputs;
439 encoder_inputs.reserve(1 + states.size()); 440 encoder_inputs.reserve(1 + states.size());
440 441
@@ -39,7 +39,8 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { @@ -39,7 +39,8 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
39 std::vector<Ort::Value> GetEncoderInitStates() override; 39 std::vector<Ort::Value> GetEncoderInitStates() override;
40 40
41 std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( 41 std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
42 - Ort::Value features, std::vector<Ort::Value> states) override; 42 + Ort::Value features, std::vector<Ort::Value> states,
  43 + Ort::Value processed_frames) override;
43 44
44 Ort::Value RunDecoder(Ort::Value decoder_input) override; 45 Ort::Value RunDecoder(Ort::Value decoder_input) override;
45 46
@@ -168,6 +168,26 @@ void Print3D(Ort::Value *v) { @@ -168,6 +168,26 @@ void Print3D(Ort::Value *v) {
168 fprintf(stderr, "\n"); 168 fprintf(stderr, "\n");
169 } 169 }
170 170
  171 +void Print4D(Ort::Value *v) {
  172 + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
  173 + const float *d = v->GetTensorData<float>();
  174 +
  175 + for (int32_t p = 0; p != static_cast<int32_t>(shape[0]); ++p) {
  176 + fprintf(stderr, "---plane %d---\n", p);
  177 + for (int32_t q = 0; q != static_cast<int32_t>(shape[1]); ++q) {
  178 + fprintf(stderr, "---subplane %d---\n", q);
  179 + for (int32_t r = 0; r != static_cast<int32_t>(shape[2]); ++r) {
  180 + for (int32_t c = 0; c != static_cast<int32_t>(shape[3]); ++c, ++d) {
  181 + fprintf(stderr, "%.3f ", *d);
  182 + }
  183 + fprintf(stderr, "\n");
  184 + }
  185 + fprintf(stderr, "\n");
  186 + }
  187 + }
  188 + fprintf(stderr, "\n");
  189 +}
  190 +
171 std::vector<char> ReadFile(const std::string &filename) { 191 std::vector<char> ReadFile(const std::string &filename) {
172 std::ifstream input(filename, std::ios::binary); 192 std::ifstream input(filename, std::ios::binary);
173 std::vector<char> buffer(std::istreambuf_iterator<char>(input), {}); 193 std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
@@ -75,6 +75,9 @@ void Print2D(Ort::Value *v); @@ -75,6 +75,9 @@ void Print2D(Ort::Value *v);
75 // Print a 3-D tensor to stderr 75 // Print a 3-D tensor to stderr
76 void Print3D(Ort::Value *v); 76 void Print3D(Ort::Value *v);
77 77
  78 +// Print a 4-D tensor to stderr
  79 +void Print4D(Ort::Value *v);
  80 +
78 template <typename T = float> 81 template <typename T = float>
79 void Fill(Ort::Value *tensor, T value) { 82 void Fill(Ort::Value *tensor, T value) {
80 auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount(); 83 auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount();
  1 +// sherpa-onnx/csrc/stack-test.cc
  2 +//
  3 +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
  4 +
  5 +#include "sherpa-onnx/csrc/stack.h"
  6 +
  7 +#include "gtest/gtest.h"
  8 +#include "sherpa-onnx/csrc/onnx-utils.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +TEST(Stack, Test1DTensors) {
  13 + Ort::AllocatorWithDefaultOptions allocator;
  14 +
  15 + std::array<int64_t, 1> a_shape{3};
  16 + std::array<int64_t, 1> b_shape{3};
  17 +
  18 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  19 + a_shape.size());
  20 +
  21 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  22 + b_shape.size());
  23 + float *pa = a.GetTensorMutableData<float>();
  24 + float *pb = b.GetTensorMutableData<float>();
  25 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
  26 + pa[i] = i;
  27 + }
  28 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
  29 + pb[i] = i + 10;
  30 + }
  31 +
  32 + Ort::Value ans = Stack(allocator, {&a, &b}, 0);
  33 +
  34 + Print1D(&a);
  35 + Print1D(&b);
  36 + Print2D(&ans);
  37 +
  38 + const float *pans = ans.GetTensorData<float>();
  39 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
  40 + EXPECT_EQ(pa[i], pans[i]);
  41 + }
  42 +
  43 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
  44 + EXPECT_EQ(pb[i], pans[i + a_shape[0]]);
  45 + }
  46 +}
  47 +
  48 +TEST(Stack, Test2DTensorsDim0) {
  49 + Ort::AllocatorWithDefaultOptions allocator;
  50 +
  51 + std::array<int64_t, 2> a_shape{2, 3};
  52 + std::array<int64_t, 2> b_shape{2, 3};
  53 +
  54 + Ort::Value a = Ort::Value::CreateTensor<float>(
  55 + allocator, a_shape.data(), a_shape.size());
  56 +
  57 + Ort::Value b = Ort::Value::CreateTensor<float>(
  58 + allocator, b_shape.data(), b_shape.size());
  59 +
  60 + float *pa = a.GetTensorMutableData<float>();
  61 + float *pb = b.GetTensorMutableData<float>();
  62 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
  63 + pa[i] = i;
  64 + }
  65 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
  66 + pb[i] = i + 10;
  67 + }
  68 +
  69 + Ort::Value ans = Stack(allocator, {&a, &b}, 0);
  70 +
  71 + Print2D(&a);
  72 + Print2D(&b);
  73 + Print3D(&ans);
  74 +
  75 + const float *pans = ans.GetTensorData<float>();
  76 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
  77 + EXPECT_EQ(pa[i], pans[i]);
  78 + }
  79 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
  80 + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]);
  81 + }
  82 +}
  83 +
  84 +TEST(Stack, Test2DTensorsDim1) {
  85 + Ort::AllocatorWithDefaultOptions allocator;
  86 +
  87 + std::array<int64_t, 2> a_shape{4, 3};
  88 + std::array<int64_t, 2> b_shape{4, 3};
  89 +
  90 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  91 + a_shape.size());
  92 +
  93 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  94 + b_shape.size());
  95 +
  96 + float *pa = a.GetTensorMutableData<float>();
  97 + float *pb = b.GetTensorMutableData<float>();
  98 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
  99 + pa[i] = i;
  100 + }
  101 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
  102 + pb[i] = i + 10;
  103 + }
  104 +
  105 + Ort::Value ans = Stack(allocator, {&a, &b}, 1);
  106 +
  107 + Print2D(&a);
  108 + Print2D(&b);
  109 + Print3D(&ans);
  110 +
  111 + const float *pans = ans.GetTensorData<float>();
  112 +
  113 + for (int32_t r = 0; r != static_cast<int32_t>(a_shape[0]); ++r) {
  114 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[1]);
  115 + ++i, ++pa, ++pans) {
  116 + EXPECT_EQ(*pa, *pans);
  117 + }
  118 +
  119 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[1]);
  120 + ++i, ++pb, ++pans) {
  121 + EXPECT_EQ(*pb, *pans);
  122 + }
  123 + }
  124 +}
  125 +
  126 +TEST(Stack, Test3DTensorsDim0) {
  127 + Ort::AllocatorWithDefaultOptions allocator;
  128 +
  129 + std::array<int64_t, 3> a_shape{2, 3, 2};
  130 + std::array<int64_t, 3> b_shape{2, 3, 2};
  131 +
  132 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  133 + a_shape.size());
  134 +
  135 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  136 + b_shape.size());
  137 +
  138 + float *pa = a.GetTensorMutableData<float>();
  139 + float *pb = b.GetTensorMutableData<float>();
  140 + for (int32_t i = 0;
  141 + i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
  142 + pa[i] = i;
  143 + }
  144 + for (int32_t i = 0;
  145 + i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
  146 + pb[i] = i + 10;
  147 + }
  148 +
  149 + Ort::Value ans = Stack(allocator, {&a, &b}, 0);
  150 +
  151 + const float *pans = ans.GetTensorData<float>();
  152 + for (int32_t i = 0;
  153 + i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
  154 + EXPECT_EQ(pa[i], pans[i]);
  155 + }
  156 + for (int32_t i = 0;
  157 + i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
  158 + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]);
  159 + }
  160 +
  161 + Print3D(&a);
  162 + Print3D(&b);
  163 + Print4D(&ans);
  164 +}
  165 +
  166 +TEST(Stack, Test3DTensorsDim1) {
  167 + Ort::AllocatorWithDefaultOptions allocator;
  168 +
  169 + std::array<int64_t, 3> a_shape{2, 2, 3};
  170 + std::array<int64_t, 3> b_shape{2, 2, 3};
  171 +
  172 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  173 + a_shape.size());
  174 +
  175 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  176 + b_shape.size());
  177 +
  178 + float *pa = a.GetTensorMutableData<float>();
  179 + float *pb = b.GetTensorMutableData<float>();
  180 + for (int32_t i = 0;
  181 + i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
  182 + pa[i] = i;
  183 + }
  184 + for (int32_t i = 0;
  185 + i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
  186 + pb[i] = i + 10;
  187 + }
  188 +
  189 + Ort::Value ans = Stack(allocator, {&a, &b}, 1);
  190 +
  191 + const float *pans = ans.GetTensorData<float>();
  192 +
  193 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
  194 + for (int32_t k = 0; k != static_cast<int32_t>(a_shape[1] * a_shape[2]);
  195 + ++k, ++pa, ++pans) {
  196 + EXPECT_EQ(*pa, *pans);
  197 + }
  198 +
  199 + for (int32_t k = 0; k != static_cast<int32_t>(b_shape[1] * b_shape[2]);
  200 + ++k, ++pb, ++pans) {
  201 + EXPECT_EQ(*pb, *pans);
  202 + }
  203 + }
  204 +
  205 + Print3D(&a);
  206 + Print3D(&b);
  207 + Print4D(&ans);
  208 +}
  209 +
  210 +TEST(Stack, Test3DTensorsDim2) {
  211 + Ort::AllocatorWithDefaultOptions allocator;
  212 +
  213 + std::array<int64_t, 3> a_shape{2, 3, 4};
  214 + std::array<int64_t, 3> b_shape{2, 3, 4};
  215 +
  216 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  217 + a_shape.size());
  218 +
  219 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  220 + b_shape.size());
  221 +
  222 + float *pa = a.GetTensorMutableData<float>();
  223 + float *pb = b.GetTensorMutableData<float>();
  224 + for (int32_t i = 0;
  225 + i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
  226 + pa[i] = i;
  227 + }
  228 + for (int32_t i = 0;
  229 + i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
  230 + pb[i] = i + 10;
  231 + }
  232 +
  233 + Ort::Value ans = Stack(allocator, {&a, &b}, 2);
  234 +
  235 + const float *pans = ans.GetTensorData<float>();
  236 +
  237 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
  238 + for (int32_t k = 0; k != static_cast<int32_t>(a_shape[2]);
  239 + ++k, ++pa, ++pans) {
  240 + EXPECT_EQ(*pa, *pans);
  241 + }
  242 +
  243 + for (int32_t k = 0; k != static_cast<int32_t>(b_shape[2]);
  244 + ++k, ++pb, ++pans) {
  245 + EXPECT_EQ(*pb, *pans);
  246 + }
  247 + }
  248 +
  249 + Print3D(&a);
  250 + Print3D(&b);
  251 + Print4D(&ans);
  252 +}
  253 +
  254 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/stack.cc
  2 +//
  3 +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
  4 +
#include "sherpa-onnx/csrc/stack.h"

#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>

#include "sherpa-onnx/csrc/onnx-utils.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
// Return true if the two shapes are identical: same rank and the same
// extent in every dimension. std::vector::operator== performs exactly
// this size-then-elementwise comparison.
static bool Compare(const std::vector<int64_t> &a,
                    const std::vector<int64_t> &b) {
  return a == b;
}
  27 +
// Print a tensor shape to stderr as space-separated extents, followed by
// a newline.
static void PrintShape(const std::vector<int64_t> &shape) {
  for (const auto dim : shape) {
    fprintf(stderr, "%d ", static_cast<int32_t>(dim));
  }
  fprintf(stderr, "\n");
}
  34 +
  35 +template <typename T /*=float*/>
  36 +Ort::Value Stack(OrtAllocator *allocator,
  37 + const std::vector<const Ort::Value *> &values, int32_t dim) {
  38 + std::vector<int64_t> v0_shape =
  39 + values[0]->GetTensorTypeAndShapeInfo().GetShape();
  40 +
  41 + for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
  42 + auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
  43 + bool ret = Compare(v0_shape, s);
  44 + if (!ret) {
  45 + fprintf(stderr, "Incorrect shape in Stack !\n");
  46 +
  47 + fprintf(stderr, "Shape for tensor 0: ");
  48 + PrintShape(v0_shape);
  49 +
  50 + fprintf(stderr, "Shape for tensor %d: ", i);
  51 + PrintShape(s);
  52 +
  53 + exit(-1);
  54 + }
  55 + }
  56 +
  57 + std::vector<int64_t> ans_shape;
  58 + ans_shape.reserve(v0_shape.size() + 1);
  59 + ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
  60 + ans_shape.push_back(values.size());
  61 + ans_shape.insert(
  62 + ans_shape.end(),
  63 + v0_shape.data() + dim,
  64 + v0_shape.data() + v0_shape.size());
  65 +
  66 + auto leading_size = static_cast<int32_t>(std::accumulate(
  67 + v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
  68 +
  69 + auto trailing_size = static_cast<int32_t>(
  70 + std::accumulate(v0_shape.begin() + dim,
  71 + v0_shape.end(), 1,
  72 + std::multiplies<int64_t>()));
  73 +
  74 + Ort::Value ans = Ort::Value::CreateTensor<T>(
  75 + allocator, ans_shape.data(), ans_shape.size());
  76 + T *dst = ans.GetTensorMutableData<T>();
  77 +
  78 + for (int32_t i = 0; i != leading_size; ++i) {
  79 + for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
  80 + const T *src = values[n]->GetTensorData<T>();
  81 + src += i * trailing_size;
  82 +
  83 + std::copy(src, src + trailing_size, dst);
  84 + dst += trailing_size;
  85 + }
  86 + }
  87 +
  88 + return ans;
  89 +}
  90 +
  91 +template Ort::Value Stack<float>(
  92 + OrtAllocator *allocator,
  93 + const std::vector<const Ort::Value *> &values,
  94 + int32_t dim);
  95 +
  96 +template Ort::Value Stack<int64_t>(
  97 + OrtAllocator *allocator,
  98 + const std::vector<const Ort::Value *> &values,
  99 + int32_t dim);
  100 +
  101 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/stack.h
  2 +//
  3 +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_STACK_H_
  6 +#define SHERPA_ONNX_CSRC_STACK_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "onnxruntime_cxx_api.h" // NOLINT
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +/** Stack a list of tensors along the given dim.
  15 + *
  16 + * @param allocator Allocator to allocate space for the returned tensor
  17 + * @param values Pointer to a list of tensors. The shape of the tensor must
  18 + * be the same except on the dim to be stacked.
  19 + * @param dim The dim along which to concatenate the input tensors
  20 + *
  21 + * @return Return the stacked tensor
  22 + */
  23 +template <typename T = float>
  24 +Ort::Value Stack(OrtAllocator *allocator,
  25 + const std::vector<const Ort::Value *> &values, int32_t dim);
  26 +
  27 +} // namespace sherpa_onnx
  28 +
  29 +#endif // SHERPA_ONNX_CSRC_STACK_H_