Fangjun Kuang
Committed by GitHub

Add Streaming zipformer (#50)

@@ -33,8 +33,13 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, macos-latest] # windows-latest]
         python-version: ["3.7", "3.8", "3.9", "3.10"]
+        exclude:
+          - os: macos-latest
+            python-version: "3.9"
+          - os: macos-latest
+            python-version: "3.10"

     steps:
       - uses: actions/checkout@v2
@@ -8,3 +8,4 @@ sherpa-onnx-*
 __pycache__
 dist/
 sherpa_onnx.egg-info/
+.DS_Store
@@ -62,6 +62,7 @@ endif()

 if(SHERPA_ONNX_ENABLE_TESTS)
   enable_testing()
+  include(googletest)
 endif()

 add_subdirectory(sherpa-onnx)
-# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
-# See ../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 function(download_googltest)
   include(FetchContent)

@@ -26,6 +11,7 @@ function(download_googltest)
    ${PROJECT_SOURCE_DIR}/googletest-1.13.0.tar.gz
    ${PROJECT_BINARY_DIR}/googletest-1.13.0.tar.gz
    /tmp/googletest-1.13.0.tar.gz
+   /star-fj/fangjun/download/github/googletest-1.13.0.tar.gz
  )

  foreach(f IN LISTS possible_file_locations)
 include_directories(${CMAKE_SOURCE_DIR})

 add_library(sherpa-onnx-core
+  cat.cc
   features.cc
   online-lstm-transducer-model.cc
   online-recognizer.cc
@@ -8,8 +9,11 @@ add_library(sherpa-onnx-core
   online-transducer-greedy-search-decoder.cc
   online-transducer-model-config.cc
   online-transducer-model.cc
+  online-zipformer-transducer-model.cc
   onnx-utils.cc
   symbol-table.cc
+  text-utils.cc
+  unbind.cc
   wave-reader.cc
 )

@@ -27,3 +31,32 @@ endif()

 install(TARGETS sherpa-onnx-core DESTINATION lib)
 install(TARGETS sherpa-onnx DESTINATION bin)
+
+if(SHERPA_ONNX_ENABLE_TESTS)
+  set(sherpa_onnx_test_srcs
+    cat-test.cc
+    unbind-test.cc
+  )
+
+  function(sherpa_onnx_add_test source)
+    get_filename_component(name ${source} NAME_WE)
+    set(target_name ${name})
+    add_executable(${target_name} "${source}")
+
+    target_link_libraries(${target_name}
+      PRIVATE
+        gtest
+        gtest_main
+        sherpa-onnx-core
+    )
+
+    add_test(NAME "${target_name}"
+      COMMAND
+        $<TARGET_FILE:${target_name}>
+    )
+  endfunction()
+
+  foreach(source IN LISTS sherpa_onnx_test_srcs)
+    sherpa_onnx_add_test(${source})
+  endforeach()
+endif()
  1 +// sherpa-onnx/csrc/cat-test.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/cat.h"
  6 +
  7 +#include "gtest/gtest.h"
  8 +#include "sherpa-onnx/csrc/onnx-utils.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +TEST(Cat, Test1DTensors) {
  13 + Ort::AllocatorWithDefaultOptions allocator;
  14 +
  15 + std::array<int64_t, 1> a_shape{3};
  16 + std::array<int64_t, 1> b_shape{6};
  17 +
  18 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  19 + a_shape.size());
  20 +
  21 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  22 + b_shape.size());
  23 + float *pa = a.GetTensorMutableData<float>();
  24 + float *pb = b.GetTensorMutableData<float>();
  25 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
  26 + pa[i] = i;
  27 + }
  28 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
  29 + pb[i] = i + 10;
  30 + }
  31 +
  32 + Ort::Value ans = Cat(allocator, {&a, &b}, 0);
  33 +
  34 + const float *pans = ans.GetTensorData<float>();
  35 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
  36 + EXPECT_EQ(pa[i], pans[i]);
  37 + }
  38 +
  39 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
  40 + EXPECT_EQ(pb[i], pans[i + a_shape[0]]);
  41 + }
  42 +
  43 + Print1D(&a);
  44 + Print1D(&b);
  45 + Print1D(&ans);
  46 +}
  47 +
  48 +TEST(Cat, Test2DTensorsDim0) {
  49 + Ort::AllocatorWithDefaultOptions allocator;
  50 +
  51 + std::array<int64_t, 2> a_shape{2, 3};
  52 + std::array<int64_t, 2> b_shape{4, 3};
  53 +
  54 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  55 + a_shape.size());
  56 +
  57 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  58 + b_shape.size());
  59 +
  60 + float *pa = a.GetTensorMutableData<float>();
  61 + float *pb = b.GetTensorMutableData<float>();
  62 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
  63 + pa[i] = i;
  64 + }
  65 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
  66 + pb[i] = i + 10;
  67 + }
  68 +
  69 + Ort::Value ans = Cat(allocator, {&a, &b}, 0);
  70 +
  71 + const float *pans = ans.GetTensorData<float>();
  72 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
  73 + EXPECT_EQ(pa[i], pans[i]);
  74 + }
  75 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
  76 + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]);
  77 + }
  78 +
  79 + Print2D(&a);
  80 + Print2D(&b);
  81 + Print2D(&ans);
  82 +}
  83 +
  84 +TEST(Cat, Test2DTensorsDim1) {
  85 + Ort::AllocatorWithDefaultOptions allocator;
  86 +
  87 + std::array<int64_t, 2> a_shape{4, 3};
  88 + std::array<int64_t, 2> b_shape{4, 2};
  89 +
  90 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  91 + a_shape.size());
  92 +
  93 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  94 + b_shape.size());
  95 +
  96 + float *pa = a.GetTensorMutableData<float>();
  97 + float *pb = b.GetTensorMutableData<float>();
  98 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
  99 + pa[i] = i;
  100 + }
  101 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
  102 + pb[i] = i + 10;
  103 + }
  104 +
  105 + Ort::Value ans = Cat(allocator, {&a, &b}, 1);
  106 +
  107 + const float *pans = ans.GetTensorData<float>();
  108 +
  109 + for (int32_t r = 0; r != static_cast<int32_t>(a_shape[0]); ++r) {
  110 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[1]);
  111 + ++i, ++pa, ++pans) {
  112 + EXPECT_EQ(*pa, *pans);
  113 + }
  114 +
  115 + for (int32_t i = 0; i != static_cast<int32_t>(b_shape[1]);
  116 + ++i, ++pb, ++pans) {
  117 + EXPECT_EQ(*pb, *pans);
  118 + }
  119 + }
  120 +
  121 + Print2D(&a);
  122 + Print2D(&b);
  123 + Print2D(&ans);
  124 +}
  125 +
  126 +TEST(Cat, Test3DTensorsDim0) {
  127 + Ort::AllocatorWithDefaultOptions allocator;
  128 +
  129 + std::array<int64_t, 3> a_shape{2, 3, 2};
  130 + std::array<int64_t, 3> b_shape{4, 3, 2};
  131 +
  132 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  133 + a_shape.size());
  134 +
  135 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  136 + b_shape.size());
  137 +
  138 + float *pa = a.GetTensorMutableData<float>();
  139 + float *pb = b.GetTensorMutableData<float>();
  140 + for (int32_t i = 0;
  141 + i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
  142 + pa[i] = i;
  143 + }
  144 + for (int32_t i = 0;
  145 + i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
  146 + pb[i] = i + 10;
  147 + }
  148 +
  149 + Ort::Value ans = Cat(allocator, {&a, &b}, 0);
  150 +
  151 + const float *pans = ans.GetTensorData<float>();
  152 + for (int32_t i = 0;
  153 + i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
  154 + EXPECT_EQ(pa[i], pans[i]);
  155 + }
  156 + for (int32_t i = 0;
  157 + i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
  158 + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]);
  159 + }
  160 +
  161 + Print3D(&a);
  162 + Print3D(&b);
  163 + Print3D(&ans);
  164 +}
  165 +
  166 +TEST(Cat, Test3DTensorsDim1) {
  167 + Ort::AllocatorWithDefaultOptions allocator;
  168 +
  169 + std::array<int64_t, 3> a_shape{2, 2, 3};
  170 + std::array<int64_t, 3> b_shape{2, 4, 3};
  171 +
  172 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  173 + a_shape.size());
  174 +
  175 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  176 + b_shape.size());
  177 +
  178 + float *pa = a.GetTensorMutableData<float>();
  179 + float *pb = b.GetTensorMutableData<float>();
  180 + for (int32_t i = 0;
  181 + i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
  182 + pa[i] = i;
  183 + }
  184 + for (int32_t i = 0;
  185 + i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
  186 + pb[i] = i + 10;
  187 + }
  188 +
  189 + Ort::Value ans = Cat(allocator, {&a, &b}, 1);
  190 +
  191 + const float *pans = ans.GetTensorData<float>();
  192 +
  193 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
  194 + for (int32_t k = 0; k != static_cast<int32_t>(a_shape[1] * a_shape[2]);
  195 + ++k, ++pa, ++pans) {
  196 + EXPECT_EQ(*pa, *pans);
  197 + }
  198 +
  199 + for (int32_t k = 0; k != static_cast<int32_t>(b_shape[1] * b_shape[2]);
  200 + ++k, ++pb, ++pans) {
  201 + EXPECT_EQ(*pb, *pans);
  202 + }
  203 + }
  204 +
  205 + Print3D(&a);
  206 + Print3D(&b);
  207 + Print3D(&ans);
  208 +}
  209 +
  210 +TEST(Cat, Test3DTensorsDim2) {
  211 + Ort::AllocatorWithDefaultOptions allocator;
  212 +
  213 + std::array<int64_t, 3> a_shape{2, 3, 4};
  214 + std::array<int64_t, 3> b_shape{2, 3, 5};
  215 +
  216 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  217 + a_shape.size());
  218 +
  219 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  220 + b_shape.size());
  221 +
  222 + float *pa = a.GetTensorMutableData<float>();
  223 + float *pb = b.GetTensorMutableData<float>();
  224 + for (int32_t i = 0;
  225 + i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
  226 + pa[i] = i;
  227 + }
  228 + for (int32_t i = 0;
  229 + i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
  230 + pb[i] = i + 10;
  231 + }
  232 +
  233 + Ort::Value ans = Cat(allocator, {&a, &b}, 2);
  234 +
  235 + const float *pans = ans.GetTensorData<float>();
  236 +
  237 + for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
  238 + for (int32_t k = 0; k != static_cast<int32_t>(a_shape[2]);
  239 + ++k, ++pa, ++pans) {
  240 + EXPECT_EQ(*pa, *pans);
  241 + }
  242 +
  243 + for (int32_t k = 0; k != static_cast<int32_t>(b_shape[2]);
  244 + ++k, ++pb, ++pans) {
  245 + EXPECT_EQ(*pb, *pans);
  246 + }
  247 + }
  248 +
  249 + Print3D(&a);
  250 + Print3D(&b);
  251 + Print3D(&ans);
  252 +}
  253 +
  254 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/cat.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/cat.h"
  6 +
  7 +#include <algorithm>
  8 +#include <functional>
  9 +#include <numeric>
  10 +#include <utility>
  11 +
  12 +#include "sherpa-onnx/csrc/onnx-utils.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +static bool Compare(const std::vector<int64_t> &a,
  17 + const std::vector<int64_t> &b, int32_t skip_dim) {
  18 + if (a.size() != b.size()) return false;
  19 +
  20 + for (int32_t i = 0; i != static_cast<int32_t>(a.size()); ++i) {
  21 + if (i == skip_dim) continue;
  22 +
  23 + if (a[i] != b[i]) return false;
  24 + }
  25 +
  26 + return true;
  27 +}
  28 +
  29 +static void PrintShape(const std::vector<int64_t> &a) {
  30 + for (auto i : a) {
  31 + fprintf(stderr, "%d ", static_cast<int32_t>(i));
  32 + }
  33 + fprintf(stderr, "\n");
  34 +}
  35 +
  36 +template <typename T /*=float*/>
  37 +Ort::Value Cat(OrtAllocator *allocator,
  38 + const std::vector<const Ort::Value *> &values, int32_t dim) {
  39 + if (values.size() == 1u) {
  40 + return Clone(values[0]);
  41 + }
  42 +
  43 + std::vector<int64_t> v0_shape =
  44 + values[0]->GetTensorTypeAndShapeInfo().GetShape();
  45 +
  46 + int64_t total_dim = v0_shape[dim];
  47 +
  48 + for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
  49 + auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
  50 + total_dim += s[dim];
  51 +
  52 + bool ret = Compare(v0_shape, s, dim);
  53 + if (!ret) {
  54 + fprintf(stderr, "Incorrect shape in Cat !\n");
  55 +
  56 + fprintf(stderr, "Shape for tensor 0: ");
  57 + PrintShape(v0_shape);
  58 +
  59 + fprintf(stderr, "Shape for tensor %d: ", i);
  60 + PrintShape(s);
  61 +
  62 + exit(-1);
  63 + }
  64 + }
  65 +
  66 + std::vector<int64_t> ans_shape;
  67 + ans_shape.reserve(v0_shape.size());
  68 + ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
  69 + ans_shape.push_back(total_dim);
  70 + ans_shape.insert(ans_shape.end(), v0_shape.data() + dim + 1,
  71 + v0_shape.data() + v0_shape.size());
  72 +
  73 + auto leading_size = static_cast<int32_t>(std::accumulate(
  74 + v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
  75 +
  76 + auto trailing_size = static_cast<int32_t>(
  77 + std::accumulate(v0_shape.begin() + dim + 1, v0_shape.end(), 1,
  78 + std::multiplies<int64_t>()));
  79 +
  80 + Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
  81 + ans_shape.size());
  82 + T *dst = ans.GetTensorMutableData<T>();
  83 +
  84 + for (int32_t i = 0; i != leading_size; ++i) {
  85 + for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
  86 + auto this_dim = values[n]->GetTensorTypeAndShapeInfo().GetShape()[dim];
  87 + const T *src = values[n]->GetTensorData<T>();
  88 + src += i * this_dim * trailing_size;
  89 +
  90 + std::copy(src, src + this_dim * trailing_size, dst);
  91 + dst += this_dim * trailing_size;
  92 + }
  93 + }
  94 +
  95 + return std::move(ans);
  96 +}
  97 +
  98 +template Ort::Value Cat<float>(OrtAllocator *allocator,
  99 + const std::vector<const Ort::Value *> &values,
  100 + int32_t dim);
  101 +
  102 +template Ort::Value Cat<int64_t>(OrtAllocator *allocator,
  103 + const std::vector<const Ort::Value *> &values,
  104 + int32_t dim);
  105 +
  106 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/cat.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_CAT_H_
  5 +#define SHERPA_ONNX_CSRC_CAT_H_
  6 +
  7 +#include <vector>
  8 +
  9 +#include "onnxruntime_cxx_api.h" // NOLINT
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +/** Cat a list of tensors along the given dim.
  14 + *
  15 + * @param allocator Allocator to allocate space for the returned tensor
  16 + * @param values Pointer to a list of tensors. The shape of the tensor must
  17 + * be the same except on the dim to be concatenated.
  18 + * @param dim The dim along which to concatenate the input tensors
  19 + *
  20 + * @return Return the concatenated tensor
  21 + */
  22 +template <typename T = float>
  23 +Ort::Value Cat(OrtAllocator *allocator,
  24 + const std::vector<const Ort::Value *> &values, int32_t dim);
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_CAT_H_
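For orientation, here is a minimal usage sketch of the Cat() API documented above. It is not part of the commit; the wrapper function name is purely illustrative, and the calling pattern mirrors cat-test.cc earlier in this diff.

// Illustrative only: concatenate a (2, 3) and a (5, 3) float tensor along
// dim 0, giving a (7, 3) tensor. All dims other than `dim` must match.
#include <array>

#include "sherpa-onnx/csrc/cat.h"

void CatUsageSketch() {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 2> a_shape{2, 3};
  std::array<int64_t, 2> b_shape{5, 3};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());
  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  // Cat<float> is the default template instantiation.
  Ort::Value ans = sherpa_onnx::Cat(allocator, {&a, &b}, /*dim=*/0);
}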
  1 +
  2 +// sherpa-onnx/csrc/macros.h
  3 +//
  4 +// Copyright 2023 Xiaomi Corporation
  5 +
  6 +#ifndef SHERPA_ONNX_CSRC_MACROS_H_
  7 +#define SHERPA_ONNX_CSRC_MACROS_H_
  8 +
  9 +#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
  10 + do { \
  11 + auto value = \
  12 + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
  13 + if (!value) { \
  14 + fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
  15 + exit(-1); \
  16 + } \
  17 + \
  18 + dst = atoi(value.get()); \
  19 + if (dst <= 0) { \
  20 + fprintf(stderr, "Invalid value %d for %s\n", dst, src_key); \
  21 + exit(-1); \
  22 + } \
  23 + } while (0)
  24 +
  25 +#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
  26 + do { \
  27 + auto value = \
  28 + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
  29 + if (!value) { \
  30 + fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
  31 + exit(-1); \
  32 + } \
  33 + \
  34 + bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \
  35 + if (!ret) { \
  36 + fprintf(stderr, "Invalid value %s for %s\n", value.get(), src_key); \
  37 + exit(-1); \
  38 + } \
  39 + } while (0)
  40 +
  41 +#endif // SHERPA_ONNX_CSRC_MACROS_H_
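Both macros above expect two names to already exist in the caller's scope: `meta_data` (an Ort::ModelMetadata) and `allocator`. A minimal sketch of the intended call pattern, mirroring how InitEncoder() uses them later in this commit (the member names are those of the model classes below):

// Inside e.g. OnlineZipformerTransducerModel::InitEncoder():
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator;  // consumed by the macros below

// Reads a single positive integer from the exported model metadata.
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
// Reads a comma-separated list of integers, e.g. "2,4,3,2,4".
SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers");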
@@ -3,6 +3,8 @@
 // Copyright (c) 2023 Xiaomi Corporation
 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"

+#include <assert.h>
+
 #include <algorithm>
 #include <memory>
 #include <sstream>
@@ -11,23 +13,11 @@
 #include <vector>

 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/cat.h"
+#include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
-
-#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
-  do { \
-    auto value = \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) { \
-      fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
-      exit(-1); \
-    } \
-    dst = atoi(value.get()); \
-    if (dst <= 0) { \
-      fprintf(stderr, "Invalud value %d for %s\n", dst, src_key); \
-      exit(-1); \
-    } \
-  } while (0)
+#include "sherpa-onnx/csrc/unbind.h"

 namespace sherpa_onnx {

@@ -64,7 +54,7 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
     fprintf(stderr, "%s\n", os.str().c_str());
   }

-  Ort::AllocatorWithDefaultOptions allocator;
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
   SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
   SHERPA_ONNX_READ_META_DATA(T_, "T");
   SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
@@ -91,7 +81,7 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
     fprintf(stderr, "%s\n", os.str().c_str());
   }

-  Ort::AllocatorWithDefaultOptions allocator;
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
   SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
   SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
 }
@@ -120,37 +110,19 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
     const std::vector<std::vector<Ort::Value>> &states) const {
   int32_t batch_size = static_cast<int32_t>(states.size());

-  std::array<int64_t, 3> h_shape{num_encoder_layers_, batch_size, d_model_};
-  Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
-                                                 h_shape.size());
-
-  std::array<int64_t, 3> c_shape{num_encoder_layers_, batch_size,
-                                 rnn_hidden_size_};
-
-  Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
-                                                 c_shape.size());
-
-  float *dst_h = h.GetTensorMutableData<float>();
-  float *dst_c = c.GetTensorMutableData<float>();
+  std::vector<const Ort::Value *> h_buf(batch_size);
+  std::vector<const Ort::Value *> c_buf(batch_size);

-  for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
   for (int32_t i = 0; i != batch_size; ++i) {
-    const float *src_h =
-        states[i][0].GetTensorData<float>() + layer * d_model_;
-
-    const float *src_c =
-        states[i][1].GetTensorData<float>() + layer * rnn_hidden_size_;
-
-    std::copy(src_h, src_h + d_model_, dst_h);
-    std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
-
-    dst_h += d_model_;
-    dst_c += rnn_hidden_size_;
-  }
+    assert(states[i].size() == 2);
+    h_buf[i] = &states[i][0];
+    c_buf[i] = &states[i][1];
   }

-  std::vector<Ort::Value> ans;
+  Ort::Value h = Cat(allocator_, h_buf, 1);
+  Ort::Value c = Cat(allocator_, c_buf, 1);

+  std::vector<Ort::Value> ans;
   ans.reserve(2);
   ans.push_back(std::move(h));
   ans.push_back(std::move(c));
@@ -161,37 +133,19 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
 std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
     const std::vector<Ort::Value> &states) const {
   int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
+  assert(states.size() == 2);

   std::vector<std::vector<Ort::Value>> ans(batch_size);

-  // allocate space
-  std::array<int64_t, 3> h_shape{num_encoder_layers_, 1, d_model_};
-  std::array<int64_t, 3> c_shape{num_encoder_layers_, 1, rnn_hidden_size_};
+  std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1);
+  std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1);

-  for (int32_t i = 0; i != batch_size; ++i) {
-    Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
-                                                   h_shape.size());
-    Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
-                                                   c_shape.size());
-    ans[i].push_back(std::move(h));
-    ans[i].push_back(std::move(c));
-  }
+  assert(h_vec.size() == batch_size);
+  assert(c_vec.size() == batch_size);

-  for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
   for (int32_t i = 0; i != batch_size; ++i) {
-    const float *src_h = states[0].GetTensorData<float>() +
-                         layer * batch_size * d_model_ + i * d_model_;
-    const float *src_c = states[1].GetTensorData<float>() +
-                         layer * batch_size * rnn_hidden_size_ +
-                         i * rnn_hidden_size_;
-
-    float *dst_h = ans[i][0].GetTensorMutableData<float>() + layer * d_model_;
-    float *dst_c =
-        ans[i][1].GetTensorMutableData<float>() + layer * rnn_hidden_size_;
-
-    std::copy(src_h, src_h + d_model_, dst_h);
-    std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
-  }
+    ans[i].push_back(std::move(h_vec[i]));
+    ans[i].push_back(std::move(c_vec[i]));
   }

   return ans;
@@ -206,20 +160,15 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
   Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
                                                  h_shape.size());

-  std::fill(h.GetTensorMutableData<float>(),
-            h.GetTensorMutableData<float>() +
-                num_encoder_layers_ * kBatchSize * d_model_,
-            0);
+  Fill<float>(&h, 0);

   std::array<int64_t, 3> c_shape{num_encoder_layers_, kBatchSize,
                                  rnn_hidden_size_};
+
   Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
                                                  c_shape.size());

-  std::fill(c.GetTensorMutableData<float>(),
-            c.GetTensorMutableData<float>() +
-                num_encoder_layers_ * kBatchSize * rnn_hidden_size_,
-            0);
+  Fill<float>(&c, 0);

   std::vector<Ort::Value> states;

@@ -8,11 +8,13 @@
 #include <string>

 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
+#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 namespace sherpa_onnx {

 enum class ModelType {
   kLstm,
+  kZipformer,
   kUnkown,
 };

@@ -40,6 +42,8 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) {

   if (model_type.get() == std::string("lstm")) {
     return ModelType::kLstm;
+  } else if (model_type.get() == std::string("zipformer")) {
+    return ModelType::kZipformer;
   } else {
     fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
     return ModelType::kUnkown;
@@ -53,6 +57,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
   switch (model_type) {
     case ModelType::kLstm:
       return std::make_unique<OnlineLstmTransducerModel>(config);
+    case ModelType::kZipformer:
+      return std::make_unique<OnlineZipformerTransducerModel>(config);
     case ModelType::kUnkown:
       return nullptr;
   }
  1 +// sherpa-onnx/csrc/online-zipformer-transducer-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
  6 +
  7 +#include <assert.h>
  8 +
  9 +#include <algorithm>
  10 +#include <memory>
  11 +#include <sstream>
  12 +#include <string>
  13 +#include <utility>
  14 +#include <vector>
  15 +
  16 +#include "onnxruntime_cxx_api.h" // NOLINT
  17 +#include "sherpa-onnx/csrc/cat.h"
  18 +#include "sherpa-onnx/csrc/macros.h"
  19 +#include "sherpa-onnx/csrc/online-transducer-decoder.h"
  20 +#include "sherpa-onnx/csrc/onnx-utils.h"
  21 +#include "sherpa-onnx/csrc/text-utils.h"
  22 +#include "sherpa-onnx/csrc/unbind.h"
  23 +
  24 +namespace sherpa_onnx {
  25 +
  26 +OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
  27 + const OnlineTransducerModelConfig &config)
  28 + : env_(ORT_LOGGING_LEVEL_WARNING),
  29 + config_(config),
  30 + sess_opts_{},
  31 + allocator_{} {
  32 + sess_opts_.SetIntraOpNumThreads(config.num_threads);
  33 + sess_opts_.SetInterOpNumThreads(config.num_threads);
  34 +
  35 + InitEncoder(config.encoder_filename);
  36 + InitDecoder(config.decoder_filename);
  37 + InitJoiner(config.joiner_filename);
  38 +}
  39 +
  40 +void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
  41 + encoder_sess_ = std::make_unique<Ort::Session>(
  42 + env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
  43 +
  44 + GetInputNames(encoder_sess_.get(), &encoder_input_names_,
  45 + &encoder_input_names_ptr_);
  46 +
  47 + GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
  48 + &encoder_output_names_ptr_);
  49 +
  50 + // get meta data
  51 + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
  52 + if (config_.debug) {
  53 + std::ostringstream os;
  54 + os << "---encoder---\n";
  55 + PrintModelMetadata(os, meta_data);
  56 + fprintf(stderr, "%s\n", os.str().c_str());
  57 + }
  58 +
  59 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  60 + SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims");
  61 + SHERPA_ONNX_READ_META_DATA_VEC(attention_dims_, "attention_dims");
  62 + SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers");
  63 + SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels");
  64 + SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len");
  65 +
  66 + SHERPA_ONNX_READ_META_DATA(T_, "T");
  67 + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
  68 +
  69 + if (config_.debug) {
  70 + auto print = [](const std::vector<int32_t> &v, const char *name) {
  71 + fprintf(stderr, "%s: ", name);
  72 + for (auto i : v) {
  73 + fprintf(stderr, "%d ", i);
  74 + }
  75 + fprintf(stderr, "\n");
  76 + };
  77 + print(encoder_dims_, "encoder_dims");
  78 + print(attention_dims_, "attention_dims");
  79 + print(num_encoder_layers_, "num_encoder_layers");
  80 + print(cnn_module_kernels_, "cnn_module_kernels");
  81 + print(left_context_len_, "left_context_len");
  82 + fprintf(stderr, "T: %d\n", T_);
  83 + fprintf(stderr, "decode_chunk_len_: %d\n", decode_chunk_len_);
  84 + }
  85 +}
  86 +
  87 +void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
  88 + decoder_sess_ = std::make_unique<Ort::Session>(
  89 + env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
  90 +
  91 + GetInputNames(decoder_sess_.get(), &decoder_input_names_,
  92 + &decoder_input_names_ptr_);
  93 +
  94 + GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
  95 + &decoder_output_names_ptr_);
  96 +
  97 + // get meta data
  98 + Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
  99 + if (config_.debug) {
  100 + std::ostringstream os;
  101 + os << "---decoder---\n";
  102 + PrintModelMetadata(os, meta_data);
  103 + fprintf(stderr, "%s\n", os.str().c_str());
  104 + }
  105 +
  106 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  107 + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  108 + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
  109 +}
  110 +
  111 +void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) {
  112 + joiner_sess_ = std::make_unique<Ort::Session>(
  113 + env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
  114 +
  115 + GetInputNames(joiner_sess_.get(), &joiner_input_names_,
  116 + &joiner_input_names_ptr_);
  117 +
  118 + GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
  119 + &joiner_output_names_ptr_);
  120 +
  121 + // get meta data
  122 + Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
  123 + if (config_.debug) {
  124 + std::ostringstream os;
  125 + os << "---joiner---\n";
  126 + PrintModelMetadata(os, meta_data);
  127 + fprintf(stderr, "%s\n", os.str().c_str());
  128 + }
  129 +}
  130 +
  131 +std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
  132 + const std::vector<std::vector<Ort::Value>> &states) const {
  133 + int32_t batch_size = static_cast<int32_t>(states.size());
  134 + int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size());
  135 +
  136 + std::vector<const Ort::Value *> buf(batch_size);
  137 +
  138 + std::vector<Ort::Value> ans;
  139 + ans.reserve(states[0].size());
  140 +
  141 + // cached_len
  142 + for (int32_t i = 0; i != num_encoders; ++i) {
  143 + for (int32_t n = 0; n != batch_size; ++n) {
  144 + buf[n] = &states[n][i];
  145 + }
  146 + auto v = Cat<int64_t>(allocator_, buf, 1); // (num_layers, 1)
  147 + ans.push_back(std::move(v));
  148 + }
  149 +
  150 + // cached_avg
  151 + for (int32_t i = 0; i != num_encoders; ++i) {
  152 + for (int32_t n = 0; n != batch_size; ++n) {
  153 + buf[n] = &states[n][num_encoders + i];
  154 + }
  155 + auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims)
  156 + ans.push_back(std::move(v));
  157 + }
  158 +
  159 + // cached_key
  160 + for (int32_t i = 0; i != num_encoders; ++i) {
  161 + for (int32_t n = 0; n != batch_size; ++n) {
  162 + buf[n] = &states[n][num_encoders * 2 + i];
  163 + }
  164 + // (num_layers, left_context_len, 1, attention_dims)
  165 + auto v = Cat(allocator_, buf, 2);
  166 + ans.push_back(std::move(v));
  167 + }
  168 +
  169 + // cached_val
  170 + for (int32_t i = 0; i != num_encoders; ++i) {
  171 + for (int32_t n = 0; n != batch_size; ++n) {
  172 + buf[n] = &states[n][num_encoders * 3 + i];
  173 + }
  174 + // (num_layers, left_context_len, 1, attention_dims/2)
  175 + auto v = Cat(allocator_, buf, 2);
  176 + ans.push_back(std::move(v));
  177 + }
  178 +
  179 + // cached_val2
  180 + for (int32_t i = 0; i != num_encoders; ++i) {
  181 + for (int32_t n = 0; n != batch_size; ++n) {
  182 + buf[n] = &states[n][num_encoders * 4 + i];
  183 + }
  184 + // (num_layers, left_context_len, 1, attention_dims/2)
  185 + auto v = Cat(allocator_, buf, 2);
  186 + ans.push_back(std::move(v));
  187 + }
  188 +
  189 + // cached_conv1
  190 + for (int32_t i = 0; i != num_encoders; ++i) {
  191 + for (int32_t n = 0; n != batch_size; ++n) {
  192 + buf[n] = &states[n][num_encoders * 5 + i];
  193 + }
  194 + // (num_layers, 1, encoder_dims, cnn_module_kernels-1)
  195 + auto v = Cat(allocator_, buf, 1);
  196 + ans.push_back(std::move(v));
  197 + }
  198 +
  199 + // cached_conv2
  200 + for (int32_t i = 0; i != num_encoders; ++i) {
  201 + for (int32_t n = 0; n != batch_size; ++n) {
  202 + buf[n] = &states[n][num_encoders * 6 + i];
  203 + }
  204 + // (num_layers, 1, encoder_dims, cnn_module_kernels-1)
  205 + auto v = Cat(allocator_, buf, 1);
  206 + ans.push_back(std::move(v));
  207 + }
  208 +
  209 + return ans;
  210 +}
  211 +
  212 +std::vector<std::vector<Ort::Value>>
  213 +OnlineZipformerTransducerModel::UnStackStates(
  214 + const std::vector<Ort::Value> &states) const {
  215 + assert(states.size() == num_encoder_layers_.size() * 7);
  216 +
  217 + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
  218 + int32_t num_encoders = num_encoder_layers_.size();
  219 +
  220 + std::vector<std::vector<Ort::Value>> ans;
  221 + ans.resize(batch_size);
  222 +
  223 + // cached_len
  224 + for (int32_t i = 0; i != num_encoders; ++i) {
  225 + auto v = Unbind<int64_t>(allocator_, &states[i], 1);
  226 + assert(v.size() == batch_size);
  227 +
  228 + for (int32_t n = 0; n != batch_size; ++n) {
  229 + ans[n].push_back(std::move(v[n]));
  230 + }
  231 + }
  232 +
  233 + // cached_avg
  234 + for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) {
  235 + auto v = Unbind(allocator_, &states[i], 1);
  236 + assert(v.size() == batch_size);
  237 +
  238 + for (int32_t n = 0; n != batch_size; ++n) {
  239 + ans[n].push_back(std::move(v[n]));
  240 + }
  241 + }
  242 +
  243 + // cached_key
  244 + for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) {
  245 + auto v = Unbind(allocator_, &states[i], 2);
  246 + assert(v.size() == batch_size);
  247 +
  248 + for (int32_t n = 0; n != batch_size; ++n) {
  249 + ans[n].push_back(std::move(v[n]));
  250 + }
  251 + }
  252 +
  253 + // cached_val
  254 + for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) {
  255 + auto v = Unbind(allocator_, &states[i], 2);
  256 + assert(v.size() == batch_size);
  257 +
  258 + for (int32_t n = 0; n != batch_size; ++n) {
  259 + ans[n].push_back(std::move(v[n]));
  260 + }
  261 + }
  262 +
  263 + // cached_val2
  264 + for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) {
  265 + auto v = Unbind(allocator_, &states[i], 2);
  266 + assert(v.size() == batch_size);
  267 +
  268 + for (int32_t n = 0; n != batch_size; ++n) {
  269 + ans[n].push_back(std::move(v[n]));
  270 + }
  271 + }
  272 +
  273 + // cached_conv1
  274 + for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) {
  275 + auto v = Unbind(allocator_, &states[i], 1);
  276 + assert(v.size() == batch_size);
  277 +
  278 + for (int32_t n = 0; n != batch_size; ++n) {
  279 + ans[n].push_back(std::move(v[n]));
  280 + }
  281 + }
  282 +
  283 + // cached_conv2
  284 + for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) {
  285 + auto v = Unbind(allocator_, &states[i], 1);
  286 + assert(v.size() == batch_size);
  287 +
  288 + for (int32_t n = 0; n != batch_size; ++n) {
  289 + ans[n].push_back(std::move(v[n]));
  290 + }
  291 + }
  292 +
  293 + return ans;
  294 +}
  295 +
  296 +std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {
  297 + // Please see
  298 + // https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py#L673
  299 + // for details
  300 +
  301 + int32_t n = static_cast<int32_t>(encoder_dims_.size());
  302 + std::vector<Ort::Value> cached_len_vec;
  303 + std::vector<Ort::Value> cached_avg_vec;
  304 + std::vector<Ort::Value> cached_key_vec;
  305 + std::vector<Ort::Value> cached_val_vec;
  306 + std::vector<Ort::Value> cached_val2_vec;
  307 + std::vector<Ort::Value> cached_conv1_vec;
  308 + std::vector<Ort::Value> cached_conv2_vec;
  309 +
  310 + cached_len_vec.reserve(n);
  311 + cached_avg_vec.reserve(n);
  312 + cached_key_vec.reserve(n);
  313 + cached_val_vec.reserve(n);
  314 + cached_val2_vec.reserve(n);
  315 + cached_conv1_vec.reserve(n);
  316 + cached_conv2_vec.reserve(n);
  317 +
  318 + for (int32_t i = 0; i != n; ++i) {
  319 + {
  320 + std::array<int64_t, 2> s{num_encoder_layers_[i], 1};
  321 + auto v =
  322 + Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
  323 + Fill<int64_t>(&v, 0);
  324 + cached_len_vec.push_back(std::move(v));
  325 + }
  326 +
  327 + {
  328 + std::array<int64_t, 3> s{num_encoder_layers_[i], 1, encoder_dims_[i]};
  329 + auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  330 + Fill(&v, 0);
  331 + cached_avg_vec.push_back(std::move(v));
  332 + }
  333 +
  334 + {
  335 + std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
  336 + attention_dims_[i]};
  337 + auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  338 + Fill(&v, 0);
  339 + cached_key_vec.push_back(std::move(v));
  340 + }
  341 +
  342 + {
  343 + std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
  344 + attention_dims_[i] / 2};
  345 + auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  346 + Fill(&v, 0);
  347 + cached_val_vec.push_back(std::move(v));
  348 + }
  349 +
  350 + {
  351 + std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
  352 + attention_dims_[i] / 2};
  353 + auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  354 + Fill(&v, 0);
  355 + cached_val2_vec.push_back(std::move(v));
  356 + }
  357 +
  358 + {
  359 + std::array<int64_t, 4> s{num_encoder_layers_[i], 1, encoder_dims_[i],
  360 + cnn_module_kernels_[i] - 1};
  361 + auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  362 + Fill(&v, 0);
  363 + cached_conv1_vec.push_back(std::move(v));
  364 + }
  365 +
  366 + {
  367 + std::array<int64_t, 4> s{num_encoder_layers_[i], 1, encoder_dims_[i],
  368 + cnn_module_kernels_[i] - 1};
  369 + auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
  370 + Fill(&v, 0);
  371 + cached_conv2_vec.push_back(std::move(v));
  372 + }
  373 + }
  374 +
  375 + std::vector<Ort::Value> ans;
  376 + ans.reserve(n * 7);
  377 +
  378 + for (auto &v : cached_len_vec) ans.push_back(std::move(v));
  379 + for (auto &v : cached_avg_vec) ans.push_back(std::move(v));
  380 + for (auto &v : cached_key_vec) ans.push_back(std::move(v));
  381 + for (auto &v : cached_val_vec) ans.push_back(std::move(v));
  382 + for (auto &v : cached_val2_vec) ans.push_back(std::move(v));
  383 + for (auto &v : cached_conv1_vec) ans.push_back(std::move(v));
  384 + for (auto &v : cached_conv2_vec) ans.push_back(std::move(v));
  385 +
  386 + return ans;
  387 +}
  388 +
  389 +std::pair<Ort::Value, std::vector<Ort::Value>>
  390 +OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
  391 + std::vector<Ort::Value> states) {
  392 + auto memory_info =
  393 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  394 +
  395 + std::vector<Ort::Value> encoder_inputs;
  396 + encoder_inputs.reserve(1 + states.size());
  397 +
  398 + encoder_inputs.push_back(std::move(features));
  399 + for (auto &v : states) {
  400 + encoder_inputs.push_back(std::move(v));
  401 + }
  402 +
  403 + auto encoder_out = encoder_sess_->Run(
  404 + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
  405 + encoder_inputs.size(), encoder_output_names_ptr_.data(),
  406 + encoder_output_names_ptr_.size());
  407 +
  408 + std::vector<Ort::Value> next_states;
  409 + next_states.reserve(states.size());
  410 +
  411 + for (int32_t i = 1; i != static_cast<int32_t>(encoder_out.size()); ++i) {
  412 + next_states.push_back(std::move(encoder_out[i]));
  413 + }
  414 +
  415 + return {std::move(encoder_out[0]), std::move(next_states)};
  416 +}
  417 +
  418 +Ort::Value OnlineZipformerTransducerModel::BuildDecoderInput(
  419 + const std::vector<OnlineTransducerDecoderResult> &results) {
  420 + int32_t batch_size = static_cast<int32_t>(results.size());
  421 + std::array<int64_t, 2> shape{batch_size, context_size_};
  422 + Ort::Value decoder_input =
  423 + Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
  424 + int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
  425 +
  426 + for (const auto &r : results) {
  427 + const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
  428 + const int64_t *end = r.tokens.data() + r.tokens.size();
  429 + std::copy(begin, end, p);
  430 + p += context_size_;
  431 + }
  432 +
  433 + return decoder_input;
  434 +}
  435 +
  436 +Ort::Value OnlineZipformerTransducerModel::RunDecoder(
  437 + Ort::Value decoder_input) {
  438 + auto decoder_out = decoder_sess_->Run(
  439 + {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
  440 + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
  441 + return std::move(decoder_out[0]);
  442 +}
  443 +
  444 +Ort::Value OnlineZipformerTransducerModel::RunJoiner(Ort::Value encoder_out,
  445 + Ort::Value decoder_out) {
  446 + std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
  447 + std::move(decoder_out)};
  448 + auto logit =
  449 + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
  450 + joiner_input.size(), joiner_output_names_ptr_.data(),
  451 + joiner_output_names_ptr_.size());
  452 +
  453 + return std::move(logit[0]);
  454 +}
  455 +
  456 +} // namespace sherpa_onnx
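As StackStates() and UnStackStates() above show, the encoder state is kept as one flat vector of 7 * num_encoders tensors, grouped by kind in a fixed order. A small sketch of that layout, inferred from the code above (the helper name here is only illustrative, not part of the commit):

// Order of the groups in the flat state vector:
//   0: cached_len   1: cached_avg   2: cached_key   3: cached_val
//   4: cached_val2  5: cached_conv1 6: cached_conv2
// The tensor of group g for encoder stack j lives at index g * num_encoders + j.
int32_t StateIndex(int32_t group, int32_t j, int32_t num_encoders) {
  return group * num_encoders + j;
}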
  1 +// sherpa-onnx/csrc/online-zipformer-transducer-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "onnxruntime_cxx_api.h" // NOLINT
  13 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  14 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  15 +
  16 +namespace sherpa_onnx {
  17 +
  18 +class OnlineZipformerTransducerModel : public OnlineTransducerModel {
  19 + public:
  20 + explicit OnlineZipformerTransducerModel(
  21 + const OnlineTransducerModelConfig &config);
  22 +
  23 + std::vector<Ort::Value> StackStates(
  24 + const std::vector<std::vector<Ort::Value>> &states) const override;
  25 +
  26 + std::vector<std::vector<Ort::Value>> UnStackStates(
  27 + const std::vector<Ort::Value> &states) const override;
  28 +
  29 + std::vector<Ort::Value> GetEncoderInitStates() override;
  30 +
  31 + std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
  32 + Ort::Value features, std::vector<Ort::Value> states) override;
  33 +
  34 + Ort::Value BuildDecoderInput(
  35 + const std::vector<OnlineTransducerDecoderResult> &results) override;
  36 +
  37 + Ort::Value RunDecoder(Ort::Value decoder_input) override;
  38 +
  39 + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
  40 +
  41 + int32_t ContextSize() const override { return context_size_; }
  42 +
  43 + int32_t ChunkSize() const override { return T_; }
  44 +
  45 + int32_t ChunkShift() const override { return decode_chunk_len_; }
  46 +
  47 + int32_t VocabSize() const override { return vocab_size_; }
  48 + OrtAllocator *Allocator() override { return allocator_; }
  49 +
  50 + private:
  51 + void InitEncoder(const std::string &encoder_filename);
  52 + void InitDecoder(const std::string &decoder_filename);
  53 + void InitJoiner(const std::string &joiner_filename);
  54 +
  55 + private:
  56 + Ort::Env env_;
  57 + Ort::SessionOptions sess_opts_;
  58 + Ort::AllocatorWithDefaultOptions allocator_;
  59 +
  60 + std::unique_ptr<Ort::Session> encoder_sess_;
  61 + std::unique_ptr<Ort::Session> decoder_sess_;
  62 + std::unique_ptr<Ort::Session> joiner_sess_;
  63 +
  64 + std::vector<std::string> encoder_input_names_;
  65 + std::vector<const char *> encoder_input_names_ptr_;
  66 +
  67 + std::vector<std::string> encoder_output_names_;
  68 + std::vector<const char *> encoder_output_names_ptr_;
  69 +
  70 + std::vector<std::string> decoder_input_names_;
  71 + std::vector<const char *> decoder_input_names_ptr_;
  72 +
  73 + std::vector<std::string> decoder_output_names_;
  74 + std::vector<const char *> decoder_output_names_ptr_;
  75 +
  76 + std::vector<std::string> joiner_input_names_;
  77 + std::vector<const char *> joiner_input_names_ptr_;
  78 +
  79 + std::vector<std::string> joiner_output_names_;
  80 + std::vector<const char *> joiner_output_names_ptr_;
  81 +
  82 + OnlineTransducerModelConfig config_;
  83 +
  84 + std::vector<int32_t> encoder_dims_;
  85 + std::vector<int32_t> attention_dims_;
  86 + std::vector<int32_t> num_encoder_layers_;
  87 + std::vector<int32_t> cnn_module_kernels_;
  88 + std::vector<int32_t> left_context_len_;
  89 +
  90 + int32_t T_ = 0;
  91 + int32_t decode_chunk_len_ = 0;
  92 +
  93 + int32_t context_size_ = 0;
  94 + int32_t vocab_size_ = 0;
  95 +};
  96 +
  97 +} // namespace sherpa_onnx
  98 +
  99 +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_
@@ -46,16 +46,74 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
   }
 }

-Ort::Value Clone(Ort::Value *v) {
+Ort::Value Clone(const Ort::Value *v) {
   auto type_and_shape = v->GetTensorTypeAndShapeInfo();
   std::vector<int64_t> shape = type_and_shape.GetShape();

   auto memory_info =
       Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

-  return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),
-                                  type_and_shape.GetElementCount(),
-                                  shape.data(), shape.size());
+  switch (type_and_shape.GetElementType()) {
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
+      return Ort::Value::CreateTensor(
+          memory_info,
+          const_cast<Ort::Value *>(v)->GetTensorMutableData<int32_t>(),
+          type_and_shape.GetElementCount(), shape.data(), shape.size());
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
+      return Ort::Value::CreateTensor(
+          memory_info,
+          const_cast<Ort::Value *>(v)->GetTensorMutableData<int64_t>(),
+          type_and_shape.GetElementCount(), shape.data(), shape.size());
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+      return Ort::Value::CreateTensor(
+          memory_info,
+          const_cast<Ort::Value *>(v)->GetTensorMutableData<float>(),
+          type_and_shape.GetElementCount(), shape.data(), shape.size());
+    default:
+      fprintf(stderr, "Unsupported type: %d\n",
+              static_cast<int32_t>(type_and_shape.GetElementType()));
+      exit(-1);
+      // unreachable code
+      return Ort::Value{nullptr};
+  }
+}
+
+void Print1D(Ort::Value *v) {
+  std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
+  const float *d = v->GetTensorData<float>();
+  for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
+    fprintf(stderr, "%.3f ", d[i]);
+  }
+  fprintf(stderr, "\n");
+}
+
+void Print2D(Ort::Value *v) {
+  std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
+  const float *d = v->GetTensorData<float>();
+
+  for (int32_t r = 0; r != static_cast<int32_t>(shape[0]); ++r) {
+    for (int32_t c = 0; c != static_cast<int32_t>(shape[1]); ++c, ++d) {
+      fprintf(stderr, "%.3f ", *d);
+    }
+    fprintf(stderr, "\n");
+  }
+  fprintf(stderr, "\n");
+}
+
+void Print3D(Ort::Value *v) {
+  std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
+  const float *d = v->GetTensorData<float>();
+
+  for (int32_t p = 0; p != static_cast<int32_t>(shape[0]); ++p) {
+    fprintf(stderr, "---plane %d---\n", p);
+    for (int32_t r = 0; r != static_cast<int32_t>(shape[1]); ++r) {
+      for (int32_t c = 0; c != static_cast<int32_t>(shape[2]); ++c, ++d) {
+        fprintf(stderr, "%.3f ", *d);
+      }
+      fprintf(stderr, "\n");
+    }
+  }
+  fprintf(stderr, "\n");
 }

 } // namespace sherpa_onnx
@@ -56,7 +56,23 @@ void PrintModelMetadata(std::ostream &os,
                         const Ort::ModelMetadata &meta_data);  // NOLINT

 // Return a shallow copy of v
-Ort::Value Clone(Ort::Value *v);
+Ort::Value Clone(const Ort::Value *v);
+
+// Print a 1-D tensor to stderr
+void Print1D(Ort::Value *v);
+
+// Print a 2-D tensor to stderr
+void Print2D(Ort::Value *v);
+
+// Print a 3-D tensor to stderr
+void Print3D(Ort::Value *v);
+
+template <typename T = float>
+void Fill(Ort::Value *tensor, T value) {
+  auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount();
+  auto p = tensor->GetTensorMutableData<T>();
+  std::fill(p, p + n, value);
+}

 } // namespace sherpa_onnx

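Fill<T>() replaces the hand-rolled std::fill calls that this commit removes from OnlineLstmTransducerModel::GetEncoderInitStates(); a minimal sketch of it in use (the variable names are taken from that function):

std::array<int64_t, 3> h_shape{num_encoder_layers_, kBatchSize, d_model_};
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
                                               h_shape.size());
Fill<float>(&h, 0);  // zero every element, regardless of the tensor's shape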
  1 +// sherpa-onnx/csrc/text-utils.cc
  2 +//
  3 +// Copyright 2009-2011 Saarland University; Microsoft Corporation
  4 +// Copyright 2023 Xiaomi Corporation
  5 +
  6 +#include "sherpa-onnx/csrc/text-utils.h"
  7 +
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +// This file is copied/modified from
  12 +// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +void SplitStringToVector(const std::string &full, const char *delim,
  17 + bool omit_empty_strings,
  18 + std::vector<std::string> *out) {
  19 + size_t start = 0, found = 0, end = full.size();
  20 + out->clear();
  21 + while (found != std::string::npos) {
  22 + found = full.find_first_of(delim, start);
  23 + // start != end condition is for when the delimiter is at the end
  24 + if (!omit_empty_strings || (found != start && start != end))
  25 + out->push_back(full.substr(start, found - start));
  26 + start = found + 1;
  27 + }
  28 +}
  29 +
  30 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/text-utils.h
  2 +//
  3 +// Copyright 2009-2011 Saarland University; Microsoft Corporation
  4 +// Copyright 2023 Xiaomi Corporation
  5 +#ifndef SHERPA_ONNX_CSRC_TEXT_UTILS_H_
  6 +#define SHERPA_ONNX_CSRC_TEXT_UTILS_H_
  7 +#include <stdlib.h>
  8 +
  9 +#include <string>
  10 +#include <vector>
  11 +
  12 +#ifdef _MSC_VER
  13 +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
  14 + _strtoi64(cur_cstr, end_cstr, 10);
  15 +#else
  16 +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10);
  17 +#endif
  18 +
  19 +// This file is copied/modified from
  20 +// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.h
  21 +
  22 +namespace sherpa_onnx {
  23 +
  24 +/// Split a string using any of the single character delimiters.
  25 +/// If omit_empty_strings == true, the output will contain any
  26 +/// nonempty strings after splitting on any of the
  27 +/// characters in the delimiter. If omit_empty_strings == false,
  28 +/// the output will contain n+1 strings if there are n characters
  29 +/// in the set "delim" within the input string. In this case
  30 +/// the empty string is split to a single empty string.
  31 +void SplitStringToVector(const std::string &full, const char *delim,
  32 + bool omit_empty_strings,
  33 + std::vector<std::string> *out);
  34 +
  35 +/**
  36 + \brief Split a string (e.g. 1:2:3) into a vector of integers.
  37 +
  38 + \param [in] delim String containing a list of characters, any of which
  39 + is allowed as a delimiter.
  40 + \param [in] omit_empty_strings If true, empty strings between delimiters are
  41 + allowed and will not produce an output integer; if false,
  42 + instances of characters in 'delim' that are consecutive or
  43 + at the start or end of the string would be an error.
  44 + You'll normally want this to be true if 'delim' consists
  45 + of spaces, and false otherwise.
  46 + \param [out] out The output list of integers.
  47 +*/
  48 +template <class I>
  49 +bool SplitStringToIntegers(const std::string &full, const char *delim,
  50 + bool omit_empty_strings, // typically false [but
  51 + // should probably be true
  52 + // if "delim" is spaces].
  53 + std::vector<I> *out) {
  54 + static_assert(std::is_integral<I>::value, "");
  55 + if (*(full.c_str()) == '\0') {
  56 + out->clear();
  57 + return true;
  58 + }
  59 + std::vector<std::string> split;
  60 + SplitStringToVector(full, delim, omit_empty_strings, &split);
  61 + out->resize(split.size());
  62 + for (size_t i = 0; i < split.size(); i++) {
  63 + const char *this_str = split[i].c_str();
  64 + char *end = NULL;
  65 + int64_t j = 0;
  66 + j = SHERPA_ONNX_STRTOLL(this_str, &end);
  67 + if (end == this_str || *end != '\0') {
  68 + out->clear();
  69 + return false;
  70 + } else {
  71 + I jI = static_cast<I>(j);
  72 + if (static_cast<int64_t>(jI) != j) {
  73 + // output type cannot fit this integer.
  74 + out->clear();
  75 + return false;
  76 + }
  77 + (*out)[i] = jI;
  78 + }
  79 + }
  80 + return true;
  81 +}
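// A minimal usage sketch (illustration only; "1:2:3"-style strings are the
// intended input per the comment above):
//
//   std::vector<int32_t> dims;
//   bool ok = SplitStringToIntegers("1:2:3", ":",
//                                   /*omit_empty_strings=*/false, &dims);
//   // ok == true, dims == {1, 2, 3}
//
//   ok = SplitStringToIntegers("1:x:3", ":", /*omit_empty_strings=*/false, &dims);
//   // ok == false and dims is cleared, since "x" is not an integer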
  82 +
  83 +} // namespace sherpa_onnx
  84 +
  85 +#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
  1 +// sherpa-onnx/csrc/unbind-test.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/unbind.h"
  6 +
+#include <array>
+#include <vector>
+
  7 +#include "gtest/gtest.h"
  8 +#include "sherpa-onnx/csrc/cat.h"
  9 +#include "sherpa-onnx/csrc/onnx-utils.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +TEST(Unbind, Test1DTensors) {
  14 + Ort::AllocatorWithDefaultOptions allocator;
  15 + std::array<int64_t, 1> shape{3};
  16 + Ort::Value v =
  17 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  18 + float *p = v.GetTensorMutableData<float>();
  19 +
  20 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
  21 + p[i] = i;
  22 + }
  23 + auto ans = Unbind(allocator, &v, 0);
  24 + EXPECT_EQ(ans.size(), shape[0]);
  25 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
  26 + EXPECT_EQ(ans[i].GetTensorData<float>()[0], p[i]);
  27 + }
  28 + Print1D(&v);
  29 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
  30 + Print1D(&ans[i]);
  31 + }
  32 +
  33 + // For Cat
  34 + std::vector<const Ort::Value *> vec(ans.size());
  35 + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
  36 + vec[i] = &ans[i];
  37 + }
  38 + Ort::Value v2 = Cat(allocator, vec, 0);
  39 + const float *p2 = v2.GetTensorData<float>();
  40 + for (int32_t i = 0; i != shape[0]; ++i) {
  41 + EXPECT_EQ(p[i], p2[i]);
  42 + }
  43 +}
  44 +
  45 +TEST(Unbind, Test2DTensorsDim0) {
  46 + Ort::AllocatorWithDefaultOptions allocator;
  47 + std::array<int64_t, 2> shape{3, 2};
  48 + Ort::Value v =
  49 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  50 + float *p = v.GetTensorMutableData<float>();
  51 +
  52 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1]); ++i) {
  53 + p[i] = i;
  54 + }
  55 + auto ans = Unbind(allocator, &v, 0);
  56 +
  57 + Print2D(&v);
  58 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
  59 + Print2D(&ans[i]);
  60 + }
  61 +
  62 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
  63 + const float *pans = ans[i].GetTensorData<float>();
  64 + for (int32_t k = 0; k != static_cast<int32_t>(shape[1]); ++k, ++p) {
  65 + EXPECT_EQ(*p, pans[k]);
  66 + }
  67 + }
  68 +
  69 + // For Cat
  70 + std::vector<const Ort::Value *> vec(ans.size());
  71 + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
  72 + vec[i] = &ans[i];
  73 + }
  74 + Ort::Value v2 = Cat(allocator, vec, 0);
  75 + Print2D(&v2);
  76 +
  77 + p = v.GetTensorMutableData<float>();
  78 + const float *p2 = v2.GetTensorData<float>();
  79 + for (int32_t i = 0; i != shape[0] * shape[1]; ++i) {
  80 + EXPECT_EQ(p[i], p2[i]);
  81 + }
  82 +}
  83 +
  84 +TEST(Unbind, Test2DTensorsDim1) {
  85 + Ort::AllocatorWithDefaultOptions allocator;
  86 + std::array<int64_t, 2> shape{3, 2};
  87 + Ort::Value v =
  88 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  89 + float *p = v.GetTensorMutableData<float>();
  90 +
  91 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1]); ++i) {
  92 + p[i] = i;
  93 + }
  94 + auto ans = Unbind(allocator, &v, 1);
  95 +
  96 + Print2D(&v);
  97 + for (int32_t i = 0; i != static_cast<int32_t>(shape[1]); ++i) {
  98 + Print2D(&ans[i]);
  99 + }
  100 +
  101 + // For Cat
  102 + std::vector<const Ort::Value *> vec(ans.size());
  103 + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
  104 + vec[i] = &ans[i];
  105 + }
  106 + Ort::Value v2 = Cat(allocator, vec, 1);
  107 + Print2D(&v2);
  108 +
  109 + p = v.GetTensorMutableData<float>();
  110 + const float *p2 = v2.GetTensorData<float>();
  111 + for (int32_t i = 0; i != shape[0] * shape[1]; ++i) {
  112 + EXPECT_EQ(p[i], p2[i]);
  113 + }
  114 +}
  115 +
  116 +TEST(Unbind, Test3DTensorsDim0) {
  117 + Ort::AllocatorWithDefaultOptions allocator;
  118 + std::array<int64_t, 3> shape{3, 2, 5};
  119 + Ort::Value v =
  120 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  121 + float *p = v.GetTensorMutableData<float>();
  122 +
  123 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
  124 + ++i) {
  125 + p[i] = i;
  126 + }
  127 + auto ans = Unbind(allocator, &v, 0);
  128 +
  129 + Print3D(&v);
  130 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
  131 + Print3D(&ans[i]);
  132 + }
  133 +
  134 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
  135 + const float *pans = ans[i].GetTensorData<float>();
  136 + for (int32_t k = 0; k != static_cast<int32_t>(shape[1] * shape[2]);
  137 + ++k, ++p) {
  138 + EXPECT_EQ(*p, pans[k]);
  139 + }
  140 + }
  141 +
  142 + // For Cat
  143 + std::vector<const Ort::Value *> vec(ans.size());
  144 + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
  145 + vec[i] = &ans[i];
  146 + }
  147 + Ort::Value v2 = Cat(allocator, vec, 0);
  148 + Print3D(&v2);
  149 +
  150 + p = v.GetTensorMutableData<float>();
  151 + const float *p2 = v2.GetTensorData<float>();
  152 + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) {
  153 + EXPECT_EQ(p[i], p2[i]);
  154 + }
  155 +}
  156 +
  157 +TEST(Unbind, Test3DTensorsDim1) {
  158 + Ort::AllocatorWithDefaultOptions allocator;
  159 + std::array<int64_t, 3> shape{3, 2, 5};
  160 + Ort::Value v =
  161 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  162 + float *p = v.GetTensorMutableData<float>();
  163 +
  164 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
  165 + ++i) {
  166 + p[i] = i;
  167 + }
  168 + auto ans = Unbind(allocator, &v, 1);
  169 +
  170 + Print3D(&v);
  171 + for (int32_t i = 0; i != static_cast<int32_t>(shape[1]); ++i) {
  172 + Print3D(&ans[i]);
  173 + }
  174 +
  175 + // For Cat
  176 + std::vector<const Ort::Value *> vec(ans.size());
  177 + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
  178 + vec[i] = &ans[i];
  179 + }
  180 + Ort::Value v2 = Cat(allocator, vec, 1);
  181 + Print3D(&v2);
  182 +
  183 + p = v.GetTensorMutableData<float>();
  184 + const float *p2 = v2.GetTensorData<float>();
  185 + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) {
  186 + EXPECT_EQ(p[i], p2[i]);
  187 + }
  188 +}
  189 +
  190 +TEST(Unbind, Test3DTensorsDim2) {
  191 + Ort::AllocatorWithDefaultOptions allocator;
  192 + std::array<int64_t, 3> shape{3, 2, 5};
  193 + Ort::Value v =
  194 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  195 + float *p = v.GetTensorMutableData<float>();
  196 +
  197 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
  198 + ++i) {
  199 + p[i] = i;
  200 + }
  201 + auto ans = Unbind(allocator, &v, 2);
  202 +
  203 + Print3D(&v);
  204 + for (int32_t i = 0; i != static_cast<int32_t>(shape[2]); ++i) {
  205 + Print3D(&ans[i]);
  206 + }
  207 +
  208 + // For Cat
  209 + std::vector<const Ort::Value *> vec(ans.size());
  210 + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) {
  211 + vec[i] = &ans[i];
  212 + }
  213 + Ort::Value v2 = Cat(allocator, vec, 2);
  214 + Print3D(&v2);
  215 +
  216 + p = v.GetTensorMutableData<float>();
  217 + const float *p2 = v2.GetTensorData<float>();
  218 + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) {
  219 + EXPECT_EQ(p[i], p2[i]);
  220 + }
  221 +}
  222 +
  223 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/unbind.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/unbind.h"
  6 +
  7 +#include <assert.h>
  8 +
  9 +#include <algorithm>
  10 +#include <functional>
  11 +#include <numeric>
  12 +#include <utility>
  13 +#include <vector>
  14 +
  15 +#include "onnxruntime_cxx_api.h" // NOLINT
  16 +#include "sherpa-onnx/csrc/onnx-utils.h"
  17 +
  18 +namespace sherpa_onnx {
  19 +
  20 +template <typename T /*= float*/>
  21 +std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
  22 + int32_t dim) {
  23 + std::vector<int64_t> shape = value->GetTensorTypeAndShapeInfo().GetShape();
  24 + assert(dim >= 0);
  25 + assert(dim < static_cast<int32_t>(shape.size()));
  26 + int32_t n = static_cast<int32_t>(shape[dim]);
  27 + if (n == 1) {
  28 + std::vector<Ort::Value> ans;
  29 + ans.push_back(Clone(value));
  30 + return ans;
  31 + }
  32 +
  33 + std::vector<int64_t> ans_shape = shape;
  34 + ans_shape[dim] = 1;  // Unlike torch, we keep the unbound dim with size 1
  35 +
  36 + // allocate output tensors, one per slice along dim
  37 + std::vector<Ort::Value> ans;
  38 + ans.reserve(n);
  39 + for (int32_t i = 0; i != n; ++i) {
  40 + Ort::Value t = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
  41 + ans_shape.size());
  42 + ans.push_back(std::move(t));
  43 + }
  44 +
  45 + auto leading_size = static_cast<int32_t>(std::accumulate(
  46 + shape.begin(), shape.begin() + dim, 1, std::multiplies<int64_t>()));
  47 +
  48 + auto trailing_size = static_cast<int32_t>(std::accumulate(
  49 + shape.begin() + dim + 1, shape.end(), 1, std::multiplies<int64_t>()));
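  // For example (comment added for clarity), with shape (3, 2, 5) and dim == 1:
  // leading_size == 3, trailing_size == 5, n == 2. The input is treated as a
  // (leading_size, n, trailing_size) array, and the copy loop below sends each
  // contiguous run of trailing_size elements to the output tensor whose index
  // matches its position along dim.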
  50 +
  51 + const T *src = value->GetTensorData<T>();
  52 +
  53 + for (int32_t i = 0; i != leading_size; ++i) {
  54 + for (int32_t k = 0; k != n; ++k) {
  55 + T *dst = ans[k].GetTensorMutableData<T>() + i * trailing_size;
  56 + std::copy(src, src + trailing_size, dst);
  57 + src += trailing_size;
  58 + }
  59 + }
  60 +
  61 + return ans;
  62 +}
  63 +
  64 +template std::vector<Ort::Value> Unbind<float>(OrtAllocator *allocator,
  65 + const Ort::Value *value,
  66 + int32_t dim);
  67 +
  68 +template std::vector<Ort::Value> Unbind<int64_t>(OrtAllocator *allocator,
  69 + const Ort::Value *value,
  70 + int32_t dim);
  71 +
  72 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/unbind.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_UNBIND_H_
  5 +#define SHERPA_ONNX_CSRC_UNBIND_H_
  6 +
  7 +#include <vector>
  8 +
  9 +#include "onnxruntime_cxx_api.h" // NOLINT
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +/** It is similar to torch.unbind(), but the unbound dim is kept with size 1
  14 + * in each output tensor.
  15 + *
  16 + * @param allocator Allocator to allocate space for the returned tensors
  17 + * @param value The tensor to unbind
  18 + * @param dim The dim along which to unbind the tensor
  19 + *
  20 + * @return A list of tensors, one per slice of value along dim.
  21 + */
  22 +template <typename T = float>
  23 +std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
  24 + int32_t dim);
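// A minimal usage sketch (illustration only; v is assumed to be an existing
// float tensor of shape (4, 5)):
//
//   Ort::AllocatorWithDefaultOptions allocator;
//   std::vector<Ort::Value> pieces = Unbind(allocator, &v, /*dim=*/0);
//   // pieces.size() == 4; each piece has shape (1, 5)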
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_UNBIND_H_