Support building GPU-capable sherpa-onnx on Linux aarch64. (#1500)
Thanks to @Peakyxh for providing pre-built onnxruntime libraries with CUDA support for Linux aarch64. Tested on a Jetson Nano B01.
Showing 41 changed files with 546 additions and 300 deletions.
@@ -34,11 +34,12 @@ concurrency:
 jobs:
   aarch64_linux_gnu_shared:
     runs-on: ${{ matrix.os }}
-    name: aarch64 shared lib test
+    name: aarch64 shared GPU ${{ matrix.gpu }}
     strategy:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
+        gpu: [ON, OFF]

     steps:
       - uses: actions/checkout@v4
@@ -79,15 +80,24 @@ jobs:
           make -j2
           make install

-      - name: cache-toolchain
-        id: cache-toolchain
+      - name: cache-toolchain (CPU)
+        if: matrix.gpu == 'OFF'
+        id: cache-toolchain-cpu
         uses: actions/cache@v4
         with:
           path: toolchain
           key: gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz

-      - name: Download toolchain
-        if: steps.cache-toolchain.outputs.cache-hit != 'true'
+      - name: cache-toolchain (GPU)
+        if: matrix.gpu == 'ON'
+        id: cache-toolchain-gpu
+        uses: actions/cache@v4
+        with:
+          path: toolchain
+          key: gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
+
+      - name: Download toolchain (CPU, gcc 7.5)
+        if: steps.cache-toolchain-cpu.outputs.cache-hit != 'true' && matrix.gpu == 'OFF'
         shell: bash
         run: |
           wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
@@ -95,6 +105,15 @@ jobs:
           mkdir $GITHUB_WORKSPACE/toolchain
           tar xf ./gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain

+      - name: Download toolchain (GPU, gcc 10.3)
+        if: steps.cache-toolchain-gpu.outputs.cache-hit != 'true' && matrix.gpu == 'ON'
+        shell: bash
+        run: |
+          wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
+
+          mkdir $GITHUB_WORKSPACE/toolchain
+          tar xf ./gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
+
       - name: Set environment variable
         if: steps.cache-build-result.outputs.cache-hit != 'true'
         shell: bash
@@ -103,19 +122,31 @@ jobs:
           echo "$GITHUB_WORKSPACE/bin" >> "$GITHUB_PATH"
           ls -lh "$GITHUB_WORKSPACE/toolchain/bin"

-          echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV"
-          echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV"
+          if [[ ${{ matrix.gpu }} == OFF ]]; then
+            echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV"
+            echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV"
+          else
+            echo "CC=aarch64-none-linux-gnu-gcc" >> "$GITHUB_ENV"
+            echo "CXX=aarch64-none-linux-gnu-g++" >> "$GITHUB_ENV"
+          fi

       - name: Display toolchain info
         shell: bash
         run: |
-          aarch64-linux-gnu-gcc --version
+          if [[ ${{ matrix.gpu }} == OFF ]]; then
+            which aarch64-linux-gnu-gcc
+            aarch64-linux-gnu-gcc --version
+          else
+            which aarch64-none-linux-gnu-gcc
+            aarch64-none-linux-gnu-gcc --version
+          fi

       - name: Display qemu-aarch64 -h
         shell: bash
         run: |
           export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
           export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
+          export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
           qemu-aarch64 -h

       - name: build aarch64-linux-gnu
@@ -127,6 +158,7 @@ jobs:
           cmake --version

           export BUILD_SHARED_LIBS=ON
+          export SHERPA_ONNX_ENABLE_GPU=${{ matrix.gpu }}

           ./build-aarch64-linux-gnu.sh

@@ -140,7 +172,11 @@ jobs:
         run: |
           export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH
           export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
-          export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
+          if [[ ${{ matrix.gpu }} == OFF ]]; then
+            export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
+          else
+            export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
+          fi

           ls -lh ./build-aarch64-linux-gnu/bin

@@ -151,11 +187,20 @@ jobs:
       - name: Copy files
         shell: bash
         run: |
-          aarch64-linux-gnu-strip --version
+          if [[ ${{ matrix.gpu }} == OFF ]]; then
+            aarch64-linux-gnu-strip --version
+          else
+            aarch64-none-linux-gnu-strip --version
+          fi

           SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)

           dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-linux-aarch64-shared
+          if [[ ${{ matrix.gpu }} == OFF ]]; then
+            dst=${dst}-cpu
+          else
+            dst=${dst}-gpu
+          fi
           mkdir $dst

           cp -a build-aarch64-linux-gnu/install/bin $dst/
@@ -166,7 +211,11 @@ jobs:

           ls -lh $dst/bin/
           echo "strip"
-          aarch64-linux-gnu-strip $dst/bin/*
+          if [[ ${{ matrix.gpu }} == OFF ]]; then
+            aarch64-linux-gnu-strip $dst/bin/*
+          else
+            aarch64-none-linux-gnu-strip $dst/bin/*
+          fi

           tree $dst

@@ -174,8 +223,8 @@ jobs:

       - uses: actions/upload-artifact@v4
         with:
-          name: sherpa-onnx-linux-aarch64-shared
-          path: sherpa-onnx-*linux-aarch64-shared.tar.bz2
+          name: sherpa-onnx-linux-aarch64-shared-gpu-${{ matrix.gpu }}
+          path: sherpa-onnx-*linux-aarch64-shared*.tar.bz2

       # https://huggingface.co/docs/hub/spaces-github-actions
       - name: Publish to huggingface
@@ -198,7 +247,7 @@ jobs:
            cd huggingface
            mkdir -p aarch64

-           cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64
+           cp -v ../sherpa-onnx-*-shared*.tar.bz2 ./aarch64

            git status
            git lfs track "*.bz2"
| @@ -44,6 +44,21 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then | @@ -44,6 +44,21 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then | ||
| 44 | BUILD_SHARED_LIBS=OFF | 44 | BUILD_SHARED_LIBS=OFF |
| 45 | fi | 45 | fi |
| 46 | 46 | ||
| 47 | +if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"" ]]; then | ||
| 48 | + # By default, use CPU | ||
| 49 | + SHERPA_ONNX_ENABLE_GPU=OFF | ||
| 50 | + | ||
| 51 | + # If you use GPU, then please make sure you have NVIDIA GPUs on your board. | ||
| 52 | + # It uses onnxruntime 1.11.0. | ||
| 53 | + # | ||
| 54 | + # Tested on Jetson Nano B01 | ||
| 55 | +fi | ||
| 56 | + | ||
| 57 | +if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"ON" ]]; then | ||
| 58 | + # Build shared libs if building GPU is enabled. | ||
| 59 | + BUILD_SHARED_LIBS=ON | ||
| 60 | +fi | ||
| 61 | + | ||
| 47 | cmake \ | 62 | cmake \ |
| 48 | -DBUILD_PIPER_PHONMIZE_EXE=OFF \ | 63 | -DBUILD_PIPER_PHONMIZE_EXE=OFF \ |
| 49 | -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ | 64 | -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ |
| @@ -51,6 +66,7 @@ cmake \ | @@ -51,6 +66,7 @@ cmake \ | ||
| 51 | -DBUILD_ESPEAK_NG_TESTS=OFF \ | 66 | -DBUILD_ESPEAK_NG_TESTS=OFF \ |
| 52 | -DCMAKE_INSTALL_PREFIX=./install \ | 67 | -DCMAKE_INSTALL_PREFIX=./install \ |
| 53 | -DCMAKE_BUILD_TYPE=Release \ | 68 | -DCMAKE_BUILD_TYPE=Release \ |
| 69 | + -DSHERPA_ONNX_ENABLE_GPU=$SHERPA_ONNX_ENABLE_GPU \ | ||
| 54 | -DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \ | 70 | -DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \ |
| 55 | -DSHERPA_ONNX_ENABLE_TESTS=OFF \ | 71 | -DSHERPA_ONNX_ENABLE_TESTS=OFF \ |
| 56 | -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ | 72 | -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ |
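For reference, the CI above drives this script purely through environment variables. A minimal sketch of a local GPU-enabled cross-compile follows; the toolchain path is an assumption (the workflow unpacks gcc-arm 10.3 into $GITHUB_WORKSPACE/toolchain), and at runtime the resulting binaries expect an NVIDIA GPU board such as the Jetson Nano B01 mentioned in the script comment.

```bash
# Sketch only: cross-compile sherpa-onnx with CUDA-enabled onnxruntime for aarch64.
# Assumes the gcc-arm-10.3 aarch64 cross toolchain is unpacked at $HOME/toolchain
# (hypothetical path; adjust to wherever you extracted it).
export PATH=$HOME/toolchain/bin:$PATH
export CC=aarch64-none-linux-gnu-gcc
export CXX=aarch64-none-linux-gnu-g++

# A GPU build forces shared libraries (see the script change above).
export BUILD_SHARED_LIBS=ON
export SHERPA_ONNX_ENABLE_GPU=ON

./build-aarch64-linux-gnu.sh
```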
cmake/onnxruntime-linux-aarch64-gpu.cmake (new file, mode 100644)
+# Copyright (c)  2022-2024  Xiaomi Corporation
+message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
+message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
+
+if(NOT CMAKE_SYSTEM_NAME STREQUAL Linux)
+  message(FATAL_ERROR "This file is for Linux only. Given: ${CMAKE_SYSTEM_NAME}")
+endif()
+
+if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
+  message(FATAL_ERROR "This file is for aarch64 only. Given: ${CMAKE_SYSTEM_PROCESSOR}")
+endif()
+
+if(NOT BUILD_SHARED_LIBS)
+  message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
+endif()
+
+if(NOT SHERPA_ONNX_ENABLE_GPU)
+  message(FATAL_ERROR "This file is for NVIDIA GPU only. Given SHERPA_ONNX_ENABLE_GPU: ${SHERPA_ONNX_ENABLE_GPU}")
+endif()
+
+set(onnxruntime_URL  "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.11.0/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
+set(onnxruntime_URL2 "https://hf-mirror.com/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
+set(onnxruntime_HASH "SHA256=36eded935551e23aead09d4173bdf0bd1e7b01fdec15d77f97d6e34029aa60d7")
+
+# If you don't have access to the Internet,
+# please download onnxruntime to one of the following locations.
+# You can add more if you want.
+set(possible_file_locations
+  $ENV{HOME}/Downloads/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
+  ${CMAKE_SOURCE_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
+  ${CMAKE_BINARY_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
+  /tmp/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
+  /star-fj/fangjun/download/github/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
+)
+
+foreach(f IN LISTS possible_file_locations)
+  if(EXISTS ${f})
+    set(onnxruntime_URL  "${f}")
+    file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
+    message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
+    set(onnxruntime_URL2)
+    break()
+  endif()
+endforeach()
+
+FetchContent_Declare(onnxruntime
+  URL
+    ${onnxruntime_URL}
+    ${onnxruntime_URL2}
+  URL_HASH          ${onnxruntime_HASH}
+)
+
+FetchContent_GetProperties(onnxruntime)
+if(NOT onnxruntime_POPULATED)
+  message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
+  FetchContent_Populate(onnxruntime)
+endif()
+message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
+
+find_library(location_onnxruntime onnxruntime
+  PATHS
+  "${onnxruntime_SOURCE_DIR}/lib"
+  NO_CMAKE_SYSTEM_PATH
+)
+
+message(STATUS "location_onnxruntime: ${location_onnxruntime}")
+
+add_library(onnxruntime SHARED IMPORTED)
+
+set_target_properties(onnxruntime PROPERTIES
+  IMPORTED_LOCATION ${location_onnxruntime}
+  INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
+)
+
+find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda
+  PATHS
+  "${onnxruntime_SOURCE_DIR}/lib"
+  NO_CMAKE_SYSTEM_PATH
+)
+
+add_library(onnxruntime_providers_cuda SHARED IMPORTED)
+set_target_properties(onnxruntime_providers_cuda PROPERTIES
+  IMPORTED_LOCATION ${location_onnxruntime_cuda_lib}
+)
+message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}")
+
+# for libonnxruntime_providers_shared.so
+find_library(location_onnxruntime_providers_shared_lib onnxruntime_providers_shared
+  PATHS
+  "${onnxruntime_SOURCE_DIR}/lib"
+  NO_CMAKE_SYSTEM_PATH
+)
+add_library(onnxruntime_providers_shared SHARED IMPORTED)
+set_target_properties(onnxruntime_providers_shared PROPERTIES
+  IMPORTED_LOCATION ${location_onnxruntime_providers_shared_lib}
+)
+message(STATUS "location_onnxruntime_providers_shared_lib: ${location_onnxruntime_providers_shared_lib}")
+
+file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime*")
+message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
+install(FILES ${onnxruntime_lib_files} DESTINATION lib)
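As the comment in the new file says, the archive can also be fetched by hand when the build machine has no Internet access. A sketch, using one of the local paths listed above (~/Downloads) and the URL and SHA256 pinned in the file:

```bash
# Sketch: pre-download the CUDA-enabled onnxruntime archive so the CMake code
# above finds it locally instead of downloading it during the build.
mkdir -p ~/Downloads
cd ~/Downloads
wget https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.11.0/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2

# Optional: check the file against the hash pinned in the CMake file.
echo "36eded935551e23aead09d4173bdf0bd1e7b01fdec15d77f97d6e34029aa60d7  onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2" | sha256sum -c
```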
@@ -13,7 +13,9 @@ function(download_onnxruntime)
       include(onnxruntime-linux-riscv64-static)
     endif()
   elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
-    if(BUILD_SHARED_LIBS)
+    if(SHERPA_ONNX_ENABLE_GPU)
+      include(onnxruntime-linux-aarch64-gpu)
+    elseif(BUILD_SHARED_LIBS)
       include(onnxruntime-linux-aarch64)
     else()
       include(onnxruntime-linux-aarch64-static)
 function(download_piper_phonemize)
   include(FetchContent)

-  set(piper_phonemize_URL  "https://github.com/csukuangfj/piper-phonemize/archive/dc6b5f4441bffe521047086930b0fc12686acd56.zip")
-  set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip")
-  set(piper_phonemize_HASH "SHA256=b9faa04204b1756fa455a962abb1f037041c040133d55be58d11f11ab9b3ce14")
+  set(piper_phonemize_URL  "https://github.com/csukuangfj/piper-phonemize/archive/38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
+  set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
+  set(piper_phonemize_HASH "SHA256=ab4d06ca76047e1585c63c482f39ffead5315785345055360703cc9382c5e74b")

   # If you don't have access to the Internet,
   # please pre-download kaldi-decoder
   set(possible_file_locations
-    $ENV{HOME}/Downloads/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
-    ${CMAKE_SOURCE_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
-    ${CMAKE_BINARY_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
-    /tmp/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
-    /star-fj/fangjun/download/github/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
+    $ENV{HOME}/Downloads/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
+    ${CMAKE_SOURCE_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
+    ${CMAKE_BINARY_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
+    /tmp/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
+    /star-fj/fangjun/download/github/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
   )

   foreach(f IN LISTS possible_file_locations)
@@ -7,6 +7,8 @@
 #include <stdio.h>
 #include <stdlib.h>

+#include <utility>
+
 #if __ANDROID_API__ >= 8
 #include "android/log.h"
 #define SHERPA_ONNX_LOGE(...) \
@@ -36,30 +38,28 @@
 #endif

 // Read an integer
-#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
-  do { \
-    auto value = \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) { \
-      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
-      exit(-1); \
-    } \
-    \
-    dst = atoi(value.get()); \
-    if (dst < 0) { \
-      SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
-      exit(-1); \
-    } \
+#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
+  do { \
+    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
+    if (value.empty()) { \
+      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
+      exit(-1); \
+    } \
+    \
+    dst = atoi(value.c_str()); \
+    if (dst < 0) { \
+      SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
+      exit(-1); \
+    } \
   } while (0)

 #define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \
   do { \
-    auto value = \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) { \
+    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
+    if (value.empty()) { \
       dst = default_value; \
     } else { \
-      dst = atoi(value.get()); \
+      dst = atoi(value.c_str()); \
       if (dst < 0) { \
         SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
         exit(-1); \
@@ -68,118 +68,111 @@
   } while (0)

 // read a vector of integers
-#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
-  do { \
-    auto value = \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) { \
-      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
-      exit(-1); \
-    } \
-    \
-    bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \
-    if (!ret) { \
-      SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
-      exit(-1); \
-    } \
+#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
+  do { \
+    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
+    if (value.empty()) { \
+      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
+      exit(-1); \
+    } \
+    \
+    bool ret = SplitStringToIntegers(value.c_str(), ",", true, &dst); \
+    if (!ret) { \
+      SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
+      exit(-1); \
+    } \
   } while (0)

 // read a vector of floats
-#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
-  do { \
-    auto value = \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) { \
-      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
-      exit(-1); \
-    } \
-    \
-    bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \
-    if (!ret) { \
-      SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
-      exit(-1); \
-    } \
+#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
+  do { \
+    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
+    if (value.empty()) { \
+      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
+      exit(-1); \
+    } \
+    \
+    bool ret = SplitStringToFloats(value.c_str(), ",", true, &dst); \
+    if (!ret) { \
+      SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
+      exit(-1); \
+    } \
   } while (0)

 // read a vector of strings
-#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
-  do { \
-    auto value = \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) { \
-      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
-      exit(-1); \
-    } \
-    SplitStringToVector(value.get(), ",", false, &dst); \
-    \
-    if (dst.empty()) { \
-      SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
-                       value.get(), src_key); \
-      exit(-1); \
-    } \
+#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
+  do { \
+    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
+    if (value.empty()) { \
+      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
+      exit(-1); \
+    } \
+    SplitStringToVector(value.c_str(), ",", false, &dst); \
+    \
+    if (dst.empty()) { \
+      SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
+                       value.c_str(), src_key); \
+      exit(-1); \
+    } \
   } while (0)

 // read a vector of strings separated by sep
-#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
-  do { \
-    auto value = \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) { \
-      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
-      exit(-1); \
-    } \
-    SplitStringToVector(value.get(), sep, false, &dst); \
-    \
-    if (dst.empty()) { \
-      SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
-                       value.get(), src_key); \
-      exit(-1); \
-    } \
+#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
+  do { \
+    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
+    if (value.empty()) { \
+      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
+      exit(-1); \
+    } \
+    SplitStringToVector(value.c_str(), sep, false, &dst); \
+    \
+    if (dst.empty()) { \
+      SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
+                       value.c_str(), src_key); \
+      exit(-1); \
+    } \
   } while (0)

 // Read a string
-#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
-  do { \
-    auto value = \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) { \
-      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
-      exit(-1); \
-    } \
-    \
-    dst = value.get(); \
-    if (dst.empty()) { \
-      SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
-      exit(-1); \
-    } \
+#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
+  do { \
+    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
+    if (value.empty()) { \
+      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
+      exit(-1); \
+    } \
+    \
+    dst = std::move(value); \
+    if (dst.empty()) { \
+      SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
+      exit(-1); \
+    } \
   } while (0)

-#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
-  do { \
-    auto value = \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) { \
-      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
-      exit(-1); \
-    } \
-    \
-    dst = value.get(); \
+#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
+  do { \
+    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
+    if (value.empty()) { \
+      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
+      exit(-1); \
+    } \
+    \
+    dst = std::move(value); \
   } while (0)

-#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
-                                                    default_value) \
-  do { \
-    auto value = \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
-    if (!value) { \
-      dst = default_value; \
-    } else { \
-      dst = value.get(); \
-      if (dst.empty()) { \
-        SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
-        exit(-1); \
-      } \
-    } \
+#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
+                                                    default_value) \
+  do { \
+    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
+    if (value.empty()) { \
+      dst = default_value; \
+    } else { \
+      dst = std::move(value); \
+      if (dst.empty()) { \
+        SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
+        exit(-1); \
+      } \
+    } \
   } while (0)

 #define SHERPA_ONNX_EXIT(code) exit(code)
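The rewritten macros assume a LookupCustomModelMetaData() helper that returns a std::string instead of the smart pointer returned by LookupCustomMetadataMapAllocated(). That helper is not part of this excerpt; the following is only a sketch of the shape the macros rely on, with the signature and the empty-string-on-miss behavior inferred from the call sites above.

```cpp
// Hedged sketch, not taken from this diff: what the macros above expect from
// LookupCustomModelMetaData(). It wraps the onnxruntime C++ API call that the
// old macros used directly and returns "" when the key is missing, which is
// why every rewritten macro checks value.empty().
#include <string>

#include "onnxruntime_cxx_api.h"  // NOLINT

std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
                                      const char *key, OrtAllocator *allocator) {
  auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
  if (!v) {
    return {};  // key not present in the custom metadata map
  }
  return v.get();  // copy into a std::string; callers may std::move() it
}
```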
@@ -46,7 +46,7 @@ class OfflineCEDModel::Impl {

   int32_t NumEventClasses() const { return num_event_classes_; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

  private:
   void Init(void *model_data, size_t model_data_length) {

@@ -44,7 +44,7 @@ class OfflineCtTransformerModel::Impl {
     return std::move(ans[0]);
   }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

   const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
     return meta_data_;

@@ -53,8 +53,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,

   Ort::AllocatorWithDefaultOptions allocator;
   auto model_type =
-      meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
-  if (!model_type) {
+      LookupCustomModelMetaData(meta_data, "model_type", allocator);
+  if (model_type.empty()) {
     SHERPA_ONNX_LOGE(
         "No model_type in the metadata!\n"
         "If you are using models from NeMo, please refer to\n"
@@ -74,22 +74,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
     return ModelType::kUnknown;
   }

-  if (model_type.get() == std::string("EncDecCTCModelBPE")) {
+  if (model_type == "EncDecCTCModelBPE") {
     return ModelType::kEncDecCTCModelBPE;
-  } else if (model_type.get() == std::string("EncDecCTCModel")) {
+  } else if (model_type == "EncDecCTCModel") {
     return ModelType::kEncDecCTCModel;
-  } else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) {
+  } else if (model_type == "EncDecHybridRNNTCTCBPEModel") {
     return ModelType::kEncDecHybridRNNTCTCBPEModel;
-  } else if (model_type.get() == std::string("tdnn")) {
+  } else if (model_type == "tdnn") {
     return ModelType::kTdnn;
-  } else if (model_type.get() == std::string("zipformer2_ctc")) {
+  } else if (model_type == "zipformer2_ctc") {
     return ModelType::kZipformerCtc;
-  } else if (model_type.get() == std::string("wenet_ctc")) {
+  } else if (model_type == "wenet_ctc") {
     return ModelType::kWenetCtc;
-  } else if (model_type.get() == std::string("telespeech_ctc")) {
+  } else if (model_type == "telespeech_ctc") {
     return ModelType::kTeleSpeechCtc;
   } else {
-    SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
+    SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
     return ModelType::kUnknown;
   }
 }

@@ -155,7 +155,7 @@ class OfflineMoonshineModel::Impl {
     return {std::move(cached_decoder_out[0]), std::move(next_states)};
   }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

  private:
   void InitPreprocessor(void *model_data, size_t model_data_length) {

@@ -68,7 +68,7 @@ class OfflineNemoEncDecCtcModel::Impl {

   int32_t SubsamplingFactor() const { return subsampling_factor_; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

   std::string FeatureNormalizationMethod() const { return normalize_type_; }

@@ -56,7 +56,7 @@ class OfflineParaformerModel::Impl {

   const std::vector<float> &InverseStdDev() const { return inv_stddev_; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

  private:
   void Init(void *model_data, size_t model_data_length) {
@@ -121,9 +121,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(

   Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below

-  auto model_type_ptr =
-      meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
-  if (!model_type_ptr) {
+  auto model_type =
+      LookupCustomModelMetaData(meta_data, "model_type", allocator);
+  if (model_type.empty()) {
     SHERPA_ONNX_LOGE(
         "No model_type in the metadata!\n\n"
         "Please refer to the following URLs to add metadata"
@@ -164,7 +164,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
         "\n");
     exit(-1);
   }
-  std::string model_type(model_type_ptr.get());

   if (model_type == "conformer" || model_type == "zipformer" ||
       model_type == "zipformer2") {

@@ -301,9 +300,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(

   Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below

-  auto model_type_ptr =
-      meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
-  if (!model_type_ptr) {
+  auto model_type =
+      LookupCustomModelMetaData(meta_data, "model_type", allocator);
+  if (model_type.empty()) {
     SHERPA_ONNX_LOGE(
         "No model_type in the metadata!\n\n"
         "Please refer to the following URLs to add metadata"
@@ -344,7 +343,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
         "\n");
     exit(-1);
   }
-  std::string model_type(model_type_ptr.get());

   if (model_type == "conformer" || model_type == "zipformer" ||
       model_type == "zipformer2") {
@@ -56,7 +56,7 @@ class OfflineSenseVoiceModel::Impl {
     return meta_data_;
   }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

  private:
   void Init(void *model_data, size_t model_data_length) {

@@ -63,7 +63,7 @@ class OfflineTdnnCtcModel::Impl {

   int32_t VocabSize() const { return vocab_size_; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

  private:
   void Init(void *model_data, size_t model_data_length) {

@@ -69,7 +69,7 @@ class OfflineTeleSpeechCtcModel::Impl {

   int32_t SubsamplingFactor() const { return subsampling_factor_; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

  private:
   void Init(void *model_data, size_t model_data_length) {

@@ -95,11 +95,11 @@ class OfflineTransducerModel::Impl {
   int32_t VocabSize() const { return vocab_size_; }
   int32_t ContextSize() const { return context_size_; }
   int32_t SubsamplingFactor() const { return 4; }
-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

   Ort::Value BuildDecoderInput(
       const std::vector<OfflineTransducerDecoderResult> &results,
-      int32_t end_index) const {
+      int32_t end_index) {
     assert(end_index <= results.size());

     int32_t batch_size = end_index;
@@ -122,7 +122,7 @@ class OfflineTransducerModel::Impl {
   }

   Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
-                               int32_t end_index) const {
+                               int32_t end_index) {
     assert(end_index <= results.size());

     int32_t batch_size = end_index;

@@ -123,7 +123,7 @@ class OfflineTransducerNeMoModel::Impl {
     return std::move(logit[0]);
   }

-  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
+  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) {
     std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
     Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
                                                     s0_shape.size());
@@ -149,7 +149,7 @@ class OfflineTransducerNeMoModel::Impl {
   int32_t SubsamplingFactor() const { return subsampling_factor_; }
   int32_t VocabSize() const { return vocab_size_; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

   std::string FeatureNormalizationMethod() const { return normalize_type_; }

@@ -47,7 +47,7 @@ class OfflineWenetCtcModel::Impl {

   int32_t SubsamplingFactor() const { return subsampling_factor_; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

  private:
   void Init(void *model_data, size_t model_data_length) {

@@ -188,7 +188,7 @@ class OfflineWhisperModel::Impl {
     return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
   }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

   const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }

@@ -47,7 +47,7 @@ class OfflineZipformerAudioTaggingModel::Impl {

   int32_t NumEventClasses() const { return num_event_classes_; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

  private:
   void Init(void *model_data, size_t model_data_length) {

@@ -48,7 +48,7 @@ class OfflineZipformerCtcModel::Impl {
   int32_t VocabSize() const { return vocab_size_; }
   int32_t SubsamplingFactor() const { return 4; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

  private:
   void Init(void *model_data, size_t model_data_length) {

@@ -47,7 +47,7 @@ class OnlineCNNBiLSTMModel::Impl {
     return {std::move(ans[0]), std::move(ans[1])};
   }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

   const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
     return meta_data_;
@@ -163,8 +163,11 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
     conv_vec[i] = &states[i][1];
   }

-  Ort::Value attn = Cat(allocator_, attn_vec, 2);
-  Ort::Value conv = Cat(allocator_, conv_vec, 2);
+  auto allocator =
+      const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
+
+  Ort::Value attn = Cat(allocator, attn_vec, 2);
+  Ort::Value conv = Cat(allocator, conv_vec, 2);

   std::vector<Ort::Value> ans;
   ans.reserve(2);
@@ -183,8 +186,11 @@ OnlineConformerTransducerModel::UnStackStates(

   std::vector<std::vector<Ort::Value>> ans(batch_size);

-  std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);
-  std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2);
+  auto allocator =
+      const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
+
+  std::vector<Ort::Value> attn_vec = Unbind(allocator, &states[0], 2);
+  std::vector<Ort::Value> conv_vec = Unbind(allocator, &states[1], 2);

   assert(attn_vec.size() == batch_size);
   assert(conv_vec.size() == batch_size);
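The hunks above and below repeat one pattern: Allocator() loses its const qualifier, and const member functions fetch the allocator through a const_cast on this. A plausible reason, which is an assumption since the PR text does not state it, is that the older onnxruntime 1.11 C++ headers used for the aarch64 GPU build only expose a non-const conversion from Ort::AllocatorWithDefaultOptions to OrtAllocator*, so a const method cannot hand the raw allocator to helpers such as Cat()/Unbind() without the cast. A stripped-down illustration:

```cpp
// Hedged sketch of the const_cast pattern used in the surrounding hunks.
// Assumption: only a non-const conversion to OrtAllocator* is available on the
// allocator wrapper, so a const member function has to cast away const on
// `this` to reach it.
#include <array>
#include <cstdint>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

class StateHolder {
 public:
  // The public interface stays const; the cast is confined to one line.
  std::vector<Ort::Value> MakeZeroState(int64_t dim) const {
    auto allocator = const_cast<StateHolder *>(this)->allocator_;

    std::array<int64_t, 2> shape{1, dim};
    std::vector<Ort::Value> ans;
    ans.push_back(
        Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()));
    return ans;
  }

 private:
  Ort::AllocatorWithDefaultOptions allocator_;
};
```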
@@ -158,9 +158,10 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
     h_buf[i] = &states[i][0];
     c_buf[i] = &states[i][1];
   }
+  auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;

-  Ort::Value h = Cat(allocator_, h_buf, 1);
-  Ort::Value c = Cat(allocator_, c_buf, 1);
+  Ort::Value h = Cat(allocator, h_buf, 1);
+  Ort::Value c = Cat(allocator, c_buf, 1);

   std::vector<Ort::Value> ans;
   ans.reserve(2);
@@ -177,8 +178,10 @@ std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(

   std::vector<std::vector<Ort::Value>> ans(batch_size);

-  std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1);
-  std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1);
+  auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
+
+  std::vector<Ort::Value> h_vec = Unbind(allocator, &states[0], 1);
+  std::vector<Ort::Value> c_vec = Unbind(allocator, &states[1], 1);

   assert(h_vec.size() == batch_size);
   assert(c_vec.size() == batch_size);

@@ -102,7 +102,7 @@ class OnlineNeMoCtcModel::Impl {

   int32_t ChunkShift() const { return chunk_shift_; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

   // Return a vector containing 3 tensors
   // - cache_last_channel
@@ -119,7 +119,7 @@ class OnlineNeMoCtcModel::Impl {
   }

   std::vector<Ort::Value> StackStates(
-      std::vector<std::vector<Ort::Value>> states) const {
+      std::vector<std::vector<Ort::Value>> states) {
     int32_t batch_size = static_cast<int32_t>(states.size());
     if (batch_size == 1) {
       return std::move(states[0]);
@@ -157,6 +157,8 @@ class OnlineNeMoCtcModel::Impl {
       std::vector<Ort::Value> states) const {
     assert(states.size() == 3);

+    auto allocator = const_cast<Impl *>(this)->allocator_;
+
     std::vector<std::vector<Ort::Value>> ans;

     auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
@@ -171,9 +173,9 @@ class OnlineNeMoCtcModel::Impl {
     for (int32_t i = 0; i != 3; ++i) {
       std::vector<Ort::Value> v;
       if (i == 2) {
-        v = Unbind<int64_t>(allocator_, &states[i], 0);
+        v = Unbind<int64_t>(allocator, &states[i], 0);
       } else {
-        v = Unbind(allocator_, &states[i], 0);
+        v = Unbind(allocator, &states[i], 0);
       }

       assert(v.size() == batch_size);

@@ -105,7 +105,7 @@ class OnlineParaformerModel::Impl {

   const std::vector<float> &InverseStdDev() const { return inv_stddev_; }

-  OrtAllocator *Allocator() const { return allocator_; }
+  OrtAllocator *Allocator() { return allocator_; }

  private:
   void InitEncoder(void *model_data, size_t model_data_length) {

@@ -5,10 +5,10 @@

 #include "sherpa-onnx/csrc/online-rnn-lm.h"

+#include <algorithm>
 #include <string>
 #include <utility>
 #include <vector>
-#include <algorithm>

 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/macros.h"
| @@ -53,49 +53,49 @@ class OnlineRnnLM::Impl { | @@ -53,49 +53,49 @@ class OnlineRnnLM::Impl { | ||
| 53 | 53 | ||
| 54 | // classic rescore function | 54 | // classic rescore function |
| 55 | void ComputeLMScore(float scale, int32_t context_size, | 55 | void ComputeLMScore(float scale, int32_t context_size, |
| 56 | - std::vector<Hypotheses> *hyps) { | ||
| 57 | - Ort::AllocatorWithDefaultOptions allocator; | ||
| 58 | - | ||
| 59 | - for (auto &hyp : *hyps) { | ||
| 60 | - for (auto &h_m : hyp) { | ||
| 61 | - auto &h = h_m.second; | ||
| 62 | - auto &ys = h.ys; | ||
| 63 | - const int32_t token_num_in_chunk = | ||
| 64 | - ys.size() - context_size - h.cur_scored_pos - 1; | ||
| 65 | - | ||
| 66 | - if (token_num_in_chunk < 1) { | ||
| 67 | - continue; | ||
| 68 | - } | ||
| 69 | - | ||
| 70 | - if (h.nn_lm_states.empty()) { | ||
| 71 | - h.nn_lm_states = Convert(GetInitStates()); | ||
| 72 | - } | ||
| 73 | - | ||
| 74 | - if (token_num_in_chunk >= h.lm_rescore_min_chunk) { | ||
| 75 | - std::array<int64_t, 2> x_shape{1, token_num_in_chunk}; | ||
| 76 | - | ||
| 77 | - Ort::Value x = Ort::Value::CreateTensor<int64_t>( | ||
| 78 | - allocator, x_shape.data(), x_shape.size()); | ||
| 79 | - int64_t *p_x = x.GetTensorMutableData<int64_t>(); | ||
| 80 | - std::copy(ys.begin() + context_size + h.cur_scored_pos, | ||
| 81 | - ys.end() - 1, p_x); | ||
| 82 | - | ||
| 83 | - // streaming forward by NN LM | ||
| 84 | - auto out = ScoreToken(std::move(x), | ||
| 85 | - Convert(std::move(h.nn_lm_states))); | ||
| 86 | - | ||
| 87 | - // update NN LM score in hyp | ||
| 88 | - const float *p_nll = out.first.GetTensorData<float>(); | ||
| 89 | - h.lm_log_prob = -scale * (*p_nll); | ||
| 90 | - | ||
| 91 | - // update NN LM states in hyp | ||
| 92 | - h.nn_lm_states = Convert(std::move(out.second)); | ||
| 93 | - | ||
| 94 | - h.cur_scored_pos += token_num_in_chunk; | ||
| 95 | - } | 56 | + std::vector<Hypotheses> *hyps) { |
| 57 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 58 | + | ||
| 59 | + for (auto &hyp : *hyps) { | ||
| 60 | + for (auto &h_m : hyp) { | ||
| 61 | + auto &h = h_m.second; | ||
| 62 | + auto &ys = h.ys; | ||
| 63 | + const int32_t token_num_in_chunk = | ||
| 64 | + ys.size() - context_size - h.cur_scored_pos - 1; | ||
| 65 | + | ||
| 66 | + if (token_num_in_chunk < 1) { | ||
| 67 | + continue; | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + if (h.nn_lm_states.empty()) { | ||
| 71 | + h.nn_lm_states = Convert(GetInitStates()); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + if (token_num_in_chunk >= h.lm_rescore_min_chunk) { | ||
| 75 | + std::array<int64_t, 2> x_shape{1, token_num_in_chunk}; | ||
| 76 | + | ||
| 77 | + Ort::Value x = Ort::Value::CreateTensor<int64_t>( | ||
| 78 | + allocator, x_shape.data(), x_shape.size()); | ||
| 79 | + int64_t *p_x = x.GetTensorMutableData<int64_t>(); | ||
| 80 | + std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1, | ||
| 81 | + p_x); | ||
| 82 | + | ||
| 83 | + // streaming forward by NN LM | ||
| 84 | + auto out = | ||
| 85 | + ScoreToken(std::move(x), Convert(std::move(h.nn_lm_states))); | ||
| 86 | + | ||
| 87 | + // update NN LM score in hyp | ||
| 88 | + const float *p_nll = out.first.GetTensorData<float>(); | ||
| 89 | + h.lm_log_prob = -scale * (*p_nll); | ||
| 90 | + | ||
| 91 | + // update NN LM states in hyp | ||
| 92 | + h.nn_lm_states = Convert(std::move(out.second)); | ||
| 93 | + | ||
| 94 | + h.cur_scored_pos += token_num_in_chunk; | ||
| 96 | } | 95 | } |
| 97 | } | 96 | } |
| 98 | } | 97 | } |
| 98 | + } | ||
| 99 | 99 | ||
| 100 | std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( | 100 | std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( |
| 101 | Ort::Value x, std::vector<Ort::Value> states) { | 101 | Ort::Value x, std::vector<Ort::Value> states) { |
| @@ -125,7 +125,7 @@ class OnlineRnnLM::Impl { | @@ -125,7 +125,7 @@ class OnlineRnnLM::Impl { | ||
| 125 | } | 125 | } |
| 126 | 126 | ||
| 127 | // get init states for classic rescore | 127 | // get init states for classic rescore |
| 128 | - std::vector<Ort::Value> GetInitStates() const { | 128 | + std::vector<Ort::Value> GetInitStates() { |
| 129 | std::vector<Ort::Value> ans; | 129 | std::vector<Ort::Value> ans; |
| 130 | ans.reserve(init_states_.size()); | 130 | ans.reserve(init_states_.size()); |
| 131 | 131 | ||
| @@ -226,7 +226,7 @@ std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken( | @@ -226,7 +226,7 @@ std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken( | ||
| 226 | 226 | ||
| 227 | // classic rescore scores | 227 | // classic rescore scores |
| 228 | void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size, | 228 | void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size, |
| 229 | - std::vector<Hypotheses> *hyps) { | 229 | + std::vector<Hypotheses> *hyps) { |
| 230 | return impl_->ComputeLMScore(scale, context_size, hyps); | 230 | return impl_->ComputeLMScore(scale, context_size, hyps); |
| 231 | } | 231 | } |
| 232 | 232 | ||
| @@ -235,5 +235,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) { | @@ -235,5 +235,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) { | ||
| 235 | return impl_->ComputeLMScoreSF(scale, hyp); | 235 | return impl_->ComputeLMScoreSF(scale, hyp); |
| 236 | } | 236 | } |
| 237 | 237 | ||
| 238 | - | ||
| 239 | } // namespace sherpa_onnx | 238 | } // namespace sherpa_onnx |
| @@ -54,8 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -54,8 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 54 | 54 | ||
| 55 | Ort::AllocatorWithDefaultOptions allocator; | 55 | Ort::AllocatorWithDefaultOptions allocator; |
| 56 | auto model_type = | 56 | auto model_type = |
| 57 | - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); | ||
| 58 | - if (!model_type) { | 57 | + LookupCustomModelMetaData(meta_data, "model_type", allocator); |
| 58 | + if (model_type.empty()) { | ||
| 59 | SHERPA_ONNX_LOGE( | 59 | SHERPA_ONNX_LOGE( |
| 60 | "No model_type in the metadata!\n" | 60 | "No model_type in the metadata!\n" |
| 61 | "Please make sure you are using the latest export-onnx.py from icefall " | 61 | "Please make sure you are using the latest export-onnx.py from icefall " |
| @@ -63,16 +63,16 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -63,16 +63,16 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 63 | return ModelType::kUnknown; | 63 | return ModelType::kUnknown; |
| 64 | } | 64 | } |
| 65 | 65 | ||
| 66 | - if (model_type.get() == std::string("conformer")) { | 66 | + if (model_type == "conformer") { |
| 67 | return ModelType::kConformer; | 67 | return ModelType::kConformer; |
| 68 | - } else if (model_type.get() == std::string("lstm")) { | 68 | + } else if (model_type == "lstm") { |
| 69 | return ModelType::kLstm; | 69 | return ModelType::kLstm; |
| 70 | - } else if (model_type.get() == std::string("zipformer")) { | 70 | + } else if (model_type == "zipformer") { |
| 71 | return ModelType::kZipformer; | 71 | return ModelType::kZipformer; |
| 72 | - } else if (model_type.get() == std::string("zipformer2")) { | 72 | + } else if (model_type == "zipformer2") { |
| 73 | return ModelType::kZipformer2; | 73 | return ModelType::kZipformer2; |
| 74 | } else { | 74 | } else { |
| 75 | - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 75 | + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); |
| 76 | return ModelType::kUnknown; | 76 | return ModelType::kUnknown; |
| 77 | } | 77 | } |
| 78 | } | 78 | } |
| @@ -197,7 +197,7 @@ class OnlineTransducerNeMoModel::Impl { | @@ -197,7 +197,7 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 197 | 197 | ||
| 198 | int32_t VocabSize() const { return vocab_size_; } | 198 | int32_t VocabSize() const { return vocab_size_; } |
| 199 | 199 | ||
| 200 | - OrtAllocator *Allocator() const { return allocator_; } | 200 | + OrtAllocator *Allocator() { return allocator_; } |
| 201 | 201 | ||
| 202 | std::string FeatureNormalizationMethod() const { return normalize_type_; } | 202 | std::string FeatureNormalizationMethod() const { return normalize_type_; } |
| 203 | 203 | ||
| @@ -224,6 +224,8 @@ class OnlineTransducerNeMoModel::Impl { | @@ -224,6 +224,8 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 224 | 224 | ||
| 225 | std::vector<Ort::Value> ans; | 225 | std::vector<Ort::Value> ans; |
| 226 | 226 | ||
| 227 | + auto allocator = const_cast<Impl *>(this)->allocator_; | ||
| 228 | + | ||
| 227 | // stack cache_last_channel | 229 | // stack cache_last_channel |
| 228 | std::vector<const Ort::Value *> buf(batch_size); | 230 | std::vector<const Ort::Value *> buf(batch_size); |
| 229 | 231 | ||
| @@ -239,9 +241,9 @@ class OnlineTransducerNeMoModel::Impl { | @@ -239,9 +241,9 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 239 | 241 | ||
| 240 | Ort::Value c{nullptr}; | 242 | Ort::Value c{nullptr}; |
| 241 | if (i == 2) { | 243 | if (i == 2) { |
| 242 | - c = Cat<int64_t>(allocator_, buf, 0); | 244 | + c = Cat<int64_t>(allocator, buf, 0); |
| 243 | } else { | 245 | } else { |
| 244 | - c = Cat(allocator_, buf, 0); | 246 | + c = Cat(allocator, buf, 0); |
| 245 | } | 247 | } |
| 246 | 248 | ||
| 247 | ans.push_back(std::move(c)); | 249 | ans.push_back(std::move(c)); |
| @@ -251,7 +253,7 @@ class OnlineTransducerNeMoModel::Impl { | @@ -251,7 +253,7 @@ class OnlineTransducerNeMoModel::Impl { | ||
| 251 | } | 253 | } |
| 252 | 254 | ||
| 253 | std::vector<std::vector<Ort::Value>> UnStackStates( | 255 | std::vector<std::vector<Ort::Value>> UnStackStates( |
| 254 | - std::vector<Ort::Value> states) const { | 256 | + std::vector<Ort::Value> states) { |
| 255 | assert(states.size() == 3); | 257 | assert(states.size() == 3); |
| 256 | 258 | ||
| 257 | std::vector<std::vector<Ort::Value>> ans; | 259 | std::vector<std::vector<Ort::Value>> ans; |
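Several hunks in this commit (here and in the Zipformer transducer models further down) follow the same pattern: Cat() and Unbind() take a non-const OrtAllocator *, so member functions that used to be const either drop the const qualifier (as the various Allocator() accessors now do) or fetch the pointer through a const_cast on this. Below is a minimal stand-alone sketch of the const_cast variant; the class name, the Stack() method, and the commented-out Cat() call are hypothetical and only mirror the shape of the code above.

#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

class Model {
 public:
  explicit Model(OrtAllocator *allocator) : allocator_(allocator) {}

  // Logically const: it only reads `states`, but the helper it would
  // forward to needs a mutable OrtAllocator *.
  std::vector<Ort::Value> Stack(
      const std::vector<const Ort::Value *> &states) const {
    auto *allocator = const_cast<Model *>(this)->allocator_;
    (void)states;
    (void)allocator;  // e.g. return {Cat(allocator, states, /*dim=*/0)};
    return {};
  }

 private:
  OrtAllocator *allocator_ = nullptr;
};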
| @@ -101,7 +101,7 @@ class OnlineWenetCtcModel::Impl { | @@ -101,7 +101,7 @@ class OnlineWenetCtcModel::Impl { | ||
| 101 | return config_.wenet_ctc.chunk_size * subsampling_factor_; | 101 | return config_.wenet_ctc.chunk_size * subsampling_factor_; |
| 102 | } | 102 | } |
| 103 | 103 | ||
| 104 | - OrtAllocator *Allocator() const { return allocator_; } | 104 | + OrtAllocator *Allocator() { return allocator_; } |
| 105 | 105 | ||
| 106 | // Return a vector containing 3 tensors | 106 | // Return a vector containing 3 tensors |
| 107 | // - attn_cache | 107 | // - attn_cache |
| @@ -179,12 +179,15 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | @@ -179,12 +179,15 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | ||
| 179 | std::vector<Ort::Value> ans; | 179 | std::vector<Ort::Value> ans; |
| 180 | ans.reserve(states[0].size()); | 180 | ans.reserve(states[0].size()); |
| 181 | 181 | ||
| 182 | + auto allocator = | ||
| 183 | + const_cast<OnlineZipformerTransducerModel *>(this)->allocator_; | ||
| 184 | + | ||
| 182 | // cached_len | 185 | // cached_len |
| 183 | for (int32_t i = 0; i != num_encoders; ++i) { | 186 | for (int32_t i = 0; i != num_encoders; ++i) { |
| 184 | for (int32_t n = 0; n != batch_size; ++n) { | 187 | for (int32_t n = 0; n != batch_size; ++n) { |
| 185 | buf[n] = &states[n][i]; | 188 | buf[n] = &states[n][i]; |
| 186 | } | 189 | } |
| 187 | - auto v = Cat<int64_t>(allocator_, buf, 1); // (num_layers, 1) | 190 | + auto v = Cat<int64_t>(allocator, buf, 1); // (num_layers, 1) |
| 188 | ans.push_back(std::move(v)); | 191 | ans.push_back(std::move(v)); |
| 189 | } | 192 | } |
| 190 | 193 | ||
| @@ -193,7 +196,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | @@ -193,7 +196,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | ||
| 193 | for (int32_t n = 0; n != batch_size; ++n) { | 196 | for (int32_t n = 0; n != batch_size; ++n) { |
| 194 | buf[n] = &states[n][num_encoders + i]; | 197 | buf[n] = &states[n][num_encoders + i]; |
| 195 | } | 198 | } |
| 196 | - auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims) | 199 | + auto v = Cat(allocator, buf, 1); // (num_layers, 1, encoder_dims) |
| 197 | ans.push_back(std::move(v)); | 200 | ans.push_back(std::move(v)); |
| 198 | } | 201 | } |
| 199 | 202 | ||
| @@ -203,7 +206,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | @@ -203,7 +206,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | ||
| 203 | buf[n] = &states[n][num_encoders * 2 + i]; | 206 | buf[n] = &states[n][num_encoders * 2 + i]; |
| 204 | } | 207 | } |
| 205 | // (num_layers, left_context_len, 1, attention_dims) | 208 | // (num_layers, left_context_len, 1, attention_dims) |
| 206 | - auto v = Cat(allocator_, buf, 2); | 209 | + auto v = Cat(allocator, buf, 2); |
| 207 | ans.push_back(std::move(v)); | 210 | ans.push_back(std::move(v)); |
| 208 | } | 211 | } |
| 209 | 212 | ||
| @@ -213,7 +216,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | @@ -213,7 +216,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | ||
| 213 | buf[n] = &states[n][num_encoders * 3 + i]; | 216 | buf[n] = &states[n][num_encoders * 3 + i]; |
| 214 | } | 217 | } |
| 215 | // (num_layers, left_context_len, 1, attention_dims/2) | 218 | // (num_layers, left_context_len, 1, attention_dims/2) |
| 216 | - auto v = Cat(allocator_, buf, 2); | 219 | + auto v = Cat(allocator, buf, 2); |
| 217 | ans.push_back(std::move(v)); | 220 | ans.push_back(std::move(v)); |
| 218 | } | 221 | } |
| 219 | 222 | ||
| @@ -223,7 +226,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | @@ -223,7 +226,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | ||
| 223 | buf[n] = &states[n][num_encoders * 4 + i]; | 226 | buf[n] = &states[n][num_encoders * 4 + i]; |
| 224 | } | 227 | } |
| 225 | // (num_layers, left_context_len, 1, attention_dims/2) | 228 | // (num_layers, left_context_len, 1, attention_dims/2) |
| 226 | - auto v = Cat(allocator_, buf, 2); | 229 | + auto v = Cat(allocator, buf, 2); |
| 227 | ans.push_back(std::move(v)); | 230 | ans.push_back(std::move(v)); |
| 228 | } | 231 | } |
| 229 | 232 | ||
| @@ -233,7 +236,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | @@ -233,7 +236,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | ||
| 233 | buf[n] = &states[n][num_encoders * 5 + i]; | 236 | buf[n] = &states[n][num_encoders * 5 + i]; |
| 234 | } | 237 | } |
| 235 | // (num_layers, 1, encoder_dims, cnn_module_kernels-1) | 238 | // (num_layers, 1, encoder_dims, cnn_module_kernels-1) |
| 236 | - auto v = Cat(allocator_, buf, 1); | 239 | + auto v = Cat(allocator, buf, 1); |
| 237 | ans.push_back(std::move(v)); | 240 | ans.push_back(std::move(v)); |
| 238 | } | 241 | } |
| 239 | 242 | ||
| @@ -243,7 +246,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | @@ -243,7 +246,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( | ||
| 243 | buf[n] = &states[n][num_encoders * 6 + i]; | 246 | buf[n] = &states[n][num_encoders * 6 + i]; |
| 244 | } | 247 | } |
| 245 | // (num_layers, 1, encoder_dims, cnn_module_kernels-1) | 248 | // (num_layers, 1, encoder_dims, cnn_module_kernels-1) |
| 246 | - auto v = Cat(allocator_, buf, 1); | 249 | + auto v = Cat(allocator, buf, 1); |
| 247 | ans.push_back(std::move(v)); | 250 | ans.push_back(std::move(v)); |
| 248 | } | 251 | } |
| 249 | 252 | ||
| @@ -258,12 +261,15 @@ OnlineZipformerTransducerModel::UnStackStates( | @@ -258,12 +261,15 @@ OnlineZipformerTransducerModel::UnStackStates( | ||
| 258 | int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; | 261 | int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; |
| 259 | int32_t num_encoders = num_encoder_layers_.size(); | 262 | int32_t num_encoders = num_encoder_layers_.size(); |
| 260 | 263 | ||
| 264 | + auto allocator = | ||
| 265 | + const_cast<OnlineZipformerTransducerModel *>(this)->allocator_; | ||
| 266 | + | ||
| 261 | std::vector<std::vector<Ort::Value>> ans; | 267 | std::vector<std::vector<Ort::Value>> ans; |
| 262 | ans.resize(batch_size); | 268 | ans.resize(batch_size); |
| 263 | 269 | ||
| 264 | // cached_len | 270 | // cached_len |
| 265 | for (int32_t i = 0; i != num_encoders; ++i) { | 271 | for (int32_t i = 0; i != num_encoders; ++i) { |
| 266 | - auto v = Unbind<int64_t>(allocator_, &states[i], 1); | 272 | + auto v = Unbind<int64_t>(allocator, &states[i], 1); |
| 267 | assert(v.size() == batch_size); | 273 | assert(v.size() == batch_size); |
| 268 | 274 | ||
| 269 | for (int32_t n = 0; n != batch_size; ++n) { | 275 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -273,7 +279,7 @@ OnlineZipformerTransducerModel::UnStackStates( | @@ -273,7 +279,7 @@ OnlineZipformerTransducerModel::UnStackStates( | ||
| 273 | 279 | ||
| 274 | // cached_avg | 280 | // cached_avg |
| 275 | for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) { | 281 | for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) { |
| 276 | - auto v = Unbind(allocator_, &states[i], 1); | 282 | + auto v = Unbind(allocator, &states[i], 1); |
| 277 | assert(v.size() == batch_size); | 283 | assert(v.size() == batch_size); |
| 278 | 284 | ||
| 279 | for (int32_t n = 0; n != batch_size; ++n) { | 285 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -283,7 +289,7 @@ OnlineZipformerTransducerModel::UnStackStates( | @@ -283,7 +289,7 @@ OnlineZipformerTransducerModel::UnStackStates( | ||
| 283 | 289 | ||
| 284 | // cached_key | 290 | // cached_key |
| 285 | for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) { | 291 | for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) { |
| 286 | - auto v = Unbind(allocator_, &states[i], 2); | 292 | + auto v = Unbind(allocator, &states[i], 2); |
| 287 | assert(v.size() == batch_size); | 293 | assert(v.size() == batch_size); |
| 288 | 294 | ||
| 289 | for (int32_t n = 0; n != batch_size; ++n) { | 295 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -293,7 +299,7 @@ OnlineZipformerTransducerModel::UnStackStates( | @@ -293,7 +299,7 @@ OnlineZipformerTransducerModel::UnStackStates( | ||
| 293 | 299 | ||
| 294 | // cached_val | 300 | // cached_val |
| 295 | for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) { | 301 | for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) { |
| 296 | - auto v = Unbind(allocator_, &states[i], 2); | 302 | + auto v = Unbind(allocator, &states[i], 2); |
| 297 | assert(v.size() == batch_size); | 303 | assert(v.size() == batch_size); |
| 298 | 304 | ||
| 299 | for (int32_t n = 0; n != batch_size; ++n) { | 305 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -303,7 +309,7 @@ OnlineZipformerTransducerModel::UnStackStates( | @@ -303,7 +309,7 @@ OnlineZipformerTransducerModel::UnStackStates( | ||
| 303 | 309 | ||
| 304 | // cached_val2 | 310 | // cached_val2 |
| 305 | for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) { | 311 | for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) { |
| 306 | - auto v = Unbind(allocator_, &states[i], 2); | 312 | + auto v = Unbind(allocator, &states[i], 2); |
| 307 | assert(v.size() == batch_size); | 313 | assert(v.size() == batch_size); |
| 308 | 314 | ||
| 309 | for (int32_t n = 0; n != batch_size; ++n) { | 315 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -313,7 +319,7 @@ OnlineZipformerTransducerModel::UnStackStates( | @@ -313,7 +319,7 @@ OnlineZipformerTransducerModel::UnStackStates( | ||
| 313 | 319 | ||
| 314 | // cached_conv1 | 320 | // cached_conv1 |
| 315 | for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) { | 321 | for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) { |
| 316 | - auto v = Unbind(allocator_, &states[i], 1); | 322 | + auto v = Unbind(allocator, &states[i], 1); |
| 317 | assert(v.size() == batch_size); | 323 | assert(v.size() == batch_size); |
| 318 | 324 | ||
| 319 | for (int32_t n = 0; n != batch_size; ++n) { | 325 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -323,7 +329,7 @@ OnlineZipformerTransducerModel::UnStackStates( | @@ -323,7 +329,7 @@ OnlineZipformerTransducerModel::UnStackStates( | ||
| 323 | 329 | ||
| 324 | // cached_conv2 | 330 | // cached_conv2 |
| 325 | for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) { | 331 | for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) { |
| 326 | - auto v = Unbind(allocator_, &states[i], 1); | 332 | + auto v = Unbind(allocator, &states[i], 1); |
| 327 | assert(v.size() == batch_size); | 333 | assert(v.size() == batch_size); |
| 328 | 334 | ||
| 329 | for (int32_t n = 0; n != batch_size; ++n) { | 335 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -70,7 +70,7 @@ class OnlineZipformer2CtcModel::Impl { | @@ -70,7 +70,7 @@ class OnlineZipformer2CtcModel::Impl { | ||
| 70 | 70 | ||
| 71 | int32_t ChunkShift() const { return decode_chunk_len_; } | 71 | int32_t ChunkShift() const { return decode_chunk_len_; } |
| 72 | 72 | ||
| 73 | - OrtAllocator *Allocator() const { return allocator_; } | 73 | + OrtAllocator *Allocator() { return allocator_; } |
| 74 | 74 | ||
| 75 | // Return a vector containing 3 tensors | 75 | // Return a vector containing 3 tensors |
| 76 | // - attn_cache | 76 | // - attn_cache |
| @@ -86,7 +86,7 @@ class OnlineZipformer2CtcModel::Impl { | @@ -86,7 +86,7 @@ class OnlineZipformer2CtcModel::Impl { | ||
| 86 | } | 86 | } |
| 87 | 87 | ||
| 88 | std::vector<Ort::Value> StackStates( | 88 | std::vector<Ort::Value> StackStates( |
| 89 | - std::vector<std::vector<Ort::Value>> states) const { | 89 | + std::vector<std::vector<Ort::Value>> states) { |
| 90 | int32_t batch_size = static_cast<int32_t>(states.size()); | 90 | int32_t batch_size = static_cast<int32_t>(states.size()); |
| 91 | 91 | ||
| 92 | std::vector<const Ort::Value *> buf(batch_size); | 92 | std::vector<const Ort::Value *> buf(batch_size); |
| @@ -159,7 +159,7 @@ class OnlineZipformer2CtcModel::Impl { | @@ -159,7 +159,7 @@ class OnlineZipformer2CtcModel::Impl { | ||
| 159 | } | 159 | } |
| 160 | 160 | ||
| 161 | std::vector<std::vector<Ort::Value>> UnStackStates( | 161 | std::vector<std::vector<Ort::Value>> UnStackStates( |
| 162 | - std::vector<Ort::Value> states) const { | 162 | + std::vector<Ort::Value> states) { |
| 163 | int32_t m = std::accumulate(num_encoder_layers_.begin(), | 163 | int32_t m = std::accumulate(num_encoder_layers_.begin(), |
| 164 | num_encoder_layers_.end(), 0); | 164 | num_encoder_layers_.end(), 0); |
| 165 | assert(states.size() == m * 6 + 2); | 165 | assert(states.size() == m * 6 + 2); |
| @@ -185,6 +185,9 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( | @@ -185,6 +185,9 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( | ||
| 185 | 185 | ||
| 186 | std::vector<const Ort::Value *> buf(batch_size); | 186 | std::vector<const Ort::Value *> buf(batch_size); |
| 187 | 187 | ||
| 188 | + auto allocator = | ||
| 189 | + const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_; | ||
| 190 | + | ||
| 188 | std::vector<Ort::Value> ans; | 191 | std::vector<Ort::Value> ans; |
| 189 | int32_t num_states = static_cast<int32_t>(states[0].size()); | 192 | int32_t num_states = static_cast<int32_t>(states[0].size()); |
| 190 | ans.reserve(num_states); | 193 | ans.reserve(num_states); |
| @@ -194,42 +197,42 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( | @@ -194,42 +197,42 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( | ||
| 194 | for (int32_t n = 0; n != batch_size; ++n) { | 197 | for (int32_t n = 0; n != batch_size; ++n) { |
| 195 | buf[n] = &states[n][6 * i]; | 198 | buf[n] = &states[n][6 * i]; |
| 196 | } | 199 | } |
| 197 | - auto v = Cat(allocator_, buf, 1); | 200 | + auto v = Cat(allocator, buf, 1); |
| 198 | ans.push_back(std::move(v)); | 201 | ans.push_back(std::move(v)); |
| 199 | } | 202 | } |
| 200 | { | 203 | { |
| 201 | for (int32_t n = 0; n != batch_size; ++n) { | 204 | for (int32_t n = 0; n != batch_size; ++n) { |
| 202 | buf[n] = &states[n][6 * i + 1]; | 205 | buf[n] = &states[n][6 * i + 1]; |
| 203 | } | 206 | } |
| 204 | - auto v = Cat(allocator_, buf, 1); | 207 | + auto v = Cat(allocator, buf, 1); |
| 205 | ans.push_back(std::move(v)); | 208 | ans.push_back(std::move(v)); |
| 206 | } | 209 | } |
| 207 | { | 210 | { |
| 208 | for (int32_t n = 0; n != batch_size; ++n) { | 211 | for (int32_t n = 0; n != batch_size; ++n) { |
| 209 | buf[n] = &states[n][6 * i + 2]; | 212 | buf[n] = &states[n][6 * i + 2]; |
| 210 | } | 213 | } |
| 211 | - auto v = Cat(allocator_, buf, 1); | 214 | + auto v = Cat(allocator, buf, 1); |
| 212 | ans.push_back(std::move(v)); | 215 | ans.push_back(std::move(v)); |
| 213 | } | 216 | } |
| 214 | { | 217 | { |
| 215 | for (int32_t n = 0; n != batch_size; ++n) { | 218 | for (int32_t n = 0; n != batch_size; ++n) { |
| 216 | buf[n] = &states[n][6 * i + 3]; | 219 | buf[n] = &states[n][6 * i + 3]; |
| 217 | } | 220 | } |
| 218 | - auto v = Cat(allocator_, buf, 1); | 221 | + auto v = Cat(allocator, buf, 1); |
| 219 | ans.push_back(std::move(v)); | 222 | ans.push_back(std::move(v)); |
| 220 | } | 223 | } |
| 221 | { | 224 | { |
| 222 | for (int32_t n = 0; n != batch_size; ++n) { | 225 | for (int32_t n = 0; n != batch_size; ++n) { |
| 223 | buf[n] = &states[n][6 * i + 4]; | 226 | buf[n] = &states[n][6 * i + 4]; |
| 224 | } | 227 | } |
| 225 | - auto v = Cat(allocator_, buf, 0); | 228 | + auto v = Cat(allocator, buf, 0); |
| 226 | ans.push_back(std::move(v)); | 229 | ans.push_back(std::move(v)); |
| 227 | } | 230 | } |
| 228 | { | 231 | { |
| 229 | for (int32_t n = 0; n != batch_size; ++n) { | 232 | for (int32_t n = 0; n != batch_size; ++n) { |
| 230 | buf[n] = &states[n][6 * i + 5]; | 233 | buf[n] = &states[n][6 * i + 5]; |
| 231 | } | 234 | } |
| 232 | - auto v = Cat(allocator_, buf, 0); | 235 | + auto v = Cat(allocator, buf, 0); |
| 233 | ans.push_back(std::move(v)); | 236 | ans.push_back(std::move(v)); |
| 234 | } | 237 | } |
| 235 | } | 238 | } |
| @@ -238,7 +241,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( | @@ -238,7 +241,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( | ||
| 238 | for (int32_t n = 0; n != batch_size; ++n) { | 241 | for (int32_t n = 0; n != batch_size; ++n) { |
| 239 | buf[n] = &states[n][num_states - 2]; | 242 | buf[n] = &states[n][num_states - 2]; |
| 240 | } | 243 | } |
| 241 | - auto v = Cat(allocator_, buf, 0); | 244 | + auto v = Cat(allocator, buf, 0); |
| 242 | ans.push_back(std::move(v)); | 245 | ans.push_back(std::move(v)); |
| 243 | } | 246 | } |
| 244 | 247 | ||
| @@ -246,7 +249,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( | @@ -246,7 +249,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( | ||
| 246 | for (int32_t n = 0; n != batch_size; ++n) { | 249 | for (int32_t n = 0; n != batch_size; ++n) { |
| 247 | buf[n] = &states[n][num_states - 1]; | 250 | buf[n] = &states[n][num_states - 1]; |
| 248 | } | 251 | } |
| 249 | - auto v = Cat<int64_t>(allocator_, buf, 0); | 252 | + auto v = Cat<int64_t>(allocator, buf, 0); |
| 250 | ans.push_back(std::move(v)); | 253 | ans.push_back(std::move(v)); |
| 251 | } | 254 | } |
| 252 | return ans; | 255 | return ans; |
| @@ -261,12 +264,15 @@ OnlineZipformer2TransducerModel::UnStackStates( | @@ -261,12 +264,15 @@ OnlineZipformer2TransducerModel::UnStackStates( | ||
| 261 | 264 | ||
| 262 | int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; | 265 | int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; |
| 263 | 266 | ||
| 267 | + auto allocator = | ||
| 268 | + const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_; | ||
| 269 | + | ||
| 264 | std::vector<std::vector<Ort::Value>> ans; | 270 | std::vector<std::vector<Ort::Value>> ans; |
| 265 | ans.resize(batch_size); | 271 | ans.resize(batch_size); |
| 266 | 272 | ||
| 267 | for (int32_t i = 0; i != m; ++i) { | 273 | for (int32_t i = 0; i != m; ++i) { |
| 268 | { | 274 | { |
| 269 | - auto v = Unbind(allocator_, &states[i * 6], 1); | 275 | + auto v = Unbind(allocator, &states[i * 6], 1); |
| 270 | assert(static_cast<int32_t>(v.size()) == batch_size); | 276 | assert(static_cast<int32_t>(v.size()) == batch_size); |
| 271 | 277 | ||
| 272 | for (int32_t n = 0; n != batch_size; ++n) { | 278 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -274,7 +280,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | @@ -274,7 +280,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | ||
| 274 | } | 280 | } |
| 275 | } | 281 | } |
| 276 | { | 282 | { |
| 277 | - auto v = Unbind(allocator_, &states[i * 6 + 1], 1); | 283 | + auto v = Unbind(allocator, &states[i * 6 + 1], 1); |
| 278 | assert(static_cast<int32_t>(v.size()) == batch_size); | 284 | assert(static_cast<int32_t>(v.size()) == batch_size); |
| 279 | 285 | ||
| 280 | for (int32_t n = 0; n != batch_size; ++n) { | 286 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -282,7 +288,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | @@ -282,7 +288,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | ||
| 282 | } | 288 | } |
| 283 | } | 289 | } |
| 284 | { | 290 | { |
| 285 | - auto v = Unbind(allocator_, &states[i * 6 + 2], 1); | 291 | + auto v = Unbind(allocator, &states[i * 6 + 2], 1); |
| 286 | assert(static_cast<int32_t>(v.size()) == batch_size); | 292 | assert(static_cast<int32_t>(v.size()) == batch_size); |
| 287 | 293 | ||
| 288 | for (int32_t n = 0; n != batch_size; ++n) { | 294 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -290,7 +296,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | @@ -290,7 +296,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | ||
| 290 | } | 296 | } |
| 291 | } | 297 | } |
| 292 | { | 298 | { |
| 293 | - auto v = Unbind(allocator_, &states[i * 6 + 3], 1); | 299 | + auto v = Unbind(allocator, &states[i * 6 + 3], 1); |
| 294 | assert(static_cast<int32_t>(v.size()) == batch_size); | 300 | assert(static_cast<int32_t>(v.size()) == batch_size); |
| 295 | 301 | ||
| 296 | for (int32_t n = 0; n != batch_size; ++n) { | 302 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -298,7 +304,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | @@ -298,7 +304,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | ||
| 298 | } | 304 | } |
| 299 | } | 305 | } |
| 300 | { | 306 | { |
| 301 | - auto v = Unbind(allocator_, &states[i * 6 + 4], 0); | 307 | + auto v = Unbind(allocator, &states[i * 6 + 4], 0); |
| 302 | assert(static_cast<int32_t>(v.size()) == batch_size); | 308 | assert(static_cast<int32_t>(v.size()) == batch_size); |
| 303 | 309 | ||
| 304 | for (int32_t n = 0; n != batch_size; ++n) { | 310 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -306,7 +312,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | @@ -306,7 +312,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | ||
| 306 | } | 312 | } |
| 307 | } | 313 | } |
| 308 | { | 314 | { |
| 309 | - auto v = Unbind(allocator_, &states[i * 6 + 5], 0); | 315 | + auto v = Unbind(allocator, &states[i * 6 + 5], 0); |
| 310 | assert(static_cast<int32_t>(v.size()) == batch_size); | 316 | assert(static_cast<int32_t>(v.size()) == batch_size); |
| 311 | 317 | ||
| 312 | for (int32_t n = 0; n != batch_size; ++n) { | 318 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -316,7 +322,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | @@ -316,7 +322,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | ||
| 316 | } | 322 | } |
| 317 | 323 | ||
| 318 | { | 324 | { |
| 319 | - auto v = Unbind(allocator_, &states[m * 6], 0); | 325 | + auto v = Unbind(allocator, &states[m * 6], 0); |
| 320 | assert(static_cast<int32_t>(v.size()) == batch_size); | 326 | assert(static_cast<int32_t>(v.size()) == batch_size); |
| 321 | 327 | ||
| 322 | for (int32_t n = 0; n != batch_size; ++n) { | 328 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -324,7 +330,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | @@ -324,7 +330,7 @@ OnlineZipformer2TransducerModel::UnStackStates( | ||
| 324 | } | 330 | } |
| 325 | } | 331 | } |
| 326 | { | 332 | { |
| 327 | - auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0); | 333 | + auto v = Unbind<int64_t>(allocator, &states[m * 6 + 1], 0); |
| 328 | assert(static_cast<int32_t>(v.size()) == batch_size); | 334 | assert(static_cast<int32_t>(v.size()) == batch_size); |
| 329 | 335 | ||
| 330 | for (int32_t n = 0; n != batch_size; ++n) { | 336 | for (int32_t n = 0; n != batch_size; ++n) { |
| @@ -21,6 +21,36 @@ | @@ -21,6 +21,36 @@ | ||
| 21 | 21 | ||
| 22 | namespace sherpa_onnx { | 22 | namespace sherpa_onnx { |
| 23 | 23 | ||
| 24 | +static std::string GetInputName(Ort::Session *sess, size_t index, | ||
| 25 | + OrtAllocator *allocator) { | ||
| 26 | +// Note(fangjun): We only tested 1.17.1 and 1.11.0 | ||
| 27 | +// For other versions, we may need to change it | ||
| 28 | +#if ORT_API_VERSION >= 17 | ||
| 29 | + auto v = sess->GetInputNameAllocated(index, allocator); | ||
| 30 | + return v.get(); | ||
| 31 | +#else | ||
| 32 | + auto v = sess->GetInputName(index, allocator); | ||
| 33 | + std::string ans = v; | ||
| 34 | + allocator->Free(allocator, v); | ||
| 35 | + return ans; | ||
| 36 | +#endif | ||
| 37 | +} | ||
| 38 | + | ||
| 39 | +static std::string GetOutputName(Ort::Session *sess, size_t index, | ||
| 40 | + OrtAllocator *allocator) { | ||
| 41 | +// Note(fangjun): We only tested 1.17.1 and 1.11.0 | ||
| 42 | +// For other versions, we may need to change it | ||
| 43 | +#if ORT_API_VERSION >= 17 | ||
| 44 | + auto v = sess->GetOutputNameAllocated(index, allocator); | ||
| 45 | + return v.get(); | ||
| 46 | +#else | ||
| 47 | + auto v = sess->GetOutputName(index, allocator); | ||
| 48 | + std::string ans = v; | ||
| 49 | + allocator->Free(allocator, v); | ||
| 50 | + return ans; | ||
| 51 | +#endif | ||
| 52 | +} | ||
| 53 | + | ||
| 24 | void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names, | 54 | void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names, |
| 25 | std::vector<const char *> *input_names_ptr) { | 55 | std::vector<const char *> *input_names_ptr) { |
| 26 | Ort::AllocatorWithDefaultOptions allocator; | 56 | Ort::AllocatorWithDefaultOptions allocator; |
| @@ -28,8 +58,7 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names, | @@ -28,8 +58,7 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names, | ||
| 28 | input_names->resize(node_count); | 58 | input_names->resize(node_count); |
| 29 | input_names_ptr->resize(node_count); | 59 | input_names_ptr->resize(node_count); |
| 30 | for (size_t i = 0; i != node_count; ++i) { | 60 | for (size_t i = 0; i != node_count; ++i) { |
| 31 | - auto tmp = sess->GetInputNameAllocated(i, allocator); | ||
| 32 | - (*input_names)[i] = tmp.get(); | 61 | + (*input_names)[i] = GetInputName(sess, i, allocator); |
| 33 | (*input_names_ptr)[i] = (*input_names)[i].c_str(); | 62 | (*input_names_ptr)[i] = (*input_names)[i].c_str(); |
| 34 | } | 63 | } |
| 35 | } | 64 | } |
| @@ -41,8 +70,7 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, | @@ -41,8 +70,7 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, | ||
| 41 | output_names->resize(node_count); | 70 | output_names->resize(node_count); |
| 42 | output_names_ptr->resize(node_count); | 71 | output_names_ptr->resize(node_count); |
| 43 | for (size_t i = 0; i != node_count; ++i) { | 72 | for (size_t i = 0; i != node_count; ++i) { |
| 44 | - auto tmp = sess->GetOutputNameAllocated(i, allocator); | ||
| 45 | - (*output_names)[i] = tmp.get(); | 73 | + (*output_names)[i] = GetOutputName(sess, i, allocator); |
| 46 | (*output_names_ptr)[i] = (*output_names)[i].c_str(); | 74 | (*output_names_ptr)[i] = (*output_names)[i].c_str(); |
| 47 | } | 75 | } |
| 48 | } | 76 | } |
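The GetInputNames()/GetOutputNames() wrappers above keep their public signatures, so callers are unaffected by the version-specific handling hidden inside GetInputName()/GetOutputName(). A hedged usage sketch, assuming a loaded Ort::Session named sess and an already-built std::vector<Ort::Value> named inputs (both hypothetical here): the std::string vector owns the names and the const char * vector merely aliases them, which is the form Ort::Session::Run() expects.

std::vector<std::string> input_names;
std::vector<const char *> input_names_ptr;
GetInputNames(&sess, &input_names, &input_names_ptr);

std::vector<std::string> output_names;
std::vector<const char *> output_names_ptr;
GetOutputNames(&sess, &output_names, &output_names_ptr);

auto outputs =
    sess.Run({}, input_names_ptr.data(), inputs.data(), inputs.size(),
             output_names_ptr.data(), output_names_ptr.size());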
| @@ -78,12 +106,24 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, | @@ -78,12 +106,24 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, | ||
| 78 | 106 | ||
| 79 | void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { | 107 | void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { |
| 80 | Ort::AllocatorWithDefaultOptions allocator; | 108 | Ort::AllocatorWithDefaultOptions allocator; |
| 109 | +#if ORT_API_VERSION >= 17 | ||
| 81 | std::vector<Ort::AllocatedStringPtr> v = | 110 | std::vector<Ort::AllocatedStringPtr> v = |
| 82 | meta_data.GetCustomMetadataMapKeysAllocated(allocator); | 111 | meta_data.GetCustomMetadataMapKeysAllocated(allocator); |
| 83 | for (const auto &key : v) { | 112 | for (const auto &key : v) { |
| 84 | auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator); | 113 | auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator); |
| 85 | os << key.get() << "=" << p.get() << "\n"; | 114 | os << key.get() << "=" << p.get() << "\n"; |
| 86 | } | 115 | } |
| 116 | +#else | ||
| 117 | + int64_t num_keys = 0; | ||
| 118 | + char **keys = meta_data.GetCustomMetadataMapKeys(allocator, num_keys); | ||
| 119 | + for (int32_t i = 0; i < num_keys; ++i) { | ||
| 120 | + auto v = LookupCustomModelMetaData(meta_data, keys[i], allocator); | ||
| 121 | + os << keys[i] << "=" << v << "\n"; | ||
| 122 | + allocator.Free(keys[i]); | ||
| 123 | + } | ||
| 124 | + | ||
| 125 | + allocator.Free(keys); | ||
| 126 | +#endif | ||
| 87 | } | 127 | } |
| 88 | 128 | ||
| 89 | Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { | 129 | Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { |
| @@ -361,4 +401,20 @@ std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) { | @@ -361,4 +401,20 @@ std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) { | ||
| 361 | return ans; | 401 | return ans; |
| 362 | } | 402 | } |
| 363 | 403 | ||
| 404 | +std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data, | ||
| 405 | + const char *key, | ||
| 406 | + OrtAllocator *allocator) { | ||
| 407 | +// Note(fangjun): We only tested 1.17.1 and 1.11.0 | ||
| 408 | +// For other versions, we may need to change it | ||
| 409 | +#if ORT_API_VERSION >= 17 | ||
| 410 | + auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator); | ||
| 411 | + return v.get(); | ||
| 412 | +#else | ||
| 413 | + auto v = meta_data.LookupCustomMetadataMap(key, allocator); | ||
| 414 | + std::string ans = v; | ||
| 415 | + allocator->Free(allocator, v); | ||
| 416 | + return ans; | ||
| 417 | +#endif | ||
| 418 | +} | ||
| 419 | + | ||
| 364 | } // namespace sherpa_onnx | 420 | } // namespace sherpa_onnx |
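A usage sketch for the new LookupCustomModelMetaData() helper (declared in onnx-utils.h, see the next hunk), assuming a loaded Ort::Session named sess: callers get a plain std::string on every supported onnxruntime version, so the null check on the old AllocatedStringPtr becomes an empty() check and string literals can be compared with ==, exactly as the GetModelType() rewrites above do.

Ort::AllocatorWithDefaultOptions allocator;
Ort::ModelMetadata meta_data = sess.GetModelMetadata();

std::string model_type =
    LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (model_type.empty()) {
  SHERPA_ONNX_LOGE("No model_type in the metadata!");
} else if (model_type == "zipformer2") {
  // handle zipformer2 ...
}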
| @@ -59,6 +59,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, | @@ -59,6 +59,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, | ||
| 59 | Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, | 59 | Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, |
| 60 | int32_t t); | 60 | int32_t t); |
| 61 | 61 | ||
| 62 | +std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data, | ||
| 63 | + const char *key, OrtAllocator *allocator); | ||
| 64 | + | ||
| 62 | void PrintModelMetadata(std::ostream &os, | 65 | void PrintModelMetadata(std::ostream &os, |
| 63 | const Ort::ModelMetadata &meta_data); // NOLINT | 66 | const Ort::ModelMetadata &meta_data); // NOLINT |
| 64 | 67 |
| @@ -60,6 +60,7 @@ Ort::SessionOptions GetSessionOptionsImpl( | @@ -60,6 +60,7 @@ Ort::SessionOptions GetSessionOptionsImpl( | ||
| 60 | case Provider::kCPU: | 60 | case Provider::kCPU: |
| 61 | break; // nothing to do for the CPU provider | 61 | break; // nothing to do for the CPU provider |
| 62 | case Provider::kXnnpack: { | 62 | case Provider::kXnnpack: { |
| 63 | +#if ORT_API_VERSION >= 17 | ||
| 63 | if (std::find(available_providers.begin(), available_providers.end(), | 64 | if (std::find(available_providers.begin(), available_providers.end(), |
| 64 | "XnnpackExecutionProvider") != available_providers.end()) { | 65 | "XnnpackExecutionProvider") != available_providers.end()) { |
| 65 | sess_opts.AppendExecutionProvider("XNNPACK"); | 66 | sess_opts.AppendExecutionProvider("XNNPACK"); |
| @@ -67,6 +68,11 @@ Ort::SessionOptions GetSessionOptionsImpl( | @@ -67,6 +68,11 @@ Ort::SessionOptions GetSessionOptionsImpl( | ||
| 67 | SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!", | 68 | SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!", |
| 68 | os.str().c_str()); | 69 | os.str().c_str()); |
| 69 | } | 70 | } |
| 71 | +#else | ||
| 72 | + SHERPA_ONNX_LOGE( | ||
| 73 | + "Does not support xnnpack for onnxruntime: %d. Fallback to cpu!", | ||
| 74 | + static_cast<int32_t>(ORT_API_VERSION)); | ||
| 75 | +#endif | ||
| 70 | break; | 76 | break; |
| 71 | } | 77 | } |
| 72 | case Provider::kTRT: { | 78 | case Provider::kTRT: { |
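The new guard compiles the XNNPACK branch only when the onnxruntime headers are recent enough; the notes in this PR say only 1.17.1 and 1.11.0 were tested, and the #else branch assumes headers that old lack Ort::SessionOptions::AppendExecutionProvider(), so the code logs a message and stays on the default CPU provider instead of failing to build. A minimal sketch of the same compile-time gating, wrapped in a hypothetical helper:

#include <algorithm>
#include <string>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

Ort::SessionOptions MakeSessionOptions(bool want_xnnpack) {
  Ort::SessionOptions opts;
#if ORT_API_VERSION >= 17
  if (want_xnnpack) {
    std::vector<std::string> providers = Ort::GetAvailableProviders();
    if (std::find(providers.begin(), providers.end(),
                  "XnnpackExecutionProvider") != providers.end()) {
      opts.AppendExecutionProvider("XNNPACK");
    }
  }
#else
  (void)want_xnnpack;  // headers too old; the default CPU provider is used
#endif
  return opts;
}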
| @@ -40,8 +40,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -40,8 +40,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 40 | 40 | ||
| 41 | Ort::AllocatorWithDefaultOptions allocator; | 41 | Ort::AllocatorWithDefaultOptions allocator; |
| 42 | auto model_type = | 42 | auto model_type = |
| 43 | - meta_data.LookupCustomMetadataMapAllocated("framework", allocator); | ||
| 44 | - if (!model_type) { | 43 | + LookupCustomModelMetaData(meta_data, "framework", allocator); |
| 44 | + if (model_type.empty()) { | ||
| 45 | SHERPA_ONNX_LOGE( | 45 | SHERPA_ONNX_LOGE( |
| 46 | "No model_type in the metadata!\n" | 46 | "No model_type in the metadata!\n" |
| 47 | "Please make sure you have added metadata to the model.\n\n" | 47 | "Please make sure you have added metadata to the model.\n\n" |
| @@ -52,14 +52,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -52,14 +52,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 52 | return ModelType::kUnknown; | 52 | return ModelType::kUnknown; |
| 53 | } | 53 | } |
| 54 | 54 | ||
| 55 | - if (model_type.get() == std::string("wespeaker")) { | 55 | + if (model_type == "wespeaker") { |
| 56 | return ModelType::kWeSpeaker; | 56 | return ModelType::kWeSpeaker; |
| 57 | - } else if (model_type.get() == std::string("3d-speaker")) { | 57 | + } else if (model_type == "3d-speaker") { |
| 58 | return ModelType::k3dSpeaker; | 58 | return ModelType::k3dSpeaker; |
| 59 | - } else if (model_type.get() == std::string("nemo")) { | 59 | + } else if (model_type == "nemo") { |
| 60 | return ModelType::kNeMo; | 60 | return ModelType::kNeMo; |
| 61 | } else { | 61 | } else { |
| 62 | - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 62 | + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); |
| 63 | return ModelType::kUnknown; | 63 | return ModelType::kUnknown; |
| 64 | } | 64 | } |
| 65 | } | 65 | } |
| @@ -53,7 +53,7 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl { | @@ -53,7 +53,7 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl { | ||
| 53 | return std::move(outputs[0]); | 53 | return std::move(outputs[0]); |
| 54 | } | 54 | } |
| 55 | 55 | ||
| 56 | - OrtAllocator *Allocator() const { return allocator_; } | 56 | + OrtAllocator *Allocator() { return allocator_; } |
| 57 | 57 | ||
| 58 | const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const { | 58 | const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const { |
| 59 | return meta_data_; | 59 | return meta_data_; |
| @@ -42,8 +42,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -42,8 +42,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 42 | 42 | ||
| 43 | Ort::AllocatorWithDefaultOptions allocator; | 43 | Ort::AllocatorWithDefaultOptions allocator; |
| 44 | auto model_type = | 44 | auto model_type = |
| 45 | - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); | ||
| 46 | - if (!model_type) { | 45 | + LookupCustomModelMetaData(meta_data, "model_type", allocator); |
| 46 | + if (model_type.empty()) { | ||
| 47 | SHERPA_ONNX_LOGE( | 47 | SHERPA_ONNX_LOGE( |
| 48 | "No model_type in the metadata!\n" | 48 | "No model_type in the metadata!\n" |
| 49 | "Please make sure you have added metadata to the model.\n\n" | 49 | "Please make sure you have added metadata to the model.\n\n" |
| @@ -54,11 +54,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -54,11 +54,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 54 | return ModelType::kUnknown; | 54 | return ModelType::kUnknown; |
| 55 | } | 55 | } |
| 56 | 56 | ||
| 57 | - auto model_type_str = std::string(model_type.get()); | ||
| 58 | - if (model_type_str.find("whisper") == 0) { | 57 | + if (model_type.find("whisper") == 0) { |
| 59 | return ModelType::kWhisper; | 58 | return ModelType::kWhisper; |
| 60 | } else { | 59 | } else { |
| 61 | - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 60 | + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); |
| 62 | return ModelType::kUnknown; | 61 | return ModelType::kUnknown; |
| 63 | } | 62 | } |
| 64 | } | 63 | } |
| @@ -29,20 +29,19 @@ namespace { | @@ -29,20 +29,19 @@ namespace { | ||
| 29 | const char *ws = " \t\n\r\f\v"; | 29 | const char *ws = " \t\n\r\f\v"; |
| 30 | 30 | ||
| 31 | // trim from end of string (right) | 31 | // trim from end of string (right) |
| 32 | -inline std::string &TrimRight(std::string &s, const char *t = ws) { | ||
| 33 | - s.erase(s.find_last_not_of(t) + 1); | ||
| 34 | - return s; | 32 | +inline void TrimRight(std::string *s, const char *t = ws) { |
| 33 | + s->erase(s->find_last_not_of(t) + 1); | ||
| 35 | } | 34 | } |
| 36 | 35 | ||
| 37 | // trim from beginning of string (left) | 36 | // trim from beginning of string (left) |
| 38 | -inline std::string &TrimLeft(std::string &s, const char *t = ws) { | ||
| 39 | - s.erase(0, s.find_first_not_of(t)); | ||
| 40 | - return s; | 37 | +inline void TrimLeft(std::string *s, const char *t = ws) { |
| 38 | + s->erase(0, s->find_first_not_of(t)); | ||
| 41 | } | 39 | } |
| 42 | 40 | ||
| 43 | // trim from both ends of string (right then left) | 41 | // trim from both ends of string (right then left) |
| 44 | -inline std::string &Trim(std::string &s, const char *t = ws) { | ||
| 45 | - return TrimLeft(TrimRight(s, t), t); | 42 | +inline void Trim(std::string *s, const char *t = ws) { |
| 43 | + TrimRight(s, t); | ||
| 44 | + TrimLeft(s, t); | ||
| 46 | } | 45 | } |
| 47 | } // namespace | 46 | } // namespace |
| 48 | 47 | ||
| @@ -56,7 +55,7 @@ std::unordered_map<std::string, int32_t> ReadTokens( | @@ -56,7 +55,7 @@ std::unordered_map<std::string, int32_t> ReadTokens( | ||
| 56 | std::string sym; | 55 | std::string sym; |
| 57 | int32_t id = -1; | 56 | int32_t id = -1; |
| 58 | while (std::getline(is, line)) { | 57 | while (std::getline(is, line)) { |
| 59 | - Trim(line); | 58 | + Trim(&line); |
| 60 | std::istringstream iss(line); | 59 | std::istringstream iss(line); |
| 61 | iss >> sym; | 60 | iss >> sym; |
| 62 | if (iss.eof()) { | 61 | if (iss.eof()) { |
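The trim helpers now take a std::string * instead of a non-const reference (the form that lint checks such as cpplint's runtime/references rule prefer for mutated parameters), so callers pass an address, as the Trim(&line) call in ReadTokens() above shows. A tiny usage sketch inside the same translation unit, with a made-up input line:

std::string line = "  blank 0 \t\r\n";
Trim(&line);       // line == "blank 0"
TrimLeft(&line);   // no-op: leading whitespace is already gone
TrimRight(&line);  // no-op: trailing whitespace is already gone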