Committed by
GitHub
Refactor the code (#15)
* code refactoring * Remove reference files * Update README and CI * small fixes * fix style issues * add style check for CI * fix style issues * remove kaldi-native-io
正在显示
26 个修改的文件
包含
1179 行增加
和
718 行删除
.clang-format
0 → 100644
.github/workflows/style_check.yaml
0 → 100644
| 1 | +# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 2 | +# | ||
| 3 | +# See ../../LICENSE for clarification regarding multiple authors | ||
| 4 | +# | ||
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 6 | +# you may not use this file except in compliance with the License. | ||
| 7 | +# You may obtain a copy of the License at | ||
| 8 | +# | ||
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
| 10 | +# | ||
| 11 | +# Unless required by applicable law or agreed to in writing, software | ||
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 14 | +# See the License for the specific language governing permissions and | ||
| 15 | +# limitations under the License. | ||
| 16 | + | ||
| 17 | +name: style_check | ||
| 18 | + | ||
| 19 | +on: | ||
| 20 | + push: | ||
| 21 | + branches: | ||
| 22 | + - master | ||
| 23 | + paths: | ||
| 24 | + - '.github/workflows/style_check.yaml' | ||
| 25 | + - 'sherpa-onnx/**' | ||
| 26 | + pull_request: | ||
| 27 | + branches: | ||
| 28 | + - master | ||
| 29 | + paths: | ||
| 30 | + - '.github/workflows/style_check.yaml' | ||
| 31 | + - 'sherpa-onnx/**' | ||
| 32 | + | ||
| 33 | +concurrency: | ||
| 34 | + group: style_check-${{ github.ref }} | ||
| 35 | + cancel-in-progress: true | ||
| 36 | + | ||
| 37 | +jobs: | ||
| 38 | + style_check: | ||
| 39 | + runs-on: ubuntu-latest | ||
| 40 | + strategy: | ||
| 41 | + matrix: | ||
| 42 | + python-version: [3.8] | ||
| 43 | + fail-fast: false | ||
| 44 | + | ||
| 45 | + steps: | ||
| 46 | + - uses: actions/checkout@v2 | ||
| 47 | + with: | ||
| 48 | + fetch-depth: 0 | ||
| 49 | + | ||
| 50 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 51 | + uses: actions/setup-python@v1 | ||
| 52 | + with: | ||
| 53 | + python-version: ${{ matrix.python-version }} | ||
| 54 | + | ||
| 55 | + - name: Check style with cpplint | ||
| 56 | + shell: bash | ||
| 57 | + working-directory: ${{github.workspace}} | ||
| 58 | + run: ./scripts/check_style_cpplint.sh |
| @@ -59,31 +59,28 @@ jobs: | @@ -59,31 +59,28 @@ jobs: | ||
| 59 | - name: Run tests for ubuntu/macos (English) | 59 | - name: Run tests for ubuntu/macos (English) |
| 60 | run: | | 60 | run: | |
| 61 | time ./build/bin/sherpa-onnx \ | 61 | time ./build/bin/sherpa-onnx \ |
| 62 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 62 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ | 63 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ |
| 63 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ | 64 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ |
| 64 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ | 65 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ |
| 65 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ | 66 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ |
| 66 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ | 67 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ |
| 67 | - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 68 | - greedy \ | ||
| 69 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav | 68 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav |
| 70 | 69 | ||
| 71 | time ./build/bin/sherpa-onnx \ | 70 | time ./build/bin/sherpa-onnx \ |
| 71 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 72 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ | 72 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ |
| 73 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ | 73 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ |
| 74 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ | 74 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ |
| 75 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ | 75 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ |
| 76 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ | 76 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ |
| 77 | - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 78 | - greedy \ | ||
| 79 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav | 77 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav |
| 80 | 78 | ||
| 81 | time ./build/bin/sherpa-onnx \ | 79 | time ./build/bin/sherpa-onnx \ |
| 80 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 82 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ | 81 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ |
| 83 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ | 82 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ |
| 84 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ | 83 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ |
| 85 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ | 84 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ |
| 86 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ | 85 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ |
| 87 | - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 88 | - greedy \ | ||
| 89 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav | 86 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav |
| @@ -38,7 +38,6 @@ set(CMAKE_CXX_EXTENSIONS OFF) | @@ -38,7 +38,6 @@ set(CMAKE_CXX_EXTENSIONS OFF) | ||
| 38 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) | 38 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) |
| 39 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) | 39 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) |
| 40 | 40 | ||
| 41 | -include(kaldi_native_io) | ||
| 42 | include(kaldi-native-fbank) | 41 | include(kaldi-native-fbank) |
| 43 | include(onnxruntime) | 42 | include(onnxruntime) |
| 44 | 43 |
| @@ -14,6 +14,9 @@ the following links: | @@ -14,6 +14,9 @@ the following links: | ||
| 14 | **NOTE**: We provide only non-streaming models at present. | 14 | **NOTE**: We provide only non-streaming models at present. |
| 15 | 15 | ||
| 16 | 16 | ||
| 17 | +**HINT**: The script for exporting the English model can be found at | ||
| 18 | +<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py> | ||
| 19 | + | ||
| 17 | # Usage | 20 | # Usage |
| 18 | 21 | ||
| 19 | ```bash | 22 | ```bash |
| @@ -34,13 +37,14 @@ cd .. | @@ -34,13 +37,14 @@ cd .. | ||
| 34 | git lfs install | 37 | git lfs install |
| 35 | git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 | 38 | git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 |
| 36 | 39 | ||
| 40 | +./build/bin/sherpa-onnx --help | ||
| 41 | + | ||
| 37 | ./build/bin/sherpa-onnx \ | 42 | ./build/bin/sherpa-onnx \ |
| 43 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 38 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ | 44 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ |
| 39 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ | 45 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ |
| 40 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ | 46 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ |
| 41 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ | 47 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ |
| 42 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ | 48 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ |
| 43 | - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 44 | - greedy \ | ||
| 45 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav | 49 | ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav |
| 46 | ``` | 50 | ``` |
cmake/kaldi_native_io.cmake
已删除
100644 → 0
| 1 | -function(download_kaldi_native_io) | ||
| 2 | - if(CMAKE_VERSION VERSION_LESS 3.11) | ||
| 3 | - # FetchContent is available since 3.11, | ||
| 4 | - # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules | ||
| 5 | - # so that it can be used in lower CMake versions. | ||
| 6 | - message(STATUS "Use FetchContent provided by sherpa-onnx") | ||
| 7 | - list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) | ||
| 8 | - endif() | ||
| 9 | - | ||
| 10 | - include(FetchContent) | ||
| 11 | - | ||
| 12 | - set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz") | ||
| 13 | - set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6") | ||
| 14 | - | ||
| 15 | - set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE) | ||
| 16 | - set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | ||
| 17 | - | ||
| 18 | - FetchContent_Declare(kaldi_native_io | ||
| 19 | - URL ${kaldi_native_io_URL} | ||
| 20 | - URL_HASH ${kaldi_native_io_HASH} | ||
| 21 | - ) | ||
| 22 | - | ||
| 23 | - FetchContent_GetProperties(kaldi_native_io) | ||
| 24 | - if(NOT kaldi_native_io_POPULATED) | ||
| 25 | - message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}") | ||
| 26 | - FetchContent_Populate(kaldi_native_io) | ||
| 27 | - endif() | ||
| 28 | - message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}") | ||
| 29 | - message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}") | ||
| 30 | - | ||
| 31 | - add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL) | ||
| 32 | - | ||
| 33 | - target_include_directories(kaldi_native_io_core | ||
| 34 | - PUBLIC | ||
| 35 | - ${kaldi_native_io_SOURCE_DIR}/ | ||
| 36 | - ) | ||
| 37 | -endfunction() | ||
| 38 | - | ||
| 39 | -download_kaldi_native_io() |
| @@ -10,7 +10,7 @@ function(download_onnxruntime) | @@ -10,7 +10,7 @@ function(download_onnxruntime) | ||
| 10 | include(FetchContent) | 10 | include(FetchContent) |
| 11 | 11 | ||
| 12 | if(UNIX AND NOT APPLE) | 12 | if(UNIX AND NOT APPLE) |
| 13 | - set(onnxruntime_URL "http://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz") | 13 | + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz") |
| 14 | 14 | ||
| 15 | # If you don't have access to the internet, you can first download onnxruntime to some directory, and then use | 15 | # If you don't have access to the internet, you can first download onnxruntime to some directory, and then use |
| 16 | # set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz") | 16 | # set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz") |
scripts/check_style_cpplint.sh
0 → 100755
| 1 | +#!/bin/bash | ||
| 2 | +# | ||
| 3 | +# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) | ||
| 4 | +# | ||
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 6 | +# you may not use this file except in compliance with the License. | ||
| 7 | +# You may obtain a copy of the License at | ||
| 8 | +# | ||
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
| 10 | +# | ||
| 11 | +# Unless required by applicable law or agreed to in writing, software | ||
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 14 | +# See the License for the specific language governing permissions and | ||
| 15 | +# limitations under the License. | ||
| 16 | +# | ||
| 17 | +# Usage: | ||
| 18 | +# | ||
| 19 | +# (1) To check files of the last commit | ||
| 20 | +# ./scripts/check_style_cpplint.sh | ||
| 21 | +# | ||
| 22 | +# (2) To check changed files not committed yet | ||
| 23 | +# ./scripts/check_style_cpplint.sh 1 | ||
| 24 | +# | ||
| 25 | +# (3) To check all files in the project | ||
| 26 | +# ./scripts/check_style_cpplint.sh 2 | ||
| 27 | + | ||
| 28 | +# Pinned cpplint release, plus the script directory and the repository root. | ||
| 29 | +cpplint_version="1.5.4" | ||
| 30 | +cur_dir=$(cd $(dirname $BASH_SOURCE) && pwd) | ||
| 31 | +sherpa_onnx_dir=$(cd $cur_dir/.. && pwd) | ||
| 32 | +# cpplint is downloaded and unpacked under the repository's build directory. | ||
| 33 | +build_dir=$sherpa_onnx_dir/build | ||
| 34 | +mkdir -p $build_dir | ||
| 35 | +# Path of cpplint.py inside the unpacked release. | ||
| 36 | +cpplint_src=$build_dir/cpplint-${cpplint_version}/cpplint.py | ||
| 37 | +# Fetch and patch cpplint only when its directory does not exist yet. | ||
| 38 | +if [ ! -d "$build_dir/cpplint-${cpplint_version}" ]; then | ||
| 39 | + pushd $build_dir | ||
| 40 | + if command -v wget &> /dev/null; then | ||
| 41 | + wget https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz | ||
| 42 | + elif command -v curl &> /dev/null; then | ||
| 43 | + curl -O -SL https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz | ||
| 44 | + else | ||
| 45 | + echo "Please install wget or curl to download cpplint" | ||
| 46 | + exit 1 | ||
| 47 | + fi | ||
| 48 | + tar xf ${cpplint_version}.tar.gz | ||
| 49 | + rm ${cpplint_version}.tar.gz | ||
| 50 | + | ||
| 51 | + # cpplint will report the following error for: __host__ __device__ ( | ||
| 52 | + # | ||
| 53 | + # Extra space before ( in function call [whitespace/parens] [4] | ||
| 54 | + # | ||
| 55 | + # the following patch disables the above error | ||
| 56 | + sed -i "3490i\ not Search(r'__host__ __device__\\\s+\\\(', fncall) and" $cpplint_src | ||
| 57 | + popd | ||
| 58 | +fi | ||
| 59 | +# utils.sh defines the ok/error/abort helpers; ok and abort are used below. | ||
| 60 | +source $sherpa_onnx_dir/scripts/utils.sh | ||
| 61 | + | ||
| 62 | +# Classify a path by its extension. | ||
| 63 | +# Prints "true" for *.cc, *.h and *.cu files and "false" for anything else, | ||
| 64 | +# so the printed word can be used directly as a shell condition. | ||
| 65 | +function is_source_code_file() { | ||
| 66 | +  if [[ "$1" == *.cc || "$1" == *.h || "$1" == *.cu ]]; then | ||
| 67 | +    echo true | ||
| 68 | +  else | ||
| 69 | +    echo false | ||
| 70 | +  fi | ||
| 71 | +} | ||
| 72 | +# Lint file $1 with cpplint and abort on failure (expansions are quoted: paths may contain spaces). | ||
| 73 | +function check_style() { | ||
| 74 | +  python3 "$cpplint_src" "$1" || abort "$1" | ||
| 75 | +} | ||
| 76 | +# Print, on one line, the paths that differ between HEAD^1 and the working tree (see --diff-filter). | ||
| 77 | +function check_last_commit() { | ||
| 78 | + files=$(git diff HEAD^1 --name-only --diff-filter=ACDMRUXB) | ||
| 79 | + echo $files | ||
| 80 | +} | ||
| 81 | +# Print, on one line, the tracked files that have uncommitted changes. | ||
| 82 | +function check_current_dir() { | ||
| 83 | +  # A status line normally has two fields (state and path); a line with four | ||
| 84 | +  # fields is treated as a rename and its last field is taken as the path. | ||
| 85 | +  files=$(git status -s -uno --porcelain | awk ' | ||
| 86 | +    # a file has been renamed | ||
| 87 | +    NF == 4 { print $NF; next } | ||
| 88 | +    { print $2 } | ||
| 89 | +  ') | ||
| 90 | + | ||
| 91 | +  echo $files | ||
| 92 | +} | ||
| 93 | +# Lint the files selected by $1: | ||
| 94 | +#   1 -> files with uncommitted changes, 2 -> every *.h/*.cc under sherpa-onnx/, | ||
| 95 | +#   anything else (including no argument) -> files of the last commit. | ||
| 96 | +function do_check() { | ||
| 97 | +  if [[ "$1" == 1 ]]; then | ||
| 98 | +    echo "Check changed files" | ||
| 99 | +    files=$(check_current_dir) | ||
| 100 | +  elif [[ "$1" == 2 ]]; then | ||
| 101 | +    echo "Check all files" | ||
| 102 | +    files=$(find $sherpa_onnx_dir/sherpa-onnx -name "*.h" -o -name "*.cc") | ||
| 103 | +  else | ||
| 104 | +    echo "Check last commit" | ||
| 105 | +    files=$(check_last_commit) | ||
| 106 | +  fi | ||
| 107 | + | ||
| 108 | +  # unquoted: the list is split on whitespace, one word per path | ||
| 109 | +  for f in $files; do | ||
| 110 | +    need_check=$(is_source_code_file $f) | ||
| 111 | +    if [[ "$need_check" == true ]]; then | ||
| 112 | +      # paths that are not regular files (e.g. deleted ones) are skipped | ||
| 113 | +      [[ -f $f ]] && check_style $f | ||
| 114 | +    fi | ||
| 115 | +  done | ||
| 116 | +} | ||
| 117 | +# Run the check selected by $1; the success line is printed only if no check_style call aborted. | ||
| 118 | +function main() { | ||
| 119 | + do_check $1 | ||
| 120 | + | ||
| 121 | + ok "Great! Style check passed!" | ||
| 122 | +} | ||
| 123 | + | ||
| 124 | +cd $sherpa_onnx_dir | ||
| 125 | + | ||
| 126 | +main $1 |
scripts/utils.sh
0 → 100644
| 1 | +#!/bin/bash | ||
| 2 | + | ||
| 3 | +default='\033[0m' | ||
| 4 | +bold='\033[1m' | ||
| 5 | +red='\033[31m' | ||
| 6 | +green='\033[32m' | ||
| 7 | +# Print a bold green "[OK]" tag and the message $1 (passed via %s, so a '%' in it is printed literally). | ||
| 8 | +function ok() { | ||
| 9 | +  printf "${bold}${green}[OK]${default} %s\n" "$1" | ||
| 10 | +} | ||
| 11 | +# Print a bold red "[FAILED]" tag and the message $1 (passed via %s, so a '%' in it is printed literally). | ||
| 12 | +function error() { | ||
| 13 | +  printf "${bold}${red}[FAILED]${default} %s\n" "$1" | ||
| 14 | +} | ||
| 15 | +# Print a bold red "[FAILED]" tag and the message $1 (via %s, so '%' is literal), then exit with status 1. | ||
| 16 | +function abort() { | ||
| 17 | +  printf "${bold}${red}[FAILED]${default} %s\n" "$1" | ||
| 18 | +  exit 1 | ||
| 19 | +} |
| 1 | include_directories(${CMAKE_SOURCE_DIR}) | 1 | include_directories(${CMAKE_SOURCE_DIR}) |
| 2 | -add_executable(sherpa-onnx main.cpp) | 2 | + |
| 3 | +add_executable(sherpa-onnx | ||
| 4 | + decode.cc | ||
| 5 | + rnnt-model.cc | ||
| 6 | + sherpa-onnx.cc | ||
| 7 | + symbol-table.cc | ||
| 8 | + wave-reader.cc | ||
| 9 | +) | ||
| 3 | 10 | ||
| 4 | target_link_libraries(sherpa-onnx | 11 | target_link_libraries(sherpa-onnx |
| 5 | onnxruntime | 12 | onnxruntime |
| 6 | kaldi-native-fbank-core | 13 | kaldi-native-fbank-core |
| 7 | - kaldi_native_io_core | ||
| 8 | ) | 14 | ) |
| 15 | + | ||
| 16 | +add_executable(sherpa-show-onnx-info show-onnx-info.cc) | ||
| 17 | +target_link_libraries(sherpa-show-onnx-info onnxruntime) |
sherpa-onnx/csrc/decode.cc
0 → 100644
| 1 | +/** | ||
| 2 | + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | + * | ||
| 4 | + * See LICENSE for clarification regarding multiple authors | ||
| 5 | + * | ||
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | + * you may not use this file except in compliance with the License. | ||
| 8 | + * You may obtain a copy of the License at | ||
| 9 | + * | ||
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | + * | ||
| 12 | + * Unless required by applicable law or agreed to in writing, software | ||
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | + * See the License for the specific language governing permissions and | ||
| 16 | + * limitations under the License. | ||
| 17 | + */ | ||
| 18 | + | ||
| 19 | +#include "sherpa-onnx/csrc/decode.h" | ||
| 20 | + | ||
| 21 | +#include <assert.h> | ||
| 22 | + | ||
| 23 | +#include <algorithm> | ||
| 24 | +#include <array> | ||
| 25 | +#include <iterator> | ||
| 26 | +#include <vector> | ||
| 27 | + | ||
| 28 | +namespace sherpa_onnx { | ||
| 29 | + | ||
| 30 | +// Greedy search for a single utterance (batch size 1). | ||
| 31 | +// | ||
| 32 | +// For every encoder frame the joiner is run on the projected encoder frame | ||
| 33 | +// and the current projected decoder output. The arg-max token is emitted | ||
| 34 | +// when it is not blank, and only then is the decoder re-run, so at most one | ||
| 35 | +// token is emitted per frame. | ||
| 36 | +// | ||
| 37 | +// Returns the emitted token IDs without the initial blank context. | ||
| 38 | +std::vector<int32_t> GreedySearch(RnntModel &model,  // NOLINT | ||
| 39 | +                                  const Ort::Value &encoder_out) { | ||
| 40 | +  std::vector<int64_t> encoder_out_shape = | ||
| 41 | +      encoder_out.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 42 | +  // The shape is indexed up to [2] below, so check the rank first. | ||
| 43 | +  assert(encoder_out_shape.size() == 3 && "encoder_out must be a 3-D tensor"); | ||
| 44 | +  assert(encoder_out_shape[0] == 1 && "Only batch_size=1 is implemented"); | ||
| 45 | + | ||
| 46 | +  Ort::Value projected_encoder_out = | ||
| 47 | +      model.RunJoinerEncoderProj(encoder_out.GetTensorData<float>(), | ||
| 48 | +                                 encoder_out_shape[1], encoder_out_shape[2]); | ||
| 49 | + | ||
| 50 | +  const float *p_projected_encoder_out = | ||
| 51 | +      projected_encoder_out.GetTensorData<float>(); | ||
| 52 | + | ||
| 53 | +  int32_t context_size = 2;  // hard-code it to 2 | ||
| 54 | +  int32_t blank_id = 0;      // hard-code it to 0 | ||
| 55 | + | ||
| 56 | +  // hyp starts with context_size blanks so that hyp.back() is always valid; | ||
| 57 | +  // they are stripped again before returning. | ||
| 58 | +  std::vector<int32_t> hyp(context_size, blank_id); | ||
| 59 | +  std::array<int64_t, 2> decoder_input{blank_id, blank_id}; | ||
| 60 | + | ||
| 61 | +  Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size); | ||
| 62 | + | ||
| 63 | +  std::vector<int64_t> decoder_out_shape = | ||
| 64 | +      decoder_out.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 65 | + | ||
| 66 | +  Ort::Value projected_decoder_out = model.RunJoinerDecoderProj( | ||
| 67 | +      decoder_out.GetTensorData<float>(), decoder_out_shape[2]); | ||
| 68 | + | ||
| 69 | +  // The shape API returns int64_t; make the narrowing explicit. | ||
| 70 | +  int32_t joiner_dim = static_cast<int32_t>( | ||
| 71 | +      projected_decoder_out.GetTensorTypeAndShapeInfo().GetShape()[1]); | ||
| 72 | + | ||
| 73 | +  const auto num_frames = static_cast<int32_t>(encoder_out_shape[1]); | ||
| 74 | +  for (int32_t t = 0; t != num_frames; ++t) { | ||
| 75 | +    // 64-bit offset: t * joiner_dim could overflow int32_t for long inputs. | ||
| 76 | +    const float *p_frame = | ||
| 77 | +        p_projected_encoder_out + static_cast<int64_t>(t) * joiner_dim; | ||
| 78 | + | ||
| 79 | +    Ort::Value logit = model.RunJoiner( | ||
| 80 | +        p_frame, projected_decoder_out.GetTensorData<float>(), joiner_dim); | ||
| 81 | + | ||
| 82 | +    int32_t vocab_size = static_cast<int32_t>( | ||
| 83 | +        logit.GetTensorTypeAndShapeInfo().GetShape()[1]); | ||
| 84 | + | ||
| 85 | +    const float *p_logit = logit.GetTensorData<float>(); | ||
| 86 | + | ||
| 87 | +    // arg-max over the vocabulary | ||
| 88 | +    auto y = static_cast<int32_t>(std::distance( | ||
| 89 | +        p_logit, std::max_element(p_logit, p_logit + vocab_size))); | ||
| 90 | + | ||
| 91 | +    if (y != blank_id) { | ||
| 92 | +      // The decoder context is the previous token followed by the new one. | ||
| 93 | +      decoder_input[0] = hyp.back(); | ||
| 94 | +      decoder_input[1] = y; | ||
| 95 | +      hyp.push_back(y); | ||
| 96 | +      decoder_out = model.RunDecoder(decoder_input.data(), context_size); | ||
| 97 | +      projected_decoder_out = model.RunJoinerDecoderProj( | ||
| 98 | +          decoder_out.GetTensorData<float>(), decoder_out_shape[2]); | ||
| 99 | +    } | ||
| 100 | +  } | ||
| 101 | + | ||
| 102 | +  return {hyp.begin() + context_size, hyp.end()}; | ||
| 103 | +} | ||
| 104 | + | ||
| 105 | +}  // namespace sherpa_onnx |
sherpa-onnx/csrc/decode.h
0 → 100644
| 1 | +/** | ||
| 2 | + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | + * | ||
| 4 | + * See LICENSE for clarification regarding multiple authors | ||
| 5 | + * | ||
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | + * you may not use this file except in compliance with the License. | ||
| 8 | + * You may obtain a copy of the License at | ||
| 9 | + * | ||
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | + * | ||
| 12 | + * Unless required by applicable law or agreed to in writing, software | ||
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | + * See the License for the specific language governing permissions and | ||
| 16 | + * limitations under the License. | ||
| 17 | + */ | ||
| 18 | + | ||
| 19 | +#ifndef SHERPA_ONNX_CSRC_DECODE_H_ | ||
| 20 | +#define SHERPA_ONNX_CSRC_DECODE_H_ | ||
| 21 | + | ||
| 22 | +#include <cstdint> | ||
| 23 | +#include <vector> | ||
| 24 | + | ||
| 25 | +#include "sherpa-onnx/csrc/rnnt-model.h" | ||
| 26 | + | ||
| 27 | +namespace sherpa_onnx { | ||
| 28 | + | ||
| 29 | +/** Greedy search for non-streaming ASR. | ||
| 30 | + * | ||
| 31 | + * @TODO(fangjun) Support batch size > 1 | ||
| 32 | + * | ||
| 33 | + * @param model The RnntModel. Its decoder, joiner and the two joiner | ||
| 34 | + *              projections are run during the search. | ||
| 35 | + * @param encoder_out Its shape is (1, num_frames, encoder_out_dim). | ||
| 36 | + * | ||
| 37 | + * @return The IDs of the emitted tokens, in emission order. The blank | ||
| 38 | + *         token (ID 0) is never part of the result. | ||
| 39 | + */ | ||
| 40 | +std::vector<int32_t> GreedySearch(RnntModel &model,  // NOLINT | ||
| 41 | +                                  const Ort::Value &encoder_out); | ||
| 42 | + | ||
| 43 | +}  // namespace sherpa_onnx | ||
| 44 | + | ||
| 45 | +#endif  // SHERPA_ONNX_CSRC_DECODE_H_ |
sherpa-onnx/csrc/fbank_features.h
已删除
100644 → 0
| 1 | -#include <iostream> | ||
| 2 | - | ||
| 3 | -#include "kaldi_native_io/csrc/kaldi-io.h" | ||
| 4 | -#include "kaldi_native_io/csrc/wave-reader.h" | ||
| 5 | -#include "kaldi-native-fbank/csrc/online-feature.h" | ||
| 6 | - | ||
| 7 | - | ||
| 8 | -kaldiio::Matrix<float> readWav(std::string filename, bool log = false){ | ||
| 9 | - if (log) | ||
| 10 | - std::cout << "reading " << filename << std::endl; | ||
| 11 | - | ||
| 12 | - bool binary = true; | ||
| 13 | - kaldiio::Input ki(filename, &binary); | ||
| 14 | - kaldiio::WaveHolder wh; | ||
| 15 | - | ||
| 16 | - if (!wh.Read(ki.Stream())) { | ||
| 17 | - std::cerr << "Failed to read " << filename; | ||
| 18 | - exit(EXIT_FAILURE); | ||
| 19 | - } | ||
| 20 | - | ||
| 21 | - auto &wave_data = wh.Value(); | ||
| 22 | - auto &d = wave_data.Data(); | ||
| 23 | - | ||
| 24 | - if (log) | ||
| 25 | - std::cout << "wav shape: " << "(" << d.NumRows() << "," << d.NumCols() << ")" << std::endl; | ||
| 26 | - | ||
| 27 | - return d; | ||
| 28 | -} | ||
| 29 | - | ||
| 30 | - | ||
| 31 | -std::vector<float> ComputeFeatures(knf::OnlineFbank &fbank, knf::FbankOptions opts, kaldiio::Matrix<float> samples, bool log = false){ | ||
| 32 | - int numSamples = samples.NumCols(); | ||
| 33 | - | ||
| 34 | - for (int i = 0; i < numSamples; i++) | ||
| 35 | - { | ||
| 36 | - float currentSample = samples.Row(0).Data()[i] / 32768; | ||
| 37 | - fbank.AcceptWaveform(opts.frame_opts.samp_freq, ¤tSample, 1); | ||
| 38 | - } | ||
| 39 | - | ||
| 40 | - std::vector<float> features; | ||
| 41 | - int32_t num_frames = fbank.NumFramesReady(); | ||
| 42 | - for (int32_t i = 0; i != num_frames; ++i) { | ||
| 43 | - const float *frame = fbank.GetFrame(i); | ||
| 44 | - for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) { | ||
| 45 | - features.push_back(frame[k]); | ||
| 46 | - } | ||
| 47 | - } | ||
| 48 | - if (log){ | ||
| 49 | - std::cout << "done feature extraction" << std::endl; | ||
| 50 | - std::cout << "extracted fbank shape " << "(" << num_frames << "," << opts.mel_opts.num_bins << ")" << std::endl; | ||
| 51 | - | ||
| 52 | - for (int i=0; i< 20; i++) | ||
| 53 | - std::cout << features.at(i) << std::endl; | ||
| 54 | - } | ||
| 55 | - | ||
| 56 | - return features; | ||
| 57 | -} |
sherpa-onnx/csrc/main.cpp
已删除
100644 → 0
| 1 | -#include <algorithm> | ||
| 2 | -#include <fstream> | ||
| 3 | -#include <iostream> | ||
| 4 | -#include <math.h> | ||
| 5 | -#include <time.h> | ||
| 6 | -#include <vector> | ||
| 7 | - | ||
| 8 | -#include "sherpa-onnx/csrc/fbank_features.h" | ||
| 9 | -#include "sherpa-onnx/csrc/rnnt_beam_search.h" | ||
| 10 | - | ||
| 11 | -#include "kaldi-native-fbank/csrc/online-feature.h" | ||
| 12 | - | ||
| 13 | -int main(int argc, char *argv[]) { | ||
| 14 | - char *encoder_path = argv[1]; | ||
| 15 | - char *decoder_path = argv[2]; | ||
| 16 | - char *joiner_path = argv[3]; | ||
| 17 | - char *joiner_encoder_proj_path = argv[4]; | ||
| 18 | - char *joiner_decoder_proj_path = argv[5]; | ||
| 19 | - char *token_path = argv[6]; | ||
| 20 | - std::string search_method = argv[7]; | ||
| 21 | - char *filename = argv[8]; | ||
| 22 | - | ||
| 23 | - // General parameters | ||
| 24 | - int numberOfThreads = 16; | ||
| 25 | - | ||
| 26 | - // Initialize fbanks | ||
| 27 | - knf::FbankOptions opts; | ||
| 28 | - opts.frame_opts.dither = 0; | ||
| 29 | - opts.frame_opts.samp_freq = 16000; | ||
| 30 | - opts.frame_opts.frame_shift_ms = 10.0f; | ||
| 31 | - opts.frame_opts.frame_length_ms = 25.0f; | ||
| 32 | - opts.mel_opts.num_bins = 80; | ||
| 33 | - opts.frame_opts.window_type = "povey"; | ||
| 34 | - opts.frame_opts.snip_edges = false; | ||
| 35 | - knf::OnlineFbank fbank(opts); | ||
| 36 | - | ||
| 37 | - // set session opts | ||
| 38 | - // https://onnxruntime.ai/docs/performance/tune-performance.html | ||
| 39 | - session_options.SetIntraOpNumThreads(numberOfThreads); | ||
| 40 | - session_options.SetInterOpNumThreads(numberOfThreads); | ||
| 41 | - session_options.SetGraphOptimizationLevel( | ||
| 42 | - GraphOptimizationLevel::ORT_ENABLE_EXTENDED); | ||
| 43 | - session_options.SetLogSeverityLevel(4); | ||
| 44 | - session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); | ||
| 45 | - | ||
| 46 | - api.CreateTensorRTProviderOptions(&tensorrt_options); | ||
| 47 | - std::unique_ptr<OrtTensorRTProviderOptionsV2, | ||
| 48 | - decltype(api.ReleaseTensorRTProviderOptions)> | ||
| 49 | - rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions); | ||
| 50 | - api.SessionOptionsAppendExecutionProvider_TensorRT_V2( | ||
| 51 | - static_cast<OrtSessionOptions *>(session_options), rel_trt_options.get()); | ||
| 52 | - | ||
| 53 | - // Define model | ||
| 54 | - auto model = | ||
| 55 | - get_model(encoder_path, decoder_path, joiner_path, | ||
| 56 | - joiner_encoder_proj_path, joiner_decoder_proj_path, token_path); | ||
| 57 | - | ||
| 58 | - std::vector<std::string> filename_list{filename}; | ||
| 59 | - | ||
| 60 | - for (auto filename : filename_list) { | ||
| 61 | - std::cout << filename << std::endl; | ||
| 62 | - auto samples = readWav(filename, true); | ||
| 63 | - int numSamples = samples.NumCols(); | ||
| 64 | - | ||
| 65 | - auto features = ComputeFeatures(fbank, opts, samples); | ||
| 66 | - | ||
| 67 | - auto tic = std::chrono::high_resolution_clock::now(); | ||
| 68 | - | ||
| 69 | - // # === Encoder Out === # | ||
| 70 | - int num_frames = features.size() / opts.mel_opts.num_bins; | ||
| 71 | - auto encoder_out = | ||
| 72 | - model.encoder_forward(features, std::vector<int64_t>{num_frames}, | ||
| 73 | - std::vector<int64_t>{1, num_frames, 80}, | ||
| 74 | - std::vector<int64_t>{1}, memory_info); | ||
| 75 | - | ||
| 76 | - // # === Search === # | ||
| 77 | - std::vector<std::vector<int32_t>> hyps; | ||
| 78 | - if (search_method == "greedy") | ||
| 79 | - hyps = GreedySearch(&model, &encoder_out); | ||
| 80 | - else { | ||
| 81 | - std::cout << "wrong search method!" << std::endl; | ||
| 82 | - exit(0); | ||
| 83 | - } | ||
| 84 | - auto results = hyps2result(model.tokens_map, hyps); | ||
| 85 | - | ||
| 86 | - // # === Print Elapsed Time === # | ||
| 87 | - auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>( | ||
| 88 | - std::chrono::high_resolution_clock::now() - tic); | ||
| 89 | - std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" | ||
| 90 | - << std::endl; | ||
| 91 | - std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) | ||
| 92 | - << std::endl; | ||
| 93 | - | ||
| 94 | - print_hyps(hyps); | ||
| 95 | - std::cout << results[0] << std::endl; | ||
| 96 | - } | ||
| 97 | - | ||
| 98 | - return 0; | ||
| 99 | -} |
sherpa-onnx/csrc/models.h
已删除
100644 → 0
| 1 | -#include <map> | ||
| 2 | -#include <vector> | ||
| 3 | -#include <iostream> | ||
| 4 | -#include <algorithm> | ||
| 5 | -#include <sys/stat.h> | ||
| 6 | - | ||
| 7 | -#include "utils_onnx.h" | ||
| 8 | - | ||
| 9 | - | ||
| 10 | -struct Model | ||
| 11 | -{ | ||
| 12 | - public: | ||
| 13 | - const char* encoder_path; | ||
| 14 | - const char* decoder_path; | ||
| 15 | - const char* joiner_path; | ||
| 16 | - const char* joiner_encoder_proj_path; | ||
| 17 | - const char* joiner_decoder_proj_path; | ||
| 18 | - const char* tokens_path; | ||
| 19 | - | ||
| 20 | - Ort::Session encoder = load_model(encoder_path); | ||
| 21 | - Ort::Session decoder = load_model(decoder_path); | ||
| 22 | - Ort::Session joiner = load_model(joiner_path); | ||
| 23 | - Ort::Session joiner_encoder_proj = load_model(joiner_encoder_proj_path); | ||
| 24 | - Ort::Session joiner_decoder_proj = load_model(joiner_decoder_proj_path); | ||
| 25 | - std::map<int, std::string> tokens_map = get_token_map(tokens_path); | ||
| 26 | - | ||
| 27 | - int32_t blank_id; | ||
| 28 | - int32_t unk_id; | ||
| 29 | - int32_t context_size; | ||
| 30 | - | ||
| 31 | - std::vector<Ort::Value> encoder_forward(std::vector<float> in_vector, | ||
| 32 | - std::vector<int64_t> in_vector_length, | ||
| 33 | - std::vector<int64_t> feature_dims, | ||
| 34 | - std::vector<int64_t> feature_length_dims, | ||
| 35 | - Ort::MemoryInfo &memory_info){ | ||
| 36 | - std::vector<Ort::Value> encoder_inputTensors; | ||
| 37 | - encoder_inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), feature_dims.data(), feature_dims.size())); | ||
| 38 | - encoder_inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector_length.data(), in_vector_length.size(), feature_length_dims.data(), feature_length_dims.size())); | ||
| 39 | - | ||
| 40 | - std::vector<const char*> encoder_inputNames = {encoder.GetInputName(0, allocator), encoder.GetInputName(1, allocator)}; | ||
| 41 | - std::vector<const char*> encoder_outputNames = {encoder.GetOutputName(0, allocator)}; | ||
| 42 | - | ||
| 43 | - auto out = encoder.Run(Ort::RunOptions{nullptr}, | ||
| 44 | - encoder_inputNames.data(), | ||
| 45 | - encoder_inputTensors.data(), | ||
| 46 | - encoder_inputTensors.size(), | ||
| 47 | - encoder_outputNames.data(), | ||
| 48 | - encoder_outputNames.size()); | ||
| 49 | - return out; | ||
| 50 | - } | ||
| 51 | - | ||
| 52 | - std::vector<Ort::Value> decoder_forward(std::vector<int64_t> in_vector, | ||
| 53 | - std::vector<int64_t> dims, | ||
| 54 | - Ort::MemoryInfo &memory_info){ | ||
| 55 | - std::vector<Ort::Value> inputTensors; | ||
| 56 | - inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size())); | ||
| 57 | - | ||
| 58 | - std::vector<const char*> inputNames {decoder.GetInputName(0, allocator)}; | ||
| 59 | - std::vector<const char*> outputNames {decoder.GetOutputName(0, allocator)}; | ||
| 60 | - | ||
| 61 | - auto out = decoder.Run(Ort::RunOptions{nullptr}, | ||
| 62 | - inputNames.data(), | ||
| 63 | - inputTensors.data(), | ||
| 64 | - inputTensors.size(), | ||
| 65 | - outputNames.data(), | ||
| 66 | - outputNames.size()); | ||
| 67 | - | ||
| 68 | - return out; | ||
| 69 | - } | ||
| 70 | - | ||
| 71 | - std::vector<Ort::Value> joiner_forward(std::vector<float> projected_encoder_out, | ||
| 72 | - std::vector<float> decoder_out, | ||
| 73 | - std::vector<int64_t> projected_encoder_out_dims, | ||
| 74 | - std::vector<int64_t> decoder_out_dims, | ||
| 75 | - Ort::MemoryInfo &memory_info){ | ||
| 76 | - std::vector<Ort::Value> inputTensors; | ||
| 77 | - inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, projected_encoder_out.data(), projected_encoder_out.size(), projected_encoder_out_dims.data(), projected_encoder_out_dims.size())); | ||
| 78 | - inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, decoder_out.data(), decoder_out.size(), decoder_out_dims.data(), decoder_out_dims.size())); | ||
| 79 | - std::vector<const char*> inputNames = {joiner.GetInputName(0, allocator), joiner.GetInputName(1, allocator)}; | ||
| 80 | - std::vector<const char*> outputNames = {joiner.GetOutputName(0, allocator)}; | ||
| 81 | - | ||
| 82 | - auto out = joiner.Run(Ort::RunOptions{nullptr}, | ||
| 83 | - inputNames.data(), | ||
| 84 | - inputTensors.data(), | ||
| 85 | - inputTensors.size(), | ||
| 86 | - outputNames.data(), | ||
| 87 | - outputNames.size()); | ||
| 88 | - | ||
| 89 | - return out; | ||
| 90 | - } | ||
| 91 | - | ||
| 92 | - std::vector<Ort::Value> joiner_encoder_proj_forward(std::vector<float> in_vector, | ||
| 93 | - std::vector<int64_t> dims, | ||
| 94 | - Ort::MemoryInfo &memory_info){ | ||
| 95 | - std::vector<Ort::Value> inputTensors; | ||
| 96 | - inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size())); | ||
| 97 | - | ||
| 98 | - std::vector<const char*> inputNames {joiner_encoder_proj.GetInputName(0, allocator)}; | ||
| 99 | - std::vector<const char*> outputNames {joiner_encoder_proj.GetOutputName(0, allocator)}; | ||
| 100 | - | ||
| 101 | - auto out = joiner_encoder_proj.Run(Ort::RunOptions{nullptr}, | ||
| 102 | - inputNames.data(), | ||
| 103 | - inputTensors.data(), | ||
| 104 | - inputTensors.size(), | ||
| 105 | - outputNames.data(), | ||
| 106 | - outputNames.size()); | ||
| 107 | - | ||
| 108 | - return out; | ||
| 109 | - } | ||
| 110 | - | ||
| 111 | - std::vector<Ort::Value> joiner_decoder_proj_forward(std::vector<float> in_vector, | ||
| 112 | - std::vector<int64_t> dims, | ||
| 113 | - Ort::MemoryInfo &memory_info){ | ||
| 114 | - std::vector<Ort::Value> inputTensors; | ||
| 115 | - inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size())); | ||
| 116 | - | ||
| 117 | - std::vector<const char*> inputNames {joiner_decoder_proj.GetInputName(0, allocator)}; | ||
| 118 | - std::vector<const char*> outputNames {joiner_decoder_proj.GetOutputName(0, allocator)}; | ||
| 119 | - | ||
| 120 | - auto out = joiner_decoder_proj.Run(Ort::RunOptions{nullptr}, | ||
| 121 | - inputNames.data(), | ||
| 122 | - inputTensors.data(), | ||
| 123 | - inputTensors.size(), | ||
| 124 | - outputNames.data(), | ||
| 125 | - outputNames.size()); | ||
| 126 | - | ||
| 127 | - return out; | ||
| 128 | - } | ||
| 129 | - | ||
| 130 | - Ort::Session load_model(const char* path){ | ||
| 131 | - struct stat buffer; | ||
| 132 | - if (stat(path, &buffer) != 0){ | ||
| 133 | - std::cout << "File does not exist!: " << path << std::endl; | ||
| 134 | - exit(0); | ||
| 135 | - } | ||
| 136 | - std::cout << "loading " << path << std::endl; | ||
| 137 | - Ort::Session onnx_model(env, path, session_options); | ||
| 138 | - return onnx_model; | ||
| 139 | - } | ||
| 140 | - | ||
| 141 | - void extract_constant_lm_parameters(){ | ||
| 142 | - /* | ||
| 143 | - all_in_one contains these params. We should trace all_in_one and find 'constants_lm' nodes to extract these params | ||
| 144 | - For now, these params are set staticaly. | ||
| 145 | - in: Ort::Session &all_in_one | ||
| 146 | - out: {blank_id, unk_id, context_size} | ||
| 147 | - should return std::vector<int32_t> | ||
| 148 | - */ | ||
| 149 | - blank_id = 0; | ||
| 150 | - unk_id = 0; | ||
| 151 | - context_size = 2; | ||
| 152 | - } | ||
| 153 | - | ||
| 154 | - std::map<int, std::string> get_token_map(const char* token_path){ | ||
| 155 | - std::ifstream inFile; | ||
| 156 | - inFile.open(token_path); | ||
| 157 | - if (inFile.fail()) | ||
| 158 | - std::cerr << "Could not find token file" << std::endl; | ||
| 159 | - | ||
| 160 | - std::map<int, std::string> token_map; | ||
| 161 | - | ||
| 162 | - std::string line; | ||
| 163 | - while (std::getline(inFile, line)) | ||
| 164 | - { | ||
| 165 | - int id; | ||
| 166 | - std::string token; | ||
| 167 | - | ||
| 168 | - std::istringstream iss(line); | ||
| 169 | - iss >> token; | ||
| 170 | - iss >> id; | ||
| 171 | - | ||
| 172 | - token_map[id] = token; | ||
| 173 | - } | ||
| 174 | - | ||
| 175 | - return token_map; | ||
| 176 | - } | ||
| 177 | - | ||
| 178 | -}; | ||
| 179 | - | ||
| 180 | - | ||
| 181 | -Model get_model(std::string exp_path, char* tokens_path){ | ||
| 182 | - Model model{ | ||
| 183 | - (exp_path + "/encoder_simp.onnx").c_str(), | ||
| 184 | - (exp_path + "/decoder_simp.onnx").c_str(), | ||
| 185 | - (exp_path + "/joiner_simp.onnx").c_str(), | ||
| 186 | - (exp_path + "/joiner_encoder_proj_simp.onnx").c_str(), | ||
| 187 | - (exp_path + "/joiner_decoder_proj_simp.onnx").c_str(), | ||
| 188 | - tokens_path, | ||
| 189 | - }; | ||
| 190 | - model.extract_constant_lm_parameters(); | ||
| 191 | - | ||
| 192 | - return model; | ||
| 193 | -} | ||
| 194 | - | ||
| 195 | -Model get_model(char* encoder_path, | ||
| 196 | - char* decoder_path, | ||
| 197 | - char* joiner_path, | ||
| 198 | - char* joiner_encoder_proj_path, | ||
| 199 | - char* joiner_decoder_proj_path, | ||
| 200 | - char* tokens_path){ | ||
| 201 | - Model model{ | ||
| 202 | - encoder_path, | ||
| 203 | - decoder_path, | ||
| 204 | - joiner_path, | ||
| 205 | - joiner_encoder_proj_path, | ||
| 206 | - joiner_decoder_proj_path, | ||
| 207 | - tokens_path, | ||
| 208 | - }; | ||
| 209 | - model.extract_constant_lm_parameters(); | ||
| 210 | - | ||
| 211 | - return model; | ||
| 212 | -} | ||
| 213 | - | ||
| 214 | - | ||
| 215 | -void doWarmup(Model *model, int numWarmup = 5){ | ||
| 216 | - std::cout << "Warmup is started" << std::endl; | ||
| 217 | - | ||
| 218 | - std::vector<float> encoder_warmup_sample (500 * 80, 1.0); | ||
| 219 | - for (int i=0; i<numWarmup; i++) | ||
| 220 | - auto encoder_out = model->encoder_forward(encoder_warmup_sample, | ||
| 221 | - std::vector<int64_t> {500}, | ||
| 222 | - std::vector<int64_t> {1, 500, 80}, | ||
| 223 | - std::vector<int64_t> {1}, | ||
| 224 | - memory_info); | ||
| 225 | - | ||
| 226 | - std::vector<int64_t> decoder_warmup_sample {1, 1}; | ||
| 227 | - for (int i=0; i<numWarmup; i++) | ||
| 228 | - auto decoder_out = model->decoder_forward(decoder_warmup_sample, | ||
| 229 | - std::vector<int64_t> {1, 2}, | ||
| 230 | - memory_info); | ||
| 231 | - | ||
| 232 | - std::vector<float> joiner_warmup_sample1 (512, 1.0); | ||
| 233 | - std::vector<float> joiner_warmup_sample2 (512, 1.0); | ||
| 234 | - for (int i=0; i<numWarmup; i++) | ||
| 235 | - auto logits = model->joiner_forward(joiner_warmup_sample1, | ||
| 236 | - joiner_warmup_sample2, | ||
| 237 | - std::vector<int64_t> {1, 1, 1, 512}, | ||
| 238 | - std::vector<int64_t> {1, 1, 1, 512}, | ||
| 239 | - memory_info); | ||
| 240 | - | ||
| 241 | - std::vector<float> joiner_encoder_proj_warmup_sample (100 * 512, 1.0); | ||
| 242 | - for (int i=0; i<numWarmup; i++) | ||
| 243 | - auto projected_encoder_out = model->joiner_encoder_proj_forward(joiner_encoder_proj_warmup_sample, | ||
| 244 | - std::vector<int64_t> {100, 512}, | ||
| 245 | - memory_info); | ||
| 246 | - | ||
| 247 | - std::vector<float> joiner_decoder_proj_warmup_sample (512, 1.0); | ||
| 248 | - for (int i=0; i<numWarmup; i++) | ||
| 249 | - auto projected_decoder_out = model->joiner_decoder_proj_forward(joiner_decoder_proj_warmup_sample, | ||
| 250 | - std::vector<int64_t> {1, 512}, | ||
| 251 | - memory_info); | ||
| 252 | - std::cout << "Warmup is done" << std::endl; | ||
| 253 | -} |
sherpa-onnx/csrc/rnnt-model.cc
0 → 100644
| 1 | +/** | ||
| 2 | + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | + * | ||
| 4 | + * See LICENSE for clarification regarding multiple authors | ||
| 5 | + * | ||
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | + * you may not use this file except in compliance with the License. | ||
| 8 | + * You may obtain a copy of the License at | ||
| 9 | + * | ||
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | + * | ||
| 12 | + * Unless required by applicable law or agreed to in writing, software | ||
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | + * See the License for the specific language governing permissions and | ||
| 16 | + * limitations under the License. | ||
| 17 | + */ | ||
| 18 | +#include "sherpa-onnx/csrc/rnnt-model.h" | ||
| 19 | + | ||
| 20 | +#include <array> | ||
| 21 | +#include <utility> | ||
| 22 | +#include <vector> | ||
| 23 | + | ||
| 24 | +namespace sherpa_onnx { | ||
| 25 | + | ||
| 26 | +/** | ||
| 27 | + * Get the input names of a model. | ||
| 28 | + * | ||
| 29 | + * @param sess An onnxruntime session. | ||
| 30 | + * @param input_names. On return, it contains the input names of the model. | ||
| 31 | + * @param input_names_ptr. On return, input_names_ptr[i] contains | ||
| 32 | + * input_names[i].c_str() | ||
| 33 | + */ | ||
| 34 | +static void GetInputNames(Ort::Session *sess, | ||
| 35 | + std::vector<std::string> *input_names, | ||
| 36 | + std::vector<const char *> *input_names_ptr) { | ||
| 37 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 38 | + size_t node_count = sess->GetInputCount(); | ||
| 39 | + input_names->resize(node_count); | ||
| 40 | + input_names_ptr->resize(node_count); | ||
| 41 | + for (size_t i = 0; i != node_count; ++i) { | ||
| 42 | + auto tmp = sess->GetInputNameAllocated(i, allocator); | ||
| 43 | + (*input_names)[i] = tmp.get(); | ||
| 44 | + (*input_names_ptr)[i] = (*input_names)[i].c_str(); | ||
| 45 | + } | ||
| 46 | +} | ||
| 47 | + | ||
| 48 | +/** | ||
| 49 | + * Get the output names of a model. | ||
| 50 | + * | ||
| 51 | + * @param sess An onnxruntime session. | ||
| 52 | + * @param output_names. On return, it contains the output names of the model. | ||
| 53 | + * @param output_names_ptr. On return, output_names_ptr[i] contains | ||
| 54 | + * output_names[i].c_str() | ||
| 55 | + */ | ||
| 56 | +static void GetOutputNames(Ort::Session *sess, | ||
| 57 | + std::vector<std::string> *output_names, | ||
| 58 | + std::vector<const char *> *output_names_ptr) { | ||
| 59 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 60 | + size_t node_count = sess->GetOutputCount(); | ||
| 61 | + output_names->resize(node_count); | ||
| 62 | + output_names_ptr->resize(node_count); | ||
| 63 | + for (size_t i = 0; i != node_count; ++i) { | ||
| 64 | + auto tmp = sess->GetOutputNameAllocated(i, allocator); | ||
| 65 | + (*output_names)[i] = tmp.get(); | ||
| 66 | + (*output_names_ptr)[i] = (*output_names)[i].c_str(); | ||
| 67 | + } | ||
| 68 | +} | ||
| 69 | + | ||
| 70 | +RnntModel::RnntModel(const std::string &encoder_filename, | ||
| 71 | + const std::string &decoder_filename, | ||
| 72 | + const std::string &joiner_filename, | ||
| 73 | + const std::string &joiner_encoder_proj_filename, | ||
| 74 | + const std::string &joiner_decoder_proj_filename, | ||
| 75 | + int32_t num_threads) | ||
| 76 | + : env_(ORT_LOGGING_LEVEL_WARNING) { | ||
| 77 | + sess_opts_.SetIntraOpNumThreads(num_threads); | ||
| 78 | + sess_opts_.SetInterOpNumThreads(num_threads); | ||
| 79 | + | ||
| 80 | + InitEncoder(encoder_filename); | ||
| 81 | + InitDecoder(decoder_filename); | ||
| 82 | + InitJoiner(joiner_filename); | ||
| 83 | + InitJoinerEncoderProj(joiner_encoder_proj_filename); | ||
| 84 | + InitJoinerDecoderProj(joiner_decoder_proj_filename); | ||
| 85 | +} | ||
| 86 | + | ||
| 87 | +void RnntModel::InitEncoder(const std::string &filename) { | ||
| 88 | + encoder_sess_ = | ||
| 89 | + std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_); | ||
| 90 | + GetInputNames(encoder_sess_.get(), &encoder_input_names_, | ||
| 91 | + &encoder_input_names_ptr_); | ||
| 92 | + | ||
| 93 | + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, | ||
| 94 | + &encoder_output_names_ptr_); | ||
| 95 | +} | ||
| 96 | + | ||
| 97 | +void RnntModel::InitDecoder(const std::string &filename) { | ||
| 98 | + decoder_sess_ = | ||
| 99 | + std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_); | ||
| 100 | + | ||
| 101 | + GetInputNames(decoder_sess_.get(), &decoder_input_names_, | ||
| 102 | + &decoder_input_names_ptr_); | ||
| 103 | + | ||
| 104 | + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, | ||
| 105 | + &decoder_output_names_ptr_); | ||
| 106 | +} | ||
| 107 | + | ||
| 108 | +void RnntModel::InitJoiner(const std::string &filename) { | ||
| 109 | + joiner_sess_ = | ||
| 110 | + std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_); | ||
| 111 | + | ||
| 112 | + GetInputNames(joiner_sess_.get(), &joiner_input_names_, | ||
| 113 | + &joiner_input_names_ptr_); | ||
| 114 | + | ||
| 115 | + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, | ||
| 116 | + &joiner_output_names_ptr_); | ||
| 117 | +} | ||
| 118 | + | ||
| 119 | +void RnntModel::InitJoinerEncoderProj(const std::string &filename) { | ||
| 120 | + joiner_encoder_proj_sess_ = | ||
| 121 | + std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_); | ||
| 122 | + | ||
| 123 | + GetInputNames(joiner_encoder_proj_sess_.get(), | ||
| 124 | + &joiner_encoder_proj_input_names_, | ||
| 125 | + &joiner_encoder_proj_input_names_ptr_); | ||
| 126 | + | ||
| 127 | + GetOutputNames(joiner_encoder_proj_sess_.get(), | ||
| 128 | + &joiner_encoder_proj_output_names_, | ||
| 129 | + &joiner_encoder_proj_output_names_ptr_); | ||
| 130 | +} | ||
| 131 | + | ||
| 132 | +void RnntModel::InitJoinerDecoderProj(const std::string &filename) { | ||
| 133 | + joiner_decoder_proj_sess_ = | ||
| 134 | + std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_); | ||
| 135 | + | ||
| 136 | + GetInputNames(joiner_decoder_proj_sess_.get(), | ||
| 137 | + &joiner_decoder_proj_input_names_, | ||
| 138 | + &joiner_decoder_proj_input_names_ptr_); | ||
| 139 | + | ||
| 140 | + GetOutputNames(joiner_decoder_proj_sess_.get(), | ||
| 141 | + &joiner_decoder_proj_output_names_, | ||
| 142 | + &joiner_decoder_proj_output_names_ptr_); | ||
| 143 | +} | ||
| 144 | + | ||
| 145 | +Ort::Value RnntModel::RunEncoder(const float *features, int32_t T, | ||
| 146 | + int32_t feature_dim) { | ||
| 147 | + auto memory_info = | ||
| 148 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 149 | + std::array<int64_t, 3> x_shape{1, T, feature_dim}; | ||
| 150 | + Ort::Value x = | ||
| 151 | + Ort::Value::CreateTensor(memory_info, const_cast<float *>(features), | ||
| 152 | + T * feature_dim, x_shape.data(), x_shape.size()); | ||
| 153 | + | ||
| 154 | + std::array<int64_t, 1> x_lens_shape{1}; | ||
| 155 | + int64_t x_lens_tmp = T; | ||
| 156 | + | ||
| 157 | + Ort::Value x_lens = Ort::Value::CreateTensor( | ||
| 158 | + memory_info, &x_lens_tmp, 1, x_lens_shape.data(), x_lens_shape.size()); | ||
| 159 | + | ||
| 160 | + std::array<Ort::Value, 2> encoder_inputs{std::move(x), std::move(x_lens)}; | ||
| 161 | + | ||
| 162 | + // Note: We discard encoder_out_lens since we only implement | ||
| 163 | + // batch==1. | ||
| 164 | + auto encoder_out = encoder_sess_->Run( | ||
| 165 | + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), | ||
| 166 | + encoder_inputs.size(), encoder_output_names_ptr_.data(), | ||
| 167 | + encoder_output_names_ptr_.size()); | ||
| 168 | + return std::move(encoder_out[0]); | ||
| 169 | +} | ||
| 170 | +Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T, | ||
| 171 | + int32_t encoder_out_dim) { | ||
| 172 | + auto memory_info = | ||
| 173 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 174 | + | ||
| 175 | + std::array<int64_t, 2> in_shape{T, encoder_out_dim}; | ||
| 176 | + Ort::Value in = Ort::Value::CreateTensor( | ||
| 177 | + memory_info, const_cast<float *>(encoder_out), T * encoder_out_dim, | ||
| 178 | + in_shape.data(), in_shape.size()); | ||
| 179 | + | ||
| 180 | + auto encoder_proj_out = joiner_encoder_proj_sess_->Run( | ||
| 181 | + {}, joiner_encoder_proj_input_names_ptr_.data(), &in, 1, | ||
| 182 | + joiner_encoder_proj_output_names_ptr_.data(), | ||
| 183 | + joiner_encoder_proj_output_names_ptr_.size()); | ||
| 184 | + return std::move(encoder_proj_out[0]); | ||
| 185 | +} | ||
| 186 | + | ||
| 187 | +Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input, | ||
| 188 | + int32_t context_size) { | ||
| 189 | + auto memory_info = | ||
| 190 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 191 | + | ||
| 192 | + int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 | ||
| 193 | + std::array<int64_t, 2> shape{batch_size, context_size}; | ||
| 194 | + Ort::Value in = Ort::Value::CreateTensor( | ||
| 195 | + memory_info, const_cast<int64_t *>(decoder_input), | ||
| 196 | + batch_size * context_size, shape.data(), shape.size()); | ||
| 197 | + | ||
| 198 | + auto decoder_out = decoder_sess_->Run( | ||
| 199 | + {}, decoder_input_names_ptr_.data(), &in, 1, | ||
| 200 | + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); | ||
| 201 | + return std::move(decoder_out[0]); | ||
| 202 | +} | ||
| 203 | + | ||
| 204 | +Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out, | ||
| 205 | + int32_t decoder_out_dim) { | ||
| 206 | + auto memory_info = | ||
| 207 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 208 | + | ||
| 209 | + int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 | ||
| 210 | + std::array<int64_t, 2> shape{batch_size, decoder_out_dim}; | ||
| 211 | + Ort::Value in = Ort::Value::CreateTensor( | ||
| 212 | + memory_info, const_cast<float *>(decoder_out), | ||
| 213 | + batch_size * decoder_out_dim, shape.data(), shape.size()); | ||
| 214 | + | ||
| 215 | + auto decoder_proj_out = joiner_decoder_proj_sess_->Run( | ||
| 216 | + {}, joiner_decoder_proj_input_names_ptr_.data(), &in, 1, | ||
| 217 | + joiner_decoder_proj_output_names_ptr_.data(), | ||
| 218 | + joiner_decoder_proj_output_names_ptr_.size()); | ||
| 219 | + return std::move(decoder_proj_out[0]); | ||
| 220 | +} | ||
| 221 | + | ||
| 222 | +Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out, | ||
| 223 | + const float *projected_decoder_out, | ||
| 224 | + int32_t joiner_dim) { | ||
| 225 | + auto memory_info = | ||
| 226 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 227 | + int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 | ||
| 228 | + std::array<int64_t, 2> shape{batch_size, joiner_dim}; | ||
| 229 | + | ||
| 230 | + Ort::Value enc = Ort::Value::CreateTensor( | ||
| 231 | + memory_info, const_cast<float *>(projected_encoder_out), | ||
| 232 | + batch_size * joiner_dim, shape.data(), shape.size()); | ||
| 233 | + | ||
| 234 | + Ort::Value dec = Ort::Value::CreateTensor( | ||
| 235 | + memory_info, const_cast<float *>(projected_decoder_out), | ||
| 236 | + batch_size * joiner_dim, shape.data(), shape.size()); | ||
| 237 | + | ||
| 238 | + std::array<Ort::Value, 2> inputs{std::move(enc), std::move(dec)}; | ||
| 239 | + | ||
| 240 | + auto logit = joiner_sess_->Run( | ||
| 241 | + {}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 242 | + joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size()); | ||
| 243 | + | ||
| 244 | + return std::move(logit[0]); | ||
| 245 | +} | ||
| 246 | + | ||
| 247 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/rnnt-model.h
0 → 100644
| 1 | +/** | ||
| 2 | + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | + * | ||
| 4 | + * See LICENSE for clarification regarding multiple authors | ||
| 5 | + * | ||
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | + * you may not use this file except in compliance with the License. | ||
| 8 | + * You may obtain a copy of the License at | ||
| 9 | + * | ||
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | + * | ||
| 12 | + * Unless required by applicable law or agreed to in writing, software | ||
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | + * See the License for the specific language governing permissions and | ||
| 16 | + * limitations under the License. | ||
| 17 | + */ | ||
| 18 | + | ||
| 19 | +#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_ | ||
| 20 | +#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_ | ||
| 21 | + | ||
| 22 | +#include <memory> | ||
| 23 | +#include <string> | ||
| 24 | +#include <vector> | ||
| 25 | + | ||
| 26 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 27 | + | ||
| 28 | +namespace sherpa_onnx { | ||
| 29 | + | ||
| 30 | +class RnntModel { | ||
| 31 | + public: | ||
| 32 | + /** | ||
| 33 | + * @param encoder_filename Path to the encoder model | ||
| 34 | + * @param decoder_filename Path to the decoder model | ||
| 35 | + * @param joiner_filename Path to the joiner model | ||
| 36 | + * @param joiner_encoder_proj_filename Path to the joiner encoder_proj model | ||
| 37 | + * @param joiner_decoder_proj_filename Path to the joiner decoder_proj model | ||
| 38 | + * @param num_threads Number of threads to use to run the models | ||
| 39 | + */ | ||
| 40 | + RnntModel(const std::string &encoder_filename, | ||
| 41 | + const std::string &decoder_filename, | ||
| 42 | + const std::string &joiner_filename, | ||
| 43 | + const std::string &joiner_encoder_proj_filename, | ||
| 44 | + const std::string &joiner_decoder_proj_filename, | ||
| 45 | + int32_t num_threads); | ||
| 46 | + | ||
| 47 | + /** Run the encoder model. | ||
| 48 | + * | ||
| 49 | + * @TODO(fangjun): Support batch_size > 1 | ||
| 50 | + * | ||
| 51 | + * @param features A tensor of shape (batch_size, T, feature_dim) | ||
| 52 | + * @param T Number of feature frames | ||
| 53 | + * @param feature_dim Dimension of the feature. | ||
| 54 | + * | ||
| 55 | + * @return Return a tensor of shape (batch_size, T', encoder_out_dim) | ||
| 56 | + */ | ||
| 57 | + Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim); | ||
| 58 | + | ||
| 59 | + /** Run the joiner encoder_proj model. | ||
| 60 | + * | ||
| 61 | + * @param encoder_out A tensor of shape (T, encoder_out_dim) | ||
| 62 | + * @param T Number of frames in encoder_out. | ||
| 63 | + * @param encoder_out_dim Dimension of encoder_out. | ||
| 64 | + * | ||
| 65 | + * @return Return a tensor of shape (T, joiner_dim) | ||
| 66 | + * | ||
| 67 | + */ | ||
| 68 | + Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T, | ||
| 69 | + int32_t encoder_out_dim); | ||
| 70 | + | ||
| 71 | + /** Run the decoder model. | ||
| 72 | + * | ||
| 73 | + * @TODO(fangjun): Support batch_size > 1 | ||
| 74 | + * | ||
| 75 | + * @param decoder_input A tensor of shape (batch_size, context_size). | ||
| 76 | + * @return Return a tensor of shape (batch_size, 1, decoder_out_dim) | ||
| 77 | + */ | ||
| 78 | + Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size); | ||
| 79 | + | ||
| 80 | + /** Run joiner decoder_proj model. | ||
| 81 | + * | ||
| 82 | + * @TODO(fangjun): Support batch_size > 1 | ||
| 83 | + * | ||
| 84 | + * @param decoder_out A tensor of shape (batch_size, decoder_out_dim) | ||
| 85 | + * @param decoder_out_dim Output dimension of the decoder_out. | ||
| 86 | + * | ||
| 87 | + * @return Return a tensor of shape (batch_size, joiner_dim); | ||
| 88 | + */ | ||
| 89 | + Ort::Value RunJoinerDecoderProj(const float *decoder_out, | ||
| 90 | + int32_t decoder_out_dim); | ||
| 91 | + | ||
| 92 | + /** Run the joiner model. | ||
| 93 | + * | ||
| 94 | + * @TODO(fangjun): Support batch_size > 1 | ||
| 95 | + * | ||
| 96 | + * @param projected_encoder_out A tensor of shape (batch_size, joiner_dim). | ||
| 97 | + * @param projected_decoder_out A tensor of shape (batch_size, joiner_dim). | ||
| 98 | + * | ||
| 99 | + * @return Return a tensor of shape (batch_size, vocab_size) | ||
| 100 | + */ | ||
| 101 | + Ort::Value RunJoiner(const float *projected_encoder_out, | ||
| 102 | + const float *projected_decoder_out, int32_t joiner_dim); | ||
| 103 | + | ||
| 104 | + private: | ||
| 105 | + void InitEncoder(const std::string &encoder_filename); | ||
| 106 | + void InitDecoder(const std::string &decoder_filename); | ||
| 107 | + void InitJoiner(const std::string &joiner_filename); | ||
| 108 | + void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename); | ||
| 109 | + void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename); | ||
| 110 | + | ||
| 111 | + private: | ||
| 112 | + Ort::Env env_; | ||
| 113 | + Ort::SessionOptions sess_opts_; | ||
| 114 | + std::unique_ptr<Ort::Session> encoder_sess_; | ||
| 115 | + std::unique_ptr<Ort::Session> decoder_sess_; | ||
| 116 | + std::unique_ptr<Ort::Session> joiner_sess_; | ||
| 117 | + std::unique_ptr<Ort::Session> joiner_encoder_proj_sess_; | ||
| 118 | + std::unique_ptr<Ort::Session> joiner_decoder_proj_sess_; | ||
| 119 | + | ||
| 120 | + std::vector<std::string> encoder_input_names_; | ||
| 121 | + std::vector<const char *> encoder_input_names_ptr_; | ||
| 122 | + std::vector<std::string> encoder_output_names_; | ||
| 123 | + std::vector<const char *> encoder_output_names_ptr_; | ||
| 124 | + | ||
| 125 | + std::vector<std::string> decoder_input_names_; | ||
| 126 | + std::vector<const char *> decoder_input_names_ptr_; | ||
| 127 | + std::vector<std::string> decoder_output_names_; | ||
| 128 | + std::vector<const char *> decoder_output_names_ptr_; | ||
| 129 | + | ||
| 130 | + std::vector<std::string> joiner_input_names_; | ||
| 131 | + std::vector<const char *> joiner_input_names_ptr_; | ||
| 132 | + std::vector<std::string> joiner_output_names_; | ||
| 133 | + std::vector<const char *> joiner_output_names_ptr_; | ||
| 134 | + | ||
| 135 | + std::vector<std::string> joiner_encoder_proj_input_names_; | ||
| 136 | + std::vector<const char *> joiner_encoder_proj_input_names_ptr_; | ||
| 137 | + std::vector<std::string> joiner_encoder_proj_output_names_; | ||
| 138 | + std::vector<const char *> joiner_encoder_proj_output_names_ptr_; | ||
| 139 | + | ||
| 140 | + std::vector<std::string> joiner_decoder_proj_input_names_; | ||
| 141 | + std::vector<const char *> joiner_decoder_proj_input_names_ptr_; | ||
| 142 | + std::vector<std::string> joiner_decoder_proj_output_names_; | ||
| 143 | + std::vector<const char *> joiner_decoder_proj_output_names_ptr_; | ||
| 144 | +}; | ||
| 145 | + | ||
| 146 | +} // namespace sherpa_onnx | ||
| 147 | + | ||
| 148 | +#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_ |
sherpa-onnx/csrc/rnnt_beam_search.h
已删除
100644 → 0
| 1 | -#include <vector> | ||
| 2 | -#include <iostream> | ||
| 3 | -#include <algorithm> | ||
| 4 | -#include <time.h> | ||
| 5 | - | ||
| 6 | -#include "models.h" | ||
| 7 | -#include "utils.h" | ||
| 8 | - | ||
| 9 | - | ||
| 10 | -std::vector<float> getEncoderCol(Ort::Value &tensor, int start, int length){ | ||
| 11 | - float* floatarr = tensor.GetTensorMutableData<float>(); | ||
| 12 | - std::vector<float> vector {floatarr + start, floatarr + length}; | ||
| 13 | - return vector; | ||
| 14 | -} | ||
| 15 | - | ||
| 16 | - | ||
/**
 * Fill `decoder_input` with the most recent tokens of the first hypothesis.
 *
 * Assume batch size = 1: only hyps[0] is read.
 *
 * The context size is taken from decoder_input.size(). The last
 * `context_size` tokens of hyps[0] are written to decoder_input in order.
 * If hyps[0] holds fewer tokens than that, the available tokens are written
 * to the tail of decoder_input and its leading entries keep their current
 * values; if hyps is empty, decoder_input is left untouched. (Previously both
 * cases read out of bounds.)
 *
 * @param hyps           Hypotheses; only the first one is used.
 * @param decoder_input  Buffer to update in place; its size is the context
 *                       size.
 * @return A copy of the updated decoder_input.
 */
