Fangjun Kuang
Committed by GitHub

Support building GPU-capable sherpa-onnx on Linux aarch64. (#1500)

Thanks to @Peakyxh for providing pre-built onnxruntime libraries 
with CUDA support for Linux aarch64.

Tested on a Jetson Nano B01.
Showing 41 changed files with 546 additions and 300 deletions
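For reference, a hedged sketch of how the new option is exercised, mirroring the CI steps below (it assumes you run from the sherpa-onnx source root with an aarch64 cross toolchain on PATH):

```bash
# Cross-compile a GPU-capable shared build. With GPU=ON the build script
# below selects onnxruntime 1.11.0 with CUDA support.
export BUILD_SHARED_LIBS=ON     # implied by SHERPA_ONNX_ENABLE_GPU=ON anyway
export SHERPA_ONNX_ENABLE_GPU=ON
./build-aarch64-linux-gnu.sh
```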
... ... @@ -34,11 +34,12 @@ concurrency:
jobs:
aarch64_linux_gnu_shared:
runs-on: ${{ matrix.os }}
name: aarch64 shared lib test
name: aarch64 shared GPU ${{ matrix.gpu }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
gpu: [ON, OFF]
steps:
- uses: actions/checkout@v4
... ... @@ -79,15 +80,24 @@ jobs:
make -j2
make install
- name: cache-toolchain
id: cache-toolchain
- name: cache-toolchain (CPU)
if: matrix.gpu == 'OFF'
id: cache-toolchain-cpu
uses: actions/cache@v4
with:
path: toolchain
key: gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
- name: Download toolchain
if: steps.cache-toolchain.outputs.cache-hit != 'true'
- name: cache-toolchain (GPU)
if: matrix.gpu == 'ON'
id: cache-toolchain-gpu
uses: actions/cache@v4
with:
path: toolchain
key: gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
- name: Download toolchain (CPU, gcc 7.5)
if: steps.cache-toolchain-cpu.outputs.cache-hit != 'true' && matrix.gpu == 'OFF'
shell: bash
run: |
wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
... ... @@ -95,6 +105,15 @@ jobs:
mkdir $GITHUB_WORKSPACE/toolchain
tar xf ./gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
- name: Download toolchain (GPU, gcc 10.3)
if: steps.cache-toolchain-gpu.outputs.cache-hit != 'true' && matrix.gpu == 'ON'
shell: bash
run: |
wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz
mkdir $GITHUB_WORKSPACE/toolchain
tar xf ./gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain
- name: Set environment variable
if: steps.cache-build-result.outputs.cache-hit != 'true'
shell: bash
... ... @@ -103,19 +122,31 @@ jobs:
echo "$GITHUB_WORKSPACE/bin" >> "$GITHUB_PATH"
ls -lh "$GITHUB_WORKSPACE/toolchain/bin"
echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV"
echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV"
if [[ ${{ matrix.gpu }} == OFF ]]; then
echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV"
echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV"
else
echo "CC=aarch64-none-linux-gnu-gcc" >> "$GITHUB_ENV"
echo "CXX=aarch64-none-linux-gnu-g++" >> "$GITHUB_ENV"
fi
- name: Display toolchain info
shell: bash
run: |
aarch64-linux-gnu-gcc --version
if [[ ${{ matrix.gpu }} == OFF ]]; then
which aarch64-linux-gnu-gcc
aarch64-linux-gnu-gcc --version
else
which aarch64-none-linux-gnu-gcc
aarch64-none-linux-gnu-gcc --version
fi
- name: Display qemu-aarch64 -h
shell: bash
run: |
export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
qemu-aarch64 -h
- name: build aarch64-linux-gnu
... ... @@ -127,6 +158,7 @@ jobs:
cmake --version
export BUILD_SHARED_LIBS=ON
export SHERPA_ONNX_ENABLE_GPU=${{ matrix.gpu }}
./build-aarch64-linux-gnu.sh
... ... @@ -140,7 +172,11 @@ jobs:
run: |
export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH
export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
if [[ ${{ matrix.gpu }} == OFF ]]; then
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc
else
export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc
fi
ls -lh ./build-aarch64-linux-gnu/bin
... ... @@ -151,11 +187,20 @@ jobs:
- name: Copy files
shell: bash
run: |
aarch64-linux-gnu-strip --version
if [[ ${{ matrix.gpu }} == OFF ]]; then
aarch64-linux-gnu-strip --version
else
aarch64-none-linux-gnu-strip --version
fi
SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-linux-aarch64-shared
if [[ ${{ matrix.gpu }} == OFF ]]; then
dst=${dst}-cpu
else
dst=${dst}-gpu
fi
mkdir $dst
cp -a build-aarch64-linux-gnu/install/bin $dst/
... ... @@ -166,7 +211,11 @@ jobs:
ls -lh $dst/bin/
echo "strip"
aarch64-linux-gnu-strip $dst/bin/*
if [[ ${{ matrix.gpu }} == OFF ]]; then
aarch64-linux-gnu-strip $dst/bin/*
else
aarch64-none-linux-gnu-strip $dst/bin/*
fi
tree $dst
... ... @@ -174,8 +223,8 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: sherpa-onnx-linux-aarch64-shared
path: sherpa-onnx-*linux-aarch64-shared.tar.bz2
name: sherpa-onnx-linux-aarch64-shared-gpu-${{ matrix.gpu }}
path: sherpa-onnx-*linux-aarch64-shared*.tar.bz2
# https://huggingface.co/docs/hub/spaces-github-actions
- name: Publish to huggingface
... ... @@ -198,7 +247,7 @@ jobs:
cd huggingface
mkdir -p aarch64
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64
cp -v ../sherpa-onnx-*-shared*.tar.bz2 ./aarch64
git status
git lfs track "*.bz2"
... ...
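As an aside, the QEMU_LD_PREFIX exports above are what let the workflow smoke-test cross-built binaries on an x86_64 runner. A hedged local equivalent (paths are illustrative):

```bash
# Run a cross-compiled aarch64 binary on an x86_64 host via qemu-aarch64.
# QEMU_LD_PREFIX must point at the libc sysroot of the toolchain that
# produced the binary (aarch64-none-linux-gnu for the GPU toolchain).
export QEMU_LD_PREFIX=$PWD/toolchain/aarch64-none-linux-gnu/libc
qemu-aarch64 ./build-aarch64-linux-gnu/bin/sherpa-onnx --help
```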
... ... @@ -44,6 +44,21 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then
BUILD_SHARED_LIBS=OFF
fi
if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"" ]]; then
# By default, use CPU
SHERPA_ONNX_ENABLE_GPU=OFF
# If you enable GPU support, make sure your board has an NVIDIA GPU.
# The GPU build uses onnxruntime 1.11.0.
#
# Tested on Jetson Nano B01
fi
if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"ON" ]]; then
# The GPU build requires shared libraries.
BUILD_SHARED_LIBS=ON
fi
cmake \
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
... ... @@ -51,6 +66,7 @@ cmake \
-DBUILD_ESPEAK_NG_TESTS=OFF \
-DCMAKE_INSTALL_PREFIX=./install \
-DCMAKE_BUILD_TYPE=Release \
-DSHERPA_ONNX_ENABLE_GPU=$SHERPA_ONNX_ENABLE_GPU \
-DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
... ...
# Copyright (c) 2022-2024 Xiaomi Corporation
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if(NOT CMAKE_SYSTEM_NAME STREQUAL Linux)
message(FATAL_ERROR "This file is for Linux only. Given: ${CMAKE_SYSTEM_NAME}")
endif()
if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
message(FATAL_ERROR "This file is for aarch64 only. Given: ${CMAKE_SYSTEM_PROCESSOR}")
endif()
if(NOT BUILD_SHARED_LIBS)
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()
if(NOT SHERPA_ONNX_ENABLE_GPU)
message(FATAL_ERROR "This file is for NVIDIA GPU only. Given SHERPA_ONNX_ENABLE_GPU: ${SHERPA_ONNX_ENABLE_GPU}")
endif()
set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.11.0/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
set(onnxruntime_URL2 "https://hf-mirror.com/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2")
set(onnxruntime_HASH "SHA256=36eded935551e23aead09d4173bdf0bd1e7b01fdec15d77f97d6e34029aa60d7")
# If you don't have access to the Internet,
# please download onnxruntime to one of the following locations.
# You can add more if you want.
set(possible_file_locations
$ENV{HOME}/Downloads/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
${CMAKE_SOURCE_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
${CMAKE_BINARY_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
/tmp/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
/star-fj/fangjun/download/github/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(onnxruntime_URL "${f}")
file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
set(onnxruntime_URL2)
break()
endif()
endforeach()
FetchContent_Declare(onnxruntime
URL
${onnxruntime_URL}
${onnxruntime_URL2}
URL_HASH ${onnxruntime_HASH}
)
FetchContent_GetProperties(onnxruntime)
if(NOT onnxruntime_POPULATED)
message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
FetchContent_Populate(onnxruntime)
endif()
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
find_library(location_onnxruntime onnxruntime
PATHS
"${onnxruntime_SOURCE_DIR}/lib"
NO_CMAKE_SYSTEM_PATH
)
message(STATUS "location_onnxruntime: ${location_onnxruntime}")
add_library(onnxruntime SHARED IMPORTED)
set_target_properties(onnxruntime PROPERTIES
IMPORTED_LOCATION ${location_onnxruntime}
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
)
find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda
PATHS
"${onnxruntime_SOURCE_DIR}/lib"
NO_CMAKE_SYSTEM_PATH
)
add_library(onnxruntime_providers_cuda SHARED IMPORTED)
set_target_properties(onnxruntime_providers_cuda PROPERTIES
IMPORTED_LOCATION ${location_onnxruntime_cuda_lib}
)
message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}")
# for libonnxruntime_providers_shared.so
find_library(location_onnxruntime_providers_shared_lib onnxruntime_providers_shared
PATHS
"${onnxruntime_SOURCE_DIR}/lib"
NO_CMAKE_SYSTEM_PATH
)
add_library(onnxruntime_providers_shared SHARED IMPORTED)
set_target_properties(onnxruntime_providers_shared PROPERTIES
IMPORTED_LOCATION ${location_onnxruntime_providers_shared_lib}
)
message(STATUS "location_onnxruntime_providers_shared_lib: ${location_onnxruntime_providers_shared_lib}")
file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime*")
message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
install(FILES ${onnxruntime_lib_files} DESTINATION lib)
... ...
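A hedged sketch of how the imported targets above are consumed; the consumer target name is hypothetical. Only the main library is linked directly: libonnxruntime_providers_cuda.so and libonnxruntime_providers_shared.so are loaded by onnxruntime at runtime, which is why the install(FILES ...) above copies every libonnxruntime* next to the main library.

```cmake
# Hypothetical consumer target; the provider libraries need no link line,
# only to be present beside libonnxruntime.so at runtime.
add_library(sherpa-onnx-core core.cc)
target_link_libraries(sherpa-onnx-core PUBLIC onnxruntime)
```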
... ... @@ -13,7 +13,9 @@ function(download_onnxruntime)
include(onnxruntime-linux-riscv64-static)
endif()
elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
if(BUILD_SHARED_LIBS)
if(SHERPA_ONNX_ENABLE_GPU)
include(onnxruntime-linux-aarch64-gpu)
elseif(BUILD_SHARED_LIBS)
include(onnxruntime-linux-aarch64)
else()
include(onnxruntime-linux-aarch64-static)
... ...
function(download_piper_phonemize)
include(FetchContent)
set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/dc6b5f4441bffe521047086930b0fc12686acd56.zip")
set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip")
set(piper_phonemize_HASH "SHA256=b9faa04204b1756fa455a962abb1f037041c040133d55be58d11f11ab9b3ce14")
set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip")
set(piper_phonemize_HASH "SHA256=ab4d06ca76047e1585c63c482f39ffead5315785345055360703cc9382c5e74b")
# If you don't have access to the Internet,
# please pre-download piper-phonemize
set(possible_file_locations
$ENV{HOME}/Downloads/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
${CMAKE_SOURCE_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
${CMAKE_BINARY_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
/tmp/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
/star-fj/fangjun/download/github/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip
$ENV{HOME}/Downloads/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
${CMAKE_SOURCE_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
${CMAKE_BINARY_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
/tmp/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
/star-fj/fangjun/download/github/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip
)
foreach(f IN LISTS possible_file_locations)
... ...
... ... @@ -7,6 +7,8 @@
#include <stdio.h>
#include <stdlib.h>
#include <utility>
#if __ANDROID_API__ >= 8
#include "android/log.h"
#define SHERPA_ONNX_LOGE(...) \
... ... @@ -36,30 +38,28 @@
#endif
// Read an integer
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = atoi(value.get()); \
if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = atoi(value.c_str()); \
if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
exit(-1); \
} \
} while (0)
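The macros keep their old call-site contract: `meta_data` and `allocator` must already be in scope. A minimal sketch, with a hypothetical session member `sess_`:

```cpp
// Inside some Model::Init(), as done throughout sherpa-onnx:
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator;

int32_t vocab_size = 0;
SHERPA_ONNX_READ_META_DATA(vocab_size, "vocab_size");  // exits if missing
```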
#define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
dst = default_value; \
} else { \
dst = atoi(value.get()); \
dst = atoi(value.c_str()); \
if (dst < 0) { \
SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \
exit(-1); \
... ... @@ -68,118 +68,111 @@
} while (0)
// read a vector of integers
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \
if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
bool ret = SplitStringToIntegers(value.c_str(), ",", true, &dst); \
if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
exit(-1); \
} \
} while (0)
// read a vector of floats
#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
exit(-1); \
} \
\
bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \
if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
exit(-1); \
} \
\
bool ret = SplitStringToFloats(value.c_str(), ",", true, &dst); \
if (!ret) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \
exit(-1); \
} \
} while (0)
// read a vector of strings
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.get(), ",", false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.get(), src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.c_str(), ",", false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.c_str(), src_key); \
exit(-1); \
} \
} while (0)
// read a vector of strings separated by sep
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.get(), sep, false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.get(), src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
SplitStringToVector(value.c_str(), sep, false, &dst); \
\
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \
value.c_str(), src_key); \
exit(-1); \
} \
} while (0)
// Read a string
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = value.get(); \
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \
} \
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = std::move(value); \
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \
} \
} while (0)
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = value.get(); \
// Note: with the string-based lookup, a missing key and an empty value are
// indistinguishable (both surface as ""), so this macro, true to its name,
// no longer exits on an empty result.
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key)           \
  do {                                                                     \
    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
    dst = std::move(value);                                                \
  } while (0)
#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
default_value) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
dst = default_value; \
} else { \
dst = value.get(); \
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \
} \
} \
#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \
default_value) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
dst = default_value; \
} else { \
dst = std::move(value); \
if (dst.empty()) { \
SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \
exit(-1); \
} \
} \
} while (0)
#define SHERPA_ONNX_EXIT(code) exit(code)
... ...
... ... @@ -46,7 +46,7 @@ class OfflineCEDModel::Impl {
int32_t NumEventClasses() const { return num_event_classes_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
... ...
... ... @@ -44,7 +44,7 @@ class OfflineCtTransformerModel::Impl {
return std::move(ans[0]);
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
return meta_data_;
... ...
... ... @@ -53,8 +53,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"If you are using models from NeMo, please refer to\n"
... ... @@ -74,22 +74,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown;
}
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
if (model_type == "EncDecCTCModelBPE") {
return ModelType::kEncDecCTCModelBPE;
} else if (model_type.get() == std::string("EncDecCTCModel")) {
} else if (model_type == "EncDecCTCModel") {
return ModelType::kEncDecCTCModel;
} else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) {
} else if (model_type == "EncDecHybridRNNTCTCBPEModel") {
return ModelType::kEncDecHybridRNNTCTCBPEModel;
} else if (model_type.get() == std::string("tdnn")) {
} else if (model_type == "tdnn") {
return ModelType::kTdnn;
} else if (model_type.get() == std::string("zipformer2_ctc")) {
} else if (model_type == "zipformer2_ctc") {
return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) {
} else if (model_type == "wenet_ctc") {
return ModelType::kWenetCtc;
} else if (model_type.get() == std::string("telespeech_ctc")) {
} else if (model_type == "telespeech_ctc") {
return ModelType::kTeleSpeechCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown;
}
}
... ...
... ... @@ -155,7 +155,7 @@ class OfflineMoonshineModel::Impl {
return {std::move(cached_decoder_out[0]), std::move(next_states)};
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void InitPreprocessor(void *model_data, size_t model_data_length) {
... ...
... ... @@ -68,7 +68,7 @@ class OfflineNemoEncDecCtcModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
std::string FeatureNormalizationMethod() const { return normalize_type_; }
... ...
... ... @@ -56,7 +56,7 @@ class OfflineParaformerModel::Impl {
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
... ...
... ... @@ -121,9 +121,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
auto model_type_ptr =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type_ptr) {
auto model_type =
LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n\n"
"Please refer to the following URLs to add metadata"
... ... @@ -164,7 +164,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n");
exit(-1);
}
std::string model_type(model_type_ptr.get());
if (model_type == "conformer" || model_type == "zipformer" ||
model_type == "zipformer2") {
... ... @@ -301,9 +300,9 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
auto model_type_ptr =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type_ptr) {
auto model_type =
LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n\n"
"Please refer to the following URLs to add metadata"
... ... @@ -344,7 +343,6 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n");
exit(-1);
}
std::string model_type(model_type_ptr.get());
if (model_type == "conformer" || model_type == "zipformer" ||
model_type == "zipformer2") {
... ...
... ... @@ -56,7 +56,7 @@ class OfflineSenseVoiceModel::Impl {
return meta_data_;
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
... ...
... ... @@ -63,7 +63,7 @@ class OfflineTdnnCtcModel::Impl {
int32_t VocabSize() const { return vocab_size_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
... ...
... ... @@ -69,7 +69,7 @@ class OfflineTeleSpeechCtcModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
... ...
... ... @@ -95,11 +95,11 @@ class OfflineTransducerModel::Impl {
int32_t VocabSize() const { return vocab_size_; }
int32_t ContextSize() const { return context_size_; }
int32_t SubsamplingFactor() const { return 4; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
Ort::Value BuildDecoderInput(
const std::vector<OfflineTransducerDecoderResult> &results,
int32_t end_index) const {
int32_t end_index) {
assert(end_index <= results.size());
int32_t batch_size = end_index;
... ... @@ -122,7 +122,7 @@ class OfflineTransducerModel::Impl {
}
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
int32_t end_index) const {
int32_t end_index) {
assert(end_index <= results.size());
int32_t batch_size = end_index;
... ...
... ... @@ -123,7 +123,7 @@ class OfflineTransducerNeMoModel::Impl {
return std::move(logit[0]);
}
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) {
std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size, pred_hidden_};
Ort::Value s0 = Ort::Value::CreateTensor<float>(allocator_, s0_shape.data(),
s0_shape.size());
... ... @@ -149,7 +149,7 @@ class OfflineTransducerNeMoModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; }
int32_t VocabSize() const { return vocab_size_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
std::string FeatureNormalizationMethod() const { return normalize_type_; }
... ...
... ... @@ -47,7 +47,7 @@ class OfflineWenetCtcModel::Impl {
int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
... ...
... ... @@ -188,7 +188,7 @@ class OfflineWhisperModel::Impl {
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
... ...
... ... @@ -47,7 +47,7 @@ class OfflineZipformerAudioTaggingModel::Impl {
int32_t NumEventClasses() const { return num_event_classes_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
... ...
... ... @@ -48,7 +48,7 @@ class OfflineZipformerCtcModel::Impl {
int32_t VocabSize() const { return vocab_size_; }
int32_t SubsamplingFactor() const { return 4; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
... ...
... ... @@ -47,7 +47,7 @@ class OnlineCNNBiLSTMModel::Impl {
return {std::move(ans[0]), std::move(ans[1])};
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
return meta_data_;
... ...
... ... @@ -163,8 +163,11 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
conv_vec[i] = &states[i][1];
}
Ort::Value attn = Cat(allocator_, attn_vec, 2);
Ort::Value conv = Cat(allocator_, conv_vec, 2);
auto allocator =
const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
Ort::Value attn = Cat(allocator, attn_vec, 2);
Ort::Value conv = Cat(allocator, conv_vec, 2);
std::vector<Ort::Value> ans;
ans.reserve(2);
... ... @@ -183,8 +186,11 @@ OnlineConformerTransducerModel::UnStackStates(
std::vector<std::vector<Ort::Value>> ans(batch_size);
std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);
std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2);
auto allocator =
const_cast<OnlineConformerTransducerModel *>(this)->allocator_;
std::vector<Ort::Value> attn_vec = Unbind(allocator, &states[0], 2);
std::vector<Ort::Value> conv_vec = Unbind(allocator, &states[1], 2);
assert(attn_vec.size() == batch_size);
assert(conv_vec.size() == batch_size);
... ...
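The `const_cast` above (repeated in the LSTM, NeMo, and Zipformer models below) deserves a note: `allocator_` is an `Ort::AllocatorWithDefaultOptions`, and in onnxruntime 1.11 a const instance only converts to `const OrtAllocator *`, which `Cat()`/`Unbind()` (taking a mutable `OrtAllocator *`) cannot accept; newer onnxruntime exposes a const conversion, so this did not bite before. This is also why `Allocator()` drops its `const` qualifier throughout the commit. A minimal sketch of the idiom, with a hypothetical class name:

```cpp
#include "onnxruntime_cxx_api.h"  // NOLINT

class Model {
 public:
  OrtAllocator *Allocator() { return allocator_; }  // non-const on purpose

  void StackStates() const {
    // ORT 1.11: a const AllocatorWithDefaultOptions converts only to
    // const OrtAllocator*, so shed constness before calling Cat()/Unbind().
    auto allocator = const_cast<Model *>(this)->allocator_;
    (void)allocator;  // ... pass `allocator` to Cat()/Unbind() as in the diff
  }

 private:
  Ort::AllocatorWithDefaultOptions allocator_;
};
```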
... ... @@ -158,9 +158,10 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
h_buf[i] = &states[i][0];
c_buf[i] = &states[i][1];
}
auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
Ort::Value h = Cat(allocator_, h_buf, 1);
Ort::Value c = Cat(allocator_, c_buf, 1);
Ort::Value h = Cat(allocator, h_buf, 1);
Ort::Value c = Cat(allocator, c_buf, 1);
std::vector<Ort::Value> ans;
ans.reserve(2);
... ... @@ -177,8 +178,10 @@ std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
std::vector<std::vector<Ort::Value>> ans(batch_size);
std::vector<Ort::Value> h_vec = Unbind(allocator_, &states[0], 1);
std::vector<Ort::Value> c_vec = Unbind(allocator_, &states[1], 1);
auto allocator = const_cast<OnlineLstmTransducerModel *>(this)->allocator_;
std::vector<Ort::Value> h_vec = Unbind(allocator, &states[0], 1);
std::vector<Ort::Value> c_vec = Unbind(allocator, &states[1], 1);
assert(h_vec.size() == batch_size);
assert(c_vec.size() == batch_size);
... ...
... ... @@ -102,7 +102,7 @@ class OnlineNeMoCtcModel::Impl {
int32_t ChunkShift() const { return chunk_shift_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 3 tensors
// - cache_last_channel
... ... @@ -119,7 +119,7 @@ class OnlineNeMoCtcModel::Impl {
}
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const {
std::vector<std::vector<Ort::Value>> states) {
int32_t batch_size = static_cast<int32_t>(states.size());
if (batch_size == 1) {
return std::move(states[0]);
... ... @@ -157,6 +157,8 @@ class OnlineNeMoCtcModel::Impl {
std::vector<Ort::Value> states) const {
assert(states.size() == 3);
auto allocator = const_cast<Impl *>(this)->allocator_;
std::vector<std::vector<Ort::Value>> ans;
auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
... ... @@ -171,9 +173,9 @@ class OnlineNeMoCtcModel::Impl {
for (int32_t i = 0; i != 3; ++i) {
std::vector<Ort::Value> v;
if (i == 2) {
v = Unbind<int64_t>(allocator_, &states[i], 0);
v = Unbind<int64_t>(allocator, &states[i], 0);
} else {
v = Unbind(allocator_, &states[i], 0);
v = Unbind(allocator, &states[i], 0);
}
assert(v.size() == batch_size);
... ...
... ... @@ -105,7 +105,7 @@ class OnlineParaformerModel::Impl {
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
private:
void InitEncoder(void *model_data, size_t model_data_length) {
... ...
... ... @@ -5,10 +5,10 @@
#include "sherpa-onnx/csrc/online-rnn-lm.h"
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#include <algorithm>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
... ... @@ -53,49 +53,49 @@ class OnlineRnnLM::Impl {
// classic rescore function
void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
Ort::AllocatorWithDefaultOptions allocator;
for (auto &hyp : *hyps) {
for (auto &h_m : hyp) {
auto &h = h_m.second;
auto &ys = h.ys;
const int32_t token_num_in_chunk =
ys.size() - context_size - h.cur_scored_pos - 1;
if (token_num_in_chunk < 1) {
continue;
}
if (h.nn_lm_states.empty()) {
h.nn_lm_states = Convert(GetInitStates());
}
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
allocator, x_shape.data(), x_shape.size());
int64_t *p_x = x.GetTensorMutableData<int64_t>();
std::copy(ys.begin() + context_size + h.cur_scored_pos,
ys.end() - 1, p_x);
// streaming forward by NN LM
auto out = ScoreToken(std::move(x),
Convert(std::move(h.nn_lm_states)));
// update NN LM score in hyp
const float *p_nll = out.first.GetTensorData<float>();
h.lm_log_prob = -scale * (*p_nll);
// update NN LM states in hyp
h.nn_lm_states = Convert(std::move(out.second));
h.cur_scored_pos += token_num_in_chunk;
}
std::vector<Hypotheses> *hyps) {
Ort::AllocatorWithDefaultOptions allocator;
for (auto &hyp : *hyps) {
for (auto &h_m : hyp) {
auto &h = h_m.second;
auto &ys = h.ys;
const int32_t token_num_in_chunk =
ys.size() - context_size - h.cur_scored_pos - 1;
if (token_num_in_chunk < 1) {
continue;
}
if (h.nn_lm_states.empty()) {
h.nn_lm_states = Convert(GetInitStates());
}
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
allocator, x_shape.data(), x_shape.size());
int64_t *p_x = x.GetTensorMutableData<int64_t>();
std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
p_x);
// streaming forward by NN LM
auto out =
ScoreToken(std::move(x), Convert(std::move(h.nn_lm_states)));
// update NN LM score in hyp
const float *p_nll = out.first.GetTensorData<float>();
h.lm_log_prob = -scale * (*p_nll);
// update NN LM states in hyp
h.nn_lm_states = Convert(std::move(out.second));
h.cur_scored_pos += token_num_in_chunk;
}
}
}
}
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
Ort::Value x, std::vector<Ort::Value> states) {
... ... @@ -125,7 +125,7 @@ class OnlineRnnLM::Impl {
}
// get init states for classic rescore
std::vector<Ort::Value> GetInitStates() const {
std::vector<Ort::Value> GetInitStates() {
std::vector<Ort::Value> ans;
ans.reserve(init_states_.size());
... ... @@ -226,7 +226,7 @@ std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
// classic rescore scores
void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
std::vector<Hypotheses> *hyps) {
return impl_->ComputeLMScore(scale, context_size, hyps);
}
... ... @@ -235,5 +235,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
return impl_->ComputeLMScoreSF(scale, hyp);
}
} // namespace sherpa_onnx
... ...
... ... @@ -54,8 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you are using the latest export-onnx.py from icefall "
... ... @@ -63,16 +63,16 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown;
}
if (model_type.get() == std::string("conformer")) {
if (model_type == "conformer") {
return ModelType::kConformer;
} else if (model_type.get() == std::string("lstm")) {
} else if (model_type == "lstm") {
return ModelType::kLstm;
} else if (model_type.get() == std::string("zipformer")) {
} else if (model_type == "zipformer") {
return ModelType::kZipformer;
} else if (model_type.get() == std::string("zipformer2")) {
} else if (model_type == "zipformer2") {
return ModelType::kZipformer2;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown;
}
}
... ...
... ... @@ -197,7 +197,7 @@ class OnlineTransducerNeMoModel::Impl {
int32_t VocabSize() const { return vocab_size_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
std::string FeatureNormalizationMethod() const { return normalize_type_; }
... ... @@ -224,6 +224,8 @@ class OnlineTransducerNeMoModel::Impl {
std::vector<Ort::Value> ans;
auto allocator = const_cast<Impl *>(this)->allocator_;
// stack cache_last_channel
std::vector<const Ort::Value *> buf(batch_size);
... ... @@ -239,9 +241,9 @@ class OnlineTransducerNeMoModel::Impl {
Ort::Value c{nullptr};
if (i == 2) {
c = Cat<int64_t>(allocator_, buf, 0);
c = Cat<int64_t>(allocator, buf, 0);
} else {
c = Cat(allocator_, buf, 0);
c = Cat(allocator, buf, 0);
}
ans.push_back(std::move(c));
... ... @@ -251,7 +253,7 @@ class OnlineTransducerNeMoModel::Impl {
}
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const {
std::vector<Ort::Value> states) {
assert(states.size() == 3);
std::vector<std::vector<Ort::Value>> ans;
... ...
... ... @@ -101,7 +101,7 @@ class OnlineWenetCtcModel::Impl {
return config_.wenet_ctc.chunk_size * subsampling_factor_;
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 3 tensors
// - attn_cache
... ...
... ... @@ -179,12 +179,15 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
std::vector<Ort::Value> ans;
ans.reserve(states[0].size());
auto allocator =
const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
// cached_len
for (int32_t i = 0; i != num_encoders; ++i) {
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][i];
}
auto v = Cat<int64_t>(allocator_, buf, 1); // (num_layers, 1)
auto v = Cat<int64_t>(allocator, buf, 1); // (num_layers, 1)
ans.push_back(std::move(v));
}
... ... @@ -193,7 +196,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_encoders + i];
}
auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims)
auto v = Cat(allocator, buf, 1); // (num_layers, 1, encoder_dims)
ans.push_back(std::move(v));
}
... ... @@ -203,7 +206,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 2 + i];
}
// (num_layers, left_context_len, 1, attention_dims)
auto v = Cat(allocator_, buf, 2);
auto v = Cat(allocator, buf, 2);
ans.push_back(std::move(v));
}
... ... @@ -213,7 +216,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 3 + i];
}
// (num_layers, left_context_len, 1, attention_dims/2)
auto v = Cat(allocator_, buf, 2);
auto v = Cat(allocator, buf, 2);
ans.push_back(std::move(v));
}
... ... @@ -223,7 +226,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 4 + i];
}
// (num_layers, left_context_len, 1, attention_dims/2)
auto v = Cat(allocator_, buf, 2);
auto v = Cat(allocator, buf, 2);
ans.push_back(std::move(v));
}
... ... @@ -233,7 +236,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 5 + i];
}
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
... ... @@ -243,7 +246,7 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::StackStates(
buf[n] = &states[n][num_encoders * 6 + i];
}
// (num_layers, 1, encoder_dims, cnn_module_kernels-1)
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
... ... @@ -258,12 +261,15 @@ OnlineZipformerTransducerModel::UnStackStates(
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
int32_t num_encoders = num_encoder_layers_.size();
auto allocator =
const_cast<OnlineZipformerTransducerModel *>(this)->allocator_;
std::vector<std::vector<Ort::Value>> ans;
ans.resize(batch_size);
// cached_len
for (int32_t i = 0; i != num_encoders; ++i) {
auto v = Unbind<int64_t>(allocator_, &states[i], 1);
auto v = Unbind<int64_t>(allocator, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -273,7 +279,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_avg
for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1);
auto v = Unbind(allocator, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -283,7 +289,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_key
for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2);
auto v = Unbind(allocator, &states[i], 2);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -293,7 +299,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_val
for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2);
auto v = Unbind(allocator, &states[i], 2);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -303,7 +309,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_val2
for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 2);
auto v = Unbind(allocator, &states[i], 2);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -313,7 +319,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_conv1
for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1);
auto v = Unbind(allocator, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -323,7 +329,7 @@ OnlineZipformerTransducerModel::UnStackStates(
// cached_conv2
for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) {
auto v = Unbind(allocator_, &states[i], 1);
auto v = Unbind(allocator, &states[i], 1);
assert(v.size() == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ...
... ... @@ -70,7 +70,7 @@ class OnlineZipformer2CtcModel::Impl {
int32_t ChunkShift() const { return decode_chunk_len_; }
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
// Return a vector containing 3 tensors
// - attn_cache
... ... @@ -86,7 +86,7 @@ class OnlineZipformer2CtcModel::Impl {
}
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const {
std::vector<std::vector<Ort::Value>> states) {
int32_t batch_size = static_cast<int32_t>(states.size());
std::vector<const Ort::Value *> buf(batch_size);
... ... @@ -159,7 +159,7 @@ class OnlineZipformer2CtcModel::Impl {
}
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const {
std::vector<Ort::Value> states) {
int32_t m = std::accumulate(num_encoder_layers_.begin(),
num_encoder_layers_.end(), 0);
assert(states.size() == m * 6 + 2);
... ...
... ... @@ -185,6 +185,9 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
std::vector<const Ort::Value *> buf(batch_size);
auto allocator =
const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
std::vector<Ort::Value> ans;
int32_t num_states = static_cast<int32_t>(states[0].size());
ans.reserve(num_states);
... ... @@ -194,42 +197,42 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i];
}
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 1];
}
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 2];
}
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 3];
}
auto v = Cat(allocator_, buf, 1);
auto v = Cat(allocator, buf, 1);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 4];
}
auto v = Cat(allocator_, buf, 0);
auto v = Cat(allocator, buf, 0);
ans.push_back(std::move(v));
}
{
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][6 * i + 5];
}
auto v = Cat(allocator_, buf, 0);
auto v = Cat(allocator, buf, 0);
ans.push_back(std::move(v));
}
}
... ... @@ -238,7 +241,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_states - 2];
}
auto v = Cat(allocator_, buf, 0);
auto v = Cat(allocator, buf, 0);
ans.push_back(std::move(v));
}
... ... @@ -246,7 +249,7 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
for (int32_t n = 0; n != batch_size; ++n) {
buf[n] = &states[n][num_states - 1];
}
auto v = Cat<int64_t>(allocator_, buf, 0);
auto v = Cat<int64_t>(allocator, buf, 0);
ans.push_back(std::move(v));
}
return ans;
... ... @@ -261,12 +264,15 @@ OnlineZipformer2TransducerModel::UnStackStates(
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
auto allocator =
const_cast<OnlineZipformer2TransducerModel *>(this)->allocator_;
std::vector<std::vector<Ort::Value>> ans;
ans.resize(batch_size);
for (int32_t i = 0; i != m; ++i) {
{
auto v = Unbind(allocator_, &states[i * 6], 1);
auto v = Unbind(allocator, &states[i * 6], 1);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -274,7 +280,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 1], 1);
auto v = Unbind(allocator, &states[i * 6 + 1], 1);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -282,7 +288,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 2], 1);
auto v = Unbind(allocator, &states[i * 6 + 2], 1);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -290,7 +296,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 3], 1);
auto v = Unbind(allocator, &states[i * 6 + 3], 1);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -298,7 +304,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 4], 0);
auto v = Unbind(allocator, &states[i * 6 + 4], 0);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -306,7 +312,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind(allocator_, &states[i * 6 + 5], 0);
auto v = Unbind(allocator, &states[i * 6 + 5], 0);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -316,7 +322,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
{
auto v = Unbind(allocator_, &states[m * 6], 0);
auto v = Unbind(allocator, &states[m * 6], 0);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ... @@ -324,7 +330,7 @@ OnlineZipformer2TransducerModel::UnStackStates(
}
}
{
auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0);
auto v = Unbind<int64_t>(allocator, &states[m * 6 + 1], 0);
assert(static_cast<int32_t>(v.size()) == batch_size);
for (int32_t n = 0; n != batch_size; ++n) {
... ...
... ... @@ -21,6 +21,36 @@
namespace sherpa_onnx {
static std::string GetInputName(Ort::Session *sess, size_t index,
OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
auto v = sess->GetInputNameAllocated(index, allocator);
return v.get();
#else
auto v = sess->GetInputName(index, allocator);
std::string ans = v;
allocator->Free(allocator, v);
return ans;
#endif
}
static std::string GetOutputName(Ort::Session *sess, size_t index,
OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
auto v = sess->GetOutputNameAllocated(index, allocator);
return v.get();
#else
auto v = sess->GetOutputName(index, allocator);
std::string ans = v;
allocator->Free(allocator, v);
return ans;
#endif
}
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
std::vector<const char *> *input_names_ptr) {
Ort::AllocatorWithDefaultOptions allocator;
... ... @@ -28,8 +58,7 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
input_names->resize(node_count);
input_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetInputNameAllocated(i, allocator);
(*input_names)[i] = tmp.get();
(*input_names)[i] = GetInputName(sess, i, allocator);
(*input_names_ptr)[i] = (*input_names)[i].c_str();
}
}
... ... @@ -41,8 +70,7 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
output_names->resize(node_count);
output_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetOutputNameAllocated(i, allocator);
(*output_names)[i] = tmp.get();
(*output_names)[i] = GetOutputName(sess, i, allocator);
(*output_names_ptr)[i] = (*output_names)[i].c_str();
}
}
... ... @@ -78,12 +106,24 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
Ort::AllocatorWithDefaultOptions allocator;
#if ORT_API_VERSION >= 17
std::vector<Ort::AllocatedStringPtr> v =
meta_data.GetCustomMetadataMapKeysAllocated(allocator);
for (const auto &key : v) {
auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator);
os << key.get() << "=" << p.get() << "\n";
}
#else
int64_t num_keys = 0;
char **keys = meta_data.GetCustomMetadataMapKeys(allocator, num_keys);
for (int32_t i = 0; i < num_keys; ++i) {
auto v = LookupCustomModelMetaData(meta_data, keys[i], allocator);
os << keys[i] << "=" << v << "\n";
allocator.Free(keys[i]);
}
allocator.Free(keys);
#endif
}
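A hedged sketch of the usual debug-time call site for PrintModelMetadata (sherpa-onnx models do roughly this when config.debug is set; the `sess_` member is assumed):

```cpp
#include <sstream>

// Dump all custom metadata of the loaded model through the logging macro.
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
```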
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
... ... @@ -361,4 +401,20 @@ std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
return ans;
}
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
const char *key,
OrtAllocator *allocator) {
// Note(fangjun): We only tested 1.17.1 and 1.11.0
// For other versions, we may need to change it
#if ORT_API_VERSION >= 17
auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
return v.get();
#else
auto v = meta_data.LookupCustomMetadataMap(key, allocator);
if (!v) {
  // The old API returns nullptr for a missing key; map it to "" so that
  // call sites can uniformly test value.empty().
  return {};
}
std::string ans = v;
allocator->Free(allocator, v);
return ans;
#endif
}
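Call sites (see the model files above) rely on one property of this helper: on both API paths a missing key surfaces as an empty `std::string`, so `value.empty()` is the uniform existence check:

```cpp
auto v = LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (v.empty()) { /* treat the key as absent */ }
```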
} // namespace sherpa_onnx
... ...
... ... @@ -59,6 +59,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
int32_t t);
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
const char *key, OrtAllocator *allocator);
void PrintModelMetadata(std::ostream &os,
const Ort::ModelMetadata &meta_data); // NOLINT
... ...
... ... @@ -60,6 +60,7 @@ Ort::SessionOptions GetSessionOptionsImpl(
case Provider::kCPU:
break; // nothing to do for the CPU provider
case Provider::kXnnpack: {
#if ORT_API_VERSION >= 17
if (std::find(available_providers.begin(), available_providers.end(),
"XnnpackExecutionProvider") != available_providers.end()) {
sess_opts.AppendExecutionProvider("XNNPACK");
... ... @@ -67,6 +68,11 @@ Ort::SessionOptions GetSessionOptionsImpl(
SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!",
os.str().c_str());
}
#else
SHERPA_ONNX_LOGE(
    "xnnpack is not supported for this onnxruntime API version: %d. "
    "Fallback to cpu!",
    static_cast<int32_t>(ORT_API_VERSION));
#endif
break;
}
case Provider::kTRT: {
... ...
... ... @@ -40,8 +40,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("framework", allocator);
if (!model_type) {
LookupCustomModelMetaData(meta_data, "framework", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you have added metadata to the model.\n\n"
... ... @@ -52,14 +52,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown;
}
if (model_type.get() == std::string("wespeaker")) {
if (model_type == "wespeaker") {
return ModelType::kWeSpeaker;
} else if (model_type.get() == std::string("3d-speaker")) {
} else if (model_type == "3d-speaker") {
return ModelType::k3dSpeaker;
} else if (model_type.get() == std::string("nemo")) {
} else if (model_type == "nemo") {
return ModelType::kNeMo;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown;
}
}
... ...
... ... @@ -53,7 +53,7 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl {
return std::move(outputs[0]);
}
OrtAllocator *Allocator() const { return allocator_; }
OrtAllocator *Allocator() { return allocator_; }
const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
return meta_data_;
... ...
... ... @@ -42,8 +42,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
LookupCustomModelMetaData(meta_data, "model_type", allocator);
if (model_type.empty()) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you have added metadata to the model.\n\n"
... ... @@ -54,11 +54,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kUnknown;
}
auto model_type_str = std::string(model_type.get());
if (model_type_str.find("whisper") == 0) {
if (model_type.find("whisper") == 0) {
return ModelType::kWhisper;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
return ModelType::kUnknown;
}
}
... ...
... ... @@ -29,20 +29,19 @@ namespace {
const char *ws = " \t\n\r\f\v";
// trim from end of string (right)
inline std::string &TrimRight(std::string &s, const char *t = ws) {
s.erase(s.find_last_not_of(t) + 1);
return s;
inline void TrimRight(std::string *s, const char *t = ws) {
s->erase(s->find_last_not_of(t) + 1);
}
// trim from beginning of string (left)
inline std::string &TrimLeft(std::string &s, const char *t = ws) {
s.erase(0, s.find_first_not_of(t));
return s;
inline void TrimLeft(std::string *s, const char *t = ws) {
s->erase(0, s->find_first_not_of(t));
}
// trim from both ends of string (right then left)
inline std::string &Trim(std::string &s, const char *t = ws) {
return TrimLeft(TrimRight(s, t), t);
inline void Trim(std::string *s, const char *t = ws) {
TrimRight(s, t);
TrimLeft(s, t);
}
} // namespace
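The trim helpers now take a pointer and mutate in place, making the mutation explicit at call sites (see ReadTokens below). A tiny sketch:

```cpp
std::string line = " \thello world\t\n";
Trim(&line);  // line == "hello world"
```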
... ... @@ -56,7 +55,7 @@ std::unordered_map<std::string, int32_t> ReadTokens(
std::string sym;
int32_t id = -1;
while (std::getline(is, line)) {
Trim(line);
Trim(&line);
std::istringstream iss(line);
iss >> sym;
if (iss.eof()) {
... ...