Showing 20 changed files with 1576 additions and 99 deletions.
@@ -33,8 +33,13 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, macos-latest] # windows-latest]
         python-version: ["3.7", "3.8", "3.9", "3.10"]
+        exclude:
+          - os: macos-latest
+            python-version: "3.9"
+          - os: macos-latest
+            python-version: "3.10"
 
     steps:
       - uses: actions/checkout@v2
cmake/googletest.cmake
@@ -1,18 +1,3 @@
-# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
-# See ../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 function(download_googltest)
   include(FetchContent)
 
@@ -26,6 +11,7 @@ function(download_googltest)
     ${PROJECT_SOURCE_DIR}/googletest-1.13.0.tar.gz
     ${PROJECT_BINARY_DIR}/googletest-1.13.0.tar.gz
     /tmp/googletest-1.13.0.tar.gz
+    /star-fj/fangjun/download/github/googletest-1.13.0.tar.gz
   )
 
   foreach(f IN LISTS possible_file_locations)
sherpa-onnx/csrc/CMakeLists.txt
@@ -1,6 +1,7 @@
 include_directories(${CMAKE_SOURCE_DIR})
 
 add_library(sherpa-onnx-core
+  cat.cc
   features.cc
   online-lstm-transducer-model.cc
   online-recognizer.cc
@@ -8,8 +9,11 @@ add_library(sherpa-onnx-core
   online-transducer-greedy-search-decoder.cc
   online-transducer-model-config.cc
   online-transducer-model.cc
+  online-zipformer-transducer-model.cc
   onnx-utils.cc
   symbol-table.cc
+  text-utils.cc
+  unbind.cc
   wave-reader.cc
 )
 
@@ -27,3 +31,32 @@ endif()
 
 install(TARGETS sherpa-onnx-core DESTINATION lib)
 install(TARGETS sherpa-onnx DESTINATION bin)
+
+if(SHERPA_ONNX_ENABLE_TESTS)
+  set(sherpa_onnx_test_srcs
+    cat-test.cc
+    unbind-test.cc
+  )
+
+  function(sherpa_onnx_add_test source)
+    get_filename_component(name ${source} NAME_WE)
+    set(target_name ${name})
+    add_executable(${target_name} "${source}")
+
+    target_link_libraries(${target_name}
+      PRIVATE
+        gtest
+        gtest_main
+        sherpa-onnx-core
+    )
+
+    add_test(NAME "${target_name}"
+      COMMAND
+        $<TARGET_FILE:${target_name}>
+    )
+  endfunction()
+
+  foreach(source IN LISTS sherpa_onnx_test_srcs)
+    sherpa_onnx_add_test(${source})
+  endforeach()
+endif()
sherpa-onnx/csrc/cat-test.cc
new file mode 100644
@@ -0,0 +1,254 @@
+// sherpa-onnx/csrc/cat-test.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/cat.h"
+
+#include "gtest/gtest.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+TEST(Cat, Test1DTensors) {
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  std::array<int64_t, 1> a_shape{3};
+  std::array<int64_t, 1> b_shape{6};
+
+  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
+                                                 a_shape.size());
+
+  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
+                                                 b_shape.size());
+  float *pa = a.GetTensorMutableData<float>();
+  float *pb = b.GetTensorMutableData<float>();
+  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
+    pa[i] = i;
+  }
+  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
+    pb[i] = i + 10;
+  }
+
+  Ort::Value ans = Cat(allocator, {&a, &b}, 0);
+
+  const float *pans = ans.GetTensorData<float>();
+  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
+    EXPECT_EQ(pa[i], pans[i]);
+  }
+
+  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
+    EXPECT_EQ(pb[i], pans[i + a_shape[0]]);
+  }
+
+  Print1D(&a);
+  Print1D(&b);
+  Print1D(&ans);
+}
+
+TEST(Cat, Test2DTensorsDim0) {
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  std::array<int64_t, 2> a_shape{2, 3};
+  std::array<int64_t, 2> b_shape{4, 3};
+
+  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
+                                                 a_shape.size());
+
+  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
+                                                 b_shape.size());
+
+  float *pa = a.GetTensorMutableData<float>();
+  float *pb = b.GetTensorMutableData<float>();
+  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
+    pa[i] = i;
+  }
+  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
+    pb[i] = i + 10;
+  }
+
+  Ort::Value ans = Cat(allocator, {&a, &b}, 0);
+
+  const float *pans = ans.GetTensorData<float>();
+  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
+    EXPECT_EQ(pa[i], pans[i]);
+  }
+  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
+    EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]);
+  }
+
+  Print2D(&a);
+  Print2D(&b);
+  Print2D(&ans);
+}
+
+TEST(Cat, Test2DTensorsDim1) {
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  std::array<int64_t, 2> a_shape{4, 3};
+  std::array<int64_t, 2> b_shape{4, 2};
+
+  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
+                                                 a_shape.size());
+
+  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
+                                                 b_shape.size());
+
+  float *pa = a.GetTensorMutableData<float>();
+  float *pb = b.GetTensorMutableData<float>();
+  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
+    pa[i] = i;
+  }
+  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
+    pb[i] = i + 10;
+  }
+
+  Ort::Value ans = Cat(allocator, {&a, &b}, 1);
+
+  const float *pans = ans.GetTensorData<float>();
+
+  for (int32_t r = 0; r != static_cast<int32_t>(a_shape[0]); ++r) {
+    for (int32_t i = 0; i != static_cast<int32_t>(a_shape[1]);
+         ++i, ++pa, ++pans) {
+      EXPECT_EQ(*pa, *pans);
+    }
+
+    for (int32_t i = 0; i != static_cast<int32_t>(b_shape[1]);
+         ++i, ++pb, ++pans) {
+      EXPECT_EQ(*pb, *pans);
+    }
+  }
+
+  Print2D(&a);
+  Print2D(&b);
+  Print2D(&ans);
+}
+
+TEST(Cat, Test3DTensorsDim0) {
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  std::array<int64_t, 3> a_shape{2, 3, 2};
+  std::array<int64_t, 3> b_shape{4, 3, 2};
+
+  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
+                                                 a_shape.size());
+
+  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
+                                                 b_shape.size());
+
+  float *pa = a.GetTensorMutableData<float>();
+  float *pb = b.GetTensorMutableData<float>();
+  for (int32_t i = 0;
+       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
+    pa[i] = i;
+  }
+  for (int32_t i = 0;
+       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
+    pb[i] = i + 10;
+  }
+
+  Ort::Value ans = Cat(allocator, {&a, &b}, 0);
+
+  const float *pans = ans.GetTensorData<float>();
+  for (int32_t i = 0;
+       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
+    EXPECT_EQ(pa[i], pans[i]);
+  }
+  for (int32_t i = 0;
+       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
+    EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]);
+  }
+
+  Print3D(&a);
+  Print3D(&b);
+  Print3D(&ans);
+}
+
+TEST(Cat, Test3DTensorsDim1) {
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  std::array<int64_t, 3> a_shape{2, 2, 3};
+  std::array<int64_t, 3> b_shape{2, 4, 3};
+
+  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
+                                                 a_shape.size());
+
+  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
+                                                 b_shape.size());
+
+  float *pa = a.GetTensorMutableData<float>();
+  float *pb = b.GetTensorMutableData<float>();
+  for (int32_t i = 0;
+       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
+    pa[i] = i;
+  }
+  for (int32_t i = 0;
+       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
+    pb[i] = i + 10;
+  }
+
+  Ort::Value ans = Cat(allocator, {&a, &b}, 1);
+
+  const float *pans = ans.GetTensorData<float>();
+
+  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
+    for (int32_t k = 0; k != static_cast<int32_t>(a_shape[1] * a_shape[2]);
+         ++k, ++pa, ++pans) {
+      EXPECT_EQ(*pa, *pans);
+    }
+
+    for (int32_t k = 0; k != static_cast<int32_t>(b_shape[1] * b_shape[2]);
+         ++k, ++pb, ++pans) {
+      EXPECT_EQ(*pb, *pans);
+    }
+  }
+
+  Print3D(&a);
+  Print3D(&b);
+  Print3D(&ans);
+}
+
+TEST(Cat, Test3DTensorsDim2) {
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  std::array<int64_t, 3> a_shape{2, 3, 4};
+  std::array<int64_t, 3> b_shape{2, 3, 5};
+
+  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
+                                                 a_shape.size());
+
+  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
+                                                 b_shape.size());
+
+  float *pa = a.GetTensorMutableData<float>();
+  float *pb = b.GetTensorMutableData<float>();
+  for (int32_t i = 0;
+       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
+    pa[i] = i;
+  }
+  for (int32_t i = 0;
+       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
+    pb[i] = i + 10;
+  }
+
+  Ort::Value ans = Cat(allocator, {&a, &b}, 2);
+
+  const float *pans = ans.GetTensorData<float>();
+
+  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
+    for (int32_t k = 0; k != static_cast<int32_t>(a_shape[2]);
+         ++k, ++pa, ++pans) {
+      EXPECT_EQ(*pa, *pans);
+    }
+
+    for (int32_t k = 0; k != static_cast<int32_t>(b_shape[2]);
+         ++k, ++pb, ++pans) {
+      EXPECT_EQ(*pb, *pans);
+    }
+  }
+
+  Print3D(&a);
+  Print3D(&b);
+  Print3D(&ans);
+}
+
+}  // namespace sherpa_onnx
sherpa-onnx/csrc/cat.cc
new file mode 100644
@@ -0,0 +1,106 @@
+// sherpa-onnx/csrc/cat.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/cat.h"
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <utility>
+
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+static bool Compare(const std::vector<int64_t> &a,
+                    const std::vector<int64_t> &b, int32_t skip_dim) {
+  if (a.size() != b.size()) return false;
+
+  for (int32_t i = 0; i != static_cast<int32_t>(a.size()); ++i) {
+    if (i == skip_dim) continue;
+
+    if (a[i] != b[i]) return false;
+  }
+
+  return true;
+}
+
+static void PrintShape(const std::vector<int64_t> &a) {
+  for (auto i : a) {
+    fprintf(stderr, "%d ", static_cast<int32_t>(i));
+  }
+  fprintf(stderr, "\n");
+}
+
+template <typename T /*=float*/>
+Ort::Value Cat(OrtAllocator *allocator,
+               const std::vector<const Ort::Value *> &values, int32_t dim) {
+  if (values.size() == 1u) {
+    return Clone(values[0]);
+  }
+
+  std::vector<int64_t> v0_shape =
+      values[0]->GetTensorTypeAndShapeInfo().GetShape();
+
+  int64_t total_dim = v0_shape[dim];
+
+  for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
+    auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
+    total_dim += s[dim];
+
+    bool ret = Compare(v0_shape, s, dim);
+    if (!ret) {
+      fprintf(stderr, "Incorrect shape in Cat !\n");
+
+      fprintf(stderr, "Shape for tensor 0: ");
+      PrintShape(v0_shape);
+
+      fprintf(stderr, "Shape for tensor %d: ", i);
+      PrintShape(s);
+
+      exit(-1);
+    }
+  }
+
+  std::vector<int64_t> ans_shape;
+  ans_shape.reserve(v0_shape.size());
+  ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
+  ans_shape.push_back(total_dim);
+  ans_shape.insert(ans_shape.end(), v0_shape.data() + dim + 1,
+                   v0_shape.data() + v0_shape.size());
+
+  auto leading_size = static_cast<int32_t>(std::accumulate(
+      v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
+
+  auto trailing_size = static_cast<int32_t>(
+      std::accumulate(v0_shape.begin() + dim + 1, v0_shape.end(), 1,
+                      std::multiplies<int64_t>()));
+
+  Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
+                                               ans_shape.size());
+  T *dst = ans.GetTensorMutableData<T>();
+
+  for (int32_t i = 0; i != leading_size; ++i) {
+    for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
+      auto this_dim = values[n]->GetTensorTypeAndShapeInfo().GetShape()[dim];
+      const T *src = values[n]->GetTensorData<T>();
+      src += i * this_dim * trailing_size;
+
+      std::copy(src, src + this_dim * trailing_size, dst);
+      dst += this_dim * trailing_size;
+    }
+  }
+
+  return std::move(ans);
+}
+
+template Ort::Value Cat<float>(OrtAllocator *allocator,
+                               const std::vector<const Ort::Value *> &values,
+                               int32_t dim);
+
+template Ort::Value Cat<int64_t>(OrtAllocator *allocator,
+                                 const std::vector<const Ort::Value *> &values,
+                                 int32_t dim);
+
+}  // namespace sherpa_onnx
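The copy loop in Cat above relies on viewing each input as (leading_size, this_dim, trailing_size) around the concatenation axis: per outer step it copies one contiguous slab of this_dim * trailing_size elements from each input in turn. A self-contained sketch of that decomposition on plain arrays, concatenating shapes (2, 2, 3) and (2, 4, 3) along dim 1; the names and shapes here are illustrative only, not code from the PR:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // a: shape (2, 2, 3); b: shape (2, 4, 3); concatenate along dim 1.
  std::vector<float> a(2 * 2 * 3), b(2 * 4 * 3), out(2 * 6 * 3);
  for (int i = 0; i != static_cast<int>(a.size()); ++i) a[i] = i;
  for (int i = 0; i != static_cast<int>(b.size()); ++i) b[i] = i + 10;

  const int leading_size = 2;   // product of dims before dim 1
  const int trailing_size = 3;  // product of dims after dim 1
  const int a_dim = 2, b_dim = 4;

  float *dst = out.data();
  for (int i = 0; i != leading_size; ++i) {
    // one contiguous slab per input per outer step
    const float *pa = a.data() + i * a_dim * trailing_size;
    dst = std::copy(pa, pa + a_dim * trailing_size, dst);

    const float *pb = b.data() + i * b_dim * trailing_size;
    dst = std::copy(pb, pb + b_dim * trailing_size, dst);
  }

  // out now has shape (2, 6, 3); out[0] == 0 (from a), out[6] == 10 (from b)
  std::printf("out[0]=%g out[6]=%g\n", out[0], out[6]);
  return 0;
}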
sherpa-onnx/csrc/cat.h
new file mode 100644
@@ -0,0 +1,28 @@
+// sherpa-onnx/csrc/cat.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_CAT_H_
+#define SHERPA_ONNX_CSRC_CAT_H_
+
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+
+namespace sherpa_onnx {
+
+/** Cat a list of tensors along the given dim.
+ *
+ * @param allocator Allocator to allocate space for the returned tensor
+ * @param values  Pointer to a list of tensors. The shape of the tensor must
+ *                be the same except on the dim to be concatenated.
+ * @param dim  The dim along which to concatenate the input tensors
+ *
+ * @return Return the concatenated tensor
+ */
+template <typename T = float>
+Ort::Value Cat(OrtAllocator *allocator,
+               const std::vector<const Ort::Value *> &values, int32_t dim);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_CAT_H_
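Note that cat.cc only provides explicit instantiations of this template for float and int64_t (the two element types the transducer states use), so the element type is picked at the call site. A usage sketch under that assumption, with illustrative variable names:

// float is the default element type; int64_t must be spelled out,
// e.g. for the zipformer's cached_len states.
Ort::Value h = Cat(allocator, {&h0, &h1}, /*dim=*/1);
Ort::Value len = Cat<int64_t>(allocator, {&len0, &len1}, /*dim=*/1);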
sherpa-onnx/csrc/macros.h
new file mode 100644
@@ -0,0 +1,41 @@
+
+// sherpa-onnx/csrc/macros.h
+//
+// Copyright 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_MACROS_H_
+#define SHERPA_ONNX_CSRC_MACROS_H_
+
+#define SHERPA_ONNX_READ_META_DATA(dst, src_key)                        \
+  do {                                                                  \
+    auto value =                                                        \
+        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
+    if (!value) {                                                       \
+      fprintf(stderr, "%s does not exist in the metadata\n", src_key);  \
+      exit(-1);                                                         \
+    }                                                                   \
+                                                                        \
+    dst = atoi(value.get());                                            \
+    if (dst <= 0) {                                                     \
+      fprintf(stderr, "Invalid value %d for %s\n", dst, src_key);       \
+      exit(-1);                                                         \
+    }                                                                   \
+  } while (0)
+
+#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key)                      \
+  do {                                                                    \
+    auto value =                                                          \
+        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator);   \
+    if (!value) {                                                         \
+      fprintf(stderr, "%s does not exist in the metadata\n", src_key);    \
+      exit(-1);                                                           \
+    }                                                                     \
+                                                                          \
+    bool ret = SplitStringToIntegers(value.get(), ",", true, &dst);       \
+    if (!ret) {                                                           \
+      fprintf(stderr, "Invalid value %s for %s\n", value.get(), src_key); \
+      exit(-1);                                                           \
+    }                                                                     \
+  } while (0)
+
+#endif  // SHERPA_ONNX_CSRC_MACROS_H_
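Both macros read two names from the enclosing scope rather than taking them as arguments: an Ort::ModelMetadata called meta_data and an allocator for LookupCustomMetadataMapAllocated. That is why the Init* functions below declare a local Ort::AllocatorWithDefaultOptions before expanding them. A hedged usage sketch, where `sess` is a hypothetical Ort::Session pointer:

Ort::ModelMetadata meta_data = sess->GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator;

int32_t num_layers = 0;
SHERPA_ONNX_READ_META_DATA(num_layers, "num_encoder_layers");

std::vector<int32_t> encoder_dims;  // filled via SplitStringToIntegers
SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims, "encoder_dims");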
sherpa-onnx/csrc/online-lstm-transducer-model.cc
@@ -3,6 +3,8 @@
 // Copyright (c) 2023 Xiaomi Corporation
 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
 
+#include <assert.h>
+
 #include <algorithm>
 #include <memory>
 #include <sstream>
@@ -11,23 +13,11 @@
 #include <vector>
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/cat.h"
+#include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
-
-#define SHERPA_ONNX_READ_META_DATA(dst, src_key)                        \
-  do {                                                                  \
-    auto value =                                                        \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) {                                                       \
-      fprintf(stderr, "%s does not exist in the metadata\n", src_key);  \
-      exit(-1);                                                         \
-    }                                                                   \
-    dst = atoi(value.get());                                            \
-    if (dst <= 0) {                                                     \
-      fprintf(stderr, "Invalud value %d for %s\n", dst, src_key);       \
-      exit(-1);                                                         \
-    }                                                                   \
-  } while (0)
+#include "sherpa-onnx/csrc/unbind.h"
 
 namespace sherpa_onnx {
 
@@ -64,7 +54,7 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
     fprintf(stderr, "%s\n", os.str().c_str());
   }
 
-  Ort::AllocatorWithDefaultOptions allocator;
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
   SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
   SHERPA_ONNX_READ_META_DATA(T_, "T");
   SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
@@ -91,7 +81,7 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
     fprintf(stderr, "%s\n", os.str().c_str());
   }
 
-  Ort::AllocatorWithDefaultOptions allocator;
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
   SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
   SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
 }
@@ -120,37 +110,19 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
     const std::vector<std::vector<Ort::Value>> &states) const {
   int32_t batch_size = static_cast<int32_t>(states.size());
 
-  std::array<int64_t, 3> h_shape{num_encoder_layers_, batch_size, d_model_};
-  Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
-                                                 h_shape.size());
-
-  std::array<int64_t, 3> c_shape{num_encoder_layers_, batch_size,
-                                 rnn_hidden_size_};
-
-  Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
-                                                 c_shape.size());
-
-  float *dst_h = h.GetTensorMutableData<float>();
-  float *dst_c = c.GetTensorMutableData<float>();
+  std::vector<const Ort::Value *> h_buf(batch_size);
+  std::vector<const Ort::Value *> c_buf(batch_size);
 
-  for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
-    for (int32_t i = 0; i != batch_size; ++i) {
-      const float *src_h =
-          states[i][0].GetTensorData<float>() + layer * d_model_;
-
-      const float *src_c =
-          states[i][1].GetTensorData<float>() + layer * rnn_hidden_size_;
-
-      std::copy(src_h, src_h + d_model_, dst_h);
-      std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
-
-      dst_h += d_model_;
-      dst_c += rnn_hidden_size_;
-    }
+  for (int32_t i = 0; i != batch_size; ++i) {
+    assert(states[i].size() == 2);
+    h_buf[i] = &states[i][0];
+    c_buf[i] = &states[i][1];
   }
 
-  std::vector<Ort::Value> ans;
+  Ort::Value h = Cat(allocator_, h_buf, 1);
+  Ort::Value c = Cat(allocator_, c_buf, 1);
 
+  std::vector<Ort::Value> ans;
   ans.reserve(2);
   ans.push_back(std::move(h));
   ans.push_back(std::move(c));
@@ -161,37 +133,19 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
 std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
     const std::vector<Ort::Value> &states) const {
   int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
+  assert(states.size() == 2);
 
   std::vector<std::vector<Ort::Value>> ans(batch_size);
 
-  // allocate space
-  std::array<int64_t, 3> h_shape{num_encoder_layers_, 1, d_model_};
-  std::array<int64_t, 3> c_shape{num_encoder_layers_, 1, rnn_hidden_size_};
+  std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1);
+  std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1);
 
-  for (int32_t i = 0; i != batch_size; ++i) {
-    Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
-                                                   h_shape.size());
-    Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
-                                                   c_shape.size());
-    ans[i].push_back(std::move(h));
-    ans[i].push_back(std::move(c));
-  }
+  assert(h_vec.size() == batch_size);
+  assert(c_vec.size() == batch_size);
 
-  for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
-    for (int32_t i = 0; i != batch_size; ++i) {
-      const float *src_h = states[0].GetTensorData<float>() +
-                           layer * batch_size * d_model_ + i * d_model_;
-      const float *src_c = states[1].GetTensorData<float>() +
-                           layer * batch_size * rnn_hidden_size_ +
-                           i * rnn_hidden_size_;
-
-      float *dst_h = ans[i][0].GetTensorMutableData<float>() + layer * d_model_;
-      float *dst_c =
-          ans[i][1].GetTensorMutableData<float>() + layer * rnn_hidden_size_;
-
-      std::copy(src_h, src_h + d_model_, dst_h);
-      std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
-    }
+  for (int32_t i = 0; i != batch_size; ++i) {
+    ans[i].push_back(std::move(h_vec[i]));
+    ans[i].push_back(std::move(c_vec[i]));
   }
 
   return ans;
@@ -206,20 +160,15 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
   Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
                                                  h_shape.size());
 
-  std::fill(h.GetTensorMutableData<float>(),
-            h.GetTensorMutableData<float>() +
-                num_encoder_layers_ * kBatchSize * d_model_,
-            0);
+  Fill<float>(&h, 0);
 
   std::array<int64_t, 3> c_shape{num_encoder_layers_, kBatchSize,
                                  rnn_hidden_size_};
+
   Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
                                                  c_shape.size());
 
-  std::fill(c.GetTensorMutableData<float>(),
-            c.GetTensorMutableData<float>() +
-                num_encoder_layers_ * kBatchSize * rnn_hidden_size_,
-            0);
+  Fill<float>(&c, 0);
 
   std::vector<Ort::Value> states;
 
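The net effect of the refactor above: StackStates no longer hand-rolls per-layer copies; stacking N per-stream LSTM states of shape (num_layers, 1, dim) along dim 1 with Cat yields (num_layers, N, dim), and UnStackStates inverts it with Unbind along the same dim. A round-trip sketch under those assumptions (s0, s1, and allocator are illustrative names):

std::vector<const Ort::Value *> bufs = {&s0, &s1};       // two streams
Ort::Value stacked = Cat(allocator, bufs, /*dim=*/1);    // (L, 2, D)
std::vector<Ort::Value> parts = Unbind(allocator, &stacked, /*dim=*/1);
// parts[0] and parts[1] hold the same data as s0 and s1, each (L, 1, D)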
sherpa-onnx/csrc/online-transducer-model.cc
@@ -8,11 +8,13 @@
 #include <string>
 
 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
+#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 namespace sherpa_onnx {
 
 enum class ModelType {
   kLstm,
+  kZipformer,
   kUnkown,
 };
 
@@ -40,6 +42,8 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
 
   if (model_type.get() == std::string("lstm")) {
     return ModelType::kLstm;
+  } else if (model_type.get() == std::string("zipformer")) {
+    return ModelType::kZipformer;
   } else {
     fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
     return ModelType::kUnkown;
@@ -53,6 +57,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
   switch (model_type) {
     case ModelType::kLstm:
       return std::make_unique<OnlineLstmTransducerModel>(config);
+    case ModelType::kZipformer:
+      return std::make_unique<OnlineZipformerTransducerModel>(config);
     case ModelType::kUnkown:
       return nullptr;
   }
sherpa-onnx/csrc/online-zipformer-transducer-model.cc
new file mode 100644
@@ -0,0 +1,456 @@
+// sherpa-onnx/csrc/online-zipformer-transducer-model.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
+
+#include <assert.h>
+
+#include <algorithm>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/cat.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/unbind.h"
+
+namespace sherpa_onnx {
+
+OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
+    const OnlineTransducerModelConfig &config)
+    : env_(ORT_LOGGING_LEVEL_WARNING),
+      config_(config),
+      sess_opts_{},
+      allocator_{} {
+  sess_opts_.SetIntraOpNumThreads(config.num_threads);
+  sess_opts_.SetInterOpNumThreads(config.num_threads);
+
+  InitEncoder(config.encoder_filename);
+  InitDecoder(config.decoder_filename);
+  InitJoiner(config.joiner_filename);
+}
+
+void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
+  encoder_sess_ = std::make_unique<Ort::Session>(
+      env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
+
+  GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                &encoder_input_names_ptr_);
+
+  GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                 &encoder_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---encoder---\n";
+    PrintModelMetadata(os, meta_data);
+    fprintf(stderr, "%s\n", os.str().c_str());
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+  SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims");
+  SHERPA_ONNX_READ_META_DATA_VEC(attention_dims_, "attention_dims");
+  SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers");
+  SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels");
+  SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len");
+
+  SHERPA_ONNX_READ_META_DATA(T_, "T");
+  SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
+
+  if (config_.debug) {
+    auto print = [](const std::vector<int32_t> &v, const char *name) {
+      fprintf(stderr, "%s: ", name);
+      for (auto i : v) {
+        fprintf(stderr, "%d ", i);
+      }
+      fprintf(stderr, "\n");
+    };
+    print(encoder_dims_, "encoder_dims");
+    print(attention_dims_, "attention_dims");
+    print(num_encoder_layers_, "num_encoder_layers");
+    print(cnn_module_kernels_, "cnn_module_kernels");
+    print(left_context_len_, "left_context_len");
+    fprintf(stderr, "T: %d\n", T_);
+    fprintf(stderr, "decode_chunk_len_: %d\n", decode_chunk_len_);
+  }
+}
+
+void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
+  decoder_sess_ = std::make_unique<Ort::Session>(
+      env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
+
+  GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                &decoder_input_names_ptr_);
+
+  GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                 &decoder_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---decoder---\n";
+    PrintModelMetadata(os, meta_data);
+    fprintf(stderr, "%s\n", os.str().c_str());
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+  SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+  SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
+}
+
+void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) {
+  joiner_sess_ = std::make_unique<Ort::Session>(
+      env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
+
+  GetInputNames(joiner_sess_.get(), &joiner_input_names_,
+                &joiner_input_names_ptr_);
+
+  GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
+                 &joiner_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---joiner---\n";
+    PrintModelMetadata(os, meta_data);
+    fprintf(stderr, "%s\n", os.str().c_str());
+  }
+}
+
+std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
+    const std::vector<std::vector<Ort::Value>> &states) const {
+  int32_t batch_size = static_cast<int32_t>(states.size());
+  int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size());
+
+  std::vector<const Ort::Value *> buf(batch_size);
+
+  std::vector<Ort::Value> ans;
+  ans.reserve(states[0].size());
+
+  // cached_len
+  for (int32_t i = 0; i != num_encoders; ++i) {
+    for (int32_t n = 0; n != batch_size; ++n) {
+      buf[n] = &states[n][i];
+    }
+    auto v = Cat<int64_t>(allocator_, buf, 1);  // (num_layers, 1)
+    ans.push_back(std::move(v));
+  }
+
+  // cached_avg
+  for (int32_t i = 0; i != num_encoders; ++i) {
+    for (int32_t n = 0; n != batch_size; ++n) {
+      buf[n] = &states[n][num_encoders + i];
+    }
+    auto v = Cat(allocator_, buf, 1);  // (num_layers, 1, encoder_dims)
+    ans.push_back(std::move(v));
+  }
+
+  // cached_key
+  for (int32_t i = 0; i != num_encoders; ++i) {
+    for (int32_t n = 0; n != batch_size; ++n) {
+      buf[n] = &states[n][num_encoders * 2 + i];
+    }
+    // (num_layers, left_context_len, 1, attention_dims)
+    auto v = Cat(allocator_, buf, 2);
+    ans.push_back(std::move(v));
+  }
+
+  // cached_val
+  for (int32_t i = 0; i != num_encoders; ++i) {
+    for (int32_t n = 0; n != batch_size; ++n) {
+      buf[n] = &states[n][num_encoders * 3 + i];
+    }
+    // (num_layers, left_context_len, 1, attention_dims/2)
+    auto v = Cat(allocator_, buf, 2);
+    ans.push_back(std::move(v));
+  }
+
+  // cached_val2
+  for (int32_t i = 0; i != num_encoders; ++i) {
+    for (int32_t n = 0; n != batch_size; ++n) {
+      buf[n] = &states[n][num_encoders * 4 + i];
+    }
+    // (num_layers, left_context_len, 1, attention_dims/2)
+    auto v = Cat(allocator_, buf, 2);
+    ans.push_back(std::move(v));
+  }
+
+  // cached_conv1
+  for (int32_t i = 0; i != num_encoders; ++i) {
+    for (int32_t n = 0; n != batch_size; ++n) {
+      buf[n] = &states[n][num_encoders * 5 + i];
+    }
+    // (num_layers, 1, encoder_dims, cnn_module_kernels-1)
+    auto v = Cat(allocator_, buf, 1);
+    ans.push_back(std::move(v));
+  }
+
+  // cached_conv2
+  for (int32_t i = 0; i != num_encoders; ++i) {
+    for (int32_t n = 0; n != batch_size; ++n) {
+      buf[n] = &states[n][num_encoders * 6 + i];
+    }
+    // (num_layers, 1, encoder_dims, cnn_module_kernels-1)
+    auto v = Cat(allocator_, buf, 1);
+    ans.push_back(std::move(v));
+  }
+
+  return ans;
+}
+
+std::vector<std::vector<Ort::Value>>
+OnlineZipformerTransducerModel::UnStackStates(
+    const std::vector<Ort::Value> &states) const {
+  assert(states.size() == num_encoder_layers_.size() * 7);
+
+  int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
+  int32_t num_encoders = num_encoder_layers_.size();
+
+  std::vector<std::vector<Ort::Value>> ans;
+  ans.resize(batch_size);
+
+  // cached_len
+  for (int32_t i = 0; i != num_encoders; ++i) {
+    auto v = Unbind<int64_t>(allocator_, &states[i], 1);
+    assert(v.size() == batch_size);
+
+    for (int32_t n = 0; n != batch_size; ++n) {
+      ans[n].push_back(std::move(v[n]));
+    }
+  }
+
+  // cached_avg
+  for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) {
+    auto v = Unbind(allocator_, &states[i], 1);
+    assert(v.size() == batch_size);
+
+    for (int32_t n = 0; n != batch_size; ++n) {
+      ans[n].push_back(std::move(v[n]));
+    }
+  }
+
+  // cached_key
+  for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) {
+    auto v = Unbind(allocator_, &states[i], 2);
+    assert(v.size() == batch_size);
+
+    for (int32_t n = 0; n != batch_size; ++n) {
+      ans[n].push_back(std::move(v[n]));
+    }
+  }
+
+  // cached_val
+  for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) {
+    auto v = Unbind(allocator_, &states[i], 2);
+    assert(v.size() == batch_size);
+
+    for (int32_t n = 0; n != batch_size; ++n) {
+      ans[n].push_back(std::move(v[n]));
+    }
+  }
+
+  // cached_val2
+  for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) {
+    auto v = Unbind(allocator_, &states[i], 2);
+    assert(v.size() == batch_size);
+
+    for (int32_t n = 0; n != batch_size; ++n) {
+      ans[n].push_back(std::move(v[n]));
+    }
+  }
+
+  // cached_conv1
+  for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) {
+    auto v = Unbind(allocator_, &states[i], 1);
+    assert(v.size() == batch_size);
+
+    for (int32_t n = 0; n != batch_size; ++n) {
+      ans[n].push_back(std::move(v[n]));
+    }
+  }
+
+  // cached_conv2
+  for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) {
+    auto v = Unbind(allocator_, &states[i], 1);
+    assert(v.size() == batch_size);
+
+    for (int32_t n = 0; n != batch_size; ++n) {
+      ans[n].push_back(std::move(v[n]));
+    }
+  }
+
+  return ans;
+}
+
+std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {
+  // Please see
+  // https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py#L673
+  // for details
+
+  int32_t n = static_cast<int32_t>(encoder_dims_.size());
+  std::vector<Ort::Value> cached_len_vec;
+  std::vector<Ort::Value> cached_avg_vec;
+  std::vector<Ort::Value> cached_key_vec;
+  std::vector<Ort::Value> cached_val_vec;
+  std::vector<Ort::Value> cached_val2_vec;
+  std::vector<Ort::Value> cached_conv1_vec;
+  std::vector<Ort::Value> cached_conv2_vec;
+
+  cached_len_vec.reserve(n);
+  cached_avg_vec.reserve(n);
+  cached_key_vec.reserve(n);
+  cached_val_vec.reserve(n);
+  cached_val2_vec.reserve(n);
+  cached_conv1_vec.reserve(n);
+  cached_conv2_vec.reserve(n);
+
+  for (int32_t i = 0; i != n; ++i) {
+    {
+      std::array<int64_t, 2> s{num_encoder_layers_[i], 1};
+      auto v =
+          Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
+      Fill<int64_t>(&v, 0);
+      cached_len_vec.push_back(std::move(v));
+    }
+
+    {
+      std::array<int64_t, 3> s{num_encoder_layers_[i], 1, encoder_dims_[i]};
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+      Fill(&v, 0);
+      cached_avg_vec.push_back(std::move(v));
+    }
+
+    {
+      std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
+                               attention_dims_[i]};
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+      Fill(&v, 0);
+      cached_key_vec.push_back(std::move(v));
+    }
+
+    {
+      std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
+                               attention_dims_[i] / 2};
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+      Fill(&v, 0);
+      cached_val_vec.push_back(std::move(v));
+    }
+
+    {
+      std::array<int64_t, 4> s{num_encoder_layers_[i], left_context_len_[i], 1,
+                               attention_dims_[i] / 2};
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+      Fill(&v, 0);
+      cached_val2_vec.push_back(std::move(v));
+    }
+
+    {
+      std::array<int64_t, 4> s{num_encoder_layers_[i], 1, encoder_dims_[i],
+                               cnn_module_kernels_[i] - 1};
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+      Fill(&v, 0);
+      cached_conv1_vec.push_back(std::move(v));
+    }
+
+    {
+      std::array<int64_t, 4> s{num_encoder_layers_[i], 1, encoder_dims_[i],
+                               cnn_module_kernels_[i] - 1};
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+      Fill(&v, 0);
+      cached_conv2_vec.push_back(std::move(v));
+    }
+  }
+
+  std::vector<Ort::Value> ans;
+  ans.reserve(n * 7);
+
+  for (auto &v : cached_len_vec) ans.push_back(std::move(v));
+  for (auto &v : cached_avg_vec) ans.push_back(std::move(v));
+  for (auto &v : cached_key_vec) ans.push_back(std::move(v));
+  for (auto &v : cached_val_vec) ans.push_back(std::move(v));
+  for (auto &v : cached_val2_vec) ans.push_back(std::move(v));
+  for (auto &v : cached_conv1_vec) ans.push_back(std::move(v));
+  for (auto &v : cached_conv2_vec) ans.push_back(std::move(v));
+
+  return ans;
+}
+
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
+                                           std::vector<Ort::Value> states) {
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  std::vector<Ort::Value> encoder_inputs;
+  encoder_inputs.reserve(1 + states.size());
+
+  encoder_inputs.push_back(std::move(features));
+  for (auto &v : states) {
+    encoder_inputs.push_back(std::move(v));
+  }
+
+  auto encoder_out = encoder_sess_->Run(
+      {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
+      encoder_inputs.size(), encoder_output_names_ptr_.data(),
+      encoder_output_names_ptr_.size());
+
+  std::vector<Ort::Value> next_states;
+  next_states.reserve(states.size());
+
+  for (int32_t i = 1; i != static_cast<int32_t>(encoder_out.size()); ++i) {
+    next_states.push_back(std::move(encoder_out[i]));
+  }
+
+  return {std::move(encoder_out[0]), std::move(next_states)};
+}
+
+Ort::Value OnlineZipformerTransducerModel::BuildDecoderInput(
+    const std::vector<OnlineTransducerDecoderResult> &results) {
+  int32_t batch_size = static_cast<int32_t>(results.size());
+  std::array<int64_t, 2> shape{batch_size, context_size_};
+  Ort::Value decoder_input =
+      Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
+  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
+
+  for (const auto &r : results) {
+    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
+    const int64_t *end = r.tokens.data() + r.tokens.size();
+    std::copy(begin, end, p);
+    p += context_size_;
+  }
+
+  return decoder_input;
+}
+
+Ort::Value OnlineZipformerTransducerModel::RunDecoder(
+    Ort::Value decoder_input) {
+  auto decoder_out = decoder_sess_->Run(
+      {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
+      decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
+  return std::move(decoder_out[0]);
+}
+
+Ort::Value OnlineZipformerTransducerModel::RunJoiner(Ort::Value encoder_out,
+                                                     Ort::Value decoder_out) {
+  std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
+                                            std::move(decoder_out)};
+  auto logit =
+      joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
+                        joiner_input.size(), joiner_output_names_ptr_.data(),
+                        joiner_output_names_ptr_.size());
+
+  return std::move(logit[0]);
+}
+
+}  // namespace sherpa_onnx
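For orientation, StackStates and UnStackStates above assume a fixed layout of the flat state vector: seven groups of num_encoders tensors, ordered cached_len, cached_avg, cached_key, cached_val, cached_val2, cached_conv1, cached_conv2, matching the order GetEncoderInitStates emits them. A small helper that makes the indexing convention explicit (a sketch, not part of the PR):

// states[group * num_encoders + stack], where group is one of:
// 0 cached_len, 1 cached_avg, 2 cached_key, 3 cached_val,
// 4 cached_val2, 5 cached_conv1, 6 cached_conv2
inline int32_t StateIndex(int32_t group, int32_t stack, int32_t num_encoders) {
  return group * num_encoders + stack;
}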
| 1 | +// sherpa-onnx/csrc/online-zipformer-transducer-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 13 | +#include "sherpa-onnx/csrc/online-transducer-model-config.h" | ||
| 14 | +#include "sherpa-onnx/csrc/online-transducer-model.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +class OnlineZipformerTransducerModel : public OnlineTransducerModel { | ||
| 19 | + public: | ||
| 20 | + explicit OnlineZipformerTransducerModel( | ||
| 21 | + const OnlineTransducerModelConfig &config); | ||
| 22 | + | ||
| 23 | + std::vector<Ort::Value> StackStates( | ||
| 24 | + const std::vector<std::vector<Ort::Value>> &states) const override; | ||
| 25 | + | ||
| 26 | + std::vector<std::vector<Ort::Value>> UnStackStates( | ||
| 27 | + const std::vector<Ort::Value> &states) const override; | ||
| 28 | + | ||
| 29 | + std::vector<Ort::Value> GetEncoderInitStates() override; | ||
| 30 | + | ||
| 31 | + std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( | ||
| 32 | + Ort::Value features, std::vector<Ort::Value> states) override; | ||
| 33 | + | ||
| 34 | + Ort::Value BuildDecoderInput( | ||
| 35 | + const std::vector<OnlineTransducerDecoderResult> &results) override; | ||
| 36 | + | ||
| 37 | + Ort::Value RunDecoder(Ort::Value decoder_input) override; | ||
| 38 | + | ||
| 39 | + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override; | ||
| 40 | + | ||
| 41 | + int32_t ContextSize() const override { return context_size_; } | ||
| 42 | + | ||
| 43 | + int32_t ChunkSize() const override { return T_; } | ||
| 44 | + | ||
| 45 | + int32_t ChunkShift() const override { return decode_chunk_len_; } | ||
| 46 | + | ||
| 47 | + int32_t VocabSize() const override { return vocab_size_; } | ||
| 48 | + OrtAllocator *Allocator() override { return allocator_; } | ||
| 49 | + | ||
| 50 | + private: | ||
| 51 | + void InitEncoder(const std::string &encoder_filename); | ||
| 52 | + void InitDecoder(const std::string &decoder_filename); | ||
| 53 | + void InitJoiner(const std::string &joiner_filename); | ||
| 54 | + | ||
| 55 | + private: | ||
| 56 | + Ort::Env env_; | ||
| 57 | + Ort::SessionOptions sess_opts_; | ||
| 58 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 59 | + | ||
| 60 | + std::unique_ptr<Ort::Session> encoder_sess_; | ||
| 61 | + std::unique_ptr<Ort::Session> decoder_sess_; | ||
| 62 | + std::unique_ptr<Ort::Session> joiner_sess_; | ||
| 63 | + | ||
| 64 | + std::vector<std::string> encoder_input_names_; | ||
| 65 | + std::vector<const char *> encoder_input_names_ptr_; | ||
| 66 | + | ||
| 67 | + std::vector<std::string> encoder_output_names_; | ||
| 68 | + std::vector<const char *> encoder_output_names_ptr_; | ||
| 69 | + | ||
| 70 | + std::vector<std::string> decoder_input_names_; | ||
| 71 | + std::vector<const char *> decoder_input_names_ptr_; | ||
| 72 | + | ||
| 73 | + std::vector<std::string> decoder_output_names_; | ||
| 74 | + std::vector<const char *> decoder_output_names_ptr_; | ||
| 75 | + | ||
| 76 | + std::vector<std::string> joiner_input_names_; | ||
| 77 | + std::vector<const char *> joiner_input_names_ptr_; | ||
| 78 | + | ||
| 79 | + std::vector<std::string> joiner_output_names_; | ||
| 80 | + std::vector<const char *> joiner_output_names_ptr_; | ||
| 81 | + | ||
| 82 | + OnlineTransducerModelConfig config_; | ||
| 83 | + | ||
| 84 | + std::vector<int32_t> encoder_dims_; | ||
| 85 | + std::vector<int32_t> attention_dims_; | ||
| 86 | + std::vector<int32_t> num_encoder_layers_; | ||
| 87 | + std::vector<int32_t> cnn_module_kernels_; | ||
| 88 | + std::vector<int32_t> left_context_len_; | ||
| 89 | + | ||
| 90 | + int32_t T_ = 0; | ||
| 91 | + int32_t decode_chunk_len_ = 0; | ||
| 92 | + | ||
| 93 | + int32_t context_size_ = 0; | ||
| 94 | + int32_t vocab_size_ = 0; | ||
| 95 | +}; | ||
| 96 | + | ||
| 97 | +} // namespace sherpa_onnx | ||
| 98 | + | ||
| 99 | +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_ |
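Editor's note: the header above only declares the interface. A hedged skeleton of how a caller might drive it for one chunk follows; `EncodeOneChunk` and the feature-preparation step are assumptions for illustration, not code from this PR, though the `RunEncoder()` signature matches the declaration above:

```cpp
#include <utility>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model.h"

// Run the encoder on one chunk of features and update the streaming states.
// `features` is assumed to have shape (1, model->ChunkSize(), feature_dim)
// and to be prepared elsewhere (e.g., by the feature extractor).
Ort::Value EncodeOneChunk(sherpa_onnx::OnlineTransducerModel *model,
                          Ort::Value features,
                          std::vector<Ort::Value> *states) {
  auto p = model->RunEncoder(std::move(features), std::move(*states));
  *states = std::move(p.second);
  // The encoder output is then consumed frame by frame by the search
  // algorithm via RunDecoder() and RunJoiner().
  return std::move(p.first);
}
```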
| @@ -46,16 +46,74 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { | @@ -46,16 +46,74 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { | ||
| 46 | } | 46 | } |
| 47 | } | 47 | } |
| 48 | 48 | ||
| 49 | -Ort::Value Clone(Ort::Value *v) { | 49 | +Ort::Value Clone(const Ort::Value *v) { |
| 50 | auto type_and_shape = v->GetTensorTypeAndShapeInfo(); | 50 | auto type_and_shape = v->GetTensorTypeAndShapeInfo(); |
| 51 | std::vector<int64_t> shape = type_and_shape.GetShape(); | 51 | std::vector<int64_t> shape = type_and_shape.GetShape(); |
| 52 | 52 | ||
| 53 | auto memory_info = | 53 | auto memory_info = |
| 54 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 54 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); |
| 55 | 55 | ||
| 56 | - return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(), | ||
| 57 | - type_and_shape.GetElementCount(), | ||
| 58 | - shape.data(), shape.size()); | 56 | + switch (type_and_shape.GetElementType()) { |
| 57 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: | ||
| 58 | + return Ort::Value::CreateTensor( | ||
| 59 | + memory_info, | ||
| 60 | + const_cast<Ort::Value *>(v)->GetTensorMutableData<int32_t>(), | ||
| 61 | + type_and_shape.GetElementCount(), shape.data(), shape.size()); | ||
| 62 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: | ||
| 63 | + return Ort::Value::CreateTensor( | ||
| 64 | + memory_info, | ||
| 65 | + const_cast<Ort::Value *>(v)->GetTensorMutableData<int64_t>(), | ||
| 66 | + type_and_shape.GetElementCount(), shape.data(), shape.size()); | ||
| 67 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: | ||
| 68 | + return Ort::Value::CreateTensor( | ||
| 69 | + memory_info, | ||
| 70 | + const_cast<Ort::Value *>(v)->GetTensorMutableData<float>(), | ||
| 71 | + type_and_shape.GetElementCount(), shape.data(), shape.size()); | ||
| 72 | + default: | ||
| 73 | + fprintf(stderr, "Unsupported type: %d\n", | ||
| 74 | + static_cast<int32_t>(type_and_shape.GetElementType())); | ||
| 75 | + exit(-1); | ||
| 76 | + // unreachable code | ||
| 77 | + return Ort::Value{nullptr}; | ||
| 78 | + } | ||
| 79 | +} | ||
| 80 | + | ||
| 81 | +void Print1D(Ort::Value *v) { | ||
| 82 | + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 83 | + const float *d = v->GetTensorData<float>(); | ||
| 84 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { | ||
| 85 | + fprintf(stderr, "%.3f ", d[i]); | ||
| 86 | + } | ||
| 87 | + fprintf(stderr, "\n"); | ||
| 88 | +} | ||
| 89 | + | ||
| 90 | +void Print2D(Ort::Value *v) { | ||
| 91 | + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 92 | + const float *d = v->GetTensorData<float>(); | ||
| 93 | + | ||
| 94 | + for (int32_t r = 0; r != static_cast<int32_t>(shape[0]); ++r) { | ||
| 95 | + for (int32_t c = 0; c != static_cast<int32_t>(shape[1]); ++c, ++d) { | ||
| 96 | + fprintf(stderr, "%.3f ", *d); | ||
| 97 | + } | ||
| 98 | + fprintf(stderr, "\n"); | ||
| 99 | + } | ||
| 100 | + fprintf(stderr, "\n"); | ||
| 101 | +} | ||
| 102 | + | ||
| 103 | +void Print3D(Ort::Value *v) { | ||
| 104 | + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 105 | + const float *d = v->GetTensorData<float>(); | ||
| 106 | + | ||
| 107 | + for (int32_t p = 0; p != static_cast<int32_t>(shape[0]); ++p) { | ||
| 108 | + fprintf(stderr, "---plane %d---\n", p); | ||
| 109 | + for (int32_t r = 0; r != static_cast<int32_t>(shape[1]); ++r) { | ||
| 110 | + for (int32_t c = 0; c != static_cast<int32_t>(shape[2]); ++c, ++d) { | ||
| 111 | + fprintf(stderr, "%.3f ", *d); | ||
| 112 | + } | ||
| 113 | + fprintf(stderr, "\n"); | ||
| 114 | + } | ||
| 115 | + } | ||
| 116 | + fprintf(stderr, "\n"); | ||
| 59 | } | 117 | } |
| 60 | 118 | ||
| 61 | } // namespace sherpa_onnx | 119 | } // namespace sherpa_onnx |
| @@ -56,7 +56,23 @@ void PrintModelMetadata(std::ostream &os, | @@ -56,7 +56,23 @@ void PrintModelMetadata(std::ostream &os, | ||
| 56 | const Ort::ModelMetadata &meta_data); // NOLINT | 56 | const Ort::ModelMetadata &meta_data); // NOLINT |
| 57 | 57 | ||
| 58 | // Return a shallow copy of v | 58 | // Return a shallow copy of v |
| 59 | -Ort::Value Clone(Ort::Value *v); | 59 | +Ort::Value Clone(const Ort::Value *v); |
| 60 | + | ||
| 61 | +// Print a 1-D tensor to stderr | ||
| 62 | +void Print1D(Ort::Value *v); | ||
| 63 | + | ||
| 64 | +// Print a 2-D tensor to stderr | ||
| 65 | +void Print2D(Ort::Value *v); | ||
| 66 | + | ||
| 67 | +// Print a 3-D tensor to stderr | ||
| 68 | +void Print3D(Ort::Value *v); | ||
| 69 | + | ||
| 70 | +template <typename T = float> | ||
| 71 | +void Fill(Ort::Value *tensor, T value) { | ||
| 72 | + auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount(); | ||
| 73 | + auto p = tensor->GetTensorMutableData<T>(); | ||
| 74 | + std::fill(p, p + n, value); | ||
| 75 | +} | ||
| 60 | 76 | ||
| 61 | } // namespace sherpa_onnx | 77 | } // namespace sherpa_onnx |
| 62 | 78 |
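Editor's note: `Fill()` above is a header-only template over the element type; it relies on `std::fill`, so `<algorithm>` must be in scope at the point of use. A short usage sketch, zero-initializing a freshly allocated tensor (whose buffer is not guaranteed to be zeroed):

```cpp
#include <algorithm>
#include <array>
#include <cstdint>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/onnx-utils.h"

void FillDemo() {
  Ort::AllocatorWithDefaultOptions allocator;
  std::array<int64_t, 2> shape{2, 3};
  Ort::Value t = Ort::Value::CreateTensor<int64_t>(allocator, shape.data(),
                                                   shape.size());
  sherpa_onnx::Fill<int64_t>(&t, 0);  // all 6 elements are now 0
}
```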
sherpa-onnx/csrc/text-utils.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/text-utils.cc | ||
| 2 | +// | ||
| 3 | +// Copyright 2009-2011 Saarland University; Microsoft Corporation | ||
| 4 | +// Copyright 2023 Xiaomi Corporation | ||
| 5 | + | ||
| 6 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +// This file is copied/modified from | ||
| 12 | +// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +void SplitStringToVector(const std::string &full, const char *delim, | ||
| 17 | + bool omit_empty_strings, | ||
| 18 | + std::vector<std::string> *out) { | ||
| 19 | + size_t start = 0, found = 0, end = full.size(); | ||
| 20 | + out->clear(); | ||
| 21 | + while (found != std::string::npos) { | ||
| 22 | + found = full.find_first_of(delim, start); | ||
| 23 | + // start != end condition is for when the delimiter is at the end | ||
| 24 | + if (!omit_empty_strings || (found != start && start != end)) | ||
| 25 | + out->push_back(full.substr(start, found - start)); | ||
| 26 | + start = found + 1; | ||
| 27 | + } | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | +} // namespace sherpa_onnx |
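Editor's note: a behavior sketch for `SplitStringToVector()`, derived directly from the loop above; both calls are illustrative:

```cpp
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/text-utils.h"

void SplitDemo() {
  std::vector<std::string> out;

  sherpa_onnx::SplitStringToVector("a,,b", ",", /*omit_empty_strings=*/true,
                                   &out);
  // out == {"a", "b"}

  sherpa_onnx::SplitStringToVector("a,,b", ",", /*omit_empty_strings=*/false,
                                   &out);
  // out == {"a", "", "b"}
}
```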
sherpa-onnx/csrc/text-utils.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/text-utils.h | ||
| 2 | +// | ||
| 3 | +// Copyright 2009-2011 Saarland University; Microsoft Corporation | ||
| 4 | +// Copyright 2023 Xiaomi Corporation | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_TEXT_UTILS_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_TEXT_UTILS_H_ | ||
| 7 | +#include <stdlib.h> | ||
| 8 | + | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#ifdef _MSC_VER | ||
| 13 | +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \ | ||
| 14 | + _strtoi64(cur_cstr, end_cstr, 10); | ||
| 15 | +#else | ||
| 16 | +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10); | ||
| 17 | +#endif | ||
| 18 | + | ||
| 19 | +// This file is copied/modified from | ||
| 20 | +// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.h | ||
| 21 | + | ||
| 22 | +namespace sherpa_onnx { | ||
| 23 | + | ||
| 24 | +/// Split a string using any of the single character delimiters. | ||
| 25 | +/// If omit_empty_strings == true, the output will contain any | ||
| 26 | +/// nonempty strings after splitting on any of the | ||
| 27 | +/// characters in the delimiter. If omit_empty_strings == false, | ||
| 28 | +/// the output will contain n+1 strings if there are n characters | ||
| 29 | +/// in the set "delim" within the input string. In this case | ||
| 30 | +/// the empty string is split to a single empty string. | ||
| 31 | +void SplitStringToVector(const std::string &full, const char *delim, | ||
| 32 | + bool omit_empty_strings, | ||
| 33 | + std::vector<std::string> *out); | ||
| 34 | + | ||
| 35 | +/** | ||
| 36 | + \brief Split a string (e.g. 1:2:3) into a vector of integers. | ||
| 37 | + | ||
| 38 | + \param [in] delim String containing a list of characters, any of which | ||
| 39 | + is allowed as a delimiter. | ||
| 40 | + \param [in] omit_empty_strings If true, empty strings between delimiters are | ||
| 41 | + allowed and will not produce an output integer; if false, | ||
| 42 | + instances of characters in 'delim' that are consecutive or | ||
| 43 | + at the start or end of the string would be an error. | ||
| 44 | + You'll normally want this to be true if 'delim' consists | ||
| 45 | + of spaces, and false otherwise. | ||
| 46 | + \param [out] out The output list of integers. | ||
| 47 | +*/ | ||
| 48 | +template <class I> | ||
| 49 | +bool SplitStringToIntegers(const std::string &full, const char *delim, | ||
| 50 | + bool omit_empty_strings, // typically false [but | ||
| 51 | + // should probably be true | ||
| 52 | + // if "delim" is spaces]. | ||
| 53 | + std::vector<I> *out) { | ||
| 54 | + static_assert(std::is_integral<I>::value, ""); | ||
| 55 | + if (*(full.c_str()) == '\0') { | ||
| 56 | + out->clear(); | ||
| 57 | + return true; | ||
| 58 | + } | ||
| 59 | + std::vector<std::string> split; | ||
| 60 | + SplitStringToVector(full, delim, omit_empty_strings, &split); | ||
| 61 | + out->resize(split.size()); | ||
| 62 | + for (size_t i = 0; i < split.size(); i++) { | ||
| 63 | + const char *this_str = split[i].c_str(); | ||
| 64 | + char *end = NULL; | ||
| 65 | + int64_t j = 0; | ||
| 66 | + j = SHERPA_ONNX_STRTOLL(this_str, &end); | ||
| 67 | + if (end == this_str || *end != '\0') { | ||
| 68 | + out->clear(); | ||
| 69 | + return false; | ||
| 70 | + } else { | ||
| 71 | + I jI = static_cast<I>(j); | ||
| 72 | + if (static_cast<int64_t>(jI) != j) { | ||
| 73 | + // output type cannot fit this integer. | ||
| 74 | + out->clear(); | ||
| 75 | + return false; | ||
| 76 | + } | ||
| 77 | + (*out)[i] = jI; | ||
| 78 | + } | ||
| 79 | + } | ||
| 80 | + return true; | ||
| 81 | +} | ||
| 82 | + | ||
| 83 | +} // namespace sherpa_onnx | ||
| 84 | + | ||
| 85 | +#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ |
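Editor's note: a usage sketch for `SplitStringToIntegers()`, parsing a colon-separated list; per the code above, it returns false and clears the output on a malformed field or when a value overflows the requested integer type:

```cpp
#include <cstdint>
#include <vector>

#include "sherpa-onnx/csrc/text-utils.h"

void SplitIntegersDemo() {
  std::vector<int32_t> v;
  bool ok = sherpa_onnx::SplitStringToIntegers(
      "1:2:3", ":", /*omit_empty_strings=*/false, &v);
  // ok == true, v == {1, 2, 3}

  ok = sherpa_onnx::SplitStringToIntegers(
      "1:x:3", ":", /*omit_empty_strings=*/false, &v);
  // ok == false, v is cleared
}
```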
sherpa-onnx/csrc/unbind-test.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/unbind-test.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/unbind.h" | ||
| 6 | + | ||
| 7 | +#include "gtest/gtest.h" | ||
| 8 | +#include "sherpa-onnx/csrc/cat.h" | ||
| 9 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +TEST(Unbind, Test1DTensors) { ||
| 14 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 15 | + std::array<int64_t, 1> shape{3}; | ||
| 16 | + Ort::Value v = | ||
| 17 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 18 | + float *p = v.GetTensorMutableData<float>(); | ||
| 19 | + | ||
| 20 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { | ||
| 21 | + p[i] = i; | ||
| 22 | + } | ||
| 23 | + auto ans = Unbind(allocator, &v, 0); | ||
| 24 | + EXPECT_EQ(ans.size(), shape[0]); | ||
| 25 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { | ||
| 26 | + EXPECT_EQ(ans[i].GetTensorData<float>()[0], p[i]); | ||
| 27 | + } | ||
| 28 | + Print1D(&v); | ||
| 29 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { | ||
| 30 | + Print1D(&ans[i]); | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + // For Cat | ||
| 34 | + std::vector<const Ort::Value *> vec(ans.size()); | ||
| 35 | + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) { | ||
| 36 | + vec[i] = &ans[i]; | ||
| 37 | + } | ||
| 38 | + Ort::Value v2 = Cat(allocator, vec, 0); | ||
| 39 | + const float *p2 = v2.GetTensorData<float>(); | ||
| 40 | + for (int32_t i = 0; i != shape[0]; ++i) { | ||
| 41 | + EXPECT_EQ(p[i], p2[i]); | ||
| 42 | + } | ||
| 43 | +} | ||
| 44 | + | ||
| 45 | +TEST(Unbind, Test2DTensorsDim0) { ||
| 46 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 47 | + std::array<int64_t, 2> shape{3, 2}; | ||
| 48 | + Ort::Value v = | ||
| 49 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 50 | + float *p = v.GetTensorMutableData<float>(); | ||
| 51 | + | ||
| 52 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1]); ++i) { | ||
| 53 | + p[i] = i; | ||
| 54 | + } | ||
| 55 | + auto ans = Unbind(allocator, &v, 0); | ||
| 56 | + | ||
| 57 | + Print2D(&v); | ||
| 58 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { | ||
| 59 | + Print2D(&ans[i]); | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { | ||
| 63 | + const float *pans = ans[i].GetTensorData<float>(); | ||
| 64 | + for (int32_t k = 0; k != static_cast<int32_t>(shape[1]); ++k, ++p) { | ||
| 65 | + EXPECT_EQ(*p, pans[k]); | ||
| 66 | + } | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + // For Cat | ||
| 70 | + std::vector<const Ort::Value *> vec(ans.size()); | ||
| 71 | + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) { | ||
| 72 | + vec[i] = &ans[i]; | ||
| 73 | + } | ||
| 74 | + Ort::Value v2 = Cat(allocator, vec, 0); | ||
| 75 | + Print2D(&v2); | ||
| 76 | + | ||
| 77 | + p = v.GetTensorMutableData<float>(); | ||
| 78 | + const float *p2 = v2.GetTensorData<float>(); | ||
| 79 | + for (int32_t i = 0; i != shape[0] * shape[1]; ++i) { | ||
| 80 | + EXPECT_EQ(p[i], p2[i]); | ||
| 81 | + } | ||
| 82 | +} | ||
| 83 | + | ||
| 84 | +TEST(Unbind, Test2DTensorsDim1) { ||
| 85 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 86 | + std::array<int64_t, 2> shape{3, 2}; | ||
| 87 | + Ort::Value v = | ||
| 88 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 89 | + float *p = v.GetTensorMutableData<float>(); | ||
| 90 | + | ||
| 91 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1]); ++i) { | ||
| 92 | + p[i] = i; | ||
| 93 | + } | ||
| 94 | + auto ans = Unbind(allocator, &v, 1); | ||
| 95 | + | ||
| 96 | + Print2D(&v); | ||
| 97 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[1]); ++i) { | ||
| 98 | + Print2D(&ans[i]); | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + // For Cat | ||
| 102 | + std::vector<const Ort::Value *> vec(ans.size()); | ||
| 103 | + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) { | ||
| 104 | + vec[i] = &ans[i]; | ||
| 105 | + } | ||
| 106 | + Ort::Value v2 = Cat(allocator, vec, 1); | ||
| 107 | + Print2D(&v2); | ||
| 108 | + | ||
| 109 | + p = v.GetTensorMutableData<float>(); | ||
| 110 | + const float *p2 = v2.GetTensorData<float>(); | ||
| 111 | + for (int32_t i = 0; i != shape[0] * shape[1]; ++i) { | ||
| 112 | + EXPECT_EQ(p[i], p2[i]); | ||
| 113 | + } | ||
| 114 | +} | ||
| 115 | + | ||
| 116 | +TEST(Unbind, Test3DTensorsDim0) { ||
| 117 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 118 | + std::array<int64_t, 3> shape{3, 2, 5}; | ||
| 119 | + Ort::Value v = | ||
| 120 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 121 | + float *p = v.GetTensorMutableData<float>(); | ||
| 122 | + | ||
| 123 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]); | ||
| 124 | + ++i) { | ||
| 125 | + p[i] = i; | ||
| 126 | + } | ||
| 127 | + auto ans = Unbind(allocator, &v, 0); | ||
| 128 | + | ||
| 129 | + Print3D(&v); | ||
| 130 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { | ||
| 131 | + Print3D(&ans[i]); | ||
| 132 | + } | ||
| 133 | + | ||
| 134 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { | ||
| 135 | + const float *pans = ans[i].GetTensorData<float>(); | ||
| 136 | + for (int32_t k = 0; k != static_cast<int32_t>(shape[1] * shape[2]); | ||
| 137 | + ++k, ++p) { | ||
| 138 | + EXPECT_EQ(*p, pans[k]); | ||
| 139 | + } | ||
| 140 | + } | ||
| 141 | + | ||
| 142 | + // For Cat | ||
| 143 | + std::vector<const Ort::Value *> vec(ans.size()); | ||
| 144 | + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) { | ||
| 145 | + vec[i] = &ans[i]; | ||
| 146 | + } | ||
| 147 | + Ort::Value v2 = Cat(allocator, vec, 0); | ||
| 148 | + Print3D(&v2); | ||
| 149 | + | ||
| 150 | + p = v.GetTensorMutableData<float>(); | ||
| 151 | + const float *p2 = v2.GetTensorData<float>(); | ||
| 152 | + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) { | ||
| 153 | + EXPECT_EQ(p[i], p2[i]); | ||
| 154 | + } | ||
| 155 | +} | ||
| 156 | + | ||
| 157 | +TEST(Unbind, Test3DTensorsDim1) { ||
| 158 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 159 | + std::array<int64_t, 3> shape{3, 2, 5}; | ||
| 160 | + Ort::Value v = | ||
| 161 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 162 | + float *p = v.GetTensorMutableData<float>(); | ||
| 163 | + | ||
| 164 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]); | ||
| 165 | + ++i) { | ||
| 166 | + p[i] = i; | ||
| 167 | + } | ||
| 168 | + auto ans = Unbind(allocator, &v, 1); | ||
| 169 | + | ||
| 170 | + Print3D(&v); | ||
| 171 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[1]); ++i) { | ||
| 172 | + Print3D(&ans[i]); | ||
| 173 | + } | ||
| 174 | + | ||
| 175 | + // For Cat | ||
| 176 | + std::vector<const Ort::Value *> vec(ans.size()); | ||
| 177 | + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) { | ||
| 178 | + vec[i] = &ans[i]; | ||
| 179 | + } | ||
| 180 | + Ort::Value v2 = Cat(allocator, vec, 1); | ||
| 181 | + Print3D(&v2); | ||
| 182 | + | ||
| 183 | + p = v.GetTensorMutableData<float>(); | ||
| 184 | + const float *p2 = v2.GetTensorData<float>(); | ||
| 185 | + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) { | ||
| 186 | + EXPECT_EQ(p[i], p2[i]); | ||
| 187 | + } | ||
| 188 | +} | ||
| 189 | + | ||
| 190 | +TEST(Unbind, Test3DTensorsDim2) { ||
| 191 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 192 | + std::array<int64_t, 3> shape{3, 2, 5}; | ||
| 193 | + Ort::Value v = | ||
| 194 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 195 | + float *p = v.GetTensorMutableData<float>(); | ||
| 196 | + | ||
| 197 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]); | ||
| 198 | + ++i) { | ||
| 199 | + p[i] = i; | ||
| 200 | + } | ||
| 201 | + auto ans = Unbind(allocator, &v, 2); | ||
| 202 | + | ||
| 203 | + Print3D(&v); | ||
| 204 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[2]); ++i) { | ||
| 205 | + Print3D(&ans[i]); | ||
| 206 | + } | ||
| 207 | + | ||
| 208 | + // For Cat | ||
| 209 | + std::vector<const Ort::Value *> vec(ans.size()); | ||
| 210 | + for (int32_t i = 0; i != static_cast<int32_t>(vec.size()); ++i) { | ||
| 211 | + vec[i] = &ans[i]; | ||
| 212 | + } | ||
| 213 | + Ort::Value v2 = Cat(allocator, vec, 2); | ||
| 214 | + Print3D(&v2); | ||
| 215 | + | ||
| 216 | + p = v.GetTensorMutableData<float>(); | ||
| 217 | + const float *p2 = v2.GetTensorData<float>(); | ||
| 218 | + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) { | ||
| 219 | + EXPECT_EQ(p[i], p2[i]); | ||
| 220 | + } | ||
| 221 | +} | ||
| 222 | + | ||
| 223 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/unbind.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/unbind.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/unbind.h" | ||
| 6 | + | ||
| 7 | +#include <assert.h> | ||
| 8 | + | ||
| 9 | +#include <algorithm> | ||
| 10 | +#include <functional> | ||
| 11 | +#include <numeric> | ||
| 12 | +#include <utility> | ||
| 13 | +#include <vector> | ||
| 14 | + | ||
| 15 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 16 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 17 | + | ||
| 18 | +namespace sherpa_onnx { | ||
| 19 | + | ||
| 20 | +template <typename T /*= float*/> | ||
| 21 | +std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value, | ||
| 22 | + int32_t dim) { | ||
| 23 | + std::vector<int64_t> shape = value->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 24 | + assert(dim >= 0); | ||
| 25 | + assert(dim < static_cast<int32_t>(shape.size())); | ||
| 26 | + int32_t n = static_cast<int32_t>(shape[dim]); | ||
| 27 | + if (n == 1) { | ||
| 28 | + std::vector<Ort::Value> ans; | ||
| 29 | + ans.push_back(Clone(value)); | ||
| 30 | + return ans; | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + std::vector<int64_t> ans_shape = shape; | ||
| 34 | +  ans_shape[dim] = 1;  // unlike torch.unbind(), we keep this dim with size 1 ||
| 35 | + | ||
| 36 | +  // allocate the output tensors ||
| 37 | + std::vector<Ort::Value> ans; | ||
| 38 | + ans.reserve(n); | ||
| 39 | + for (int32_t i = 0; i != n; ++i) { | ||
| 40 | + Ort::Value t = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(), | ||
| 41 | + ans_shape.size()); | ||
| 42 | + ans.push_back(std::move(t)); | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + auto leading_size = static_cast<int32_t>(std::accumulate( | ||
| 46 | + shape.begin(), shape.begin() + dim, 1, std::multiplies<int64_t>())); | ||
| 47 | + | ||
| 48 | + auto trailing_size = static_cast<int32_t>(std::accumulate( | ||
| 49 | + shape.begin() + dim + 1, shape.end(), 1, std::multiplies<int64_t>())); | ||
| 50 | + | ||
| 51 | + const T *src = value->GetTensorData<T>(); | ||
| 52 | + | ||
| 53 | + for (int32_t i = 0; i != leading_size; ++i) { | ||
| 54 | + for (int32_t k = 0; k != n; ++k) { | ||
| 55 | + T *dst = ans[k].GetTensorMutableData<T>() + i * trailing_size; | ||
| 56 | + std::copy(src, src + trailing_size, dst); | ||
| 57 | + src += trailing_size; | ||
| 58 | + } | ||
| 59 | + } | ||
| 60 | + | ||
| 61 | +  return ans; ||
| 62 | +} | ||
| 63 | + | ||
| 64 | +template std::vector<Ort::Value> Unbind<float>(OrtAllocator *allocator, | ||
| 65 | + const Ort::Value *value, | ||
| 66 | + int32_t dim); | ||
| 67 | + | ||
| 68 | +template std::vector<Ort::Value> Unbind<int64_t>(OrtAllocator *allocator, | ||
| 69 | + const Ort::Value *value, | ||
| 70 | + int32_t dim); | ||
| 71 | + | ||
| 72 | +} // namespace sherpa_onnx |
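Editor's note: a concrete trace of the copy loop above may help. For `shape = (3, 2, 5)` and `dim = 1`, the code computes `leading_size = 3` (product of the dims before `dim`), `trailing_size = 5` (product of the dims after `dim`), and `n = 2`. The source is walked exactly once: on iteration `i = 0`, elements 0..4 are copied to `ans[0]` and 5..9 to `ans[1]`; on `i = 1`, elements 10..14 go to `ans[0]` and 15..19 to `ans[1]`; on `i = 2`, elements 20..24 go to `ans[0]` and 25..29 to `ans[1]`. Each output thus receives `leading_size * trailing_size = 15` elements, matching its shape `(3, 1, 5)`.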
sherpa-onnx/csrc/unbind.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/unbind.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_UNBIND_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_UNBIND_H_ | ||
| 6 | + | ||
| 7 | +#include <vector> | ||
| 8 | + | ||
| 9 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +/** It is similar to torch.unbind(), but each returned tensor keeps ||
| 14 | + * the unbound dim with size 1 ||
| 15 | + * | ||
| 16 | + * @param allocator Allocator to allocate space for the returned tensor | ||
| 17 | + * @param value The tensor to unbind | ||
| 18 | + * @param dim The dim along which to unbind the tensor | ||
| 19 | + * | ||
| 20 | + * @return Return a list of tensors | ||
| 21 | + */ | ||
| 22 | +template <typename T = float> | ||
| 23 | +std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value, | ||
| 24 | + int32_t dim); | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx | ||
| 27 | + | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_UNBIND_H_ |
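Editor's note: because each piece keeps the unbound dim with size 1, `Unbind()` followed by `Cat()` along the same dim reconstructs the original tensor; unbind-test.cc above exercises this round trip exhaustively. A condensed sketch (the tensor contents are arbitrary; `Fill()` comes from onnx-utils.h above):

```cpp
#include <array>
#include <cstdint>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/unbind.h"

void RoundTripDemo() {
  Ort::AllocatorWithDefaultOptions allocator;
  std::array<int64_t, 2> shape{3, 2};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  sherpa_onnx::Fill<float>(&v, 0.5f);

  // (3, 2) -> three tensors of shape (1, 2)
  std::vector<Ort::Value> pieces = sherpa_onnx::Unbind(allocator, &v, /*dim=*/0);

  std::vector<const Ort::Value *> ptrs;
  for (const auto &p : pieces) ptrs.push_back(&p);

  // three (1, 2) tensors -> (3, 2) again
  Ort::Value back = sherpa_onnx::Cat(allocator, ptrs, /*dim=*/0);
}
```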