Fangjun Kuang
Committed by GitHub

Add Streaming zipformer (#50)

... ... @@ -33,8 +33,13 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
os: [ubuntu-latest, macos-latest] # windows-latest]
python-version: ["3.7", "3.8", "3.9", "3.10"]
exclude:
- os: macos-latest
python-version: "3.9"
- os: macos-latest
python-version: "3.10"
steps:
- uses: actions/checkout@v2
... ...
... ... @@ -8,3 +8,4 @@ sherpa-onnx-*
__pycache__
dist/
sherpa_onnx.egg-info/
.DS_Store
... ...
... ... @@ -62,6 +62,7 @@ endif()
if(SHERPA_ONNX_ENABLE_TESTS)
enable_testing()
include(googletest)
endif()
add_subdirectory(sherpa-onnx)
... ...
# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(download_googltest)
include(FetchContent)
... ... @@ -26,6 +11,7 @@ function(download_googltest)
${PROJECT_SOURCE_DIR}/googletest-1.13.0.tar.gz
${PROJECT_BINARY_DIR}/googletest-1.13.0.tar.gz
/tmp/googletest-1.13.0.tar.gz
/star-fj/fangjun/download/github/googletest-1.13.0.tar.gz
)
foreach(f IN LISTS possible_file_locations)
... ...
include_directories(${CMAKE_SOURCE_DIR})
add_library(sherpa-onnx-core
cat.cc
features.cc
online-lstm-transducer-model.cc
online-recognizer.cc
... ... @@ -8,8 +9,11 @@ add_library(sherpa-onnx-core
online-transducer-greedy-search-decoder.cc
online-transducer-model-config.cc
online-transducer-model.cc
online-zipformer-transducer-model.cc
onnx-utils.cc
symbol-table.cc
text-utils.cc
unbind.cc
wave-reader.cc
)
... ... @@ -27,3 +31,32 @@ endif()
install(TARGETS sherpa-onnx-core DESTINATION lib)
install(TARGETS sherpa-onnx DESTINATION bin)
if(SHERPA_ONNX_ENABLE_TESTS)
set(sherpa_onnx_test_srcs
cat-test.cc
unbind-test.cc
)
function(sherpa_onnx_add_test source)
get_filename_component(name ${source} NAME_WE)
set(target_name ${name})
add_executable(${target_name} "${source}")
target_link_libraries(${target_name}
PRIVATE
gtest
gtest_main
sherpa-onnx-core
)
add_test(NAME "${target_name}"
COMMAND
$<TARGET_FILE:${target_name}>
)
endfunction()
foreach(source IN LISTS sherpa_onnx_test_srcs)
sherpa_onnx_add_test(${source})
endforeach()
endif()
... ...
// sherpa-onnx/csrc/cat-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/cat.h"
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
TEST(Cat, Test1DTensors) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 1> a_shape{3};
std::array<int64_t, 1> b_shape{6};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
pa[i] = i;
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Cat(allocator, {&a, &b}, 0);
const float *pans = ans.GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
EXPECT_EQ(pa[i], pans[i]);
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
EXPECT_EQ(pb[i], pans[i + a_shape[0]]);
}
Print1D(&a);
Print1D(&b);
Print1D(&ans);
}
TEST(Cat, Test2DTensorsDim0) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 2> a_shape{2, 3};
std::array<int64_t, 2> b_shape{4, 3};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
pa[i] = i;
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Cat(allocator, {&a, &b}, 0);
const float *pans = ans.GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
EXPECT_EQ(pa[i], pans[i]);
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]);
}
Print2D(&a);
Print2D(&b);
Print2D(&ans);
}
TEST(Cat, Test2DTensorsDim1) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 2> a_shape{4, 3};
std::array<int64_t, 2> b_shape{4, 2};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
pa[i] = i;
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Cat(allocator, {&a, &b}, 1);
const float *pans = ans.GetTensorData<float>();
for (int32_t r = 0; r != static_cast<int32_t>(a_shape[0]); ++r) {
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[1]);
++i, ++pa, ++pans) {
EXPECT_EQ(*pa, *pans);
}
for (int32_t i = 0; i != static_cast<int32_t>(b_shape[1]);
++i, ++pb, ++pans) {
EXPECT_EQ(*pb, *pans);
}
}
Print2D(&a);
Print2D(&b);
Print2D(&ans);
}
TEST(Cat, Test3DTensorsDim0) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 3> a_shape{2, 3, 2};
std::array<int64_t, 3> b_shape{4, 3, 2};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0;
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
pa[i] = i;
}
for (int32_t i = 0;
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Cat(allocator, {&a, &b}, 0);
const float *pans = ans.GetTensorData<float>();
for (int32_t i = 0;
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
EXPECT_EQ(pa[i], pans[i]);
}
for (int32_t i = 0;
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]);
}
Print3D(&a);
Print3D(&b);
Print3D(&ans);
}
TEST(Cat, Test3DTensorsDim1) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 3> a_shape{2, 2, 3};
std::array<int64_t, 3> b_shape{2, 4, 3};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0;
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
pa[i] = i;
}
for (int32_t i = 0;
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Cat(allocator, {&a, &b}, 1);
const float *pans = ans.GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
for (int32_t k = 0; k != static_cast<int32_t>(a_shape[1] * a_shape[2]);
++k, ++pa, ++pans) {
EXPECT_EQ(*pa, *pans);
}
for (int32_t k = 0; k != static_cast<int32_t>(b_shape[1] * b_shape[2]);
++k, ++pb, ++pans) {
EXPECT_EQ(*pb, *pans);
}
}
Print3D(&a);
Print3D(&b);
Print3D(&ans);
}
TEST(Cat, Test3DTensorsDim2) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 3> a_shape{2, 3, 4};
std::array<int64_t, 3> b_shape{2, 3, 5};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
a_shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
b_shape.size());
float *pa = a.GetTensorMutableData<float>();
float *pb = b.GetTensorMutableData<float>();
for (int32_t i = 0;
i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
pa[i] = i;
}
for (int32_t i = 0;
i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
pb[i] = i + 10;
}
Ort::Value ans = Cat(allocator, {&a, &b}, 2);
const float *pans = ans.GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
for (int32_t k = 0; k != static_cast<int32_t>(a_shape[2]);
++k, ++pa, ++pans) {
EXPECT_EQ(*pa, *pans);
}
for (int32_t k = 0; k != static_cast<int32_t>(b_shape[2]);
++k, ++pb, ++pans) {
EXPECT_EQ(*pb, *pans);
}
}
Print3D(&a);
Print3D(&b);
Print3D(&ans);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/cat.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/cat.h"
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
static bool Compare(const std::vector<int64_t> &a,
const std::vector<int64_t> &b, int32_t skip_dim) {
if (a.size() != b.size()) return false;
for (int32_t i = 0; i != static_cast<int32_t>(a.size()); ++i) {
if (i == skip_dim) continue;
if (a[i] != b[i]) return false;
}
return true;
}
static void PrintShape(const std::vector<int64_t> &a) {
for (auto i : a) {
fprintf(stderr, "%d ", static_cast<int32_t>(i));
}
fprintf(stderr, "\n");
}
template <typename T /*=float*/>
Ort::Value Cat(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values, int32_t dim) {
if (values.size() == 1u) {
return Clone(values[0]);
}
std::vector<int64_t> v0_shape =
values[0]->GetTensorTypeAndShapeInfo().GetShape();
int64_t total_dim = v0_shape[dim];
for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
total_dim += s[dim];
bool ret = Compare(v0_shape, s, dim);
if (!ret) {
fprintf(stderr, "Incorrect shape in Cat !\n");
fprintf(stderr, "Shape for tensor 0: ");
PrintShape(v0_shape);
fprintf(stderr, "Shape for tensor %d: ", i);
PrintShape(s);
exit(-1);
}
}
std::vector<int64_t> ans_shape;
ans_shape.reserve(v0_shape.size());
ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
ans_shape.push_back(total_dim);
ans_shape.insert(ans_shape.end(), v0_shape.data() + dim + 1,
v0_shape.data() + v0_shape.size());
auto leading_size = static_cast<int32_t>(std::accumulate(
v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
auto trailing_size = static_cast<int32_t>(
std::accumulate(v0_shape.begin() + dim + 1, v0_shape.end(), 1,
std::multiplies<int64_t>()));
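  // Illustrative example (not from the original source): concatenating
  // tensors of shape (2, 3, 4) and (2, 5, 4) along dim = 1 gives
  // leading_size = 2 and trailing_size = 4; for each of the 2 leading
  // indices, the copy loop below moves 3 * 4 floats from the first tensor
  // and 5 * 4 floats from the second into the output.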
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
ans_shape.size());
T *dst = ans.GetTensorMutableData<T>();
for (int32_t i = 0; i != leading_size; ++i) {
for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
auto this_dim = values[n]->GetTensorTypeAndShapeInfo().GetShape()[dim];
const T *src = values[n]->GetTensorData<T>();
src += i * this_dim * trailing_size;
std::copy(src, src + this_dim * trailing_size, dst);
dst += this_dim * trailing_size;
}
}
return std::move(ans);
}
template Ort::Value Cat<float>(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values,
int32_t dim);
template Ort::Value Cat<int64_t>(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values,
int32_t dim);
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/cat.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_CAT_H_
#define SHERPA_ONNX_CSRC_CAT_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
/** Cat a list of tensors along the given dim.
*
* @param allocator Allocator to allocate space for the returned tensor
* @param values Pointer to a list of tensors. The shapes of the tensors must
*               be the same except along the dim to be concatenated.
* @param dim The dim along which to concatenate the input tensors
*
* @return Return the concatenated tensor
*/
template <typename T = float>
Ort::Value Cat(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values, int32_t dim);
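//
// Example (illustrative, mirroring cat-test.cc): concatenate a (2, 3) tensor
// `a` and a (4, 3) tensor `b` along dim 0 to get a (6, 3) tensor:
//
//   Ort::Value ans = Cat(allocator, {&a, &b}, 0);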
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_CAT_H_
... ...
// sherpa-onnx/csrc/macros.h
//
// Copyright 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_MACROS_H_
#define SHERPA_ONNX_CSRC_MACROS_H_
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
exit(-1); \
} \
\
dst = atoi(value.get()); \
if (dst <= 0) { \
fprintf(stderr, "Invalid value %d for %s\n", dst, src_key); \
exit(-1); \
} \
} while (0)
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
exit(-1); \
} \
\
bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \
if (!ret) { \
fprintf(stderr, "Invalid value %s for %s\n", value.get(), src_key); \
exit(-1); \
} \
} while (0)
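// Example usage (illustrative), assuming an Ort::ModelMetadata named
// `meta_data` and an Ort::AllocatorWithDefaultOptions named `allocator` are
// in scope, as in the transducer model sources below:
//
//   SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
//   SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims");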
#endif // SHERPA_ONNX_CSRC_MACROS_H_
... ...
... ... @@ -3,6 +3,8 @@
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include <assert.h>
#include <algorithm>
#include <memory>
#include <sstream>
... ... @@ -11,23 +13,11 @@
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
exit(-1); \
} \
dst = atoi(value.get()); \
if (dst <= 0) { \
fprintf(stderr, "Invalud value %d for %s\n", dst, src_key); \
exit(-1); \
} \
} while (0)
#include "sherpa-onnx/csrc/unbind.h"
namespace sherpa_onnx {
... ... @@ -64,7 +54,7 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
fprintf(stderr, "%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
SHERPA_ONNX_READ_META_DATA(T_, "T");
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
... ... @@ -91,7 +81,7 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
fprintf(stderr, "%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
}
... ... @@ -120,37 +110,19 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
const std::vector<std::vector<Ort::Value>> &states) const {
int32_t batch_size = static_cast<int32_t>(states.size());
std::array<int64_t, 3> h_shape{num_encoder_layers_, batch_size, d_model_};
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
h_shape.size());
std::array<int64_t, 3> c_shape{num_encoder_layers_, batch_size,
rnn_hidden_size_};
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
c_shape.size());
float *dst_h = h.GetTensorMutableData<float>();
float *dst_c = c.GetTensorMutableData<float>();
std::vector<const Ort::Value *> h_buf(batch_size);
std::vector<const Ort::Value *> c_buf(batch_size);
for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
for (int32_t i = 0; i != batch_size; ++i) {
const float *src_h =
states[i][0].GetTensorData<float>() + layer * d_model_;
const float *src_c =
states[i][1].GetTensorData<float>() + layer * rnn_hidden_size_;
std::copy(src_h, src_h + d_model_, dst_h);
std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
dst_h += d_model_;
dst_c += rnn_hidden_size_;
}
for (int32_t i = 0; i != batch_size; ++i) {
assert(states[i].size() == 2);
h_buf[i] = &states[i][0];
c_buf[i] = &states[i][1];
}
std::vector<Ort::Value> ans;
Ort::Value h = Cat(allocator_, h_buf, 1);
Ort::Value c = Cat(allocator_, c_buf, 1);
std::vector<Ort::Value> ans;
ans.reserve(2);
ans.push_back(std::move(h));
ans.push_back(std::move(c));
... ... @@ -161,37 +133,19 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
const std::vector<Ort::Value> &states) const {
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
assert(states.size() == 2);
std::vector<std::vector<Ort::Value>> ans(batch_size);
// allocate space
std::array<int64_t, 3> h_shape{num_encoder_layers_, 1, d_model_};
std::array<int64_t, 3> c_shape{num_encoder_layers_, 1, rnn_hidden_size_};
std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1);
std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1);
for (int32_t i = 0; i != batch_size; ++i) {
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
h_shape.size());
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
c_shape.size());
ans[i].push_back(std::move(h));
ans[i].push_back(std::move(c));
}
assert(h_vec.size() == batch_size);
assert(c_vec.size() == batch_size);
for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
for (int32_t i = 0; i != batch_size; ++i) {
const float *src_h = states[0].GetTensorData<float>() +
layer * batch_size * d_model_ + i * d_model_;
const float *src_c = states[1].GetTensorData<float>() +
layer * batch_size * rnn_hidden_size_ +
i * rnn_hidden_size_;
float *dst_h = ans[i][0].GetTensorMutableData<float>() + layer * d_model_;
float *dst_c =
ans[i][1].GetTensorMutableData<float>() + layer * rnn_hidden_size_;
std::copy(src_h, src_h + d_model_, dst_h);
std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
}
for (int32_t i = 0; i != batch_size; ++i) {
ans[i].push_back(std::move(h_vec[i]));
ans[i].push_back(std::move(c_vec[i]));
}
return ans;
... ... @@ -206,20 +160,15 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
h_shape.size());
std::fill(h.GetTensorMutableData<float>(),
h.GetTensorMutableData<float>() +
num_encoder_layers_ * kBatchSize * d_model_,
0);
Fill<float>(&h, 0);
std::array<int64_t, 3> c_shape{num_encoder_layers_, kBatchSize,
rnn_hidden_size_};
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
c_shape.size());
std::fill(c.GetTensorMutableData<float>(),
c.GetTensorMutableData<float>() +
num_encoder_layers_ * kBatchSize * rnn_hidden_size_,
0);
Fill<float>(&c, 0);
std::vector<Ort::Value> states;
... ...
... ... @@ -8,11 +8,13 @@
#include <string>
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
enum class ModelType {
kLstm,
kZipformer,
kUnkown,
};
... ... @@ -40,6 +42,8 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
if (model_type.get() == std::string("lstm")) {
return ModelType::kLstm;
} else if (model_type.get() == std::string("zipformer")) {
return ModelType::kZipformer;
} else {
fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
return ModelType::kUnkown;
... ... @@ -53,6 +57,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
switch (model_type) {
case ModelType::kLstm:
return std::make_unique<OnlineLstmTransducerModel>(config);
case ModelType::kZipformer:
return std::make_unique<OnlineZipformerTransducerModel>(config);
case ModelType::kUnkown:
return nullptr;
}
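// Note (illustrative): GetModelType() now maps the model_type string
// "zipformer" to ModelType::kZipformer, and Create() dispatches it to
// OnlineZipformerTransducerModel; "lstm" continues to map to
// OnlineLstmTransducerModel.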
... ...
// sherpa-onnx/csrc/online-zipformer-transducer-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
#include <assert.h>
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/unbind.h"
namespace sherpa_onnx {
OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
const OnlineTransducerModelConfig &config)
: env_(ORT_LOGGING_LEVEL_WARNING),
config_(config),
sess_opts_{},
allocator_{} {
sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads);
InitEncoder(config.encoder_filename);
InitDecoder(config.decoder_filename);
InitJoiner(config.joiner_filename);
}
void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims");
SHERPA_ONNX_READ_META_DATA_VEC(attention_dims_, "attention_dims");
SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers");
SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels");
SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len");
SHERPA_ONNX_READ_META_DATA(T_, "T");
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
if (config_.debug) {
auto print = [](const std::vector<int32_t> &v, const char *name) {
fprintf(stderr, "%s: ", name);
for (auto i : v) {
fprintf(stderr, "%d ", i);
}
fprintf(stderr, "\n");
};
print(encoder_dims_, "encoder_dims");
print(attention_dims_, "attention_dims");
print(num_encoder_layers_, "num_encoder_layers");
print(cnn_module_kernels_, "cnn_module_kernels");
print(left_context_len_, "left_context_len");
fprintf(stderr, "T: %d\n", T_);
fprintf(stderr, "decode_chunk_len_: %d\n", decode_chunk_len_);
}
}
void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---decoder---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
}
void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) {
joiner_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
&joiner_input_names_ptr_);
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
&joiner_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---joiner---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
}
}
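// Note (illustrative): for each stream, the state vector holds num_encoders
// tensors per group, in the order cached_len, cached_avg, cached_key,
// cached_val, cached_val2, cached_conv1, cached_conv2, so the tensor for
// state group g of encoder stack i lives at index g * num_encoders + i.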
std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
const std::vector<std::vector<Ort::Value>> &states) const {
int32_t batch_size = static_cast<int32_t>(states.size());
int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size());
std::vector<const Ort::Value *> buf(batch_size);
std::vector<Ort::Value> ans;
ans.reserve(states[0].size());
// cached_len
for (int32_t i = 0; i != num_encoders; ++i) {
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][i];
}
auto v = Cat<int64_t>(allocator_, buf, 1); // (num_layers, 1)
ans.push_back(std::move(v));
}
// cached_avg
for (int32_t i = 0; i != num_encoders; ++i) {
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_encoders + i];
}
auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims)
ans.push_back(std::move(v));
}
// cached_key
for (int32_t i = 0; i != num_encoders; ++i) {
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_encoders * 2 + i];
}
// (num_layers, left_context_len, 1, attention_dims)
auto v = Cat(allocator_, buf, 2);
ans.push_back(std::move(v));
}
// cached_val
for (int32_t i = 0; i != num_encoders; ++i) {
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_encoders * 3 + i];
}
// (num_layers, left_context_len, 1, attention_dims/2)
auto v = Cat(allocator_, buf, 2);
ans.push_back(std::move(v));
}
// cached_val2
for (int32_t i = 0; i != num_encoders; ++i) {
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_encoders * 4 + i];
}
// (num_layers, left_context_len, 1, attention_dims/2)
auto v = Cat(allocator_, buf, 2);
ans.push_back(std::move(v));
}
// cached_conv1
for (int32_t i = 0; i != num_encoders; ++i) {
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_encoders * 5 + i];
}
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
auto v = Cat(allocator_, buf, 1);
ans.push_back(std::move(v));
}
// cached_conv2
for (int32_t i = 0; i != num_encoders; ++i) {
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_encoders * 6 + i];
}
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
auto v = Cat(allocator_, buf, 1);
ans.push_back(std::move(v));
}
return ans;
}
std::vector<std::vector<Ort::Value>>
OnlineZipformerTransducerModel::UnStackStates(
const std::vector<Ort::Value> &states) const {
assert(states.size() == num_encoder_layers_.size() * 7);
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
int32_t num_encoders = num_encoder_layers_.size();
std::vector<std::vector<Ort::Value>> ans;
ans.resize(batch_size);
// cached_len
for (int32_t i = 0; i != num_encoders; ++i) {
auto v = Unbind<int64_t>(allocator_, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
// cached_avg
for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
// cached_key
for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
// cached_val
for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
// cached_val2
for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
// cached_conv1
for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
// cached_conv2
for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
ans[n].push_back(std::move(v[n]));
}
}
return ans;
}
std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {
// Please see
// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py#L673
// for details
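  // Each of the n encoder stacks contributes 7 state tensors (cached_len,
  // cached_avg, cached_key, cached_val, cached_val2, cached_conv1 and
  // cached_conv2); their shapes are spelled out below.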
int32_t n = static_cast<int32_t>(encoder_dims_.size());
std::vector<Ort::Value> cached_len_vec;
std::vector<Ort::Value> cached_avg_vec;
std::vector<Ort::Value> cached_key_vec;
std::vector<Ort::Value> cached_val_vec;
std::vector<Ort::Value> cached_val2_vec;
std::vector<Ort::Value> cached_conv1_vec;
std::vector<Ort::Value> cached_conv2_vec;
cached_len_vec.reserve(n);
cached_avg_vec.reserve(n);
cached_key_vec.reserve(n);
cached_val_vec.reserve(n);
cached_val2_vec.reserve(n);
cached_conv1_vec.reserve(n);
cached_conv2_vec.reserve(n);
for (int32_t i = 0; i != n; ++i) {
{
std::array<int64_t, 2> s{num_encoder_layers_[i], 1};
auto v =
Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
Fill<int64_t>(&v, 0);
cached_len_vec.push_back(std::move(v));
}
{
std::array<int64_t, 3> s{num_encoder_layers_[i], 1, encoder_dims_[i]};
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
cached_avg_vec.push_back(std::move(v));
}
{
std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
attention_dims_[i]};
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
cached_key_vec.push_back(std::move(v));
}
{
std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
attention_dims_[i] / 2};
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
cached_val_vec.push_back(std::move(v));
}
{
std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
attention_dims_[i] / 2};
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
cached_val2_vec.push_back(std::move(v));
}
{
std::array<int64_t, 4> s{num_encoder_layers_[i], 1, encoder_dims_[i],
cnn_module_kernels_[i] - 1};
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
cached_conv1_vec.push_back(std::move(v));
}
{
std::array<int64_t, 4> s{num_encoder_layers_[i], 1, encoder_dims_[i],
cnn_module_kernels_[i] - 1};
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
cached_conv2_vec.push_back(std::move(v));
}
}
std::vector<Ort::Value> ans;
ans.reserve(n * 7);
for (auto &v : cached_len_vec) ans.push_back(std::move(v));
for (auto &v : cached_avg_vec) ans.push_back(std::move(v));
for (auto &v : cached_key_vec) ans.push_back(std::move(v));
for (auto &v : cached_val_vec) ans.push_back(std::move(v));
for (auto &v : cached_val2_vec) ans.push_back(std::move(v));
for (auto &v : cached_conv1_vec) ans.push_back(std::move(v));
for (auto &v : cached_conv2_vec) ans.push_back(std::move(v));
return ans;
}
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> states) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<Ort::Value> encoder_inputs;
encoder_inputs.reserve(1 + states.size());
encoder_inputs.push_back(std::move(features));
for (auto &v : states) {
encoder_inputs.push_back(std::move(v));
}
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
encoder_inputs.size(), encoder_output_names_ptr_.data(),
encoder_output_names_ptr_.size());
std::vector<Ort::Value> next_states;
next_states.reserve(states.size());
for (int32_t i = 1; i != static_cast<int32_t>(encoder_out.size()); ++i) {
next_states.push_back(std::move(encoder_out[i]));
}
return {std::move(encoder_out[0]), std::move(next_states)};
}
Ort::Value OnlineZipformerTransducerModel::BuildDecoderInput(
const std::vector<OnlineTransducerDecoderResult> &results) {
int32_t batch_size = static_cast<int32_t>(results.size());
std::array<int64_t, 2> shape{batch_size, context_size_};
Ort::Value decoder_input =
Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
for (const auto &r : results) {
const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
const int64_t *end = r.tokens.data() + r.tokens.size();
std::copy(begin, end, p);
p += context_size_;
}
return decoder_input;
}
Ort::Value OnlineZipformerTransducerModel::RunDecoder(
Ort::Value decoder_input) {
auto decoder_out = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), &decoder_input, 1,
decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
return std::move(decoder_out[0]);
}
Ort::Value OnlineZipformerTransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
joiner_output_names_ptr_.size());
return std::move(logit[0]);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-zipformer-transducer-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace sherpa_onnx {
class OnlineZipformerTransducerModel : public OnlineTransducerModel {
public:
explicit OnlineZipformerTransducerModel(
const OnlineTransducerModelConfig &config);
std::vector<Ort::Value> StackStates(
const std::vector<std::vector<Ort::Value>> &states) const override;
std::vector<std::vector<Ort::Value>> UnStackStates(
const std::vector<Ort::Value> &states) const override;
std::vector<Ort::Value> GetEncoderInitStates() override;
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states) override;
Ort::Value BuildDecoderInput(
const std::vector<OnlineTransducerDecoderResult> &results) override;
Ort::Value RunDecoder(Ort::Value decoder_input) override;
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
int32_t ContextSize() const override { return context_size_; }
int32_t ChunkSize() const override { return T_; }
int32_t ChunkShift() const override { return decode_chunk_len_; }
int32_t VocabSize() const override { return vocab_size_; }
OrtAllocator *Allocator() override { return allocator_; }
private:
void InitEncoder(const std::string &encoder_filename);
void InitDecoder(const std::string &decoder_filename);
void InitJoiner(const std::string &joiner_filename);
private:
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::unique_ptr<Ort::Session> joiner_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<std::string> joiner_input_names_;
std::vector<const char *> joiner_input_names_ptr_;
std::vector<std::string> joiner_output_names_;
std::vector<const char *> joiner_output_names_ptr_;
OnlineTransducerModelConfig config_;
std::vector<int32_t> encoder_dims_;
std::vector<int32_t> attention_dims_;
std::vector<int32_t> num_encoder_layers_;
std::vector<int32_t> cnn_module_kernels_;
std::vector<int32_t> left_context_len_;
int32_t T_ = 0;
int32_t decode_chunk_len_ = 0;
int32_t context_size_ = 0;
int32_t vocab_size_ = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
... ...
... ... @@ -46,16 +46,74 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
}
}
Ort::Value Clone(Ort::Value *v) {
Ort::Value Clone(const Ort::Value *v) {
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
std::vector<int64_t> shape = type_and_shape.GetShape();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),
type_and_shape.GetElementCount(),
shape.data(), shape.size());
switch (type_and_shape.GetElementType()) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return Ort::Value::CreateTensor(
memory_info,
const_cast<Ort::Value *>(v)->GetTensorMutableData<int32_t>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
return Ort::Value::CreateTensor(
memory_info,
const_cast<Ort::Value *>(v)->GetTensorMutableData<int64_t>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return Ort::Value::CreateTensor(
memory_info,
const_cast<Ort::Value *>(v)->GetTensorMutableData<float>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
default:
fprintf(stderr, "Unsupported type: %d\n",
static_cast<int32_t>(type_and_shape.GetElementType()));
exit(-1);
// unreachable code
return Ort::Value{nullptr};
}
}
void Print1D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
fprintf(stderr, "%.3f ", d[i]);
}
fprintf(stderr, "\n");
}
void Print2D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
for (int32_t r = 0; r != static_cast<int32_t>(shape[0]); ++r) {
for (int32_t c = 0; c != static_cast<int32_t>(shape[1]); ++c, ++d) {
fprintf(stderr, "%.3f ", *d);
}
fprintf(stderr, "\n");
}
fprintf(stderr, "\n");
}
void Print3D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
for (int32_t p = 0; p != static_cast<int32_t>(shape[0]); ++p) {
fprintf(stderr, "---plane %d---\n", p);
for (int32_t r = 0; r != static_cast<int32_t>(shape[1]); ++r) {
for (int32_t c = 0; c != static_cast<int32_t>(shape[2]); ++c, ++d) {
fprintf(stderr, "%.3f ", *d);
}
fprintf(stderr, "\n");
}
}
fprintf(stderr, "\n");
}
} // namespace sherpa_onnx
... ...
... ... @@ -56,7 +56,23 @@ void PrintModelMetadata(std::ostream &os,
const Ort::ModelMetadata &meta_data); // NOLINT
// Return a shallow copy of v
Ort::Value Clone(Ort::Value *v);
Ort::Value Clone(const Ort::Value *v);
// Print a 1-D tensor to stderr
void Print1D(Ort::Value *v);
// Print a 2-D tensor to stderr
void Print2D(Ort::Value *v);
// Print a 3-D tensor to stderr
void Print3D(Ort::Value *v);
template <typename T = float>
void Fill(Ort::Value *tensor, T value) {
auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount();
auto p = tensor->GetTensorMutableData<T>();
std::fill(p, p + n, value);
}
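// Example (illustrative): zero-initialize a freshly created tensor, as done
// in GetEncoderInitStates():
//
//   Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
//                                                  h_shape.size());
//   Fill<float>(&h, 0);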
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/text-utils.cc
//
// Copyright 2009-2011 Saarland University; Microsoft Corporation
// Copyright 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/text-utils.h"
#include <string>
#include <vector>
// This file is copied/modified from
// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc
namespace sherpa_onnx {
void SplitStringToVector(const std::string &full, const char *delim,
bool omit_empty_strings,
std::vector<std::string> *out) {
size_t start = 0, found = 0, end = full.size();
out->clear();
while (found != std::string::npos) {
found = full.find_first_of(delim, start);
// start != end condition is for when the delimiter is at the end
if (!omit_empty_strings || (found != start && start != end))
out->push_back(full.substr(start, found - start));
start = found + 1;
}
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/text-utils.h
//
// Copyright 2009-2011 Saarland University; Microsoft Corporation
// Copyright 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_TEXT_UTILS_H_
#define SHERPA_ONNX_CSRC_TEXT_UTILS_H_
#include <stdlib.h>
#include <string>
#include <vector>
#ifdef _MSC_VER
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
_strtoi64(cur_cstr, end_cstr, 10);
#else
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10);
#endif
// This file is copied/modified from
// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.h
namespace sherpa_onnx {
/// Split a string using any of the single character delimiters.
/// If omit_empty_strings == true, the output will contain any
/// nonempty strings after splitting on any of the
/// characters in the delimiter. If omit_empty_strings == false,
/// the output will contain n+1 strings if there are n characters
/// in the set "delim" within the input string. In this case
/// the empty string is split to a single empty string.
void SplitStringToVector(const std::string &full, const char *delim,
bool omit_empty_strings,
std::vector<std::string> *out);
/**
\brief Split a string (e.g. 1:2:3) into a vector of integers.
\param [in] delim String containing a list of characters, any of which
is allowed as a delimiter.
\param [in] omit_empty_strings If true, empty strings between delimiters are
allowed and will not produce an output integer; if false,
instances of characters in 'delim' that are consecutive or
at the start or end of the string would be an error.
You'll normally want this to be true if 'delim' consists
of spaces, and false otherwise.
\param [out] out The output list of integers.
*/
template <class I>
bool SplitStringToIntegers(const std::string &full, const char *delim,
bool omit_empty_strings, // typically false [but
// should probably be true
// if "delim" is spaces].
std::vector<I> *out) {
static_assert(std::is_integral<I>::value, "");
if (*(full.c_str()) == '\0') {
out->clear();
return true;
}
std::vector<std::string> split;
SplitStringToVector(full, delim, omit_empty_strings, &split);
out->resize(split.size());
for (size_t i = 0; i < split.size(); i++) {
const char *this_str = split[i].c_str();
char *end = NULL;
int64_t j = 0;
j = SHERPA_ONNX_STRTOLL(this_str, &end);
if (end == this_str || *end != '\0') {
out->clear();
return false;
} else {
I jI = static_cast<I>(j);
if (static_cast<int64_t>(jI) != j) {
// output type cannot fit this integer.
out->clear();
return false;
}
(*out)[i] = jI;
}
}
return true;
}
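// Example (illustrative): parse a comma-separated metadata string such as the
// one read by SHERPA_ONNX_READ_META_DATA_VEC:
//
//   std::vector<int32_t> dims;
//   bool ok = SplitStringToIntegers("2,4,3,2,4", ",", true, &dims);
//   // ok == true, dims == {2, 4, 3, 2, 4}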
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
... ...
// sherpa-onnx/csrc/unbind-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/unbind.h"
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
TEST(Ubind, Test1DTensors) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 1> shape{3};
Ort::Value v =
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
float *p = v.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
p[i] = i;
}
auto ans = Unbind(allocator, &v, 0);
EXPECT_EQ(ans.size(), shape[0]);
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
EXPECT_EQ(ans[i].GetTensorData<float>()[0], p[i]);
}
Print1D(&v);
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
Print1D(&ans[i]);
}
// For Cat
std::vector<const Ort::Value *> vec(ans.size());
for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
vec[i] = &ans[i];
}
Ort::Value v2 = Cat(allocator, vec, 0);
const float *p2 = v2.GetTensorData<float>();
for (int32_t i = 0; i != shape[0]; ++i) {
EXPECT_EQ(p[i], p2[i]);
}
}
TEST(Ubind, Test2DTensorsDim0) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 2> shape{3, 2};
Ort::Value v =
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
float *p = v.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1]); ++i) {
p[i] = i;
}
auto ans = Unbind(allocator, &v, 0);
Print2D(&v);
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
Print2D(&ans[i]);
}
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
const float *pans = ans[i].GetTensorData<float>();
for (int32_t k = 0; k != static_cast<int32_t>(shape[1]); ++k, ++p) {
EXPECT_EQ(*p, pans[k]);
}
}
// For Cat
std::vector<const Ort::Value *> vec(ans.size());
for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
vec[i] = &ans[i];
}
Ort::Value v2 = Cat(allocator, vec, 0);
Print2D(&v2);
p = v.GetTensorMutableData<float>();
const float *p2 = v2.GetTensorData<float>();
for (int32_t i = 0; i != shape[0] * shape[1]; ++i) {
EXPECT_EQ(p[i], p2[i]);
}
}
TEST(Ubind, Test2DTensorsDim1) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 2> shape{3, 2};
Ort::Value v =
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
float *p = v.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1]); ++i) {
p[i] = i;
}
auto ans = Unbind(allocator, &v, 1);
Print2D(&v);
for (int32_t i = 0; i != static_cast<int32_t>(shape[1]); ++i) {
Print2D(&ans[i]);
}
// For Cat
std::vector<const Ort::Value *> vec(ans.size());
for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
vec[i] = &ans[i];
}
Ort::Value v2 = Cat(allocator, vec, 1);
Print2D(&v2);
p = v.GetTensorMutableData<float>();
const float *p2 = v2.GetTensorData<float>();
for (int32_t i = 0; i != shape[0] * shape[1]; ++i) {
EXPECT_EQ(p[i], p2[i]);
}
}
TEST(Ubind, Test3DTensorsDim0) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 3> shape{3, 2, 5};
Ort::Value v =
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
float *p = v.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
++i) {
p[i] = i;
}
auto ans = Unbind(allocator, &v, 0);
Print3D(&v);
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
Print3D(&ans[i]);
}
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
const float *pans = ans[i].GetTensorData<float>();
for (int32_t k = 0; k != static_cast<int32_t>(shape[1] * shape[2]);
++k, ++p) {
EXPECT_EQ(*p, pans[k]);
}
}
// For Cat
std::vector<const Ort::Value *> vec(ans.size());
for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
vec[i] = &ans[i];
}
Ort::Value v2 = Cat(allocator, vec, 0);
Print3D(&v2);
p = v.GetTensorMutableData<float>();
const float *p2 = v2.GetTensorData<float>();
for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) {
EXPECT_EQ(p[i], p2[i]);
}
}
TEST(Ubind, Test3DTensorsDim1) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 3> shape{3, 2, 5};
Ort::Value v =
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
float *p = v.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
++i) {
p[i] = i;
}
auto ans = Unbind(allocator, &v, 1);
Print3D(&v);
for (int32_t i = 0; i != static_cast<int32_t>(shape[1]); ++i) {
Print3D(&ans[i]);
}
// For Cat
std::vector<const Ort::Value *> vec(ans.size());
for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
vec[i] = &ans[i];
}
Ort::Value v2 = Cat(allocator, vec, 1);
Print3D(&v2);
p = v.GetTensorMutableData<float>();
const float *p2 = v2.GetTensorData<float>();
for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) {
EXPECT_EQ(p[i], p2[i]);
}
}
TEST(Ubind, Test3DTensorsDim2) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 3> shape{3, 2, 5};
Ort::Value v =
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
float *p = v.GetTensorMutableData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
++i) {
p[i] = i;
}
auto ans = Unbind(allocator, &v, 2);
Print3D(&v);
for (int32_t i = 0; i != static_cast<int32_t>(shape[2]); ++i) {
Print3D(&ans[i]);
}
// For Cat
std::vector<const Ort::Value *> vec(ans.size());
for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
vec[i] = &ans[i];
}
Ort::Value v2 = Cat(allocator, vec, 2);
Print3D(&v2);
p = v.GetTensorMutableData<float>();
const float *p2 = v2.GetTensorData<float>();
for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) {
EXPECT_EQ(p[i], p2[i]);
}
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/unbind.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/unbind.h"
#include <assert.h>
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
template <typename T /*= float*/>
std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
int32_t dim) {
std::vector<int64_t> shape = value->GetTensorTypeAndShapeInfo().GetShape();
assert(dim >= 0);
assert(dim < static_cast<int32_t>(shape.size()));
int32_t n = static_cast<int32_t>(shape[dim]);
if (n == 1) {
std::vector<Ort::Value> ans;
ans.push_back(Clone(value));
return ans;
}
std::vector<int64_t> ans_shape = shape;
ans_shape[dim] = 1;  // Unlike torch, we keep the unbind dim as 1
// allocate tensors
std::vector<Ort::Value> ans;
ans.reserve(n);
for (int32_t i = 0; i != n; ++i) {
Ort::Value t = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
ans_shape.size());
ans.push_back(std::move(t));
}
auto leading_size = static_cast<int32_t>(std::accumulate(
shape.begin(), shape.begin() + dim, 1, std::multiplies<int64_t>()));
auto trailing_size = static_cast<int32_t>(std::accumulate(
shape.begin() + dim + 1, shape.end(), 1, std::multiplies<int64_t>()));
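  // Illustrative example (not from the original source): unbinding a
  // (3, 2, 5) tensor along dim = 1 gives n = 2, leading_size = 3 and
  // trailing_size = 5; the loop below copies 5 floats into each of the 2
  // output tensors for every leading index.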
const T *src = value->GetTensorData<T>();
for (int32_t i = 0; i != leading_size; ++i) {
for (int32_t k = 0; k != n; ++k) {
T *dst = ans[k].GetTensorMutableData<T>() + i * trailing_size;
std::copy(src, src + trailing_size, dst);
src += trailing_size;
}
}
return std::move(ans);
}
template std::vector<Ort::Value> Unbind<float>(OrtAllocator *allocator,
const Ort::Value *value,
int32_t dim);
template std::vector<Ort::Value> Unbind<int64_t>(OrtAllocator *allocator,
const Ort::Value *value,
int32_t dim);
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/unbind.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_UNBIND_H_
#define SHERPA_ONNX_CSRC_UNBIND_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
/** It is similar to torch.unbind(), but we keep the unbind dim as 1 in
* the output.
*
* @param allocator Allocator to allocate space for the returned tensor
* @param value The tensor to unbind
* @param dim The dim along which to unbind the tensor
*
* @return Return a list of tensors
*/
template <typename T = float>
std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
int32_t dim);
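//
// Example (illustrative, mirroring unbind-test.cc): split a (3, 2) tensor `v`
// along dim 0 into 3 tensors, each of shape (1, 2):
//
//   std::vector<Ort::Value> pieces = Unbind(allocator, &v, 0);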
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_UNBIND_H_
... ...