std::vector<int64_t> BuildDecoderInput(const std::vector<std::vector<int32_t>> &hyps,
                                       std::vector<int64_t> &decoder_input) {
  if (hyps.empty())
    return decoder_input;

  const std::vector<int32_t> &hyp = hyps[0];
  const auto num_tokens = static_cast<std::ptrdiff_t>(
      std::min(decoder_input.size(), hyp.size()));

  // Right-align the copied tokens so the newest token is always last.
  std::copy(hyp.end() - num_tokens, hyp.end(),
            decoder_input.end() - num_tokens);

  return decoder_input;
}
| 30 | - | ||
| 31 | - | ||
| 32 | -std::vector<std::vector<int32_t>> GreedySearch( | ||
| 33 | - Model *model, // NOLINT | ||
| 34 | - std::vector<Ort::Value> *encoder_out){ | ||
| 35 | - Ort::Value &encoder_out_tensor = encoder_out->at(0); | ||
| 36 | - int encoder_out_dim1 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 37 | - int encoder_out_dim2 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2]; | ||
| 38 | - auto encoder_out_vector = ortVal2Vector(encoder_out_tensor, encoder_out_dim1 * encoder_out_dim2); | ||
| 39 | - | ||
| 40 | - // # === Greedy Search === # | ||
| 41 | - int32_t batch_size = 1; | ||
| 42 | - std::vector<int32_t> blanks(model->context_size, model->blank_id); | ||
| 43 | - std::vector<std::vector<int32_t>> hyps(batch_size, blanks); | ||
| 44 | - std::vector<int64_t> decoder_input(model->context_size, model->blank_id); | ||
| 45 | - | ||
| 46 | - auto decoder_out = model->decoder_forward(decoder_input, | ||
| 47 | - std::vector<int64_t> {batch_size, model->context_size}, | ||
| 48 | - memory_info); | ||
| 49 | - | ||
| 50 | - Ort::Value &decoder_out_tensor = decoder_out[0]; | ||
| 51 | - int decoder_out_dim = decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2]; | ||
| 52 | - auto decoder_out_vector = ortVal2Vector(decoder_out_tensor, decoder_out_dim); | ||
| 53 | - | ||
| 54 | - decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector, | ||
| 55 | - std::vector<int64_t> {1, decoder_out_dim}, | ||
| 56 | - memory_info); | ||
| 57 | - Ort::Value &projected_decoder_out_tensor = decoder_out[0]; | ||
| 58 | - auto projected_decoder_out_dim = projected_decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 59 | - auto projected_decoder_out_vector = ortVal2Vector(projected_decoder_out_tensor, projected_decoder_out_dim); | ||
| 60 | - | ||
| 61 | - auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector, | ||
| 62 | - std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2}, | ||
| 63 | - memory_info); | ||
| 64 | - Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0]; | ||
| 65 | - int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0]; | ||
| 66 | - int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 67 | - auto projected_encoder_out_vector = ortVal2Vector(projected_encoder_out_tensor, projected_encoder_out_dim1 * projected_encoder_out_dim2); | ||
| 68 | - | ||
| 69 | - int32_t offset = 0; | ||
| 70 | - for (int i=0; i< projected_encoder_out_dim1; i++){ | ||
| 71 | - int32_t cur_batch_size = 1; | ||
| 72 | - int32_t start = offset; | ||
| 73 | - int32_t end = start + cur_batch_size; | ||
| 74 | - offset = end; | ||
| 75 | - | ||
| 76 | - auto cur_encoder_out = getEncoderCol(projected_encoder_out_tensor, start * projected_encoder_out_dim2, end * projected_encoder_out_dim2); | ||
| 77 | - | ||
| 78 | - auto logits = model->joiner_forward(cur_encoder_out, | ||
| 79 | - projected_decoder_out_vector, | ||
| 80 | - std::vector<int64_t> {1, projected_encoder_out_dim2}, | ||
| 81 | - std::vector<int64_t> {1, projected_decoder_out_dim}, | ||
| 82 | - memory_info); | ||
| 83 | - | ||
| 84 | - Ort::Value &logits_tensor = logits[0]; | ||
| 85 | - int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 86 | - auto logits_vector = ortVal2Vector(logits_tensor, logits_dim); | ||
| 87 | - | ||
| 88 | - int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end()))); | ||
| 89 | - bool emitted = false; | ||
| 90 | - | ||
| 91 | - for (int32_t k = 0; k != cur_batch_size; ++k) { | ||
| 92 | - auto index = max_indices; | ||
| 93 | - if (index != model->blank_id && index != model->unk_id) { | ||
| 94 | - emitted = true; | ||
| 95 | - hyps[k].push_back(index); | ||
| 96 | - } | ||
| 97 | - } | ||
| 98 | - | ||
| 99 | - if (emitted) { | ||
| 100 | - decoder_input = BuildDecoderInput(hyps, decoder_input); | ||
| 101 | - | ||
| 102 | - decoder_out = model->decoder_forward(decoder_input, | ||
| 103 | - std::vector<int64_t> {batch_size, model->context_size}, | ||
| 104 | - memory_info); | ||
| 105 | - | ||
| 106 | - decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[2]; | ||
| 107 | - decoder_out_vector = ortVal2Vector(decoder_out[0], decoder_out_dim); | ||
| 108 | - | ||
| 109 | - decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector, | ||
| 110 | - std::vector<int64_t> {1, decoder_out_dim}, | ||
| 111 | - memory_info); | ||
| 112 | - | ||
| 113 | - projected_decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 114 | - projected_decoder_out_vector = ortVal2Vector(decoder_out[0], projected_decoder_out_dim); | ||
| 115 | - } | ||
| 116 | - } | ||
| 117 | - | ||
| 118 | - return hyps; | ||
| 119 | -} | ||
| 120 | - |
sherpa-onnx/csrc/sherpa-onnx.cc
0 → 100644
| 1 | +/** | ||
| 2 | + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | + * | ||
| 4 | + * See LICENSE for clarification regarding multiple authors | ||
| 5 | + * | ||
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | + * you may not use this file except in compliance with the License. | ||
| 8 | + * You may obtain a copy of the License at | ||
| 9 | + * | ||
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | + * | ||
| 12 | + * Unless required by applicable law or agreed to in writing, software | ||
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | + * See the License for the specific language governing permissions and | ||
| 16 | + * limitations under the License. | ||
| 17 | + */ | ||
| 18 | + | ||
| 19 | +#include <iostream> | ||
| 20 | +#include <string> | ||
| 21 | +#include <vector> | ||
| 22 | + | ||
| 23 | +#include "kaldi-native-fbank/csrc/online-feature.h" | ||
| 24 | +#include "sherpa-onnx/csrc/decode.h" | ||
| 25 | +#include "sherpa-onnx/csrc/rnnt-model.h" | ||
| 26 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 27 | +#include "sherpa-onnx/csrc/wave-reader.h" | ||
| 28 | + | ||
| 29 | +/** Compute fbank features of the input wave filename. | ||
| 30 | + * | ||
| 31 | + * @param wav_filename. Path to a mono wave file. | ||
| 32 | + * @param expected_sampling_rate Expected sampling rate of the input wave file. | ||
| 33 | + * @param num_frames On return, it contains the number of feature frames. | ||
| 34 | + * @return Return the computed feature of shape (num_frames, feature_dim) | ||
| 35 | + * stored in row-major. | ||
| 36 | + */ | ||
| 37 | +static std::vector<float> ComputeFeatures(const std::string &wav_filename, | ||
| 38 | + float expected_sampling_rate, | ||
| 39 | + int32_t *num_frames) { | ||
| 40 | + std::vector<float> samples = | ||
| 41 | + sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate); | ||
| 42 | + | ||
| 43 | + float duration = samples.size() / expected_sampling_rate; | ||
| 44 | + | ||
| 45 | + std::cout << "wav filename: " << wav_filename << "\n"; | ||
| 46 | + std::cout << "wav duration (s): " << duration << "\n"; | ||
| 47 | + | ||
| 48 | + knf::FbankOptions opts; | ||
| 49 | + opts.frame_opts.dither = 0; | ||
| 50 | + opts.frame_opts.snip_edges = false; | ||
| 51 | + opts.frame_opts.samp_freq = expected_sampling_rate; | ||
| 52 | + | ||
| 53 | + int32_t feature_dim = 80; | ||
| 54 | + | ||
| 55 | + opts.mel_opts.num_bins = feature_dim; | ||
| 56 | + | ||
| 57 | + knf::OnlineFbank fbank(opts); | ||
| 58 | + fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); | ||
| 59 | + fbank.InputFinished(); | ||
| 60 | + | ||
| 61 | + *num_frames = fbank.NumFramesReady(); | ||
| 62 | + | ||
| 63 | + std::vector<float> features(*num_frames * feature_dim); | ||
| 64 | + float *p = features.data(); | ||
| 65 | + | ||
| 66 | + for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) { | ||
| 67 | + const float *f = fbank.GetFrame(i); | ||
| 68 | + std::copy(f, f + feature_dim, p); | ||
| 69 | + } | ||
| 70 | + | ||
| 71 | + return features; | ||
| 72 | +} | ||
| 73 | + | ||
| 74 | +int main(int32_t argc, char *argv[]) { | ||
| 75 | + if (argc < 8 || argc > 9) { | ||
| 76 | + const char *usage = R"usage( | ||
| 77 | +Usage: | ||
| 78 | + ./bin/sherpa-onnx \ | ||
| 79 | + /path/to/tokens.txt \ | ||
| 80 | + /path/to/encoder.onnx \ | ||
| 81 | + /path/to/decoder.onnx \ | ||
| 82 | + /path/to/joiner.onnx \ | ||
| 83 | + /path/to/joiner_encoder_proj.ncnn.param \ | ||
| 84 | + /path/to/joiner_decoder_proj.ncnn.param \ | ||
| 85 | + /path/to/foo.wav [num_threads] | ||
| 86 | + | ||
| 87 | +You can download pre-trained models from the following repository: | ||
| 88 | +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 | ||
| 89 | +)usage"; | ||
| 90 | + std::cerr << usage << "\n"; | ||
| 91 | + | ||
| 92 | + return 0; | ||
| 93 | + } | ||
| 94 | + | ||
| 95 | + std::string tokens = argv[1]; | ||
| 96 | + std::string encoder = argv[2]; | ||
| 97 | + std::string decoder = argv[3]; | ||
| 98 | + std::string joiner = argv[4]; | ||
| 99 | + std::string joiner_encoder_proj = argv[5]; | ||
| 100 | + std::string joiner_decoder_proj = argv[6]; | ||
| 101 | + std::string wav_filename = argv[7]; | ||
| 102 | + int32_t num_threads = 4; | ||
| 103 | + if (argc == 9) { | ||
| 104 | + num_threads = atoi(argv[8]); | ||
| 105 | + } | ||
| 106 | + | ||
| 107 | + sherpa_onnx::SymbolTable sym(tokens); | ||
| 108 | + | ||
| 109 | + int32_t num_frames; | ||
| 110 | + auto features = ComputeFeatures(wav_filename, 16000, &num_frames); | ||
| 111 | + int32_t feature_dim = features.size() / num_frames; | ||
| 112 | + | ||
| 113 | + sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj, | ||
| 114 | + joiner_decoder_proj, num_threads); | ||
| 115 | + Ort::Value encoder_out = | ||
| 116 | + model.RunEncoder(features.data(), num_frames, feature_dim); | ||
| 117 | + | ||
| 118 | + auto hyp = sherpa_onnx::GreedySearch(model, encoder_out); | ||
| 119 | + | ||
| 120 | + std::string text; | ||
| 121 | + for (auto i : hyp) { | ||
| 122 | + text += sym[i]; | ||
| 123 | + } | ||
| 124 | + | ||
| 125 | + std::cout << "Recognition result for " << wav_filename << "\n" | ||
| 126 | + << text << "\n"; | ||
| 127 | + | ||
| 128 | + return 0; | ||
| 129 | +} |
| @@ -15,34 +15,21 @@ | @@ -15,34 +15,21 @@ | ||
| 15 | * See the License for the specific language governing permissions and | 15 | * See the License for the specific language governing permissions and |
| 16 | * limitations under the License. | 16 | * limitations under the License. |
| 17 | */ | 17 | */ |
| 18 | - | ||
| 19 | #include <iostream> | 18 | #include <iostream> |
| 19 | +#include <sstream> | ||
| 20 | 20 | ||
| 21 | -#include "kaldi-native-fbank/csrc/online-feature.h" | 21 | +#include "onnxruntime_cxx_api.h" // NOLINT |
| 22 | 22 | ||
| 23 | int main() { | 23 | int main() { |
| 24 | - knf::FbankOptions opts; | ||
| 25 | - opts.frame_opts.dither = 0; | ||
| 26 | - opts.mel_opts.num_bins = 10; | ||
| 27 | - | ||
| 28 | - knf::OnlineFbank fbank(opts); | ||
| 29 | - for (int32_t i = 0; i < 1600; ++i) { | ||
| 30 | - float s = (i * i - i / 2) / 32767.; | ||
| 31 | - fbank.AcceptWaveform(16000, &s, 1); | ||
| 32 | - } | ||
| 33 | - | 24 | + std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n"; |
| 25 | + std::vector<std::string> providers = Ort::GetAvailableProviders(); | ||
| 34 | std::ostringstream os; | 26 | std::ostringstream os; |
| 35 | - | ||
| 36 | - int32_t n = fbank.NumFramesReady(); | ||
| 37 | - for (int32_t i = 0; i != n; ++i) { | ||
| 38 | - const float *frame = fbank.GetFrame(i); | ||
| 39 | - for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) { | ||
| 40 | - os << frame[k] << ", "; | ||
| 41 | - } | ||
| 42 | - os << "\n"; | 27 | + os << "Available providers: "; |
| 28 | + std::string sep = ""; | ||
| 29 | + for (const auto &p : providers) { | ||
| 30 | + os << sep << p; | ||
| 31 | + sep = ", "; | ||
| 43 | } | 32 | } |
| 44 | - | ||
| 45 | std::cout << os.str() << "\n"; | 33 | std::cout << os.str() << "\n"; |
| 46 | - | ||
| 47 | return 0; | 34 | return 0; |
| 48 | } | 35 | } |
sherpa-onnx/csrc/symbol-table.cc
0 → 100644
| 1 | +/** | ||
| 2 | + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | + * | ||
| 4 | + * See LICENSE for clarification regarding multiple authors | ||
| 5 | + * | ||
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | + * you may not use this file except in compliance with the License. | ||
| 8 | + * You may obtain a copy of the License at | ||
| 9 | + * | ||
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | + * | ||
| 12 | + * Unless required by applicable law or agreed to in writing, software | ||
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | + * See the License for the specific language governing permissions and | ||
| 16 | + * limitations under the License. | ||
| 17 | + */ | ||
| 18 | + | ||
| 19 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 20 | + | ||
| 21 | +#include <cassert> | ||
| 22 | +#include <fstream> | ||
| 23 | +#include <sstream> | ||
| 24 | + | ||
| 25 | +namespace sherpa_onnx { | ||
| 26 | + | ||
| 27 | +SymbolTable::SymbolTable(const std::string &filename) { | ||
| 28 | + std::ifstream is(filename); | ||
| 29 | + std::string sym; | ||
| 30 | + int32_t id; | ||
| 31 | + while (is >> sym >> id) { | ||
| 32 | + if (sym.size() >= 3) { | ||
| 33 | + // For BPE-based models, we replace ▁ with a space | ||
| 34 | + // Unicode 9601, hex 0x2581, utf8 0xe29681 | ||
| 35 | + const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str()); | ||
| 36 | + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { | ||
| 37 | + sym = sym.replace(0, 3, " "); | ||
| 38 | + } | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + assert(!sym.empty()); | ||
| 42 | + assert(sym2id_.count(sym) == 0); | ||
| 43 | + assert(id2sym_.count(id) == 0); | ||
| 44 | + | ||
| 45 | + sym2id_.insert({sym, id}); | ||
| 46 | + id2sym_.insert({id, sym}); | ||
| 47 | + } | ||
| 48 | + assert(is.eof()); | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +std::string SymbolTable::ToString() const { | ||
| 52 | + std::ostringstream os; | ||
| 53 | + char sep = ' '; | ||
| 54 | + for (const auto &p : sym2id_) { | ||
| 55 | + os << p.first << sep << p.second << "\n"; | ||
| 56 | + } | ||
| 57 | + return os.str(); | ||
| 58 | +} | ||
| 59 | + | ||
| 60 | +const std::string &SymbolTable::operator[](int32_t id) const { | ||
| 61 | + return id2sym_.at(id); | ||
| 62 | +} | ||
| 63 | + | ||
| 64 | +int32_t SymbolTable::operator[](const std::string &sym) const { | ||
| 65 | + return sym2id_.at(sym); | ||
| 66 | +} | ||
| 67 | + | ||
| 68 | +bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; } | ||
| 69 | + | ||
| 70 | +bool SymbolTable::contains(const std::string &sym) const { | ||
| 71 | + return sym2id_.count(sym) != 0; | ||
| 72 | +} | ||
| 73 | + | ||
| 74 | +std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) { | ||
| 75 | + return os << symbol_table.ToString(); | ||
| 76 | +} | ||
| 77 | + | ||
| 78 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/symbol-table.h
0 → 100644
| 1 | +/** | ||
| 2 | + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | + * | ||
| 4 | + * See LICENSE for clarification regarding multiple authors | ||
| 5 | + * | ||
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | + * you may not use this file except in compliance with the License. | ||
| 8 | + * You may obtain a copy of the License at | ||
| 9 | + * | ||
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | + * | ||
| 12 | + * Unless required by applicable law or agreed to in writing, software | ||
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | + * See the License for the specific language governing permissions and | ||
| 16 | + * limitations under the License. | ||
| 17 | + */ | ||
| 18 | + | ||
| 19 | +#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ | ||
| 20 | +#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ | ||
| 21 | + | ||
| 22 | +#include <string> | ||
| 23 | +#include <unordered_map> | ||
| 24 | + | ||
| 25 | +namespace sherpa_onnx { | ||
| 26 | + | ||
/// It manages mapping between symbols and integer IDs.
///
/// Two hash maps are kept, one per lookup direction (symbol -> ID and
/// ID -> symbol).
class SymbolTable {
 public:
  /// Construct an empty symbol table.
  SymbolTable() = default;
  /// Construct a symbol table from a file.
  /// Each line in the file contains two fields:
  ///
  ///    sym ID
  ///
  /// Fields are separated by space(s).
  ///
  /// For BPE-based models, a leading U+2581 in a symbol is replaced with a
  /// space. Malformed input (duplicate symbols or IDs, trailing garbage) is
  /// detected with assert() only.
  explicit SymbolTable(const std::string &filename);

  /// Return a string representation of this symbol table: one "sym ID"
  /// line per entry, in unspecified order.
  std::string ToString() const;

  /// Return the symbol corresponding to the given ID.
  /// Throws std::out_of_range if there is no such ID.
  const std::string &operator[](int32_t id) const;
  /// Return the ID corresponding to the given symbol.
  /// Throws std::out_of_range if there is no such symbol.
  int32_t operator[](const std::string &sym) const;

  /// Return true if there is a symbol with the given ID.
  bool contains(int32_t id) const;

  /// Return true if there is a given symbol in the symbol table.
  bool contains(const std::string &sym) const;

 private:
  /// Maps a symbol to its ID.
  std::unordered_map<std::string, int32_t> sym2id_;
  /// Maps an ID to its symbol.
  std::unordered_map<int32_t, std::string> id2sym_;
};
| 57 | + | ||
/// Write the string representation of `symbol_table` (see
/// SymbolTable::ToString()) to `os` and return `os`.
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);
| 59 | + | ||
| 60 | +} // namespace sherpa_onnx | ||
| 61 | + | ||
| 62 | +#endif // SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ |
sherpa-onnx/csrc/utils.h
已删除
100644 → 0
| 1 | -#include <iostream> | ||
| 2 | -#include <fstream> | ||
| 3 | - | ||
| 4 | - | ||
// Writes every element of `vector` to the file `saveFileName`, one value
// per line. Failure to open the file is not reported.
void vector2file(std::vector<float> vector, std::string saveFileName) {
  std::ofstream out(saveFileName);
  for (const float value : vector) {
    out << value << '\n';
  }
}
| 11 | - | ||
| 12 | - | ||
// Converts token-ID hypotheses into text, one string per hypothesis.
//
// The first `context_size` entries of every hypothesis are skipped (a
// negative value is treated as 0). The remaining IDs are looked up in
// `token_map` and their tokens concatenated:
//   - a token starting with the BPE word-boundary marker U+2581 (UTF-8 bytes
//     0xe2 0x96 0x81) contributes a space followed by the rest of the token;
//   - a token other than the first one that starts with ASCII '_' is
//     appended after a space, as before;
//   - IDs missing from `token_map` and empty tokens contribute nothing.
// A hypothesis without any token after the context yields an empty string.
std::vector<std::string> hyps2result(std::map<int, std::string> token_map,
                                     std::vector<std::vector<int32_t>> hyps,
                                     int context_size = 2) {
  // The marker is a 3-byte UTF-8 sequence, so comparing the first char with
  // '_' (as the old code did) could never match it.
  static const char kWordBoundary[] = "\xe2\x96\x81";

  const size_t first =
      context_size > 0 ? static_cast<size_t>(context_size) : 0;

  std::vector<std::string> results;
  results.reserve(hyps.size());

  for (const auto &hyp : hyps) {
    std::string result;

    for (size_t i = first; i < hyp.size(); ++i) {
      const auto it = token_map.find(hyp[i]);
      if (it == token_map.end() || it->second.empty()) {
        continue;  // unknown ID or empty token: nothing to append
      }
      const std::string &token = it->second;

      if (token.compare(0, 3, kWordBoundary) == 0) {
        result += ' ';
        result.append(token, 3, std::string::npos);
      } else if (i > first && token[0] == '_') {
        result += ' ';
        result += token;
      } else {
        result += token;
      }
    }

    results.push_back(result);
  }
  return results;
}
| 32 | - | ||
| 33 | - | ||
// Prints the token IDs of the first hypothesis to stdout, skipping its first
// `context_size` entries. Each ID is followed by '-', and the line is closed
// with '|'.
//
// An empty `hyps` prints only the header and the closing '|' instead of
// indexing hyps[0] out of range. A negative `context_size` prints no IDs.
void print_hyps(std::vector<std::vector<int32_t>> hyps, int context_size = 2) {
  std::cout << "Hyps:" << std::endl;
  if (!hyps.empty() && context_size >= 0) {
    const std::vector<int32_t> &hyp = hyps[0];
    for (size_t i = static_cast<size_t>(context_size); i < hyp.size(); ++i) {
      std::cout << hyp[i] << "-";
    }
  }
  std::cout << "|" << std::endl;
}
sherpa-onnx/csrc/utils_onnx.h
已删除
100644 → 0
| 1 | -#include <iostream> | ||
| 2 | -#include "onnxruntime_cxx_api.h" | ||
| 3 | - | ||
// Global ONNX Runtime objects.
//
// NOTE(review): these are non-inline, namespace-scope definitions inside a
// header, so including this header from more than one translation unit
// defines each of them more than once.

