Fangjun Kuang
Committed by GitHub

Support building GPU-capable sherpa-onnx on Linux aarch64. (#1500)

Thanks to @Peakyxh for providing pre-built onnxruntime libraries 
with CUDA support for Linux aarch64.

Tested on Jetson Nano B01
Showing 41 changed files with 546 additions and 300 deletions
@@ -34,11 +34,12 @@ concurrency: @@ -34,11 +34,12 @@ concurrency:
34 jobs: 34 jobs:
35 aarch64_linux_gnu_shared: 35 aarch64_linux_gnu_shared:
36 runs-on: ${{ matrix.os }} 36 runs-on: ${{ matrix.os }}
37 - name: aarch64 shared lib test 37 + name: aarch64 shared GPU ${{ matrix.gpu }}
38 strategy: 38 strategy:
39 fail-fast: false 39 fail-fast: false
40 matrix: 40 matrix:
41 os: [ubuntu-latest] 41 os: [ubuntu-latest]
  42 + gpu: [ON, OFF]
42 43
43 steps: 44 steps:
44 - uses: actions/checkout@v4 45 - uses: actions/checkout@v4
@@ -79,15 +80,24 @@ jobs: @@ -79,15 +80,24 @@ jobs:
79 make -j2 80 make -j2
80 make install 81 make install
81 82
82 - - name: cache-toolchain  
83 - id: cache-toolchain 83 + - name: cache-toolchain (CPU)
  84 + if: matrix.gpu == 'OFF'
  85 + id: cache-toolchain-cpu
84 uses: actions/cache@v4 86 uses: actions/cache@v4
85 with: 87 with:
86 path: toolchain 88 path: toolchain
87 key: gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz 89 key: gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
88 90
89 - - name: Download toolchain  
90 - if: steps.cache-toolchain.outputs.cache-hit != 'true' 91 + - name: cache-toolchain (GPU)
  92 + if: matrix.gpu == 'ON'
  93 + id: cache-toolchain-gpu
  94 + uses: actions/cache@v4
  95 + with:
  96 + path: toolchain
  97 + key: gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
  98 +
  99 + - name: Download toolchain (CPU, gcc 7.5)
  100 + if: steps.cache-toolchain-cpu.outputs.cache-hit != 'true' && matrix.gpu == 'OFF'
91 shell: bash 101 shell: bash
92 run: | 102 run: |
93 wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz 103 wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
@@ -95,6 +105,15 @@ jobs: @@ -95,6 +105,15 @@ jobs:
95 mkdir $GITHUB_WORKSPACE/toolchain 105 mkdir $GITHUB_WORKSPACE/toolchain
96 tar xf ./gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain 106 tar xf ./gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
97 107
  108 + - name: Download toolchain (GPU, gcc 10.3)
  109 + if: steps.cache-toolchain-gpu.outputs.cache-hit != 'true' && matrix.gpu == 'ON'
  110 + shell: bash
  111 + run: |
  112 + wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
  113 +
  114 + mkdir $GITHUB_WORKSPACE/toolchain
  115 + tar xf ./gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
  116 +
98 - name: Set environment variable 117 - name: Set environment variable
99 if: steps.cache-build-result.outputs.cache-hit != 'true' 118 if: steps.cache-build-result.outputs.cache-hit != 'true'
100 shell: bash 119 shell: bash
@@ -103,19 +122,31 @@ jobs: @@ -103,19 +122,31 @@ jobs:
103 echo "$GITHUB_WORKSPACE/bin" >> "$GITHUB_PATH" 122 echo "$GITHUB_WORKSPACE/bin" >> "$GITHUB_PATH"
104 ls -lh "$GITHUB_WORKSPACE/toolchain/bin" 123 ls -lh "$GITHUB_WORKSPACE/toolchain/bin"
105 124
106 - echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV"  
107 - echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV" 125 + if [[ ${{ matrix.gpu }} == OFF ]]; then
  126 + echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV"
  127 + echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV"
  128 + else
  129 + echo "CC=aarch64-none-linux-gnu-gcc" >> "$GITHUB_ENV"
  130 + echo "CXX=aarch64-none-linux-gnu-g++" >> "$GITHUB_ENV"
  131 + fi
108 132
109 - name: Display toolchain info 133 - name: Display toolchain info
110 shell: bash 134 shell: bash
111 run: | 135 run: |
112 - aarch64-linux-gnu-gcc --version 136 + if [[ ${{ matrix.gpu }} == OFF ]]; then
  137 + which aarch64-linux-gnu-gcc
  138 + aarch64-linux-gnu-gcc --version
  139 + else
  140 + which aarch64-none-linux-gnu-gcc
  141 + aarch64-none-linux-gnu-gcc --version
  142 + fi
113 143
114 - name: Display qemu-aarch64 -h 144 - name: Display qemu-aarch64 -h
115 shell: bash 145 shell: bash
116 run: | 146 run: |
117 export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH 147 export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
118 export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc 148 export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
  149 + export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
119 qemu-aarch64 -h 150 qemu-aarch64 -h
120 151
121 - name: build aarch64-linux-gnu 152 - name: build aarch64-linux-gnu
@@ -127,6 +158,7 @@ jobs: @@ -127,6 +158,7 @@ jobs:
127 cmake --version 158 cmake --version
128 159
129 export BUILD_SHARED_LIBS=ON 160 export BUILD_SHARED_LIBS=ON
  161 + export SHERPA_ONNX_ENABLE_GPU=${{ matrix.gpu }}
130 162
131 ./build-aarch64-linux-gnu.sh 163 ./build-aarch64-linux-gnu.sh
132 164
@@ -140,7 +172,11 @@ jobs: @@ -140,7 +172,11 @@ jobs:
140 run: | 172 run: |
141 export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH 173 export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH
142 export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH 174 export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
143 - export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc 175 + if [[ ${{ matrix.gpu }} == OFF ]]; then
  176 + export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
  177 + else
  178 + export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
  179 + fi
144 180
145 ls -lh ./build-aarch64-linux-gnu/bin 181 ls -lh ./build-aarch64-linux-gnu/bin
146 182
@@ -151,11 +187,20 @@ jobs: @@ -151,11 +187,20 @@ jobs:
151 - name: Copy files 187 - name: Copy files
152 shell: bash 188 shell: bash
153 run: | 189 run: |
154 - aarch64-linux-gnu-strip --version 190 + if [[ ${{ matrix.gpu }} == OFF ]]; then
  191 + aarch64-linux-gnu-strip --version
  192 + else
  193 + aarch64-none-linux-gnu-strip --version
  194 + fi
155 195
156 SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) 196 SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
157 197
158 dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-linux-aarch64-shared 198 dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-linux-aarch64-shared
  199 + if [[ ${{ matrix.gpu }} == OFF ]]; then
  200 + dst=${dst}-cpu
  201 + else
  202 + dst=${dst}-gpu
  203 + fi
