Jingzhao Ou
Committed by GitHub

Stack and streaming conformer support (#141)

* added csrc/stack.cc

* stack: added checks

* added copyright info

* passed cpp style checks

* formatted code

* added some support for streaming conformer model support (not verified)

* code lint

* made more progress with streaming conformer support (not working yet)

* passed style check

* changes as suggested by @csukuangfj

* added some debug info

* fixed style check

* Use Cat to replace Stack

* remove debug statements

---------

Co-authored-by: Jingzhao Ou (jou2019) <jou2019@cisco.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
... ... @@ -34,6 +34,7 @@ set(sources
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
online-conformer-transducer-model.cc
online-lm.cc
online-lm-config.cc
online-lstm-transducer-model.cc
... ... @@ -52,6 +53,7 @@ set(sources
parse-options.cc
resample.cc
slice.cc
stack.cc
symbol-table.cc
text-utils.cc
transpose.cc
... ... @@ -241,6 +243,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
packed-sequence-test.cc
pad-sequence-test.cc
slice-test.cc
stack-test.cc
transpose-test.cc
unbind-test.cc
)
... ...
// sherpa-onnx/csrc/online-conformer-transducer-model.cc
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
#include <assert.h>
#include <algorithm>
#include <memory>
#include <sstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/unbind.h"
namespace sherpa_onnx {
OnlineConformerTransducerModel::OnlineConformerTransducerModel(
const OnlineTransducerModelConfig &config)
: env_(ORT_LOGGING_LEVEL_WARNING),
config_(config),
sess_opts_{},
allocator_{} {
sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads);
{
auto buf = ReadFile(config.encoder_filename);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.decoder_filename);
InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.joiner_filename);
InitJoiner(buf.data(), buf.size());
}
}
#if __ANDROID_API__ >= 9
OnlineConformerTransducerModel::OnlineConformerTransducerModel(
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
: env_(ORT_LOGGING_LEVEL_WARNING),
config_(config),
sess_opts_{},
allocator_{} {
sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads);
{
auto buf = ReadFile(mgr, config.encoder_filename);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.decoder_filename);
InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.joiner_filename);
InitJoiner(buf.data(), buf.size());
}
}
#endif
void OnlineConformerTransducerModel::InitEncoder(void *model_data,
size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
SHERPA_ONNX_READ_META_DATA(T_, "T");
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
SHERPA_ONNX_READ_META_DATA(left_context_, "left_context");
SHERPA_ONNX_READ_META_DATA(encoder_dim_, "encoder_dim");
SHERPA_ONNX_READ_META_DATA(pad_length_, "pad_length");
SHERPA_ONNX_READ_META_DATA(cnn_module_kernel_, "cnn_module_kernel");
}
void OnlineConformerTransducerModel::InitDecoder(void *model_data,
size_t model_data_length) {
decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---decoder---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
}
void OnlineConformerTransducerModel::InitJoiner(void *model_data,
size_t model_data_length) {
joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
&joiner_input_names_ptr_);
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
&joiner_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---joiner---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
}
std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
const std::vector<std::vector<Ort::Value>> &states) const {
int32_t batch_size = static_cast<int32_t>(states.size());
std::vector<const Ort::Value *> attn_vec(batch_size);
std::vector<const Ort::Value *> conv_vec(batch_size);
for (int32_t i = 0; i != batch_size; ++i) {
assert(states[i].size() == 2);
attn_vec[i] = &states[i][0];
conv_vec[i] = &states[i][1];
}
Ort::Value attn = Cat(allocator_, attn_vec, 2);
Ort::Value conv = Cat(allocator_, conv_vec, 2);
std::vector<Ort::Value> ans;
ans.reserve(2);
ans.push_back(std::move(attn));
ans.push_back(std::move(conv));
return ans;
}
std::vector<std::vector<Ort::Value>>
OnlineConformerTransducerModel::UnStackStates(
const std::vector<Ort::Value> &states) const {
const int32_t batch_size =
states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
assert(states.size() == 2);
std::vector<std::vector<Ort::Value>> ans(batch_size);
std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);
std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2);
assert(attn_vec.size() == batch_size);
assert(conv_vec.size() == batch_size);
for (int32_t i = 0; i != batch_size; ++i) {
ans[i].push_back(std::move(attn_vec[i]));
ans[i].push_back(std::move(conv_vec[i]));
}
return ans;
}
std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() {
// Please see
// https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203
// for details
constexpr int32_t kBatchSize = 1;
std::array<int64_t, 4> h_shape{
num_encoder_layers_, left_context_, kBatchSize, encoder_dim_};
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
h_shape.size());
Fill<float>(&h, 0);
std::array<int64_t, 4> c_shape{num_encoder_layers_, cnn_module_kernel_ - 1,
kBatchSize, encoder_dim_};
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
c_shape.size());
Fill<float>(&c, 0);
std::vector<Ort::Value> states;
states.reserve(2);
states.push_back(std::move(h));
states.push_back(std::move(c));
return states;
}
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineConformerTransducerModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> states,
Ort::Value processed_frames) {
std::array<Ort::Value, 4> encoder_inputs = {
std::move(features),
std::move(states[0]),
std::move(states[1]),
std::move(processed_frames)};
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
encoder_inputs.size(), encoder_output_names_ptr_.data(),
encoder_output_names_ptr_.size());
std::vector<Ort::Value> next_states;
next_states.reserve(2);
next_states.push_back(std::move(encoder_out[1]));
next_states.push_back(std::move(encoder_out[2]));
return {std::move(encoder_out[0]), std::move(next_states)};
}
Ort::Value OnlineConformerTransducerModel::RunDecoder(
Ort::Value decoder_input) {
auto decoder_out = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), &decoder_input, 1,
decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
return std::move(decoder_out[0]);
}
Ort::Value OnlineConformerTransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
joiner_output_names_ptr_.size());
return std::move(logit[0]);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-conformer-transducer-model.h
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
#ifndef SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace sherpa_onnx {
class OnlineConformerTransducerModel : public OnlineTransducerModel {
public:
explicit OnlineConformerTransducerModel(
const OnlineTransducerModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineConformerTransducerModel(AAssetManager *mgr,
const OnlineTransducerModelConfig &config);
#endif
std::vector<Ort::Value> StackStates(
const std::vector<std::vector<Ort::Value>> &states) const override;
std::vector<std::vector<Ort::Value>> UnStackStates(
const std::vector<Ort::Value> &states) const override;
std::vector<Ort::Value> GetEncoderInitStates() override;
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states,
Ort::Value processed_frames) override;
Ort::Value RunDecoder(Ort::Value decoder_input) override;
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
int32_t ContextSize() const override { return context_size_; }
int32_t ChunkSize() const override { return T_; }
int32_t ChunkShift() const override { return decode_chunk_len_; }
int32_t VocabSize() const override { return vocab_size_; }
OrtAllocator *Allocator() override { return allocator_; }
private:
void InitEncoder(void *model_data, size_t model_data_length);
void InitDecoder(void *model_data, size_t model_data_length);
void InitJoiner(void *model_data, size_t model_data_length);
private:
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::unique_ptr<Ort::Session> joiner_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<std::string> joiner_input_names_;
std::vector<const char *> joiner_input_names_ptr_;
std::vector<std::string> joiner_output_names_;
std::vector<const char *> joiner_output_names_ptr_;
OnlineTransducerModelConfig config_;
int32_t num_encoder_layers_ = 0;
int32_t T_ = 0;
int32_t decode_chunk_len_ = 0;
int32_t cnn_module_kernel_ = 0;
int32_t context_size_ = 0;
int32_t left_context_ = 0;
// TODO(jingzhaoou): to retrieve from model medadata
int32_t right_context_ = 4;
int32_t encoder_dim_ = 0;
int32_t pad_length_ = 0;
int32_t vocab_size_ = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
... ...
... ... @@ -227,7 +227,8 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> states) {
std::vector<Ort::Value> states,
Ort::Value /* processed_frames */) {
std::array<Ort::Value, 3> encoder_inputs = {
std::move(features), std::move(states[0]), std::move(states[1])};
... ...
... ... @@ -38,7 +38,8 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
std::vector<Ort::Value> GetEncoderInitStates() override;
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states) override;
Ort::Value features, std::vector<Ort::Value> states,
Ort::Value processed_frames) override;
Ort::Value RunDecoder(Ort::Value decoder_input) override;
... ...
... ... @@ -9,6 +9,7 @@
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <utility>
... ... @@ -187,11 +188,14 @@ class OnlineRecognizer::Impl {
std::vector<OnlineTransducerDecoderResult> results(n);
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> states_vec(n);
std::vector<int64_t> all_processed_frames(n);
for (int32_t i = 0; i != n; ++i) {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);
ss[i]->GetFrames(num_processed_frames, chunk_size);
// Question: should num_processed_frames include chunk_shift?
ss[i]->GetNumProcessedFrames() += chunk_shift;
std::copy(features.begin(), features.end(),
... ... @@ -199,6 +203,7 @@ class OnlineRecognizer::Impl {
results[i] = std::move(ss[i]->GetResult());
states_vec[i] = std::move(ss[i]->GetStates());
all_processed_frames[i] = num_processed_frames;
}
auto memory_info =
... ... @@ -210,9 +215,20 @@ class OnlineRecognizer::Impl {
features_vec.size(), x_shape.data(),
x_shape.size());
std::array<int64_t, 1> processed_frames_shape{
static_cast<int64_t>(all_processed_frames.size())};
Ort::Value processed_frames = Ort::Value::CreateTensor(
memory_info,
all_processed_frames.data(),
all_processed_frames.size(),
processed_frames_shape.data(),
processed_frames_shape.size());
auto states = model_->StackStates(states_vec);
auto pair = model_->RunEncoder(std::move(x), std::move(states));
auto pair = model_->RunEncoder(
std::move(x), std::move(states), std::move(processed_frames));
decoder_->Decode(std::move(pair.first), &results);
... ...
... ... @@ -10,11 +10,13 @@
#endif
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
... ... @@ -22,6 +24,7 @@
namespace {
enum class ModelType {
kConformer,
kLstm,
kZipformer,
kUnkown,
... ... @@ -57,7 +60,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnkown;
}
if (model_type.get() == std::string("lstm")) {
if (model_type.get() == std::string("conformer")) {
return ModelType::kConformer;
} else if (model_type.get() == std::string("lstm")) {
return ModelType::kLstm;
} else if (model_type.get() == std::string("zipformer")) {
return ModelType::kZipformer;
... ... @@ -78,6 +83,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
}
switch (model_type) {
case ModelType::kConformer:
return std::make_unique<OnlineConformerTransducerModel>(config);
case ModelType::kLstm:
return std::make_unique<OnlineLstmTransducerModel>(config);
case ModelType::kZipformer:
... ... @@ -132,6 +139,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
switch (model_type) {
case ModelType::kConformer:
return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
case ModelType::kLstm:
return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
case ModelType::kZipformer:
... ...
... ... @@ -64,6 +64,7 @@ class OnlineTransducerModel {
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param states Encoder state of the previous chunk. It is changed in-place.
* @param processed_frames Processed frames before subsampling. It is a 1-D tensor with data type int64_t.
*
* @return Return a tuple containing:
* - encoder_out, a tensor of shape (N, T', encoder_out_dim)
... ... @@ -71,7 +72,8 @@ class OnlineTransducerModel {
*/
virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features,
std::vector<Ort::Value> states) = 0; // NOLINT
std::vector<Ort::Value> states,
Ort::Value processed_frames) = 0; // NOLINT
/** Run the decoder network.
*
... ...
... ... @@ -434,7 +434,8 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> states) {
std::vector<Ort::Value> states,
Ort::Value /* processed_frames */) {
std::vector<Ort::Value> encoder_inputs;
encoder_inputs.reserve(1 + states.size());
... ...
... ... @@ -39,7 +39,8 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
std::vector<Ort::Value> GetEncoderInitStates() override;
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states) override;
Ort::Value features, std::vector<Ort::Value> states,
Ort::Value processed_frames) override;
Ort::Value RunDecoder(Ort::Value decoder_input) override;
... ...
... ... @@ -168,6 +168,26 @@ void Print3D(Ort::Value *v) {
fprintf(stderr, "\n");
}
void Print4D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
for (int32_t p = 0; p != static_cast<int32_t>(shape[0]); ++p) {
fprintf(stderr, "---plane %d---\n", p);
for (int32_t q = 0; q != static_cast<int32_t>(shape[1]); ++q) {
fprintf(stderr, "---subplane %d---\n", q);
for (int32_t r = 0; r != static_cast<int32_t>(shape[2]); ++r) {
for (int32_t c = 0; c != static_cast<int32_t>(shape[3]); ++c, ++d) {
fprintf(stderr, "%.3f ", *d);
}
fprintf(stderr, "\n");
}
fprintf(stderr, "\n");
}
}
fprintf(stderr, "\n");
}
std::vector<char> ReadFile(const std::string &filename) {
std::ifstream input(filename, std::ios::binary);
std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
... ...
... ... @@ -75,6 +75,9 @@ void Print2D(Ort::Value *v);
// Print a 3-D tensor to stderr
void Print3D(Ort::Value *v);
// Print a 4-D tensor to stderr
void Print4D(Ort::Value *v);
template <typename T = float>
void Fill(Ort::Value *tensor, T value) {
auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount();
... ...
// sherpa-onnx/csrc/stack-test.cc
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
#include "sherpa-onnx/csrc/stack.h"
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
TEST(Stack, Test1DTensors) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 1> a_shape{3};
std::array<int64_t, 1> b_shape{3};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
pa[i] = i;
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Stack(allocator, {&a, &b}, 0);
Print1D(&a);
Print1D(&b);
Print2D(&ans);
const float *pans = ans.GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
EXPECT_EQ(pa[i], pans[i]);
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
EXPECT_EQ(pb[i], pans[i + a_shape[0]]);
}
}
TEST(Stack, Test2DTensorsDim0) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 2> a_shape{2, 3};
std::array<int64_t, 2> b_shape{2, 3};
Ort::Value a = Ort::Value::CreateTensor<float>(
allocator, a_shape.data(), a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(
allocator, b_shape.data(), b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
pa[i] = i;
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Stack(allocator, {&a, &b}, 0);
Print2D(&a);
Print2D(&b);
Print3D(&ans);
const float *pans = ans.GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
EXPECT_EQ(pa[i], pans[i]);
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]);
}
}
TEST(Stack, Test2DTensorsDim1) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 2> a_shape{4, 3};
std::array<int64_t, 2> b_shape{4, 3};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
pa[i] = i;
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Stack(allocator, {&a, &b}, 1);
Print2D(&a);
Print2D(&b);
Print3D(&ans);
const float *pans = ans.GetTensorData<float>();
for (int32_t r = 0; r != static_cast<int32_t>(a_shape[0]); ++r) {
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[1]);
++i, ++pa, ++pans) {
EXPECT_EQ(*pa, *pans);
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[1]);
++i, ++pb, ++pans) {
EXPECT_EQ(*pb, *pans);
}
}
}
TEST(Stack, Test3DTensorsDim0) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 3> a_shape{2, 3, 2};
std::array<int64_t, 3> b_shape{2, 3, 2};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0;
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
pa[i] = i;
}
for (int32_t i = 0;
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Stack(allocator, {&a, &b}, 0);
const float *pans = ans.GetTensorData<float>();
for (int32_t i = 0;
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
EXPECT_EQ(pa[i], pans[i]);
}
for (int32_t i = 0;
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]);
}
Print3D(&a);
Print3D(&b);
Print4D(&ans);
}
TEST(Stack, Test3DTensorsDim1) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 3> a_shape{2, 2, 3};
std::array<int64_t, 3> b_shape{2, 2, 3};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0;
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
pa[i] = i;
}
for (int32_t i = 0;
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Stack(allocator, {&a, &b}, 1);
const float *pans = ans.GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
for (int32_t k = 0; k != static_cast<int32_t>(a_shape[1] * a_shape[2]);
++k, ++pa, ++pans) {
EXPECT_EQ(*pa, *pans);
}
for (int32_t k = 0; k != static_cast<int32_t>(b_shape[1] * b_shape[2]);
++k, ++pb, ++pans) {
EXPECT_EQ(*pb, *pans);
}
}
Print3D(&a);
Print3D(&b);
Print4D(&ans);
}
TEST(Stack, Test3DTensorsDim2) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 3> a_shape{2, 3, 4};
std::array<int64_t, 3> b_shape{2, 3, 4};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0;
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
pa[i] = i;
}
for (int32_t i = 0;
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Stack(allocator, {&a, &b}, 2);
const float *pans = ans.GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
for (int32_t k = 0; k != static_cast<int32_t>(a_shape[2]);
++k, ++pa, ++pans) {
EXPECT_EQ(*pa, *pans);
}
for (int32_t k = 0; k != static_cast<int32_t>(b_shape[2]);
++k, ++pb, ++pans) {
EXPECT_EQ(*pb, *pans);
}
}
Print3D(&a);
Print3D(&b);
Print4D(&ans);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/stack.cc
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
#include "sherpa-onnx/csrc/stack.h"
#include <algorithm>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
static bool Compare(const std::vector<int64_t> &a,
const std::vector<int64_t> &b) {
if (a.size() != b.size()) return false;
for (int32_t i = 0; i != static_cast<int32_t>(a.size()); ++i) {
if (a[i] != b[i]) return false;
}
return true;
}
static void PrintShape(const std::vector<int64_t> &a) {
for (auto i : a) {
fprintf(stderr, "%d ", static_cast<int32_t>(i));
}
fprintf(stderr, "\n");
}
template <typename T /*=float*/>
Ort::Value Stack(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values, int32_t dim) {
std::vector<int64_t> v0_shape =
values[0]->GetTensorTypeAndShapeInfo().GetShape();
for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
bool ret = Compare(v0_shape, s);
if (!ret) {
fprintf(stderr, "Incorrect shape in Stack !\n");
fprintf(stderr, "Shape for tensor 0: ");
PrintShape(v0_shape);
fprintf(stderr, "Shape for tensor %d: ", i);
PrintShape(s);
exit(-1);
}
}
std::vector<int64_t> ans_shape;
ans_shape.reserve(v0_shape.size() + 1);
ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
ans_shape.push_back(values.size());
ans_shape.insert(
ans_shape.end(),
v0_shape.data() + dim,
v0_shape.data() + v0_shape.size());
auto leading_size = static_cast<int32_t>(std::accumulate(
v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
auto trailing_size = static_cast<int32_t>(
std::accumulate(v0_shape.begin() + dim,
v0_shape.end(), 1,
std::multiplies<int64_t>()));
Ort::Value ans = Ort::Value::CreateTensor<T>(
allocator, ans_shape.data(), ans_shape.size());
T *dst = ans.GetTensorMutableData<T>();
for (int32_t i = 0; i != leading_size; ++i) {
for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
const T *src = values[n]->GetTensorData<T>();
src += i * trailing_size;
std::copy(src, src + trailing_size, dst);
dst += trailing_size;
}
}
return ans;
}
template Ort::Value Stack<float>(
OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values,
int32_t dim);
template Ort::Value Stack<int64_t>(
OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values,
int32_t dim);
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/stack.h
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
#ifndef SHERPA_ONNX_CSRC_STACK_H_
#define SHERPA_ONNX_CSRC_STACK_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
/** Stack a list of tensors along the given dim.
*
* @param allocator Allocator to allocate space for the returned tensor
* @param values Pointer to a list of tensors. The shape of the tensor must
* be the same except on the dim to be stacked.
* @param dim The dim along which to concatenate the input tensors
*
* @return Return the stacked tensor
*/
template <typename T = float>
Ort::Value Stack(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values, int32_t dim);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_STACK_H_
... ...