// Runtime environment; logs at warning level under the log id "test".
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
// Reference to the ONNX Runtime C API table.
const auto& api = Ort::GetApi();
// Never assigned in this header.
OrtTensorRTProviderOptionsV2* tensorrt_options;
// Default-constructed session options.
Ort::SessionOptions session_options;
// Default allocator.
Ort::AllocatorWithDefaultOptions allocator;
// CPU memory description (arena allocator, default memory type).
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
| 10 | - | ||
| 11 | - | ||
| 12 | -std::vector<float> ortVal2Vector(Ort::Value &tensor, int tensor_length){ | ||
| 13 | - /** | ||
| 14 | - * convert ort tensor to vector | ||
| 15 | - */ | ||
| 16 | - float* floatarr = tensor.GetTensorMutableData<float>(); | ||
| 17 | - std::vector<float> vector {floatarr, floatarr + tensor_length}; | ||
| 18 | - return vector; | ||
| 19 | -} | ||
| 20 | - | ||
| 21 | - | ||
| 22 | -void print_onnx_forward_output(std::vector<Ort::Value> &output_tensors, int num){ | ||
| 23 | - float* floatarr = output_tensors.front().GetTensorMutableData<float>(); | ||
| 24 | - for (int i = 0; i < num; i++) | ||
| 25 | - printf("[%d] = %f\n", i, floatarr[i]); | ||
| 26 | -} | ||
| 27 | - | ||
| 28 | - | ||
| 29 | -void print_shape_of_ort_val(std::vector<Ort::Value> &tensor){ | ||
| 30 | - auto out_shape = tensor.front().GetTensorTypeAndShapeInfo().GetShape(); | ||
| 31 | - auto out_size = out_shape.size(); | ||
| 32 | - std::cout << "("; | ||
| 33 | - for (int i=0; i<out_size; i++){ | ||
| 34 | - std::cout << out_shape[i]; | ||
| 35 | - if (i < out_size-1) | ||
| 36 | - std::cout << ","; | ||
| 37 | - } | ||
| 38 | - std::cout << ")" << std::endl; | ||
| 39 | -} | ||
| 40 | - | ||
| 41 | - | ||
| 42 | -void print_model_info(Ort::Session &session, std::string title){ | ||
| 43 | - std::cout << "=== Printing '" << title << "' model ===" << std::endl; | ||
| 44 | - Ort::AllocatorWithDefaultOptions allocator; | ||
| 45 | - | ||
| 46 | - // print number of model input nodes | ||
| 47 | - size_t num_input_nodes = session.GetInputCount(); | ||
| 48 | - std::vector<const char*> input_node_names(num_input_nodes); | ||
| 49 | - std::vector<int64_t> input_node_dims; | ||
| 50 | - | ||
| 51 | - printf("Number of inputs = %zu\n", num_input_nodes); | ||
| 52 | - | ||
| 53 | - char* output_name = session.GetOutputName(0, allocator); | ||
| 54 | - printf("output name: %s\n", output_name); | ||
| 55 | - | ||
| 56 | - // iterate over all input nodes | ||
| 57 | - for (int i = 0; i < num_input_nodes; i++) { | ||
| 58 | - // print input node names | ||
| 59 | - char* input_name = session.GetInputName(i, allocator); | ||
| 60 | - printf("Input %d : name=%s\n", i, input_name); | ||
| 61 | - input_node_names[i] = input_name; | ||
| 62 | - | ||
| 63 | - // print input node types | ||
| 64 | - Ort::TypeInfo type_info = session.GetInputTypeInfo(i); | ||
| 65 | - auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); | ||
| 66 | - | ||
| 67 | - ONNXTensorElementDataType type = tensor_info.GetElementType(); | ||
| 68 | - printf("Input %d : type=%d\n", i, type); | ||
| 69 | - | ||
| 70 | - // print input shapes/dims | ||
| 71 | - input_node_dims = tensor_info.GetShape(); | ||
| 72 | - printf("Input %d : num_dims=%zu\n", i, input_node_dims.size()); | ||
| 73 | - for (size_t j = 0; j < input_node_dims.size(); j++) | ||
| 74 | - printf("Input %d : dim %zu=%jd\n", i, j, input_node_dims[j]); | ||
| 75 | - } | ||
| 76 | - std::cout << "=======================================" << std::endl; | ||
| 77 | -} |
sherpa-onnx/csrc/wave-reader.cc
0 → 100644
| 1 | +/** | ||
| 2 | + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | + * | ||
| 4 | + * See LICENSE for clarification regarding multiple authors | ||
| 5 | + * | ||
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | + * you may not use this file except in compliance with the License. | ||
| 8 | + * You may obtain a copy of the License at | ||
| 9 | + * | ||
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | + * | ||
| 12 | + * Unless required by applicable law or agreed to in writing, software | ||
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | + * See the License for the specific language governing permissions and | ||
| 16 | + * limitations under the License. | ||
| 17 | + */ | ||
| 18 | + | ||
| 19 | +#include "sherpa-onnx/csrc/wave-reader.h" | ||
| 20 | + | ||
| 21 | +#include <cassert> | ||
| 22 | +#include <fstream> | ||
| 23 | +#include <iostream> | ||
| 24 | +#include <utility> | ||
| 25 | +#include <vector> | ||
| 26 | + | ||
| 27 | +namespace sherpa_onnx { | ||
| 28 | + | ||
namespace {
// see http://soundfile.sapp.org/doc/WaveFormat/
//
// The fields mirror the 44-byte canonical header byte for byte, so their
// order and types must not change (see the static_assert below).
//
// Note: We assume little endian here
// TODO(fangjun): Support big endian
struct WaveHeader {
  // Checks that this is a canonical 44-byte header of a mono, 16-bit PCM
  // file. The checks are assert()s, so they are compiled out when NDEBUG is
  // defined.
  void Validate() const {
    //                 F F I R
    assert(chunk_id == 0x46464952);
    assert(chunk_size == 36 + subchunk2_size);
    //               E V A W
    assert(format == 0x45564157);
    assert(subchunk1_id == 0x20746d66);
    assert(subchunk1_size == 16);  // 16 for PCM
    assert(audio_format == 1);     // 1 for PCM
    assert(num_channels == 1);     // we support only single channel for now
    assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8);
    assert(block_align == num_channels * bits_per_sample / 8);
    assert(bits_per_sample == 16);  // we support only 16 bits per sample
    assert(subchunk2_size >= 0);    // the data size cannot be negative
  }

  int32_t chunk_id;
  int32_t chunk_size;
  int32_t format;
  int32_t subchunk1_id;
  int32_t subchunk1_size;
  int16_t audio_format;
  int16_t num_channels;
  int32_t sample_rate;
  int32_t byte_rate;
  int16_t block_align;
  int16_t bits_per_sample;
  int32_t subchunk2_id;
  int32_t subchunk2_size;
};
static_assert(sizeof(WaveHeader) == 44, "");

// Read a wave file of mono-channel.
// Return its samples normalized to the range [-1, 1).
//
// @param is           Binary stream positioned at the start of the header.
// @param sample_rate  On return, the sample rate stored in the header.
std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) {
  // Zero-initialized so that a failed read never leaves indeterminate values
  // behind when the assert() below is compiled out.
  WaveHeader header{};
  is.read(reinterpret_cast<char *>(&header), sizeof(header));
  assert(static_cast<bool>(is));

  header.Validate();

  *sample_rate = header.sample_rate;

  // header.subchunk2_size contains the number of bytes in the data.
  // As we assume each sample contains two bytes, it is divided by 2 here.
  // A negative size is treated as "no data".
  const size_t num_samples =
      header.subchunk2_size > 0
          ? static_cast<size_t>(header.subchunk2_size) / sizeof(int16_t)
          : 0;
  std::vector<int16_t> samples(num_samples);

  // Read exactly as many bytes as the buffer holds. Reading subchunk2_size
  // bytes would write one byte past the buffer whenever that size is odd.
  is.read(reinterpret_cast<char *>(samples.data()),
          static_cast<std::streamsize>(num_samples * sizeof(int16_t)));

  assert(static_cast<bool>(is));

  std::vector<float> ans(samples.size());
  for (size_t i = 0; i != ans.size(); ++i) {
    ans[i] = samples[i] / 32768.;
  }

  return ans;
}

}  // namespace
| 94 | + | ||
| 95 | +std::vector<float> ReadWave(const std::string &filename, | ||
| 96 | + float expected_sample_rate) { | ||
| 97 | + std::ifstream is(filename, std::ifstream::binary); | ||
| 98 | + float sample_rate; | ||
| 99 | + auto samples = ReadWaveImpl(is, &sample_rate); | ||
| 100 | + if (expected_sample_rate != sample_rate) { | ||
| 101 | + std::cerr << "Expected sample rate: " << expected_sample_rate | ||
| 102 | + << ". Given: " << sample_rate << ".\n"; | ||
| 103 | + exit(-1); | ||
| 104 | + } | ||
| 105 | + return samples; | ||
| 106 | +} | ||
| 107 | + | ||
| 108 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/wave-reader.h
0 → 100644
| 1 | +/** | ||
| 2 | + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | + * | ||
| 4 | + * See LICENSE for clarification regarding multiple authors | ||
| 5 | + * | ||
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | + * you may not use this file except in compliance with the License. | ||
| 8 | + * You may obtain a copy of the License at | ||
| 9 | + * | ||
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | + * | ||
| 12 | + * Unless required by applicable law or agreed to in writing, software | ||
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | + * See the License for the specific language governing permissions and | ||
| 16 | + * limitations under the License. | ||
| 17 | + */ | ||
| 18 | + | ||
| 19 | +#ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_ | ||
| 20 | +#define SHERPA_ONNX_CSRC_WAVE_READER_H_ | ||
| 21 | + | ||
| 22 | +#include <istream> | ||
| 23 | +#include <string> | ||
| 24 | +#include <vector> | ||
| 25 | + | ||
| 26 | +namespace sherpa_onnx { | ||
| 27 | + | ||
/** Read a wave file with expected sample rate.

    @param filename Path to a wave file. It MUST be single channel, 16-bit
                    PCM encoded.
    @param expected_sample_rate Expected sample rate of the wave file. If the
                                sample rates don't match, an error message is
                                printed to stderr and the process exits with
                                status -1 (no exception is thrown).

    @return Return wave samples normalized to the range [-1, 1).
 */
std::vector<float> ReadWave(const std::string &filename,
                            float expected_sample_rate);
| 38 | + | ||
| 39 | +} // namespace sherpa_onnx | ||
| 40 | + | ||
| 41 | +#endif // SHERPA_ONNX_CSRC_WAVE_READER_H_ |
-
请 注册 或 登录 后发表评论