159 mkdir $dst 204 mkdir $dst
160 205
161 cp -a build-aarch64-linux-gnu/install/bin $dst/ 206 cp -a build-aarch64-linux-gnu/install/bin $dst/
@@ -166,7 +211,11 @@ jobs: @@ -166,7 +211,11 @@ jobs:
166 211
167 ls -lh $dst/bin/ 212 ls -lh $dst/bin/
168 echo "strip" 213 echo "strip"
169 - aarch64-linux-gnu-strip $dst/bin/* 214 + if [[ ${{ matrix.gpu }} == OFF ]]; then
  215 + aarch64-linux-gnu-strip $dst/bin/*
  216 + else
  217 + aarch64-none-linux-gnu-strip $dst/bin/*
  218 + fi
170 219
171 tree $dst 220 tree $dst
172 221
@@ -174,8 +223,8 @@ jobs: @@ -174,8 +223,8 @@ jobs:
174 223
175 - uses: actions/upload-artifact@v4 224 - uses: actions/upload-artifact@v4
176 with: 225 with:
177 - name: sherpa-onnx-linux-aarch64-shared  
178 - path: sherpa-onnx-*linux-aarch64-shared.tar.bz2 226 + name: sherpa-onnx-linux-aarch64-shared-gpu-${{ matrix.gpu }}
  227 + path: sherpa-onnx-*linux-aarch64-shared*.tar.bz2
179 228
180 # https://huggingface.co/docs/hub/spaces-github-actions 229 # https://huggingface.co/docs/hub/spaces-github-actions
181 - name: Publish to huggingface 230 - name: Publish to huggingface
@@ -198,7 +247,7 @@ jobs: @@ -198,7 +247,7 @@ jobs:
198 cd huggingface 247 cd huggingface
199 mkdir -p aarch64 248 mkdir -p aarch64
200 249
201 - cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64 250 + cp -v ../sherpa-onnx-*-shared*.tar.bz2 ./aarch64
202 251
203 git status 252 git status
204 git lfs track "*.bz2" 253 git lfs track "*.bz2"
@@ -44,6 +44,21 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then @@ -44,6 +44,21 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then
44 BUILD_SHARED_LIBS=OFF 44 BUILD_SHARED_LIBS=OFF
45 fi 45 fi
46 46
  47 +if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"" ]]; then
  48 + # By default, use CPU
  49 + SHERPA_ONNX_ENABLE_GPU=OFF
  50 +
  51 + # If you use GPU, then please make sure you have NVIDIA GPUs on your board.
  52 + # It uses onnxruntime 1.11.0.
  53 + #
  54 + # Tested on Jetson Nano B01
  55 +fi
  56 +
  57 +if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"ON" ]]; then
  58 + # Build shared libs if building GPU is enabled.
  59 + BUILD_SHARED_LIBS=ON
  60 +fi
  61 +
47 cmake \ 62 cmake \
48 -DBUILD_PIPER_PHONMIZE_EXE=OFF \ 63 -DBUILD_PIPER_PHONMIZE_EXE=OFF \
49 -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ 64 -DBUILD_PIPER_PHONMIZE_TESTS=OFF \
@@ -51,6 +66,7 @@ cmake \ @@ -51,6 +66,7 @@ cmake \
51 -DBUILD_ESPEAK_NG_TESTS=OFF \ 66 -DBUILD_ESPEAK_NG_TESTS=OFF \
52 -DCMAKE_INSTALL_PREFIX=./install \ 67 -DCMAKE_INSTALL_PREFIX=./install \
53 -DCMAKE_BUILD_TYPE=Release \ 68 -DCMAKE_BUILD_TYPE=Release \
  69 + -DSHERPA_ONNX_ENABLE_GPU=$SHERPA_ONNX_ENABLE_GPU \
54 -DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \ 70 -DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \
55 -DSHERPA_ONNX_ENABLE_TESTS=OFF \ 71 -DSHERPA_ONNX_ENABLE_TESTS=OFF \
56 -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ 72 -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
  1 +# Copyright (c) 2022-2024 Xiaomi Corporation
  2 +message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
  3 +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
  4 +
  5 +if(NOT CMAKE_SYSTEM_NAME STREQUAL Linux)
  6 + message(FATAL_ERROR "This file is for Linux only. Given: ${CMAKE_SYSTEM_NAME}")
  7 +endif()
  8 +
  9 +if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
  10 + message(FATAL_ERROR "This file is for aarch64 only. Given: ${CMAKE_SYSTEM_PROCESSOR}")
  11 +endif()
  12 +
  13 +if(NOT BUILD_SHARED_LIBS)
  14 + message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
  15 +endif()
  16 +
  17 +if(NOT SHERPA_ONNX_ENABLE_GPU)
  18 + message(FATAL_ERROR "This file is for NVIDIA GPU only. Given SHERPA_ONNX_ENABLE_GPU: ${SHERPA_ONNX_ENABLE_GPU}")
  19 +endif()
  20 +
  21 +set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.11.0/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
  22 +set(onnxruntime_URL2 "https://hf-mirror.com/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
  23 +set(onnxruntime_HASH "SHA256=36eded935551e23aead09d4173bdf0bd1e7b01fdec15d77f97d6e34029aa60d7")
  24 +
  25 +# If you don't have access to the Internet,
  26 +# please download onnxruntime to one of the following locations.
  27 +# You can add more if you want.
  28 +set(possible_file_locations
  29 + $ENV{HOME}/Downloads/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
  30 + ${CMAKE_SOURCE_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
  31 + ${CMAKE_BINARY_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
  32 + /tmp/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
  33 + /star-fj/fangjun/download/github/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
  34 +)
  35 +
  36 +foreach(f IN LISTS possible_file_locations)
  37 + if(EXISTS ${f})
  38 + set(onnxruntime_URL "${f}")
  39 + file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
  40 + message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
  41 + set(onnxruntime_URL2)
  42 + break()
  43 + endif()
  44 +endforeach()
  45 +
  46 +FetchContent_Declare(onnxruntime
  47 + URL
  48 + ${onnxruntime_URL}
  49 + ${onnxruntime_URL2}
  50 + URL_HASH ${onnxruntime_HASH}
  51 +)
  52 +
  53 +FetchContent_GetProperties(onnxruntime)
  54 +if(NOT onnxruntime_POPULATED)
  55 + message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
  56 + FetchContent_Populate(onnxruntime)
  57 +endif()
  58 +message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
  59 +
  60 +find_library(location_onnxruntime onnxruntime
  61 + PATHS
  62 + "${onnxruntime_SOURCE_DIR}/lib"
  63 + NO_CMAKE_SYSTEM_PATH
  64 +)
  65 +
  66 +message(STATUS "location_onnxruntime: ${location_onnxruntime}")
  67 +
  68 +add_library(onnxruntime SHARED IMPORTED)
  69 +
  70 +set_target_properties(onnxruntime PROPERTIES
  71 + IMPORTED_LOCATION ${location_onnxruntime}
  72 + INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
  73 +)
  74 +
  75 +find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda
  76 + PATHS
  77 + "${onnxruntime_SOURCE_DIR}/lib"
  78 + NO_CMAKE_SYSTEM_PATH
  79 +)
  80 +
  81 +add_library(onnxruntime_providers_cuda SHARED IMPORTED)
  82 +set_target_properties(onnxruntime_providers_cuda PROPERTIES
  83 + IMPORTED_LOCATION ${location_onnxruntime_cuda_lib}
  84 +)
  85 +message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}")
  86 +
  87 +# for libonnxruntime_providers_shared.so
  88 +find_library(location_onnxruntime_providers_shared_lib onnxruntime_providers_shared
  89 + PATHS
  90 + "${onnxruntime_SOURCE_DIR}/lib"
  91 + NO_CMAKE_SYSTEM_PATH
  92 +)
  93 +add_library(onnxruntime_providers_shared SHARED IMPORTED)
  94 +set_target_properties(onnxruntime_providers_shared PROPERTIES
  95 + IMPORTED_LOCATION ${location_onnxruntime_providers_shared_lib}
  96 +)
  97 +message(STATUS "location_onnxruntime_providers_shared_lib: ${location_onnxruntime_providers_shared_lib}")
  98 +
  99 +file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime*")
  100 +message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
  101 +install(FILES ${onnxruntime_lib_files} DESTINATION lib)
@@ -13,7 +13,9 @@ function(download_onnxruntime) @@ -13,7 +13,9 @@ function(download_onnxruntime)
13 include(onnxruntime-linux-riscv64-static) 13 include(onnxruntime-linux-riscv64-static)
14 endif() 14 endif()
15 elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) 15 elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
16 - if(BUILD_SHARED_LIBS) 16 + if(SHERPA_ONNX_ENABLE_GPU)
  17 + include(onnxruntime-linux-aarch64-gpu)
  18 + elseif(BUILD_SHARED_LIBS)
17 include(onnxruntime-linux-aarch64) 19 include(onnxruntime-linux-aarch64)
18 else() 20 else()
19 include(onnxruntime-linux-aarch64-static) 21 include(onnxruntime-linux-aarch64-static)
1 function(download_piper_phonemize) 1 function(download_piper_phonemize)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/dc6b5f4441bffe521047086930b0fc12686acd56.zip")  
5 - set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip")  
6 - set(piper_phonemize_HASH "SHA256=b9faa04204b1756fa455a962abb1f037041c040133d55be58d11f11ab9b3ce14") 4 + set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
  5 + set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
  6 + set(piper_phonemize_HASH "SHA256=ab4d06ca76047e1585c63c482f39ffead5315785345055360703cc9382c5e74b")
7 7
8 # If you don't have access to the Internet, 8 # If you don't have access to the Internet,
9 # please pre-download kaldi-decoder 9 # please pre-download kaldi-decoder
10 set(possible_file_locations 10 set(possible_file_locations
11 - $ENV{HOME}/Downloads/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip  
12 - ${CMAKE_SOURCE_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip  
13 - ${CMAKE_BINARY_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip  
14 - /tmp/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip  
15 - /star-fj/fangjun/download/github/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip 11 + $ENV{HOME}/Downloads/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
  12 + ${CMAKE_SOURCE_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
  13 + ${CMAKE_BINARY_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
  14 + /tmp/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
  15 + /star-fj/fangjun/download/github/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
16 ) 16 )
17 17
18 foreach(f IN LISTS possible_file_locations) 18 foreach(f IN LISTS possible_file_locations)
@@ -7,6 +7,8 @@ @@ -7,6 +7,8 @@
7 #include <stdio.h> 7 #include <stdio.h>
8 #include <stdlib.h> 8 #include <stdlib.h>
9 9
  10 +#include <utility>
  11 +
10 #if __ANDROID_API__ >= 8 12 #if __ANDROID_API__ >= 8
11 #include "android/log.h" 13 #include "android/log.h"
12 #define SHERPA_ONNX_LOGE(...) \ 14 #define SHERPA_ONNX_LOGE(...) \
@@ -36,30 +38,28 @@ @@ -36,30 +38,28 @@
36 #endif 38 #endif
37 39
38 // Read an integer 40 // Read an integer
39 -#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \  
40 - do { \  
41 - auto value = \  
42 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
43 - if (!value) { \  
44 - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \  
45 - exit(-1); \  
46 - } \  
47 - \  
48 - dst = atoi(value.get()); \  
49 - if (dst < 0) { \  
50 - SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \  
51 - exit(-1); \  
52 - } \ 41 +#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
  42 + do { \
  43 + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
  44 + if (value.empty()) { \
  45 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
  46 + exit(-1); \
  47 + } \
  48 + \
  49 + dst = atoi(value.c_str()); \
  50 + if (dst < 0) { \
  51 + SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
  52 + exit(-1); \
  53 + } \
53 } while (0) 54 } while (0)
54 55
55 #define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \ 56 #define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \
56 do { \ 57 do { \
57 - auto value = \  
58 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
59 - if (!value) { \ 58 + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
  59 + if (value.empty()) { \
60 dst = default_value; \ 60 dst = default_value; \
61 } else { \ 61 } else { \
62 - dst = atoi(value.get()); \ 62 + dst = atoi(value.c_str()); \
63 if (dst < 0) { \ 63 if (dst < 0) { \
64 SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ 64 SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
65 exit(-1); \ 65 exit(-1); \
@@ -68,118 +68,111 @@ @@ -68,118 +68,111 @@
68 } while (0) 68 } while (0)
69 69
70 // read a vector of integers 70 // read a vector of integers
71 -#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \  
72 - do { \  
73 - auto value = \  
74 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
75 - if (!value) { \  
76 - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \  
77 - exit(-1); \  
78 - } \  
79 - \  
80 - bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \  
81 - if (!ret) { \  
82 - SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \  
83 - exit(-1); \  
84 - } \ 71 +#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
  72 + do { \
  73 + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
  74 + if (value.empty()) { \
  75 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
  76 + exit(-1); \
  77 + } \
  78 + \
  79 + bool ret = SplitStringToIntegers(value.c_str(), ",", true, &dst); \
  80 + if (!ret) { \
  81 + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
  82 + exit(-1); \
  83 + } \
85 } while (0) 84 } while (0)
86 85
87 // read a vector of floats 86 // read a vector of floats
88 -#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \  
89 - do { \  
90 - auto value = \  
91 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
92 - if (!value) { \  
93 - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \  
94 - exit(-1); \  
95 - } \  
96 - \  
97 - bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \  
98 - if (!ret) { \  
99 - SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \  
100 - exit(-1); \  
101 - } \ 87 +#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
  88 + do { \
  89 + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
  90 + if (value.empty()) { \
  91 + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
  92 + exit(-1); \
  93 + } \
  94 + \
  95 + bool ret = SplitStringToFloats(value.c_str(), ",", true, &dst); \
  96 + if (!ret) { \
  97 + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
  98 + exit(-1); \
  99 + } \
102 } while (0) 100 } while (0)
103 101
104 // read a vector of strings 102 // read a vector of strings
105 -#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \  
106 - do { \  
107 - auto value = \  
108 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
109 - if (!value) { \  
110 - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \  
111 - exit(-1); \  
112 - } \  
113 - SplitStringToVector(value.get(), ",", false, &dst); \  
114 - \  
115 - if (dst.empty()) { \  
116 - SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \  
117 - value.get(), src_key); \  
118 - exit(-1); \  
119 - } \ 103 +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
  104 + do { \
  105 + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
  106 + if (value.empty()) { \
  107 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
  108 + exit(-1); \
  109 + } \
  110 + SplitStringToVector(value.c_str(), ",", false, &dst); \
  111 + \
  112 + if (dst.empty()) { \
  113 + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
  114 + value.c_str(), src_key); \
  115 + exit(-1); \
  116 + } \
120 } while (0) 117 } while (0)
121 118
122 // read a vector of strings separated by sep 119 // read a vector of strings separated by sep
123 -#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \  
124 - do { \  
125 - auto value = \  
126 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
127 - if (!value) { \  
128 - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \  
129 - exit(-1); \  
130 - } \  
131 - SplitStringToVector(value.get(), sep, false, &dst); \  
132 - \  
133 - if (dst.empty()) { \  
134 - SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \  
135 - value.get(), src_key); \  
136 - exit(-1); \  
137 - } \ 120 +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
  121 + do { \
  122 + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
  123 + if (value.empty()) { \
  124 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
  125 + exit(-1); \
  126 + } \
  127 + SplitStringToVector(value.c_str(), sep, false, &dst); \
  128 + \
  129 + if (dst.empty()) { \
  130 + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
  131 + value.c_str(), src_key); \
  132 + exit(-1); \
  133 + } \
138 } while (0) 134 } while (0)
139 135
140 // Read a string 136 // Read a string
141 -#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \  
142 - do { \  
143 - auto value = \  
144 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
145 - if (!value) { \  
146 - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \  
147 - exit(-1); \  
148 - } \  
149 - \  
150 - dst = value.get(); \  
151 - if (dst.empty()) { \  
152 - SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \  
153 - exit(-1); \  
154 - } \ 137 +#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
  138 + do { \
  139 + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
  140 + if (value.empty()) { \
  141 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
  142 + exit(-1); \
  143 + } \
  144 + \
  145 + dst = std::move(value); \
  146 + if (dst.empty()) { \
  147 + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
  148 + exit(-1); \
  149 + } \
155 } while (0) 150 } while (0)
156 151
157 -#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \  
158 - do { \  
159 - auto value = \  
160 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
161 - if (!value) { \  
162 - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \  
163 - exit(-1); \  
164 - } \  
165 - \  
166 - dst = value.get(); \ 152 +#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
  153 + do { \
  154 + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
  155 + if (value.empty()) { \
  156 + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
  157 + exit(-1); \
  158 + } \
  159 + \
  160 + dst = std::move(value); \
167 } while (0) 161 } while (0)
168 162
169 -#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \  
170 - default_value) \  
171 - do { \  
172 - auto value = \  
173 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
174 - if (!value) { \  
175 - dst = default_value; \  
176 - } else { \  
177 - dst = value.get(); \  
178 - if (dst.empty()) { \  
179 - SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \  
180 - exit(-1); \  
181 - } \  
182 - } \ 163 +#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
  164 + default_value) \
  165 + do { \
  166 + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
  167 + if (value.empty()) { \
  168 + dst = default_value; \
  169 + } else { \
  170 + dst = std::move(value); \
  171 + if (dst.empty()) { \
  172 + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
  173 + exit(-1); \
  174 + } \
  175 + } \
183 } while (0) 176 } while (0)
184 177
185 #define SHERPA_ONNX_EXIT(code) exit(code) 178 #define SHERPA_ONNX_EXIT(code) exit(code)
@@ -46,7 +46,7 @@ class OfflineCEDModel::Impl { @@ -46,7 +46,7 @@ class OfflineCEDModel::Impl {
46 46
47 int32_t NumEventClasses() const { return num_event_classes_; } 47 int32_t NumEventClasses() const { return num_event_classes_; }
48 48
49 - OrtAllocator *Allocator() const { return allocator_; } 49 + OrtAllocator *Allocator() { return allocator_; }
50 50
51 private: 51 private:
52 void Init(void *model_data, size_t model_data_length) { 52 void Init(void *model_data, size_t model_data_length) {
@@ -44,7 +44,7 @@ class OfflineCtTransformerModel::Impl { @@ -44,7 +44,7 @@ class OfflineCtTransformerModel::Impl {
44 return std::move(ans[0]); 44 return std::move(ans[0]);
45 } 45 }
46 46
47 - OrtAllocator *Allocator() const { return allocator_; } 47 + OrtAllocator *Allocator() { return allocator_; }
48 48
49 const OfflineCtTransformerModelMetaData &GetModelMetadata() const { 49 const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
50 return meta_data_; 50 return meta_data_;
@@ -53,8 +53,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -53,8 +53,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
53 53
54 Ort::AllocatorWithDefaultOptions allocator; 54 Ort::AllocatorWithDefaultOptions allocator;
55 auto model_type = 55 auto model_type =
56 - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);  
57 - if (!model_type) { 56 + LookupCustomModelMetaData(meta_data, "model_type", allocator);
  57 + if (model_type.empty()) {
58 SHERPA_ONNX_LOGE( 58 SHERPA_ONNX_LOGE(
59 "No model_type in the metadata!\n" 59 "No model_type in the metadata!\n"
60 "If you are using models from NeMo, please refer to\n" 60 "If you are using models from NeMo, please refer to\n"
@@ -74,22 +74,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -74,22 +74,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
74 return ModelType::kUnknown; 74 return ModelType::kUnknown;
75 } 75 }
76 76
77 - if (model_type.get() == std::string("EncDecCTCModelBPE")) { 77 + if (model_type == "EncDecCTCModelBPE") {
78 return ModelType::kEncDecCTCModelBPE; 78 return ModelType::kEncDecCTCModelBPE;
79 - } else if (model_type.get() == std::string("EncDecCTCModel")) { 79 + } else if (model_type == "EncDecCTCModel") {
80 return ModelType::kEncDecCTCModel; 80 return ModelType::kEncDecCTCModel;
81 - } else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) { 81 + } else if (model_type == "EncDecHybridRNNTCTCBPEModel") {
82 return ModelType::kEncDecHybridRNNTCTCBPEModel; 82 return ModelType::kEncDecHybridRNNTCTCBPEModel;
83 - } else if (model_type.get() == std::string("tdnn")) { 83 + } else if (model_type == "tdnn") {
84 return ModelType::kTdnn; 84 return ModelType::kTdnn;
85 - } else if (model_type.get() == std::string("zipformer2_ctc")) { 85 + } else if (model_type == "zipformer2_ctc") {
86 return ModelType::kZipformerCtc; 86 return ModelType::kZipformerCtc;
87 - } else if (model_type.get() == std::string("wenet_ctc")) { 87 + } else if (model_type == "wenet_ctc") {
88 return ModelType::kWenetCtc; 88 return ModelType::kWenetCtc;
89 - } else if (model_type.get() == std::string("telespeech_ctc")) { 89 + } else if (model_type == "telespeech_ctc") {
90 return ModelType::kTeleSpeechCtc; 90 return ModelType::kTeleSpeechCtc;
91 } else { 91 } else {
92 - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 92 + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
93 return ModelType::kUnknown; 93 return ModelType::kUnknown;
94 } 94 }
95 } 95 }
@@ -155,7 +155,7 @@ class OfflineMoonshineModel::Impl { @@ -155,7 +155,7 @@ class OfflineMoonshineModel::Impl {
155 return {std::move(cached_decoder_out[0]), std::move(next_states)}; 155 return {std::move(cached_decoder_out[0]), std::move(next_states)};
156 } 156 }
157 157
158 - OrtAllocator *Allocator() const { return allocator_; } 158 + OrtAllocator *Allocator() { return allocator_; }
159 159
160 private: 160 private:
161 void InitPreprocessor(void *model_data, size_t model_data_length) { 161 void InitPreprocessor(void *model_data, size_t model_data_length) {
@@ -68,7 +68,7 @@ class OfflineNemoEncDecCtcModel::Impl { @@ -68,7 +68,7 @@ class OfflineNemoEncDecCtcModel::Impl {
68 68
69 int32_t SubsamplingFactor() const { return subsampling_factor_; } 69 int32_t SubsamplingFactor() const { return subsampling_factor_; }
70 70
71 - OrtAllocator *Allocator() const { return allocator_; } 71 + OrtAllocator *Allocator() { return allocator_; }
72 72
73 std::string FeatureNormalizationMethod() const { return normalize_type_; } 73 std::string FeatureNormalizationMethod() const { return normalize_type_; }
74 74
@@ -56,7 +56,7 @@ class OfflineParaformerModel::Impl { @@ -56,7 +56,7 @@ class OfflineParaformerModel::Impl {
56 56
57 const std::vector<float> &InverseStdDev() const { return inv_stddev_; } 57 const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
58 58
59 - OrtAllocator *Allocator() const { return allocator_; } 59 + OrtAllocator *Allocator() { return allocator_; }
60 60
61 private: 61 private:
62 void Init(void *model_data, size_t model_data_length) { 62 void Init(void *model_data, size_t model_data_length) {
@@ -121,9 +121,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -121,9 +121,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
121 121
122 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below 122 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
123 123
124 - auto model_type_ptr =  
125 - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);  
126 - if (!model_type_ptr) { 124 + auto model_type =
  125 + LookupCustomModelMetaData(meta_data, "model_type", allocator);
  126 + if (!model_type.empty()) {
127 SHERPA_ONNX_LOGE( 127 SHERPA_ONNX_LOGE(
128 "No model_type in the metadata!\n\n" 128 "No model_type in the metadata!\n\n"
129 "Please refer to the following URLs to add metadata" 129 "Please refer to the following URLs to add metadata"
@@ -164,7 +164,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -164,7 +164,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
164 "\n"); 164 "\n");
165 exit(-1); 165 exit(-1);
166 } 166 }
167 - std::string model_type(model_type_ptr.get());  
168 167
169 if (model_type == "conformer" || model_type == "zipformer" || 168 if (model_type == "conformer" || model_type == "zipformer" ||
170 model_type == "zipformer2") { 169 model_type == "zipformer2") {
@@ -301,9 +300,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -301,9 +300,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
301 300
302 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below 301 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
303 302
304 - auto model_type_ptr =  
305 - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);  
306 - if (!model_type_ptr) { 303 + auto model_type =
  304 + LookupCustomModelMetaData(meta_data, "model_type", allocator);
  305 + if (model_type.empty()) {
307 SHERPA_ONNX_LOGE( 306 SHERPA_ONNX_LOGE(
308 "No model_type in the metadata!\n\n" 307 "No model_type in the metadata!\n\n"
309 "Please refer to the following URLs to add metadata" 308 "Please refer to the following URLs to add metadata"
@@ -344,7 +343,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -344,7 +343,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
344 "\n"); 343 "\n");
345 exit(-1); 344 exit(-1);
346 } 345 }
347 - std::string model_type(model_type_ptr.get());  
348 346
349 if (model_type == "conformer" || model_type == "zipformer" || 347 if (model_type == "conformer" || model_type == "zipformer" ||
350 model_type == "zipformer2") { 348 model_type == "zipformer2") {
@@ -56,7 +56,7 @@ class OfflineSenseVoiceModel::Impl { @@ -56,7 +56,7 @@ class OfflineSenseVoiceModel::Impl {
56 return meta_data_; 56 return meta_data_;
57 } 57 }
58 58
59 - OrtAllocator *Allocator() const { return allocator_; } 59 + OrtAllocator *Allocator() { return allocator_; }
60 60
61 private: 61 private:
62 void Init(void *model_data, size_t model_data_length) { 62 void Init(void *model_data, size_t model_data_length) {
@@ -63,7 +63,7 @@ class OfflineTdnnCtcModel::Impl { @@ -63,7 +63,7 @@ class OfflineTdnnCtcModel::Impl {
63 63
64 int32_t VocabSize() const { return vocab_size_; } 64 int32_t VocabSize() const { return vocab_size_; }
65 65
66 - OrtAllocator *Allocator() const { return allocator_; } 66 + OrtAllocator *Allocator() { return allocator_; }
67 67
68 private: 68 private:
69 void Init(void *model_data, size_t model_data_length) { 69 void Init(void *model_data, size_t model_data_length) {
@@ -69,7 +69,7 @@ class OfflineTeleSpeechCtcModel::Impl { @@ -69,7 +69,7 @@ class OfflineTeleSpeechCtcModel::Impl {
69 69
70 int32_t SubsamplingFactor() const { return subsampling_factor_; } 70 int32_t SubsamplingFactor() const { return subsampling_factor_; }
71 71
72 - OrtAllocator *Allocator() const { return allocator_; } 72 + OrtAllocator *Allocator() { return allocator_; }
73 73
74 private: 74 private:
75 void Init(void *model_data, size_t model_data_length) { 75 void Init(void *model_data, size_t model_data_length) {
@@ -95,11 +95,11 @@ class OfflineTransducerModel::Impl { @@ -95,11 +95,11 @@ class OfflineTransducerModel::Impl {
95 int32_t VocabSize() const { return vocab_size_; } 95 int32_t VocabSize() const { return vocab_size_; }
96 int32_t ContextSize() const { return context_size_; } 96 int32_t ContextSize() const { return context_size_; }
97 int32_t SubsamplingFactor() const { return 4; } 97 int32_t SubsamplingFactor() const { return 4; }
98 - OrtAllocator *Allocator() const { return allocator_; } 98 + OrtAllocator *Allocator() { return allocator_; }
99 99
100 Ort::Value BuildDecoderInput( 100 Ort::Value BuildDecoderInput(
101 const std::vector<OfflineTransducerDecoderResult> &results, 101 const std::vector<OfflineTransducerDecoderResult> &results,
102 - int32_t end_index) const { 102 + int32_t end_index) {
103 assert(end_index <= results.size()); 103 assert(end_index <= results.size());
104 104
105 int32_t batch_size = end_index; 105 int32_t batch_size = end_index;
@@ -122,7 +122,7 @@ class OfflineTransducerModel::Impl { @@ -122,7 +122,7 @@ class OfflineTransducerModel::Impl {
122 } 122 }
123 123
124 Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results, 124 Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
125 - int32_t end_index) const { 125 + int32_t end_index) {
126 assert(end_index <= results.size()); 126 assert(end_index <= results.size());
127 127
128 int32_t batch_size = end_index; 128 int32_t batch_size = end_index;
@@ -123,7 +123,7 @@ class OfflineTransducerNeMoModel::Impl { @@ -123,7 +123,7 @@ class OfflineTransducerNeMoModel::Impl {
123 return std::move(logit[0]); 123 return std::move(logit[0]);
124 } 124 }
125 125
126 - std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const { 126 + std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) {
127 std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; 127 std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
128 Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(), 128 Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
129 s0_shape.size()); 129 s0_shape.size());
@@ -149,7 +149,7 @@ class OfflineTransducerNeMoModel::Impl { @@ -149,7 +149,7 @@ class OfflineTransducerNeMoModel::Impl {
149 int32_t SubsamplingFactor() const { return subsampling_factor_; } 149 int32_t SubsamplingFactor() const { return subsampling_factor_; }
150 int32_t VocabSize() const { return vocab_size_; } 150 int32_t VocabSize() const { return vocab_size_; }
151 151
152 - OrtAllocator *Allocator() const { return allocator_; } 152 + OrtAllocator *Allocator() { return allocator_; }
153 153
154 std::string FeatureNormalizationMethod() const { return normalize_type_; } 154 std::string FeatureNormalizationMethod() const { return normalize_type_; }
155 155
@@ -47,7 +47,7 @@ class OfflineWenetCtcModel::Impl { @@ -47,7 +47,7 @@ class OfflineWenetCtcModel::Impl {
47 47
48 int32_t SubsamplingFactor() const { return subsampling_factor_; } 48 int32_t SubsamplingFactor() const { return subsampling_factor_; }
49 49
50 - OrtAllocator *Allocator() const { return allocator_; } 50 + OrtAllocator *Allocator() { return allocator_; }
51 51
52 private: 52 private:
53 void Init(void *model_data, size_t model_data_length) { 53 void Init(void *model_data, size_t model_data_length) {
@@ -188,7 +188,7 @@ class OfflineWhisperModel::Impl { @@ -188,7 +188,7 @@ class OfflineWhisperModel::Impl {
188 return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)}; 188 return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
189 } 189 }
190 190
191 - OrtAllocator *Allocator() const { return allocator_; } 191 + OrtAllocator *Allocator() { return allocator_; }
192 192
193 const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; } 193 const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
194 194
@@ -47,7 +47,7 @@ class OfflineZipformerAudioTaggingModel::Impl { @@ -47,7 +47,7 @@ class OfflineZipformerAudioTaggingModel::Impl {
47 47
48 int32_t NumEventClasses() const { return num_event_classes_; } 48 int32_t NumEventClasses() const { return num_event_classes_; }
49 49
50 - OrtAllocator *Allocator() const { return allocator_; } 50 + OrtAllocator *Allocator() { return allocator_; }
51 51
52 private: 52 private:
53 void Init(void *model_data, size_t model_data_length) { 53 void Init(void *model_data, size_t model_data_length) {
@@ -48,7 +48,7 @@ class OfflineZipformerCtcModel::Impl { @@ -48,7 +48,7 @@ class OfflineZipformerCtcModel::Impl {
48 int32_t VocabSize() const { return vocab_size_; } 48 int32_t VocabSize() const { return vocab_size_; }
49 int32_t SubsamplingFactor() const { return 4; } 49 int32_t SubsamplingFactor() const { return 4; }
50 50
51 - OrtAllocator *Allocator() const { return allocator_; } 51 + OrtAllocator *Allocator() { return allocator_; }
52 52
53 private: 53 private:
54 void Init(void *model_data, size_t model_data_length) { 54 void Init(void *model_data, size_t model_data_length) {
@@ -47,7 +47,7 @@ class OnlineCNNBiLSTMModel::Impl { @@ -47,7 +47,7 @@ class OnlineCNNBiLSTMModel::Impl {
47 return {std::move(ans[0]), std::move(ans[1])}; 47 return {std::move(ans[0]), std::move(ans[1])};
48 } 48 }
49 49
50 - OrtAllocator *Allocator() const { return allocator_; } 50 + OrtAllocator *Allocator() { return allocator_; }
51 51
52 const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const { 52 const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
53 return meta_data_; 53 return meta_data_;
@@ -163,8 +163,11 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates( @@ -163,8 +163,11 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
163 conv_vec[i] = &states[i][1]; 163 conv_vec[i] = &states[i][1];
164 } 164 }
165 165
166 - Ort::Value attn = Cat(allocator_, attn_vec, 2);  
167 - Ort::Value conv = Cat(allocator_, conv_vec, 2); 166 + auto allocator =
  167 + const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
  168 +
  169 + Ort::Value attn = Cat(allocator, attn_vec, 2);
  170 + Ort::Value conv = Cat(allocator, conv_vec, 2);
168 171
169 std::vector<Ort::Value> ans; 172 std::vector<Ort::Value> ans;
170 ans.reserve(2); 173 ans.reserve(2);
@@ -183,8 +186,11 @@ OnlineConformerTransducerModel::UnStackStates( @@ -183,8 +186,11 @@ OnlineConformerTransducerModel::UnStackStates(
183 186
184 std::vector<std::vector<Ort::Value>> ans(batch_size); 187 std::vector<std::vector<Ort::Value>> ans(batch_size);
185 188
186 - std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);  
187 - std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2); 189 + auto allocator =
  190 + const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
  191 +
  192 + std::vector<Ort::Value> attn_vec = Unbind(allocator, &states[0], 2);
  193 + std::vector<Ort::Value> conv_vec = Unbind(allocator, &states[1], 2);
188 194
189 assert(attn_vec.size() == batch_size); 195 assert(attn_vec.size() == batch_size);
190 assert(conv_vec.size() == batch_size); 196 assert(conv_vec.size() == batch_size);
@@ -158,9 +158,10 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates( @@ -158,9 +158,10 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
158 h_buf[i] = &states[i][0]; 158 h_buf[i] = &states[i][0];
159 c_buf[i] = &states[i][1]; 159 c_buf[i] = &states[i][1];
160 } 160 }
  161 + auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
161 162
162 - Ort::Value h = Cat(allocator_, h_buf, 1);  
163 - Ort::Value c = Cat(allocator_, c_buf, 1); 163 + Ort::Value h = Cat(allocator, h_buf, 1);
  164 + Ort::Value c = Cat(allocator, c_buf, 1);
164 165
165 std::vector<Ort::Value> ans; 166 std::vector<Ort::Value> ans;
166 ans.reserve(2); 167 ans.reserve(2);
@@ -177,8 +178,10 @@ std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates( @@ -177,8 +178,10 @@ std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
177 178
178 std::vector<std::vector<Ort::Value>> ans(batch_size); 179 std::vector<std::vector<Ort::Value>> ans(batch_size);
179 180
180 - std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1);  
181 - std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1); 181 + auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
  182 +
  183 + std::vector<Ort::Value> h_vec = Unbind(allocator, &states[0], 1);
  184 + std::vector<Ort::Value> c_vec = Unbind(allocator, &states[1], 1);
182 185
183 assert(h_vec.size() == batch_size); 186 assert(h_vec.size() == batch_size);
184 assert(c_vec.size() == batch_size); 187 assert(c_vec.size() == batch_size);
@@ -102,7 +102,7 @@ class OnlineNeMoCtcModel::Impl { @@ -102,7 +102,7 @@ class OnlineNeMoCtcModel::Impl {
102 102
103 int32_t ChunkShift() const { return chunk_shift_; } 103 int32_t ChunkShift() const { return chunk_shift_; }
104 104
105 - OrtAllocator *Allocator() const { return allocator_; } 105 + OrtAllocator *Allocator() { return allocator_; }
106 106
107 // Return a vector containing 3 tensors 107 // Return a vector containing 3 tensors
108 // - cache_last_channel 108 // - cache_last_channel
@@ -119,7 +119,7 @@ class OnlineNeMoCtcModel::Impl { @@ -119,7 +119,7 @@ class OnlineNeMoCtcModel::Impl {
119 } 119 }
120 120
121 std::vector<Ort::Value> StackStates( 121 std::vector<Ort::Value> StackStates(
122 - std::vector<std::vector<Ort::Value>> states) const { 122 + std::vector<std::vector<Ort::Value>> states) {
123 int32_t batch_size = static_cast<int32_t>(states.size()); 123 int32_t batch_size = static_cast<int32_t>(states.size());
124 if (batch_size == 1) { 124 if (batch_size == 1) {
125 return std::move(states[0]); 125 return std::move(states[0]);
@@ -157,6 +157,8 @@ class OnlineNeMoCtcModel::Impl { @@ -157,6 +157,8 @@ class OnlineNeMoCtcModel::Impl {
157 std::vector<Ort::Value> states) const { 157 std::vector<Ort::Value> states) const {
158 assert(states.size() == 3); 158 assert(states.size() == 3);
159 159
  160 + auto allocator = const_cast<Impl *>(this)->allocator_;
  161 +
160 std::vector<std::vector<Ort::Value>> ans; 162 std::vector<std::vector<Ort::Value>> ans;
161 163
162 auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape(); 164 auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
@@ -171,9 +173,9 @@ class OnlineNeMoCtcModel::Impl { @@ -171,9 +173,9 @@ class OnlineNeMoCtcModel::Impl {
171 for (int32_t i = 0; i != 3; ++i) { 173 for (int32_t i = 0; i != 3; ++i) {
172 std::vector<Ort::Value> v; 174 std::vector<Ort::Value> v;
173 if (i == 2) { 175 if (i == 2) {
174 - v = Unbind<int64_t>(allocator_, &states[i], 0); 176 + v = Unbind<int64_t>(allocator, &states[i], 0);
175 } else { 177 } else {
176 - v = Unbind(allocator_, &states[i], 0); 178 + v = Unbind(allocator, &states[i], 0);
177 } 179 }
178 180
179 assert(v.size() == batch_size); 181 assert(v.size() == batch_size);
@@ -105,7 +105,7 @@ class OnlineParaformerModel::Impl { @@ -105,7 +105,7 @@ class OnlineParaformerModel::Impl {
105 105
106 const std::vector<float> &InverseStdDev() const { return inv_stddev_; } 106 const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
107 107
108 - OrtAllocator *Allocator() const { return allocator_; } 108 + OrtAllocator *Allocator() { return allocator_; }
109 109
110 private: 110 private:
111 void InitEncoder(void *model_data, size_t model_data_length) { 111 void InitEncoder(void *model_data, size_t model_data_length) {
@@ -5,10 +5,10 @@ @@ -5,10 +5,10 @@
5 5
6 #include "sherpa-onnx/csrc/online-rnn-lm.h" 6 #include "sherpa-onnx/csrc/online-rnn-lm.h"
7 7
  8 +#include <algorithm>
8 #include <string> 9 #include <string>
9 #include <utility> 10 #include <utility>
10 #include <vector> 11 #include <vector>
11 -#include <algorithm>  
12 12
13 #include "onnxruntime_cxx_api.h" // NOLINT 13 #include "onnxruntime_cxx_api.h" // NOLINT
14 #include "sherpa-onnx/csrc/macros.h" 14 #include "sherpa-onnx/csrc/macros.h"
@@ -53,49 +53,49 @@ class OnlineRnnLM::Impl { @@ -53,49 +53,49 @@ class OnlineRnnLM::Impl {
53 53
54 // classic rescore function 54 // classic rescore function
55 void ComputeLMScore(float scale, int32_t context_size, 55 void ComputeLMScore(float scale, int32_t context_size,
56 - std::vector<Hypotheses> *hyps) {  
57 - Ort::AllocatorWithDefaultOptions allocator;  
58 -  
59 - for (auto &hyp : *hyps) {  
60 - for (auto &h_m : hyp) {  
61 - auto &h = h_m.second;  
62 - auto &ys = h.ys;  
63 - const int32_t token_num_in_chunk =  
64 - ys.size() - context_size - h.cur_scored_pos - 1;  
65 -  
66 - if (token_num_in_chunk < 1) {  
67 - continue;  
68 - }  
69 -  
70 - if (h.nn_lm_states.empty()) {  
71 - h.nn_lm_states = Convert(GetInitStates());  
72 - }  
73 -  
74 - if (token_num_in_chunk >= h.lm_rescore_min_chunk) {  
75 - std::array<int64_t, 2> x_shape{1, token_num_in_chunk};  
76 -  
77 - Ort::Value x = Ort::Value::CreateTensor<int64_t>(  
78 - allocator, x_shape.data(), x_shape.size());  
79 - int64_t *p_x = x.GetTensorMutableData<int64_t>();  
80 - std::copy(ys.begin() + context_size + h.cur_scored_pos,  
81 - ys.end() - 1, p_x);  
82 -  
83 - // streaming forward by NN LM  
84 - auto out = ScoreToken(std::move(x),  
85 - Convert(std::move(h.nn_lm_states)));  
86 -  
87 - // update NN LM score in hyp  
88 - const float *p_nll = out.first.GetTensorData<float>();  
89 - h.lm_log_prob = -scale * (*p_nll);  
90 -  
91 - // update NN LM states in hyp  
92 - h.nn_lm_states = Convert(std::move(out.second));  
93 -  
94 - h.cur_scored_pos += token_num_in_chunk;  
95 - } 56 + std::vector<Hypotheses> *hyps) {
  57 + Ort::AllocatorWithDefaultOptions allocator;
  58 +
  59 + for (auto &hyp : *hyps) {
  60 + for (auto &h_m : hyp) {
  61 + auto &h = h_m.second;
  62 + auto &ys = h.ys;
  63 + const int32_t token_num_in_chunk =
  64 + ys.size() - context_size - h.cur_scored_pos - 1;
  65 +
  66 + if (token_num_in_chunk < 1) {
  67 + continue;
  68 + }
  69 +
  70 + if (h.nn_lm_states.empty()) {
  71 + h.nn_lm_states = Convert(GetInitStates());
  72 + }
  73 +
  74 + if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
  75 + std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
  76 +
  77 + Ort::Value x = Ort::Value::CreateTensor<int64_t>(
  78 + allocator, x_shape.data(), x_shape.size());
  79 + int64_t *p_x = x.GetTensorMutableData<int64_t>();
  80 + std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
  81 + p_x);
  82 +
  83 + // streaming forward by NN LM
  84 + auto out =
  85 + ScoreToken(std::move(x), Convert(std::move(h.nn_lm_states)));
  86 +
  87 + // update NN LM score in hyp
  88 + const float *p_nll = out.first.GetTensorData<float>();
  89 + h.lm_log_prob = -scale * (*p_nll);
  90 +
  91 + // update NN LM states in hyp
  92 + h.nn_lm_states = Convert(std::move(out.second));
  93 +
  94 + h.cur_scored_pos += token_num_in_chunk;
96 } 95 }
97 } 96 }
98 } 97 }
  98 + }
99 99
100 std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( 100 std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
101 Ort::Value x, std::vector<Ort::Value> states) { 101 Ort::Value x, std::vector<Ort::Value> states) {
@@ -125,7 +125,7 @@ class OnlineRnnLM::Impl { @@ -125,7 +125,7 @@ class OnlineRnnLM::Impl {
125 } 125 }
126 126
127 // get init states for classic rescore 127 // get init states for classic rescore
128 - std::vector<Ort::Value> GetInitStates() const { 128 + std::vector<Ort::Value> GetInitStates() {
129 std::vector<Ort::Value> ans; 129 std::vector<Ort::Value> ans;
130 ans.reserve(init_states_.size()); 130 ans.reserve(init_states_.size());
131 131
@@ -226,7 +226,7 @@ std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken( @@ -226,7 +226,7 @@ std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
226 226
227 // classic rescore scores 227 // classic rescore scores
228 void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size, 228 void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
229 - std::vector<Hypotheses> *hyps) { 229 + std::vector<Hypotheses> *hyps) {
230 return impl_->ComputeLMScore(scale, context_size, hyps); 230 return impl_->ComputeLMScore(scale, context_size, hyps);
231 } 231 }
232 232
@@ -235,5 +235,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) { @@ -235,5 +235,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
235 return impl_->ComputeLMScoreSF(scale, hyp); 235 return impl_->ComputeLMScoreSF(scale, hyp);
236 } 236 }
237 237
238 -  
239 } // namespace sherpa_onnx 238 } // namespace sherpa_onnx
@@ -54,8 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -54,8 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
54 54
55 Ort::AllocatorWithDefaultOptions allocator; 55 Ort::AllocatorWithDefaultOptions allocator;
56 auto model_type = 56 auto model_type =
57 - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);  
58 - if (!model_type) { 57 + LookupCustomModelMetaData(meta_data, "model_type", allocator);
  58 + if (model_type.empty()) {
59 SHERPA_ONNX_LOGE( 59 SHERPA_ONNX_LOGE(
60 "No model_type in the metadata!\n" 60 "No model_type in the metadata!\n"
61 "Please make sure you are using the latest export-onnx.py from icefall " 61 "Please make sure you are using the latest export-onnx.py from icefall "
@@ -63,16 +63,16 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -63,16 +63,16 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
63 return ModelType::kUnknown; 63 return ModelType::kUnknown;
64 } 64 }
65 65
66 - if (model_type.get() == std::string("conformer")) { 66 + if (model_type == "conformer") {
67 return ModelType::kConformer; 67 return ModelType::kConformer;
68 - } else if (model_type.get() == std::string("lstm")) { 68 + } else if (model_type == "lstm") {
69 return ModelType::kLstm; 69 return ModelType::kLstm;
70 - } else if (model_type.get() == std::string("zipformer")) { 70 + } else if (model_type == "zipformer") {
71 return ModelType::kZipformer; 71 return ModelType::kZipformer;
72 - } else if (model_type.get() == std::string("zipformer2")) { 72 + } else if (model_type == "zipformer2") {
73 return ModelType::kZipformer2; 73 return ModelType::kZipformer2;
74 } else { 74 } else {
75 - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 75 + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
76 return ModelType::kUnknown; 76 return ModelType::kUnknown;
77 } 77 }
78 } 78 }
@@ -197,7 +197,7 @@ class OnlineTransducerNeMoModel::Impl { @@ -197,7 +197,7 @@ class OnlineTransducerNeMoModel::Impl {
197 197
198 int32_t VocabSize() const { return vocab_size_; } 198 int32_t VocabSize() const { return vocab_size_; }
199 199
200 - OrtAllocator *Allocator() const { return allocator_; } 200 + OrtAllocator *Allocator() { return allocator_; }
201 201
202 std::string FeatureNormalizationMethod() const { return normalize_type_; } 202 std::string FeatureNormalizationMethod() const { return normalize_type_; }
203 203
@@ -224,6 +224,8 @@ class OnlineTransducerNeMoModel::Impl { @@ -224,6 +224,8 @@ class OnlineTransducerNeMoModel::Impl {
224 224
225 std::vector<Ort::Value> ans; 225 std::vector<Ort::Value> ans;
226 226
  227 + auto allocator = const_cast<Impl *>(this)->allocator_;
  228 +
227 // stack cache_last_channel 229 // stack cache_last_channel
228 std::vector<const Ort::Value *> buf(batch_size); 230 std::vector<const Ort::Value *> buf(batch_size);
229 231
@@ -239,9 +241,9 @@ class OnlineTransducerNeMoModel::Impl { @@ -239,9 +241,9 @@ class OnlineTransducerNeMoModel::Impl {
239 241
240 Ort::Value c{nullptr}; 242 Ort::Value c{nullptr};
241 if (i == 2) { 243 if (i == 2) {
242 - c = Cat<int64_t>(allocator_, buf, 0); 244 + c = Cat<int64_t>(allocator, buf, 0);
243 } else { 245 } else {
244 - c = Cat(allocator_, buf, 0); 246 + c = Cat(allocator, buf, 0);
245 } 247 }
246 248
247 ans.push_back(std::move(c)); 249 ans.push_back(std::move(c));
@@ -251,7 +253,7 @@ class OnlineTransducerNeMoModel::Impl { @@ -251,7 +253,7 @@ class OnlineTransducerNeMoModel::Impl {
251 } 253 }
252 254
253 std::vector<std::vector<Ort::Value>> UnStackStates( 255 std::vector<std::vector<Ort::Value>> UnStackStates(
254 - std::vector<Ort::Value> states) const { 256 + std::vector<Ort::Value> states) {
255 assert(states.size() == 3); 257 assert(states.size() == 3);
256 258
257 std::vector<std::vector<Ort::Value>> ans; 259 std::vector<std::vector<Ort::Value>> ans;
@@ -101,7 +101,7 @@ class OnlineWenetCtcModel::Impl { @@ -101,7 +101,7 @@ class OnlineWenetCtcModel::Impl {
101 return config_.wenet_ctc.chunk_size * subsampling_factor_; 101 return config_.wenet_ctc.chunk_size * subsampling_factor_;
102 } 102 }
103 103
104 - OrtAllocator *Allocator() const { return allocator_; } 104 + OrtAllocator *Allocator() { return allocator_; }
105 105
106 // Return a vector containing 3 tensors 106 // Return a vector containing 3 tensors
107 // - attn_cache 107 // - attn_cache
@@ -179,12 +179,15 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( @@ -179,12 +179,15 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
179 std::vector<Ort::Value> ans; 179 std::vector<Ort::Value> ans;
180 ans.reserve(states[0].size()); 180 ans.reserve(states[0].size());
181 181
  182 + auto allocator =
  183 + const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
  184 +
182 // cached_len 185 // cached_len
183 for (int32_t i = 0; i != num_encoders; ++i) { 186 for (int32_t i = 0; i != num_encoders; ++i) {
184 for (int32_t n = 0; n != batch_size; ++n) { 187 for (int32_t n = 0; n != batch_size; ++n) {
185 buf[n] = &states[n][i]; 188 buf[n] = &states[n][i];
186 } 189 }
187 - auto v = Cat<int64_t>(allocator_, buf, 1); // (num_layers, 1) 190 + auto v = Cat<int64_t>(allocator, buf, 1); // (num_layers, 1)
188 ans.push_back(std::move(v)); 191 ans.push_back(std::move(v));
189 } 192 }
190 193
@@ -193,7 +196,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( @@ -193,7 +196,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
193 for (int32_t n = 0; n != batch_size; ++n) { 196 for (int32_t n = 0; n != batch_size; ++n) {
194 buf[n] = &states[n][num_encoders + i]; 197 buf[n] = &states[n][num_encoders + i];
195 } 198 }
196 - auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims) 199 + auto v = Cat(allocator, buf, 1); // (num_layers, 1, encoder_dims)
197 ans.push_back(std::move(v)); 200 ans.push_back(std::move(v));
198 } 201 }
199 202
@@ -203,7 +206,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( @@ -203,7 +206,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
203 buf[n] = &states[n][num_encoders * 2 + i]; 206 buf[n] = &states[n][num_encoders * 2 + i];
204 } 207 }
205 // (num_layers, left_context_len, 1, attention_dims) 208 // (num_layers, left_context_len, 1, attention_dims)
206 - auto v = Cat(allocator_, buf, 2); 209 + auto v = Cat(allocator, buf, 2);
207 ans.push_back(std::move(v)); 210 ans.push_back(std::move(v));
208 } 211 }
209 212
@@ -213,7 +216,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( @@ -213,7 +216,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
213 buf[n] = &states[n][num_encoders * 3 + i]; 216 buf[n] = &states[n][num_encoders * 3 + i];
214 } 217 }
215 // (num_layers, left_context_len, 1, attention_dims/2) 218 // (num_layers, left_context_len, 1, attention_dims/2)
216 - auto v = Cat(allocator_, buf, 2); 219 + auto v = Cat(allocator, buf, 2);
217 ans.push_back(std::move(v)); 220 ans.push_back(std::move(v));
218 } 221 }
219 222
@@ -223,7 +226,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( @@ -223,7 +226,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
223 buf[n] = &states[n][num_encoders * 4 + i]; 226 buf[n] = &states[n][num_encoders * 4 + i];
224 } 227 }
225 // (num_layers, left_context_len, 1, attention_dims/2) 228 // (num_layers, left_context_len, 1, attention_dims/2)
226 - auto v = Cat(allocator_, buf, 2); 229 + auto v = Cat(allocator, buf, 2);
227 ans.push_back(std::move(v)); 230 ans.push_back(std::move(v));
228 } 231 }
229 232
@@ -233,7 +236,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( @@ -233,7 +236,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
233 buf[n] = &states[n][num_encoders * 5 + i]; 236 buf[n] = &states[n][num_encoders * 5 + i];
234 } 237 }
235 // (num_layers, 1, encoder_dims, cnn_module_kernels-1) 238 // (num_layers, 1, encoder_dims, cnn_module_kernels-1)
236 - auto v = Cat(allocator_, buf, 1); 239 + auto v = Cat(allocator, buf, 1);
237 ans.push_back(std::move(v)); 240 ans.push_back(std::move(v));
238 } 241 }
239 242
@@ -243,7 +246,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates( @@ -243,7 +246,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
243 buf[n] = &states[n][num_encoders * 6 + i]; 246 buf[n] = &states[n][num_encoders * 6 + i];
244 } 247 }
245 // (num_layers, 1, encoder_dims, cnn_module_kernels-1) 248 // (num_layers, 1, encoder_dims, cnn_module_kernels-1)
246 - auto v = Cat(allocator_, buf, 1); 249 + auto v = Cat(allocator, buf, 1);
247 ans.push_back(std::move(v)); 250 ans.push_back(std::move(v));
248 } 251 }
249 252
@@ -258,12 +261,15 @@ OnlineZipformerTransducerModel::UnStackStates( @@ -258,12 +261,15 @@ OnlineZipformerTransducerModel::UnStackStates(
258 int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; 261 int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
259 int32_t num_encoders = num_encoder_layers_.size(); 262 int32_t num_encoders = num_encoder_layers_.size();
260 263
  264 + auto allocator =
  265 + const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
  266 +
261 std::vector<std::vector<Ort::Value>> ans; 267 std::vector<std::vector<Ort::Value>> ans;
262 ans.resize(batch_size); 268 ans.resize(batch_size);
263 269
264 // cached_len 270 // cached_len
265 for (int32_t i = 0; i != num_encoders; ++i) { 271 for (int32_t i = 0; i != num_encoders; ++i) {
266 - auto v = Unbind<int64_t>(allocator_, &states[i], 1); 272 + auto v = Unbind<int64_t>(allocator, &states[i], 1);
267 assert(v.size() == batch_size); 273 assert(v.size() == batch_size);
268 274
269 for (int32_t n = 0; n != batch_size; ++n) { 275 for (int32_t n = 0; n != batch_size; ++n) {
@@ -273,7 +279,7 @@ OnlineZipformerTransducerModel::UnStackStates( @@ -273,7 +279,7 @@ OnlineZipformerTransducerModel::UnStackStates(
273 279
274 // cached_avg 280 // cached_avg
275 for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) { 281 for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) {
276 - auto v = Unbind(allocator_, &states[i], 1); 282 + auto v = Unbind(allocator, &states[i], 1);
277 assert(v.size() == batch_size); 283 assert(v.size() == batch_size);
278 284
279 for (int32_t n = 0; n != batch_size; ++n) { 285 for (int32_t n = 0; n != batch_size; ++n) {
@@ -283,7 +289,7 @@ OnlineZipformerTransducerModel::UnStackStates( @@ -283,7 +289,7 @@ OnlineZipformerTransducerModel::UnStackStates(
283 289
284 // cached_key 290 // cached_key
285 for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) { 291 for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) {
286 - auto v = Unbind(allocator_, &states[i], 2); 292 + auto v = Unbind(allocator, &states[i], 2);
287 assert(v.size() == batch_size); 293 assert(v.size() == batch_size);
288 294
289 for (int32_t n = 0; n != batch_size; ++n) { 295 for (int32_t n = 0; n != batch_size; ++n) {
@@ -293,7 +299,7 @@ OnlineZipformerTransducerModel::UnStackStates( @@ -293,7 +299,7 @@ OnlineZipformerTransducerModel::UnStackStates(
293 299
294 // cached_val 300 // cached_val
295 for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) { 301 for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) {
296 - auto v = Unbind(allocator_, &states[i], 2); 302 + auto v = Unbind(allocator, &states[i], 2);
297 assert(v.size() == batch_size); 303 assert(v.size() == batch_size);
298 304
299 for (int32_t n = 0; n != batch_size; ++n) { 305 for (int32_t n = 0; n != batch_size; ++n) {
@@ -303,7 +309,7 @@ OnlineZipformerTransducerModel::UnStackStates( @@ -303,7 +309,7 @@ OnlineZipformerTransducerModel::UnStackStates(
303 309
304 // cached_val2 310 // cached_val2
305 for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) { 311 for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) {
306 - auto v = Unbind(allocator_, &states[i], 2); 312 + auto v = Unbind(allocator, &states[i], 2);
307 assert(v.size() == batch_size); 313 assert(v.size() == batch_size);
308 314
309 for (int32_t n = 0; n != batch_size; ++n) { 315 for (int32_t n = 0; n != batch_size; ++n) {
@@ -313,7 +319,7 @@ OnlineZipformerTransducerModel::UnStackStates( @@ -313,7 +319,7 @@ OnlineZipformerTransducerModel::UnStackStates(
313 319
314 // cached_conv1 320 // cached_conv1
315 for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) { 321 for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) {
316 - auto v = Unbind(allocator_, &states[i], 1); 322 + auto v = Unbind(allocator, &states[i], 1);
317 assert(v.size() == batch_size); 323 assert(v.size() == batch_size);
318 324
319 for (int32_t n = 0; n != batch_size; ++n) { 325 for (int32_t n = 0; n != batch_size; ++n) {
@@ -323,7 +329,7 @@ OnlineZipformerTransducerModel::UnStackStates( @@ -323,7 +329,7 @@ OnlineZipformerTransducerModel::UnStackStates(
323 329
324 // cached_conv2 330 // cached_conv2
325 for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) { 331 for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) {
326 - auto v = Unbind(allocator_, &states[i], 1); 332 + auto v = Unbind(allocator, &states[i], 1);
327 assert(v.size() == batch_size); 333 assert(v.size() == batch_size);
328 334
329 for (int32_t n = 0; n != batch_size; ++n) { 335 for (int32_t n = 0; n != batch_size; ++n) {
@@ -70,7 +70,7 @@ class OnlineZipformer2CtcModel::Impl { @@ -70,7 +70,7 @@ class OnlineZipformer2CtcModel::Impl {
70 70
71 int32_t ChunkShift() const { return decode_chunk_len_; } 71 int32_t ChunkShift() const { return decode_chunk_len_; }
72 72
73 - OrtAllocator *Allocator() const { return allocator_; } 73 + OrtAllocator *Allocator() { return allocator_; }
74 74
75 // Return a vector containing 3 tensors 75 // Return a vector containing 3 tensors
76 // - attn_cache 76 // - attn_cache
@@ -86,7 +86,7 @@ class OnlineZipformer2CtcModel::Impl { @@ -86,7 +86,7 @@ class OnlineZipformer2CtcModel::Impl {
86 } 86 }
87 87
88 std::vector<Ort::Value> StackStates( 88 std::vector<Ort::Value> StackStates(
89 - std::vector<std::vector<Ort::Value>> states) const { 89 + std::vector<std::vector<Ort::Value>> states) {
90 int32_t batch_size = static_cast<int32_t>(states.size()); 90 int32_t batch_size = static_cast<int32_t>(states.size());
91 91
92 std::vector<const Ort::Value *> buf(batch_size); 92 std::vector<const Ort::Value *> buf(batch_size);
@@ -159,7 +159,7 @@ class OnlineZipformer2CtcModel::Impl { @@ -159,7 +159,7 @@ class OnlineZipformer2CtcModel::Impl {
159 } 159 }
160 160
161 std::vector<std::vector<Ort::Value>> UnStackStates( 161 std::vector<std::vector<Ort::Value>> UnStackStates(
162 - std::vector<Ort::Value> states) const { 162 + std::vector<Ort::Value> states) {
163 int32_t m = std::accumulate(num_encoder_layers_.begin(), 163 int32_t m = std::accumulate(num_encoder_layers_.begin(),
164 num_encoder_layers_.end(), 0); 164 num_encoder_layers_.end(), 0);
165 assert(states.size() == m * 6 + 2); 165 assert(states.size() == m * 6 + 2);
@@ -185,6 +185,9 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( @@ -185,6 +185,9 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
185 185
186 std::vector<const Ort::Value *> buf(batch_size); 186 std::vector<const Ort::Value *> buf(batch_size);
187 187
  188 + auto allocator =
  189 + const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
  190 +
188 std::vector<Ort::Value> ans; 191 std::vector<Ort::Value> ans;
189 int32_t num_states = static_cast<int32_t>(states[0].size()); 192 int32_t num_states = static_cast<int32_t>(states[0].size());
190 ans.reserve(num_states); 193 ans.reserve(num_states);
@@ -194,42 +197,42 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( @@ -194,42 +197,42 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
194 for (int32_t n = 0; n != batch_size; ++n) { 197 for (int32_t n = 0; n != batch_size; ++n) {
195 buf[n] = &states[n][6 * i]; 198 buf[n] = &states[n][6 * i];
196 } 199 }
197 - auto v = Cat(allocator_, buf, 1); 200 + auto v = Cat(allocator, buf, 1);
198 ans.push_back(std::move(v)); 201 ans.push_back(std::move(v));
199 } 202 }
200 { 203 {
201 for (int32_t n = 0; n != batch_size; ++n) { 204 for (int32_t n = 0; n != batch_size; ++n) {
202 buf[n] = &states[n][6 * i + 1]; 205 buf[n] = &states[n][6 * i + 1];
203 } 206 }
204 - auto v = Cat(allocator_, buf, 1); 207 + auto v = Cat(allocator, buf, 1);
205 ans.push_back(std::move(v)); 208 ans.push_back(std::move(v));
206 } 209 }
207 { 210 {
208 for (int32_t n = 0; n != batch_size; ++n) { 211 for (int32_t n = 0; n != batch_size; ++n) {
209 buf[n] = &states[n][6 * i + 2]; 212 buf[n] = &states[n][6 * i + 2];
210 } 213 }
211 - auto v = Cat(allocator_, buf, 1); 214 + auto v = Cat(allocator, buf, 1);
212 ans.push_back(std::move(v)); 215 ans.push_back(std::move(v));
213 } 216 }
214 { 217 {
215 for (int32_t n = 0; n != batch_size; ++n) { 218 for (int32_t n = 0; n != batch_size; ++n) {
216 buf[n] = &states[n][6 * i + 3]; 219 buf[n] = &states[n][6 * i + 3];
217 } 220 }
218 - auto v = Cat(allocator_, buf, 1); 221 + auto v = Cat(allocator, buf, 1);
219 ans.push_back(std::move(v)); 222 ans.push_back(std::move(v));
220 } 223 }
221 { 224 {
222 for (int32_t n = 0; n != batch_size; ++n) { 225 for (int32_t n = 0; n != batch_size; ++n) {
223 buf[n] = &states[n][6 * i + 4]; 226 buf[n] = &states[n][6 * i + 4];
224 } 227 }
225 - auto v = Cat(allocator_, buf, 0); 228 + auto v = Cat(allocator, buf, 0);
226 ans.push_back(std::move(v)); 229 ans.push_back(std::move(v));
227 } 230 }
228 { 231 {
229 for (int32_t n = 0; n != batch_size; ++n) { 232 for (int32_t n = 0; n != batch_size; ++n) {
230 buf[n] = &states[n][6 * i + 5]; 233 buf[n] = &states[n][6 * i + 5];
231 } 234 }
232 - auto v = Cat(allocator_, buf, 0); 235 + auto v = Cat(allocator, buf, 0);
233 ans.push_back(std::move(v)); 236 ans.push_back(std::move(v));
234 } 237 }
235 } 238 }
@@ -238,7 +241,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( @@ -238,7 +241,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
238 for (int32_t n = 0; n != batch_size; ++n) { 241 for (int32_t n = 0; n != batch_size; ++n) {
239 buf[n] = &states[n][num_states - 2]; 242 buf[n] = &states[n][num_states - 2];
240 } 243 }
241 - auto v = Cat(allocator_, buf, 0); 244 + auto v = Cat(allocator, buf, 0);
242 ans.push_back(std::move(v)); 245 ans.push_back(std::move(v));
243 } 246 }
244 247
@@ -246,7 +249,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates( @@ -246,7 +249,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
246 for (int32_t n = 0; n != batch_size; ++n) { 249 for (int32_t n = 0; n != batch_size; ++n) {
247 buf[n] = &states[n][num_states - 1]; 250 buf[n] = &states[n][num_states - 1];
248 } 251 }
249 - auto v = Cat<int64_t>(allocator_, buf, 0); 252 + auto v = Cat<int64_t>(allocator, buf, 0);
250 ans.push_back(std::move(v)); 253 ans.push_back(std::move(v));
251 } 254 }
252 return ans; 255 return ans;
@@ -261,12 +264,15 @@ OnlineZipformer2TransducerModel::UnStackStates( @@ -261,12 +264,15 @@ OnlineZipformer2TransducerModel::UnStackStates(
261 264
262 int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; 265 int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
263 266
  267 + auto allocator =
  268 + const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
  269 +
264 std::vector<std::vector<Ort::Value>> ans; 270 std::vector<std::vector<Ort::Value>> ans;
265 ans.resize(batch_size); 271 ans.resize(batch_size);
266 272
267 for (int32_t i = 0; i != m; ++i) { 273 for (int32_t i = 0; i != m; ++i) {
268 { 274 {
269 - auto v = Unbind(allocator_, &states[i * 6], 1); 275 + auto v = Unbind(allocator, &states[i * 6], 1);
270 assert(static_cast<int32_t>(v.size()) == batch_size); 276 assert(static_cast<int32_t>(v.size()) == batch_size);
271 277
272 for (int32_t n = 0; n != batch_size; ++n) { 278 for (int32_t n = 0; n != batch_size; ++n) {
@@ -274,7 +280,7 @@ OnlineZipformer2TransducerModel::UnStackStates( @@ -274,7 +280,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
274 } 280 }
275 } 281 }
276 { 282 {
277 - auto v = Unbind(allocator_, &states[i * 6 + 1], 1); 283 + auto v = Unbind(allocator, &states[i * 6 + 1], 1);
278 assert(static_cast<int32_t>(v.size()) == batch_size); 284 assert(static_cast<int32_t>(v.size()) == batch_size);
279 285
280 for (int32_t n = 0; n != batch_size; ++n) { 286 for (int32_t n = 0; n != batch_size; ++n) {
@@ -282,7 +288,7 @@ OnlineZipformer2TransducerModel::UnStackStates( @@ -282,7 +288,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
282 } 288 }
283 } 289 }
284 { 290 {
285 - auto v = Unbind(allocator_, &states[i * 6 + 2], 1); 291 + auto v = Unbind(allocator, &states[i * 6 + 2], 1);
286 assert(static_cast<int32_t>(v.size()) == batch_size); 292 assert(static_cast<int32_t>(v.size()) == batch_size);
287 293
288 for (int32_t n = 0; n != batch_size; ++n) { 294 for (int32_t n = 0; n != batch_size; ++n) {
@@ -290,7 +296,7 @@ OnlineZipformer2TransducerModel::UnStackStates( @@ -290,7 +296,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
290 } 296 }
291 } 297 }
292 { 298 {
293 - auto v = Unbind(allocator_, &states[i * 6 + 3], 1); 299 + auto v = Unbind(allocator, &states[i * 6 + 3], 1);
294 assert(static_cast<int32_t>(v.size()) == batch_size); 300 assert(static_cast<int32_t>(v.size()) == batch_size);
295 301
296 for (int32_t n = 0; n != batch_size; ++n) { 302 for (int32_t n = 0; n != batch_size; ++n) {
@@ -298,7 +304,7 @@ OnlineZipformer2TransducerModel::UnStackStates( @@ -298,7 +304,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
298 } 304 }
299 } 305 }
300 { 306 {
301 - auto v = Unbind(allocator_, &states[i * 6 + 4], 0); 307 + auto v = Unbind(allocator, &states[i * 6 + 4], 0);
302 assert(static_cast<int32_t>(v.size()) == batch_size); 308 assert(static_cast<int32_t>(v.size()) == batch_size);
303 309
304 for (int32_t n = 0; n != batch_size; ++n) { 310 for (int32_t n = 0; n != batch_size; ++n) {
@@ -306,7 +312,7 @@ OnlineZipformer2TransducerModel::UnStackStates( @@ -306,7 +312,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
306 } 312 }
307 } 313 }
308 { 314 {
309 - auto v = Unbind(allocator_, &states[i * 6 + 5], 0); 315 + auto v = Unbind(allocator, &states[i * 6 + 5], 0);
310 assert(static_cast<int32_t>(v.size()) == batch_size); 316 assert(static_cast<int32_t>(v.size()) == batch_size);
311 317
312 for (int32_t n = 0; n != batch_size; ++n) { 318 for (int32_t n = 0; n != batch_size; ++n) {
@@ -316,7 +322,7 @@ OnlineZipformer2TransducerModel::UnStackStates( @@ -316,7 +322,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
316 } 322 }
317 323
318 { 324 {
319 - auto v = Unbind(allocator_, &states[m * 6], 0); 325 + auto v = Unbind(allocator, &states[m * 6], 0);
320 assert(static_cast<int32_t>(v.size()) == batch_size); 326 assert(static_cast<int32_t>(v.size()) == batch_size);
321 327
322 for (int32_t n = 0; n != batch_size; ++n) { 328 for (int32_t n = 0; n != batch_size; ++n) {
@@ -324,7 +330,7 @@ OnlineZipformer2TransducerModel::UnStackStates( @@ -324,7 +330,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
324 } 330 }
325 } 331 }
326 { 332 {
327 - auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0); 333 + auto v = Unbind<int64_t>(allocator, &states[m * 6 + 1], 0);
328 assert(static_cast<int32_t>(v.size()) == batch_size); 334 assert(static_cast<int32_t>(v.size()) == batch_size);
329 335
330 for (int32_t n = 0; n != batch_size; ++n) { 336 for (int32_t n = 0; n != batch_size; ++n) {
@@ -21,6 +21,36 @@ @@ -21,6 +21,36 @@
21 21
22 namespace sherpa_onnx { 22 namespace sherpa_onnx {
23 23
  24 +static std::string GetInputName(Ort::Session *sess, size_t index,
  25 + OrtAllocator *allocator) {
  26 +// Note(fangjun): We only tested 1.17.1 and 1.11.0
  27 +// For other versions, we may need to change it
  28 +#if ORT_API_VERSION >= 17
  29 + auto v = sess->GetInputNameAllocated(index, allocator);
  30 + return v.get();
  31 +#else
  32 + auto v = sess->GetInputName(index, allocator);
  33 + std::string ans = v;
  34 + allocator->Free(allocator, v);
  35 + return ans;
  36 +#endif
  37 +}
  38 +
  39 +static std::string GetOutputName(Ort::Session *sess, size_t index,
  40 + OrtAllocator *allocator) {
  41 +// Note(fangjun): We only tested 1.17.1 and 1.11.0
  42 +// For other versions, we may need to change it
  43 +#if ORT_API_VERSION >= 17
  44 + auto v = sess->GetOutputNameAllocated(index, allocator);
  45 + return v.get();
  46 +#else
  47 + auto v = sess->GetOutputName(index, allocator);
  48 + std::string ans = v;
  49 + allocator->Free(allocator, v);
  50 + return ans;
  51 +#endif
  52 +}
  53 +
24 void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names, 54 void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
25 std::vector<const char *> *input_names_ptr) { 55 std::vector<const char *> *input_names_ptr) {
26 Ort::AllocatorWithDefaultOptions allocator; 56 Ort::AllocatorWithDefaultOptions allocator;
@@ -28,8 +58,7 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names, @@ -28,8 +58,7 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
28 input_names->resize(node_count); 58 input_names->resize(node_count);
29 input_names_ptr->resize(node_count); 59 input_names_ptr->resize(node_count);
30 for (size_t i = 0; i != node_count; ++i) { 60 for (size_t i = 0; i != node_count; ++i) {
31 - auto tmp = sess->GetInputNameAllocated(i, allocator);  
32 - (*input_names)[i] = tmp.get(); 61 + (*input_names)[i] = GetInputName(sess, i, allocator);
33 (*input_names_ptr)[i] = (*input_names)[i].c_str(); 62 (*input_names_ptr)[i] = (*input_names)[i].c_str();
34 } 63 }
35 } 64 }
@@ -41,8 +70,7 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, @@ -41,8 +70,7 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
41 output_names->resize(node_count); 70 output_names->resize(node_count);
42 output_names_ptr->resize(node_count); 71 output_names_ptr->resize(node_count);
43 for (size_t i = 0; i != node_count; ++i) { 72 for (size_t i = 0; i != node_count; ++i) {
44 - auto tmp = sess->GetOutputNameAllocated(i, allocator);  
45 - (*output_names)[i] = tmp.get(); 73 + (*output_names)[i] = GetOutputName(sess, i, allocator);
46 (*output_names_ptr)[i] = (*output_names)[i].c_str(); 74 (*output_names_ptr)[i] = (*output_names)[i].c_str();
47 } 75 }
48 } 76 }
@@ -78,12 +106,24 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, @@ -78,12 +106,24 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
78 106
79 void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { 107 void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
80 Ort::AllocatorWithDefaultOptions allocator; 108 Ort::AllocatorWithDefaultOptions allocator;
  109 +#if ORT_API_VERSION >= 17
81 std::vector<Ort::AllocatedStringPtr> v = 110 std::vector<Ort::AllocatedStringPtr> v =
82 meta_data.GetCustomMetadataMapKeysAllocated(allocator); 111 meta_data.GetCustomMetadataMapKeysAllocated(allocator);
83 for (const auto &key : v) { 112 for (const auto &key : v) {
84 auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator); 113 auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator);
85 os << key.get() << "=" << p.get() << "\n"; 114 os << key.get() << "=" << p.get() << "\n";
86 } 115 }
  116 +#else
  117 + int64_t num_keys = 0;
  118 + char **keys = meta_data.GetCustomMetadataMapKeys(allocator, num_keys);
  119 + for (int32_t i = 0; i < num_keys; ++i) {
  120 + auto v = LookupCustomModelMetaData(meta_data, keys[i], allocator);
  121 + os << keys[i] << "=" << v << "\n";
  122 + allocator.Free(keys[i]);
  123 + }
  124 +
  125 + allocator.Free(keys);
  126 +#endif
87 } 127 }
88 128
89 Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { 129 Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
@@ -361,4 +401,20 @@ std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) { @@ -361,4 +401,20 @@ std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
361 return ans; 401 return ans;
362 } 402 }
363 403
  404 +std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
  405 + const char *key,
  406 + OrtAllocator *allocator) {
  407 +// Note(fangjun): We only tested 1.17.1 and 1.11.0
  408 +// For other versions, we may need to change it
  409 +#if ORT_API_VERSION >= 17
  410 + auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
  411 + return v.get();
  412 +#else
  413 + auto v = meta_data.LookupCustomMetadataMap(key, allocator);
  414 + std::string ans = v;
  415 + allocator->Free(allocator, v);
  416 + return ans;
  417 +#endif
  418 +}
  419 +
364 } // namespace sherpa_onnx 420 } // namespace sherpa_onnx
@@ -59,6 +59,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, @@ -59,6 +59,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
59 Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, 59 Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
60 int32_t t); 60 int32_t t);
61 61
  62 +std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
  63 + const char *key, OrtAllocator *allocator);
  64 +
62 void PrintModelMetadata(std::ostream &os, 65 void PrintModelMetadata(std::ostream &os,
63 const Ort::ModelMetadata &meta_data); // NOLINT 66 const Ort::ModelMetadata &meta_data); // NOLINT
64 67
@@ -60,6 +60,7 @@ Ort::SessionOptions GetSessionOptionsImpl( @@ -60,6 +60,7 @@ Ort::SessionOptions GetSessionOptionsImpl(
60 case Provider::kCPU: 60 case Provider::kCPU:
61 break; // nothing to do for the CPU provider 61 break; // nothing to do for the CPU provider
62 case Provider::kXnnpack: { 62 case Provider::kXnnpack: {
  63 +#if ORT_API_VERSION >= 17
63 if (std::find(available_providers.begin(), available_providers.end(), 64 if (std::find(available_providers.begin(), available_providers.end(),
64 "XnnpackExecutionProvider") != available_providers.end()) { 65 "XnnpackExecutionProvider") != available_providers.end()) {
65 sess_opts.AppendExecutionProvider("XNNPACK"); 66 sess_opts.AppendExecutionProvider("XNNPACK");
@@ -67,6 +68,11 @@ Ort::SessionOptions GetSessionOptionsImpl( @@ -67,6 +68,11 @@ Ort::SessionOptions GetSessionOptionsImpl(
67 SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!", 68 SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!",
68 os.str().c_str()); 69 os.str().c_str());
69 } 70 }
  71 +#else
  72 + SHERPA_ONNX_LOGE(
  73 + "Does not support xnnpack for onnxruntime: %d. Fallback to cpu!",
  74 + static_cast<int32_t>(ORT_API_VERSION));
  75 +#endif
70 break; 76 break;
71 } 77 }
72 case Provider::kTRT: { 78 case Provider::kTRT: {
@@ -40,8 +40,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -40,8 +40,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
40 40
41 Ort::AllocatorWithDefaultOptions allocator; 41 Ort::AllocatorWithDefaultOptions allocator;
42 auto model_type = 42 auto model_type =
43 - meta_data.LookupCustomMetadataMapAllocated("framework", allocator);  
44 - if (!model_type) { 43 + LookupCustomModelMetaData(meta_data, "framework", allocator);
  44 + if (model_type.empty()) {
45 SHERPA_ONNX_LOGE( 45 SHERPA_ONNX_LOGE(
46 "No model_type in the metadata!\n" 46 "No model_type in the metadata!\n"
47 "Please make sure you have added metadata to the model.\n\n" 47 "Please make sure you have added metadata to the model.\n\n"
@@ -52,14 +52,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -52,14 +52,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
52 return ModelType::kUnknown; 52 return ModelType::kUnknown;
53 } 53 }
54 54
55 - if (model_type.get() == std::string("wespeaker")) { 55 + if (model_type == "wespeaker") {
56 return ModelType::kWeSpeaker; 56 return ModelType::kWeSpeaker;
57 - } else if (model_type.get() == std::string("3d-speaker")) { 57 + } else if (model_type == "3d-speaker") {
58 return ModelType::k3dSpeaker; 58 return ModelType::k3dSpeaker;
59 - } else if (model_type.get() == std::string("nemo")) { 59 + } else if (model_type == "nemo") {
60 return ModelType::kNeMo; 60 return ModelType::kNeMo;
61 } else { 61 } else {
62 - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 62 + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
63 return ModelType::kUnknown; 63 return ModelType::kUnknown;
64 } 64 }
65 } 65 }
@@ -53,7 +53,7 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl { @@ -53,7 +53,7 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl {
53 return std::move(outputs[0]); 53 return std::move(outputs[0]);
54 } 54 }
55 55
56 - OrtAllocator *Allocator() const { return allocator_; } 56 + OrtAllocator *Allocator() { return allocator_; }
57 57
58 const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const { 58 const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
59 return meta_data_; 59 return meta_data_;
@@ -42,8 +42,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -42,8 +42,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
42 42
43 Ort::AllocatorWithDefaultOptions allocator; 43 Ort::AllocatorWithDefaultOptions allocator;
44 auto model_type = 44 auto model_type =
45 - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);  
46 - if (!model_type) { 45 + LookupCustomModelMetaData(meta_data, "model_type", allocator);
  46 + if (model_type.empty()) {
47 SHERPA_ONNX_LOGE( 47 SHERPA_ONNX_LOGE(
48 "No model_type in the metadata!\n" 48 "No model_type in the metadata!\n"
49 "Please make sure you have added metadata to the model.\n\n" 49 "Please make sure you have added metadata to the model.\n\n"
@@ -54,11 +54,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -54,11 +54,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
54 return ModelType::kUnknown; 54 return ModelType::kUnknown;
55 } 55 }
56 56
57 - auto model_type_str = std::string(model_type.get());  
58 - if (model_type_str.find("whisper") == 0) { 57 + if (model_type.find("whisper") == 0) {
59 return ModelType::kWhisper; 58 return ModelType::kWhisper;
60 } else { 59 } else {
61 - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 60 + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
62 return ModelType::kUnknown; 61 return ModelType::kUnknown;
63 } 62 }
64 } 63 }
@@ -29,20 +29,19 @@ namespace { @@ -29,20 +29,19 @@ namespace {
29 const char *ws = " \t\n\r\f\v"; 29 const char *ws = " \t\n\r\f\v";
30 30
31 // trim from end of string (right) 31 // trim from end of string (right)
32 -inline std::string &TrimRight(std::string &s, const char *t = ws) {  
33 - s.erase(s.find_last_not_of(t) + 1);  
34 - return s; 32 +inline void TrimRight(std::string *s, const char *t = ws) {
  33 + s->erase(s->find_last_not_of(t) + 1);
35 } 34 }
36 35
37 // trim from beginning of string (left) 36 // trim from beginning of string (left)
38 -inline std::string &TrimLeft(std::string &s, const char *t = ws) {  
39 - s.erase(0, s.find_first_not_of(t));  
40 - return s; 37 +inline void TrimLeft(std::string *s, const char *t = ws) {
  38 + s->erase(0, s->find_first_not_of(t));
41 } 39 }
42 40
43 // trim from both ends of string (right then left) 41 // trim from both ends of string (right then left)
44 -inline std::string &Trim(std::string &s, const char *t = ws) {  
45 - return TrimLeft(TrimRight(s, t), t); 42 +inline void Trim(std::string *s, const char *t = ws) {
  43 + TrimRight(s, t);
  44 + TrimLeft(s, t);
46 } 45 }
47 } // namespace 46 } // namespace
48 47
@@ -56,7 +55,7 @@ std::unordered_map<std::string, int32_t> ReadTokens( @@ -56,7 +55,7 @@ std::unordered_map<std::string, int32_t> ReadTokens(
56 std::string sym; 55 std::string sym;
57 int32_t id = -1; 56 int32_t id = -1;
58 while (std::getline(is, line)) { 57 while (std::getline(is, line)) {
59 - Trim(line); 58 + Trim(&line);
60 std::istringstream iss(line); 59 std::istringstream iss(line);
61 iss >> sym; 60 iss >> sym;
62 if (iss.eof()) { 61 if (iss.eof()) {