Stack and streaming conformer support (#141)
* added csrc/stack.cc
* stack: added checks
* added copyright info
* passed cpp style checks
* formatted code
* added some support for streaming conformer model support (not verified)
* code lint
* made more progress with streaming conformer support (not working yet)
* passed style check
* changes as suggested by @csukuangfj
* added some debug info
* fixed style check
* Use Cat to replace Stack
* remove debug statements

---------

Co-authored-by: Jingzhao Ou (jou2019) <jou2019@cisco.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
Showing 15 changed files with 836 additions and 8 deletions.
sherpa-onnx/csrc/CMakeLists.txt

@@ -34,6 +34,7 @@ set(sources
   offline-transducer-model-config.cc
   offline-transducer-model.cc
   offline-transducer-modified-beam-search-decoder.cc
+  online-conformer-transducer-model.cc
   online-lm.cc
   online-lm-config.cc
   online-lstm-transducer-model.cc
@@ -52,6 +53,7 @@ set(sources
   parse-options.cc
   resample.cc
   slice.cc
+  stack.cc
   symbol-table.cc
   text-utils.cc
   transpose.cc
@@ -241,6 +243,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
     packed-sequence-test.cc
     pad-sequence-test.cc
     slice-test.cc
+    stack-test.cc
     transpose-test.cc
     unbind-test.cc
   )
sherpa-onnx/csrc/online-conformer-transducer-model.cc (new file)

// sherpa-onnx/csrc/online-conformer-transducer-model.cc
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)

#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"

#include <assert.h>

#include <algorithm>
#include <array>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/unbind.h"

namespace sherpa_onnx {

OnlineConformerTransducerModel::OnlineConformerTransducerModel(
    const OnlineTransducerModelConfig &config)
    : env_(ORT_LOGGING_LEVEL_WARNING),
      config_(config),
      sess_opts_{},
      allocator_{} {
  sess_opts_.SetIntraOpNumThreads(config.num_threads);
  sess_opts_.SetInterOpNumThreads(config.num_threads);

  {
    auto buf = ReadFile(config.encoder_filename);
    InitEncoder(buf.data(), buf.size());
  }

  {
    auto buf = ReadFile(config.decoder_filename);
    InitDecoder(buf.data(), buf.size());
  }

  {
    auto buf = ReadFile(config.joiner_filename);
    InitJoiner(buf.data(), buf.size());
  }
}

#if __ANDROID_API__ >= 9
OnlineConformerTransducerModel::OnlineConformerTransducerModel(
    AAssetManager *mgr, const OnlineTransducerModelConfig &config)
    : env_(ORT_LOGGING_LEVEL_WARNING),
      config_(config),
      sess_opts_{},
      allocator_{} {
  sess_opts_.SetIntraOpNumThreads(config.num_threads);
  sess_opts_.SetInterOpNumThreads(config.num_threads);

  {
    auto buf = ReadFile(mgr, config.encoder_filename);
    InitEncoder(buf.data(), buf.size());
  }

  {
    auto buf = ReadFile(mgr, config.decoder_filename);
    InitDecoder(buf.data(), buf.size());
  }

  {
    auto buf = ReadFile(mgr, config.joiner_filename);
    InitJoiner(buf.data(), buf.size());
  }
}
#endif

void OnlineConformerTransducerModel::InitEncoder(void *model_data,
                                                 size_t model_data_length) {
  encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                 model_data_length, sess_opts_);

  GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                &encoder_input_names_ptr_);

  GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                 &encoder_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---encoder---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
  SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
  SHERPA_ONNX_READ_META_DATA(T_, "T");
  SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
  SHERPA_ONNX_READ_META_DATA(left_context_, "left_context");
  SHERPA_ONNX_READ_META_DATA(encoder_dim_, "encoder_dim");
  SHERPA_ONNX_READ_META_DATA(pad_length_, "pad_length");
  SHERPA_ONNX_READ_META_DATA(cnn_module_kernel_, "cnn_module_kernel");
}

void OnlineConformerTransducerModel::InitDecoder(void *model_data,
                                                 size_t model_data_length) {
  decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                 model_data_length, sess_opts_);

  GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                &decoder_input_names_ptr_);

  GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                 &decoder_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---decoder---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
  SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
}

void OnlineConformerTransducerModel::InitJoiner(void *model_data,
                                                size_t model_data_length) {
  joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                model_data_length, sess_opts_);

  GetInputNames(joiner_sess_.get(), &joiner_input_names_,
                &joiner_input_names_ptr_);

  GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
                 &joiner_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---joiner---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }
}

std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
    const std::vector<std::vector<Ort::Value>> &states) const {
  int32_t batch_size = static_cast<int32_t>(states.size());

  std::vector<const Ort::Value *> attn_vec(batch_size);
  std::vector<const Ort::Value *> conv_vec(batch_size);

  for (int32_t i = 0; i != batch_size; ++i) {
    assert(states[i].size() == 2);
    attn_vec[i] = &states[i][0];
    conv_vec[i] = &states[i][1];
  }

  Ort::Value attn = Cat(allocator_, attn_vec, 2);
  Ort::Value conv = Cat(allocator_, conv_vec, 2);

  std::vector<Ort::Value> ans;
  ans.reserve(2);
  ans.push_back(std::move(attn));
  ans.push_back(std::move(conv));

  return ans;
}

std::vector<std::vector<Ort::Value>>
OnlineConformerTransducerModel::UnStackStates(
    const std::vector<Ort::Value> &states) const {
  assert(states.size() == 2);
  const int32_t batch_size =
      states[0].GetTensorTypeAndShapeInfo().GetShape()[2];

  std::vector<std::vector<Ort::Value>> ans(batch_size);

  std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);
  std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2);

  assert(attn_vec.size() == batch_size);
  assert(conv_vec.size() == batch_size);

  for (int32_t i = 0; i != batch_size; ++i) {
    ans[i].push_back(std::move(attn_vec[i]));
    ans[i].push_back(std::move(conv_vec[i]));
  }

  return ans;
}

std::vector<Ort::Value>
OnlineConformerTransducerModel::GetEncoderInitStates() {
  // Please see
  // https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203
  // for details
  constexpr int32_t kBatchSize = 1;
  std::array<int64_t, 4> h_shape{num_encoder_layers_, left_context_,
                                 kBatchSize, encoder_dim_};
  Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
                                                 h_shape.size());

  Fill<float>(&h, 0);

  std::array<int64_t, 4> c_shape{num_encoder_layers_, cnn_module_kernel_ - 1,
                                 kBatchSize, encoder_dim_};

  Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
                                                 c_shape.size());

  Fill<float>(&c, 0);

  std::vector<Ort::Value> states;

  states.reserve(2);
  states.push_back(std::move(h));
  states.push_back(std::move(c));

  return states;
}

std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineConformerTransducerModel::RunEncoder(Ort::Value features,
                                           std::vector<Ort::Value> states,
                                           Ort::Value processed_frames) {
  std::array<Ort::Value, 4> encoder_inputs = {
      std::move(features),
      std::move(states[0]),
      std::move(states[1]),
      std::move(processed_frames)};

  auto encoder_out = encoder_sess_->Run(
      {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
      encoder_inputs.size(), encoder_output_names_ptr_.data(),
      encoder_output_names_ptr_.size());

  std::vector<Ort::Value> next_states;
  next_states.reserve(2);
  next_states.push_back(std::move(encoder_out[1]));
  next_states.push_back(std::move(encoder_out[2]));

  return {std::move(encoder_out[0]), std::move(next_states)};
}

Ort::Value OnlineConformerTransducerModel::RunDecoder(
    Ort::Value decoder_input) {
  auto decoder_out = decoder_sess_->Run(
      {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
      decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
  return std::move(decoder_out[0]);
}

Ort::Value OnlineConformerTransducerModel::RunJoiner(Ort::Value encoder_out,
                                                     Ort::Value decoder_out) {
  std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
                                            std::move(decoder_out)};
  auto logit =
      joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
                        joiner_input.size(), joiner_output_names_ptr_.data(),
                        joiner_output_names_ptr_.size());

  return std::move(logit[0]);
}

}  // namespace sherpa_onnx
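For reference, here is a minimal sketch (illustrative only, not part of this commit) of the state-batching scheme that StackStates()/UnStackStates() implement above: per-stream caches are concatenated with Cat along dim 2, the batch axis of the conformer attention and convolution caches, and split back with Unbind. The concrete shape values below are made-up assumptions.

#include <array>
#include <cassert>
#include <vector>

#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/unbind.h"

void RoundTripConformerStates(OrtAllocator *allocator) {
  // Two single-stream attention caches, shape
  // (num_encoder_layers, left_context, batch=1, encoder_dim);
  // the small sizes here are for illustration only.
  std::array<int64_t, 4> shape{2, 3, 1, 4};
  Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator, shape.data(),
                                                  shape.size());
  Ort::Value s1 = Ort::Value::CreateTensor<float>(allocator, shape.data(),
                                                  shape.size());
  sherpa_onnx::Fill<float>(&s0, 0);
  sherpa_onnx::Fill<float>(&s1, 1);

  // Batch the two streams: dim 2 grows from 1 to 2.
  std::vector<const Ort::Value *> batch = {&s0, &s1};
  Ort::Value stacked = sherpa_onnx::Cat(allocator, batch, 2);

  // Split back into one state per stream.
  std::vector<Ort::Value> unstacked =
      sherpa_onnx::Unbind(allocator, &stacked, 2);
  assert(unstacked.size() == 2);
}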
sherpa-onnx/csrc/online-conformer-transducer-model.h (new file)

// sherpa-onnx/csrc/online-conformer-transducer-model.h
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)

#ifndef SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"

namespace sherpa_onnx {

class OnlineConformerTransducerModel : public OnlineTransducerModel {
 public:
  explicit OnlineConformerTransducerModel(
      const OnlineTransducerModelConfig &config);

#if __ANDROID_API__ >= 9
  OnlineConformerTransducerModel(AAssetManager *mgr,
                                 const OnlineTransducerModelConfig &config);
#endif

  std::vector<Ort::Value> StackStates(
      const std::vector<std::vector<Ort::Value>> &states) const override;

  std::vector<std::vector<Ort::Value>> UnStackStates(
      const std::vector<Ort::Value> &states) const override;

  std::vector<Ort::Value> GetEncoderInitStates() override;

  std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
      Ort::Value features, std::vector<Ort::Value> states,
      Ort::Value processed_frames) override;

  Ort::Value RunDecoder(Ort::Value decoder_input) override;

  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;

  int32_t ContextSize() const override { return context_size_; }

  int32_t ChunkSize() const override { return T_; }

  int32_t ChunkShift() const override { return decode_chunk_len_; }

  int32_t VocabSize() const override { return vocab_size_; }

  OrtAllocator *Allocator() override { return allocator_; }

 private:
  void InitEncoder(void *model_data, size_t model_data_length);
  void InitDecoder(void *model_data, size_t model_data_length);
  void InitJoiner(void *model_data, size_t model_data_length);

 private:
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;
  std::unique_ptr<Ort::Session> joiner_sess_;

  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  std::vector<std::string> joiner_input_names_;
  std::vector<const char *> joiner_input_names_ptr_;

  std::vector<std::string> joiner_output_names_;
  std::vector<const char *> joiner_output_names_ptr_;

  OnlineTransducerModelConfig config_;

  int32_t num_encoder_layers_ = 0;
  int32_t T_ = 0;
  int32_t decode_chunk_len_ = 0;
  int32_t cnn_module_kernel_ = 0;
  int32_t context_size_ = 0;
  int32_t left_context_ = 0;
  // TODO(jingzhaoou): retrieve from model metadata
  int32_t right_context_ = 4;
  int32_t encoder_dim_ = 0;
  int32_t pad_length_ = 0;
  int32_t vocab_size_ = 0;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
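For reference, a minimal usage sketch of the new class (illustrative only, not part of this commit). The config field names match those read by the constructor in the .cc file; the model file names and the 80-dim feature size are assumptions for illustration.

#include <array>
#include <utility>

#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

void RunOneChunk() {
  sherpa_onnx::OnlineTransducerModelConfig config;
  config.encoder_filename = "encoder.onnx";  // assumption: local model files
  config.decoder_filename = "decoder.onnx";
  config.joiner_filename = "joiner.onnx";
  config.num_threads = 1;

  sherpa_onnx::OnlineConformerTransducerModel model(config);
  auto states = model.GetEncoderInitStates();

  // One chunk of features: (N=1, T=ChunkSize(), C=feature_dim).
  int64_t feature_dim = 80;  // assumption: 80-dim fbank
  std::array<int64_t, 3> x_shape{1, model.ChunkSize(), feature_dim};
  Ort::Value x = Ort::Value::CreateTensor<float>(
      model.Allocator(), x_shape.data(), x_shape.size());
  sherpa_onnx::Fill<float>(&x, 0);  // silence, just to have valid input

  // processed_frames: one int64 entry per stream; 0 for the first chunk.
  std::array<int64_t, 1> p_shape{1};
  Ort::Value p = Ort::Value::CreateTensor<int64_t>(
      model.Allocator(), p_shape.data(), p_shape.size());
  p.GetTensorMutableData<int64_t>()[0] = 0;

  auto result = model.RunEncoder(std::move(x), std::move(states), std::move(p));
  Ort::Value encoder_out = std::move(result.first);
  states = std::move(result.second);  // feed into the next chunk
}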
sherpa-onnx/csrc/online-lstm-transducer-model.cc

@@ -227,7 +227,8 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {

 std::pair<Ort::Value, std::vector<Ort::Value>>
 OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
-                                      std::vector<Ort::Value> states) {
+                                      std::vector<Ort::Value> states,
+                                      Ort::Value /* processed_frames */) {
   std::array<Ort::Value, 3> encoder_inputs = {
       std::move(features), std::move(states[0]), std::move(states[1])};

sherpa-onnx/csrc/online-lstm-transducer-model.h

@@ -38,7 +38,8 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
   std::vector<Ort::Value> GetEncoderInitStates() override;

   std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
-      Ort::Value features, std::vector<Ort::Value> states) override;
+      Ort::Value features, std::vector<Ort::Value> states,
+      Ort::Value processed_frames) override;

   Ort::Value RunDecoder(Ort::Value decoder_input) override;

sherpa-onnx/csrc/online-recognizer.cc

@@ -9,6 +9,7 @@
 #include <algorithm>
 #include <iomanip>
+#include <iostream>
 #include <memory>
 #include <sstream>
 #include <utility>
@@ -187,11 +188,14 @@ class OnlineRecognizer::Impl {
     std::vector<OnlineTransducerDecoderResult> results(n);
     std::vector<float> features_vec(n * chunk_size * feature_dim);
     std::vector<std::vector<Ort::Value>> states_vec(n);
+    std::vector<int64_t> all_processed_frames(n);

     for (int32_t i = 0; i != n; ++i) {
+      const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
       std::vector<float> features =
-          ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);
+          ss[i]->GetFrames(num_processed_frames, chunk_size);

+      // Question: should num_processed_frames include chunk_shift?
       ss[i]->GetNumProcessedFrames() += chunk_shift;

       std::copy(features.begin(), features.end(),
@@ -199,6 +203,7 @@ class OnlineRecognizer::Impl {

       results[i] = std::move(ss[i]->GetResult());
       states_vec[i] = std::move(ss[i]->GetStates());
+      all_processed_frames[i] = num_processed_frames;
     }

     auto memory_info =
@@ -210,9 +215,20 @@ class OnlineRecognizer::Impl {
                                           features_vec.size(), x_shape.data(),
                                           x_shape.size());

+    std::array<int64_t, 1> processed_frames_shape{
+        static_cast<int64_t>(all_processed_frames.size())};
+
+    Ort::Value processed_frames = Ort::Value::CreateTensor(
+        memory_info,
+        all_processed_frames.data(),
+        all_processed_frames.size(),
+        processed_frames_shape.data(),
+        processed_frames_shape.size());
+
     auto states = model_->StackStates(states_vec);

-    auto pair = model_->RunEncoder(std::move(x), std::move(states));
+    auto pair = model_->RunEncoder(
+        std::move(x), std::move(states), std::move(processed_frames));

     decoder_->Decode(std::move(pair.first), &results);

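Note the lifetime requirement baked into the hunk above: CreateTensor with a memory_info does not copy, it wraps the caller's buffer, so all_processed_frames must stay alive until the RunEncoder() call returns. A minimal sketch of the pattern (illustrative values, not part of this commit):

auto memory_info =
    Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

// One entry per stream; the values are illustrative.
std::vector<int64_t> frames = {0, 32, 64};
std::array<int64_t, 1> shape{static_cast<int64_t>(frames.size())};

// ORT keeps a pointer into `frames`; it must outlive any Run() using it.
Ort::Value processed_frames = Ort::Value::CreateTensor(
    memory_info, frames.data(), frames.size(), shape.data(), shape.size());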
sherpa-onnx/csrc/online-transducer-model.cc

@@ -10,11 +10,13 @@
 #endif

 #include <algorithm>
+#include <iostream>
 #include <memory>
 #include <sstream>
 #include <string>

 #include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
 #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
@@ -22,6 +24,7 @@
 namespace {

 enum class ModelType {
+  kConformer,
   kLstm,
   kZipformer,
   kUnkown,
@@ -57,7 +60,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
     return ModelType::kUnkown;
   }

-  if (model_type.get() == std::string("lstm")) {
+  if (model_type.get() == std::string("conformer")) {
+    return ModelType::kConformer;
+  } else if (model_type.get() == std::string("lstm")) {
     return ModelType::kLstm;
   } else if (model_type.get() == std::string("zipformer")) {
     return ModelType::kZipformer;
@@ -78,6 +83,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
   }

   switch (model_type) {
+    case ModelType::kConformer:
+      return std::make_unique<OnlineConformerTransducerModel>(config);
     case ModelType::kLstm:
       return std::make_unique<OnlineLstmTransducerModel>(config);
     case ModelType::kZipformer:
@@ -132,6 +139,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
   auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);

   switch (model_type) {
+    case ModelType::kConformer:
+      return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
     case ModelType::kLstm:
       return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
     case ModelType::kZipformer:
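For reference, a sketch of how the model_type string driving this dispatch can be read from ONNX custom metadata (illustrative, not part of this commit). It assumes the Ort::ModelMetadata::LookupCustomMetadataMapAllocated API, which matches the model_type.get() usage visible in GetModelType() above; the real function compares the value against "conformer", "lstm", and "zipformer".

#include <string>

#include "onnxruntime_cxx_api.h"  // NOLINT

std::string ReadModelType(Ort::Session &sess) {
  Ort::ModelMetadata meta_data = sess.GetModelMetadata();
  Ort::AllocatorWithDefaultOptions allocator;
  // Returns an owning string pointer; empty if the key is absent.
  auto model_type =
      meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
  return model_type ? std::string(model_type.get()) : std::string();
}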
sherpa-onnx/csrc/online-transducer-model.h

@@ -64,6 +64,7 @@ class OnlineTransducerModel {
    *
    * @param features A tensor of shape (N, T, C). It is changed in-place.
    * @param states Encoder state of the previous chunk. It is changed in-place.
+   * @param processed_frames Processed frames before subsampling. It is a 1-D tensor with data type int64_t.
    *
    * @return Return a tuple containing:
    *  - encoder_out, a tensor of shape (N, T', encoder_out_dim)
@@ -71,7 +72,8 @@ class OnlineTransducerModel {
    */
   virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
       Ort::Value features,
-      std::vector<Ort::Value> states) = 0;  // NOLINT
+      std::vector<Ort::Value> states,
+      Ort::Value processed_frames) = 0;  // NOLINT

   /** Run the decoder network.
    *
sherpa-onnx/csrc/online-zipformer-transducer-model.cc

@@ -434,7 +434,8 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {

 std::pair<Ort::Value, std::vector<Ort::Value>>
 OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
-                                           std::vector<Ort::Value> states) {
+                                           std::vector<Ort::Value> states,
+                                           Ort::Value /* processed_frames */) {
   std::vector<Ort::Value> encoder_inputs;
   encoder_inputs.reserve(1 + states.size());

sherpa-onnx/csrc/online-zipformer-transducer-model.h

@@ -39,7 +39,8 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
   std::vector<Ort::Value> GetEncoderInitStates() override;

   std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
-      Ort::Value features, std::vector<Ort::Value> states) override;
+      Ort::Value features, std::vector<Ort::Value> states,
+      Ort::Value processed_frames) override;

   Ort::Value RunDecoder(Ort::Value decoder_input) override;

sherpa-onnx/csrc/onnx-utils.cc

@@ -168,6 +168,26 @@ void Print3D(Ort::Value *v) {
   fprintf(stderr, "\n");
 }

+void Print4D(Ort::Value *v) {
+  std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
+  const float *d = v->GetTensorData<float>();
+
+  for (int32_t p = 0; p != static_cast<int32_t>(shape[0]); ++p) {
+    fprintf(stderr, "---plane %d---\n", p);
+    for (int32_t q = 0; q != static_cast<int32_t>(shape[1]); ++q) {
+      fprintf(stderr, "---subplane %d---\n", q);
+      for (int32_t r = 0; r != static_cast<int32_t>(shape[2]); ++r) {
+        for (int32_t c = 0; c != static_cast<int32_t>(shape[3]); ++c, ++d) {
+          fprintf(stderr, "%.3f ", *d);
+        }
+        fprintf(stderr, "\n");
+      }
+      fprintf(stderr, "\n");
+    }
+  }
+  fprintf(stderr, "\n");
+}
+
 std::vector<char> ReadFile(const std::string &filename) {
   std::ifstream input(filename, std::ios::binary);
   std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
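A small usage sketch for the new helper (illustrative, not part of this commit): Print4D() walks the float data in row-major order, printing shape[0] planes of shape[1] sub-planes, each a shape[2] x shape[3] grid.

Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 4> shape{2, 2, 3, 4};
Ort::Value v =
    Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
Fill<float>(&v, 0.5f);  // Fill is declared in onnx-utils.h (see below)
Print4D(&v);            // 2 planes x 2 subplanes of 3x4 values each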
sherpa-onnx/csrc/onnx-utils.h

@@ -75,6 +75,9 @@ void Print2D(Ort::Value *v);
 // Print a 3-D tensor to stderr
 void Print3D(Ort::Value *v);

+// Print a 4-D tensor to stderr
+void Print4D(Ort::Value *v);
+
 template <typename T = float>
 void Fill(Ort::Value *tensor, T value) {
   auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount();
sherpa-onnx/csrc/stack-test.cc (new file)
// sherpa-onnx/csrc/stack-test.cc
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)

#include "sherpa-onnx/csrc/stack.h"

#include <array>

#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

TEST(Stack, Test1DTensors) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 1> a_shape{3};
  std::array<int64_t, 1> b_shape{3};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());
  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 0);

  Print1D(&a);
  Print1D(&b);
  Print2D(&ans);

  const float *pans = ans.GetTensorData<float>();
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
    EXPECT_EQ(pa[i], pans[i]);
  }

  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
    EXPECT_EQ(pb[i], pans[i + a_shape[0]]);
  }
}

TEST(Stack, Test2DTensorsDim0) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 2> a_shape{2, 3};
  std::array<int64_t, 2> b_shape{2, 3};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());
  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 0);

  Print2D(&a);
  Print2D(&b);
  Print3D(&ans);

  const float *pans = ans.GetTensorData<float>();
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
    EXPECT_EQ(pa[i], pans[i]);
  }
  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
    EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]);
  }
}

TEST(Stack, Test2DTensorsDim1) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 2> a_shape{4, 3};
  std::array<int64_t, 2> b_shape{4, 3};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());
  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 1);

  Print2D(&a);
  Print2D(&b);
  Print3D(&ans);

  const float *pans = ans.GetTensorData<float>();

  for (int32_t r = 0; r != static_cast<int32_t>(a_shape[0]); ++r) {
    for (int32_t i = 0; i != static_cast<int32_t>(a_shape[1]);
         ++i, ++pa, ++pans) {
      EXPECT_EQ(*pa, *pans);
    }

    for (int32_t i = 0; i != static_cast<int32_t>(b_shape[1]);
         ++i, ++pb, ++pans) {
      EXPECT_EQ(*pb, *pans);
    }
  }
}

TEST(Stack, Test3DTensorsDim0) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 3> a_shape{2, 3, 2};
  std::array<int64_t, 3> b_shape{2, 3, 2};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());
  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  for (int32_t i = 0;
       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0;
       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 0);

  const float *pans = ans.GetTensorData<float>();
  for (int32_t i = 0;
       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
    EXPECT_EQ(pa[i], pans[i]);
  }
  for (int32_t i = 0;
       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
    EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]);
  }

  Print3D(&a);
  Print3D(&b);
  Print4D(&ans);
}

TEST(Stack, Test3DTensorsDim1) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 3> a_shape{2, 2, 3};
  std::array<int64_t, 3> b_shape{2, 2, 3};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());
  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  for (int32_t i = 0;
       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0;
       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 1);

  const float *pans = ans.GetTensorData<float>();

  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
    for (int32_t k = 0; k != static_cast<int32_t>(a_shape[1] * a_shape[2]);
         ++k, ++pa, ++pans) {
      EXPECT_EQ(*pa, *pans);
    }

    for (int32_t k = 0; k != static_cast<int32_t>(b_shape[1] * b_shape[2]);
         ++k, ++pb, ++pans) {
      EXPECT_EQ(*pb, *pans);
    }
  }

  Print3D(&a);
  Print3D(&b);
  Print4D(&ans);
}

TEST(Stack, Test3DTensorsDim2) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 3> a_shape{2, 3, 4};
  std::array<int64_t, 3> b_shape{2, 3, 4};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());
  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  for (int32_t i = 0;
       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0;
       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 2);

  const float *pans = ans.GetTensorData<float>();

  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
    for (int32_t k = 0; k != static_cast<int32_t>(a_shape[2]);
         ++k, ++pa, ++pans) {
      EXPECT_EQ(*pa, *pans);
    }

    for (int32_t k = 0; k != static_cast<int32_t>(b_shape[2]);
         ++k, ++pb, ++pans) {
      EXPECT_EQ(*pb, *pans);
    }
  }

  Print3D(&a);
  Print3D(&b);
  Print4D(&ans);
}

}  // namespace sherpa_onnx
sherpa-onnx/csrc/stack.cc (new file)
// sherpa-onnx/csrc/stack.cc
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)

#include "sherpa-onnx/csrc/stack.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>

#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

static bool Compare(const std::vector<int64_t> &a,
                    const std::vector<int64_t> &b) {
  if (a.size() != b.size()) return false;

  for (int32_t i = 0; i != static_cast<int32_t>(a.size()); ++i) {
    if (a[i] != b[i]) return false;
  }

  return true;
}

static void PrintShape(const std::vector<int64_t> &a) {
  for (auto i : a) {
    fprintf(stderr, "%d ", static_cast<int32_t>(i));
  }
  fprintf(stderr, "\n");
}

template <typename T /*= float*/>
Ort::Value Stack(OrtAllocator *allocator,
                 const std::vector<const Ort::Value *> &values, int32_t dim) {
  std::vector<int64_t> v0_shape =
      values[0]->GetTensorTypeAndShapeInfo().GetShape();

  for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
    auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
    bool ret = Compare(v0_shape, s);
    if (!ret) {
      fprintf(stderr, "Incorrect shape in Stack!\n");

      fprintf(stderr, "Shape for tensor 0: ");
      PrintShape(v0_shape);

      fprintf(stderr, "Shape for tensor %d: ", i);
      PrintShape(s);

      exit(-1);
    }
  }

  std::vector<int64_t> ans_shape;
  ans_shape.reserve(v0_shape.size() + 1);
  ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
  ans_shape.push_back(values.size());
  ans_shape.insert(ans_shape.end(), v0_shape.data() + dim,
                   v0_shape.data() + v0_shape.size());

  auto leading_size = static_cast<int32_t>(std::accumulate(
      v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));

  auto trailing_size = static_cast<int32_t>(std::accumulate(
      v0_shape.begin() + dim, v0_shape.end(), 1, std::multiplies<int64_t>()));

  Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
                                               ans_shape.size());
  T *dst = ans.GetTensorMutableData<T>();

  for (int32_t i = 0; i != leading_size; ++i) {
    for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
      const T *src = values[n]->GetTensorData<T>();
      src += i * trailing_size;

      std::copy(src, src + trailing_size, dst);
      dst += trailing_size;
    }
  }

  return ans;
}

template Ort::Value Stack<float>(OrtAllocator *allocator,
                                 const std::vector<const Ort::Value *> &values,
                                 int32_t dim);

template Ort::Value Stack<int64_t>(
    OrtAllocator *allocator, const std::vector<const Ort::Value *> &values,
    int32_t dim);

}  // namespace sherpa_onnx
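To make the leading/trailing decomposition concrete, here is a standalone sketch (illustrative, not part of this commit) of the copy order Stack() produces for two 2x3 tensors stacked along dim = 1: leading_size = 2 outer blocks, each taking one trailing_size = 3 run from every input in turn.

#include <cassert>
#include <vector>

void StackIndexingDemo() {
  const int leading = 2, trailing = 3;        // from v0_shape = {2, 3}, dim = 1
  std::vector<float> a = {0, 1, 2, 3, 4, 5};  // 2x3, row-major
  std::vector<float> b = {10, 11, 12, 13, 14, 15};

  std::vector<float> out;  // will hold a 2x2x3 tensor, row-major
  for (int i = 0; i != leading; ++i) {
    // One trailing block from each input per leading block.
    out.insert(out.end(), a.begin() + i * trailing,
               a.begin() + (i + 1) * trailing);
    out.insert(out.end(), b.begin() + i * trailing,
               b.begin() + (i + 1) * trailing);
  }
  assert(out.size() == 12);
  assert(out[3] == 10 && out[6] == 3);  // b's row 0 follows a's row 0
}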
sherpa-onnx/csrc/stack.h (new file)
// sherpa-onnx/csrc/stack.h
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)

#ifndef SHERPA_ONNX_CSRC_STACK_H_
#define SHERPA_ONNX_CSRC_STACK_H_

#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

namespace sherpa_onnx {

/** Stack a list of tensors along the given dim.
 *
 * @param allocator Allocator to allocate space for the returned tensor
 * @param values Pointer to a list of tensors. The shapes of the tensors
 *               must be identical.
 * @param dim The dim along which to stack the input tensors. A new axis
 *            of size values.size() is inserted at this position.
 *
 * @return Return the stacked tensor
 */
template <typename T = float>
Ort::Value Stack(OrtAllocator *allocator,
                 const std::vector<const Ort::Value *> &values, int32_t dim);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_STACK_H_
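For reference, a usage sketch of the non-default instantiation (illustrative, not part of this commit): stack.cc explicitly instantiates Stack for both float and int64_t, so int64 tensors work too. Two length-3 vectors stacked along dim 0 give a (2, 3) tensor.

#include <array>

#include "sherpa-onnx/csrc/stack.h"

Ort::Value StackTwoInt64(OrtAllocator *allocator) {
  std::array<int64_t, 1> shape{3};
  Ort::Value a = Ort::Value::CreateTensor<int64_t>(allocator, shape.data(),
                                                   shape.size());
  Ort::Value b = Ort::Value::CreateTensor<int64_t>(allocator, shape.data(),
                                                   shape.size());
  for (int32_t i = 0; i != 3; ++i) {
    a.GetTensorMutableData<int64_t>()[i] = i;
    b.GetTensorMutableData<int64_t>()[i] = 10 + i;
  }
  // Explicit template argument selects the int64_t instantiation.
  return sherpa_onnx::Stack<int64_t>(allocator, {&a, &b}, 0);  // shape (2, 3)
}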