Showing 28 changed files with 1,310 additions and 979 deletions.
.github/scripts/test-online-transducer.sh
0 → 100755
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+echo "EXE is $EXE"
+echo "PATH: $PATH"
+
+which $EXE
+
+log "------------------------------------------------------------"
+log "Run LSTM transducer (English)"
+log "------------------------------------------------------------"
+
+repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17
+
+log "Start testing ${repo_url}"
+repo=$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "*.onnx"
+popd
+
+waves=(
+$repo/test_wavs/1089-134686-0001.wav
+$repo/test_wavs/1221-135766-0001.wav
+$repo/test_wavs/1221-135766-0002.wav
+)
+
+for wave in ${waves[@]}; do
+  time $EXE \
+    $repo/tokens.txt \
+    $repo/encoder-epoch-99-avg-1.onnx \
+    $repo/decoder-epoch-99-avg-1.onnx \
+    $repo/joiner-epoch-99-avg-1.onnx \
+    $wave \
+    4
+done
+
+rm -rf $repo
.github/workflows/linux.yaml
0 → 100644
+name: linux
+
+on:
+  push:
+    branches:
+      - master
+    paths:
+      - '.github/workflows/linux.yaml'
+      - '.github/scripts/test-online-transducer.sh'
+      - 'CMakeLists.txt'
+      - 'cmake/**'
+      - 'sherpa-onnx/csrc/*'
+  pull_request:
+    branches:
+      - master
+    paths:
+      - '.github/workflows/linux.yaml'
+      - '.github/scripts/test-online-transducer.sh'
+      - 'CMakeLists.txt'
+      - 'cmake/**'
+      - 'sherpa-onnx/csrc/*'
+
+concurrency:
+  group: linux-${{ github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read
+
+jobs:
+  linux:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Configure CMake
+        shell: bash
+        run: |
+          mkdir build
+          cd build
+          cmake -D CMAKE_BUILD_TYPE=Release ..
+
+      - name: Build sherpa-onnx for ubuntu
+        shell: bash
+        run: |
+          cd build
+          make -j2
+
+          ls -lh lib
+          ls -lh bin
+
+      - name: Display dependencies of sherpa-onnx for linux
+        shell: bash
+        run: |
+          file build/bin/sherpa-onnx
+          readelf -d build/bin/sherpa-onnx
+
+      - name: Test online transducer
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx
+
+          .github/scripts/test-online-transducer.sh
.github/workflows/macos.yaml
0 → 100644
+name: macos
+
+on:
+  push:
+    branches:
+      - master
+    paths:
+      - '.github/workflows/macos.yaml'
+      - '.github/scripts/test-online-transducer.sh'
+      - 'CMakeLists.txt'
+      - 'cmake/**'
+      - 'sherpa-onnx/csrc/*'
+  pull_request:
+    branches:
+      - master
+    paths:
+      - '.github/workflows/macos.yaml'
+      - '.github/scripts/test-online-transducer.sh'
+      - 'CMakeLists.txt'
+      - 'cmake/**'
+      - 'sherpa-onnx/csrc/*'
+
+concurrency:
+  group: macos-${{ github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read
+
+jobs:
+  macos:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [macos-latest]
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Configure CMake
+        shell: bash
+        run: |
+          mkdir build
+          cd build
+          cmake -D CMAKE_BUILD_TYPE=Release ..
+
+      - name: Build sherpa for macos
+        shell: bash
+        run: |
+          cd build
+          make -j2
+
+          ls -lh lib
+          ls -lh bin
+
+
+      - name: Display dependencies of sherpa-onnx for macos
+        shell: bash
+        run: |
+          file build/bin/sherpa-onnx
+          otool -L build/bin/sherpa-onnx
+          otool -l build/bin/sherpa-onnx
+
+      - name: Test online transducer
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx
+
+          .github/scripts/test-online-transducer.sh
.github/workflows/test-linux-macos-windows.yaml
100644 → 0
-name: test-linux-macos-windows
-
-on:
-  push:
-    branches:
-      - master
-    paths:
-      - '.github/workflows/test-linux-macos-windows.yaml'
-      - 'CMakeLists.txt'
-      - 'cmake/**'
-      - 'sherpa-onnx/csrc/*'
-  pull_request:
-    branches:
-      - master
-    paths:
-      - '.github/workflows/test-linux-macos-windows.yaml'
-      - 'CMakeLists.txt'
-      - 'cmake/**'
-      - 'sherpa-onnx/csrc/*'
-
-concurrency:
-  group: test-linux-macos-windows-${{ github.ref }}
-  cancel-in-progress: true
-
-permissions:
-  contents: read
-
-jobs:
-  test-linux-macos-windows:
-    runs-on: ${{ matrix.os }}
-    strategy:
-      fail-fast: false
-      matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
-
-    steps:
-      - uses: actions/checkout@v2
-        with:
-          fetch-depth: 0
-
-      # see https://github.com/microsoft/setup-msbuild
-      - name: Add msbuild to PATH
-        if: startsWith(matrix.os, 'windows')
-        uses: microsoft/setup-msbuild@v1.0.2
-
-      - name: Download pretrained model and test-data (English)
-        shell: bash
-        run: |
-          git lfs install
-          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
-          cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
-          ls -lh exp/onnx/*.onnx
-          git lfs pull --include "exp/onnx/*.onnx"
-          ls -lh exp/onnx/*.onnx
-
-      - name: Download pretrained model and test-data (Chinese)
-        shell: bash
-        run: |
-          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
-          cd icefall_asr_wenetspeech_pruned_transducer_stateless2
-          ls -lh exp/*.onnx
-          git lfs pull --include "exp/*.onnx"
-          ls -lh exp/*.onnx
-
-      - name: Configure CMake
-        shell: bash
-        run: |
-          mkdir build
-          cd build
-          cmake -D CMAKE_BUILD_TYPE=Release ..
-
-      - name: Build sherpa-onnx for ubuntu/macos
-        if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
-        shell: bash
-        run: |
-          cd build
-          make VERBOSE=1 -j3
-
-      - name: Build sherpa-onnx for Windows
-        if: startsWith(matrix.os, 'windows')
-        shell: bash
-        run: |
-          cmake --build ./build --config Release
-
-      - name: Run tests for ubuntu/macos (English)
-        if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
-        shell: bash
-        run: |
-          time ./build/bin/sherpa-onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
-
-          time ./build/bin/sherpa-onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
-
-          time ./build/bin/sherpa-onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
-
-      - name: Run tests for Windows (English)
-        if: startsWith(matrix.os, 'windows')
-        shell: bash
-        run: |
-          ./build/bin/Release/sherpa-onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
-
-          ./build/bin/Release/sherpa-onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
-
-          ./build/bin/Release/sherpa-onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
-            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
-
-      - name: Run tests for ubuntu/macos (Chinese)
-        if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
-        shell: bash
-        run: |
-          time ./build/bin/sherpa-onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav
-
-          time ./build/bin/sherpa-onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000001.wav
-
-          time ./build/bin/sherpa-onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000002.wav
-
-      - name: Run tests for windows (Chinese)
-        if: startsWith(matrix.os, 'windows')
-        shell: bash
-        run: |
-          ./build/bin/Release/sherpa-onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav
-
-          ./build/bin/Release/sherpa-onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000001.wav
-
-          ./build/bin/Release/sherpa-onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
-            ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000002.wav
.github/workflows/windows-x64.yaml
0 → 100644
+name: windows-x64
+
+on:
+  push:
+    branches:
+      - master
+    paths:
+      - '.github/workflows/windows-x64.yaml'
+      - '.github/scripts/test-online-transducer.sh'
+      - 'CMakeLists.txt'
+      - 'cmake/**'
+      - 'sherpa-onnx/csrc/*'
+  pull_request:
+    branches:
+      - master
+    paths:
+      - '.github/workflows/windows-x64.yaml'
+      - '.github/scripts/test-online-transducer.sh'
+      - 'CMakeLists.txt'
+      - 'cmake/**'
+      - 'sherpa-onnx/csrc/*'
+
+concurrency:
+  group: windows-x64-${{ github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read
+
+jobs:
+  windows_x64:
+    runs-on: ${{ matrix.os }}
+    name: ${{ matrix.vs-version }}
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - vs-version: vs2015
+            toolset-version: v140
+            os: windows-2019
+
+          - vs-version: vs2017
+            toolset-version: v141
+            os: windows-2019
+
+          - vs-version: vs2019
+            toolset-version: v142
+            os: windows-2022
+
+          - vs-version: vs2022
+            toolset-version: v143
+            os: windows-2022
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Configure CMake
+        shell: bash
+        run: |
+          mkdir build
+          cd build
+          cmake -T ${{ matrix.toolset-version }},host=x64 -A x64 -D CMAKE_BUILD_TYPE=Release ..
+
+      - name: Build sherpa-onnx for windows
+        shell: bash
+        run: |
+          cd build
+          cmake --build . --config Release -- -m:2
+
+          ls -lh ./bin/Release/sherpa-onnx.exe
+
+      - name: Test sherpa-onnx for Windows x64
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx.exe
+
+          .github/scripts/test-online-transducer.sh
README.md

@@ -2,89 +2,7 @@
 
 Documentation: <https://k2-fsa.github.io/sherpa/onnx/index.html>
 
-Try it in colab:
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tmQbdlYeTl_klmtaGiUb7a7ZPz-AkBSH?usp=sharing)
-
 See <https://github.com/k2-fsa/sherpa>
 
 This repo uses [onnxruntime](https://github.com/microsoft/onnxruntime) and
 does not depend on libtorch.
-
-We provide exported models in onnx format and they can be downloaded using
-the following links:
-
-- English: <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13>
-- Chinese: <https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2>
-
-**NOTE**: We provide only non-streaming models at present.
-
-
-**HINT**: The script for exporting the English model can be found at
-<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py>
-
-**HINT**: The script for exporting the Chinese model can be found at
-<https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py>
-
-## Build for Linux/macOS
-
-```bash
-git clone https://github.com/k2-fsa/sherpa-onnx
-cd sherpa-onnx
-mkdir build
-cd build
-cmake -DCMAKE_BUILD_TYPE=Release ..
-make -j6
-cd ..
-```
-
-## Build for Windows
-
-```bash
-git clone https://github.com/k2-fsa/sherpa-onnx
-cd sherpa-onnx
-mkdir build
-cd build
-cmake -DCMAKE_BUILD_TYPE=Release ..
-cmake --build . --config Release
-cd ..
-```
-
-## Download the pretrained model (English)
-
-```bash
-GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
-cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
-git lfs pull --include "exp/onnx/*.onnx"
-cd ..
-
-./build/bin/sherpa-onnx --help
-
-./build/bin/sherpa-onnx \
-  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
-  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
-  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
-  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
-  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
-  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
-  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
-```
-
-## Download the pretrained model (Chinese)
-
-```bash
-GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
-cd icefall_asr_wenetspeech_pruned_transducer_stateless2
-git lfs pull --include "exp/*.onnx"
-cd ..
-
-./build/bin/sherpa-onnx --help
-
-./build/bin/sherpa-onnx \
-  ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
-  ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
-  ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
-  ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
-  ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
-  ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
-  ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav
-```
sherpa-onnx/csrc/CMakeLists.txt

@@ -2,7 +2,11 @@ include_directories(${CMAKE_SOURCE_DIR})
 
 add_executable(sherpa-onnx
   decode.cc
-  rnnt-model.cc
+  features.cc
+  online-lstm-transducer-model.cc
+  online-transducer-model-config.cc
+  online-transducer-model.cc
+  onnx-utils.cc
   sherpa-onnx.cc
   symbol-table.cc
   wave-reader.cc
@@ -13,5 +17,5 @@ target_link_libraries(sherpa-onnx
   kaldi-native-fbank-core
 )
 
-# add_executable(sherpa-show-onnx-info show-onnx-info.cc)
-# target_link_libraries(sherpa-show-onnx-info onnxruntime)
+add_executable(sherpa-onnx-show-info show-onnx-info.cc)
+target_link_libraries(sherpa-onnx-show-info onnxruntime)
sherpa-onnx/csrc/decode.cc

-/**
- * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
- *
- * See LICENSE for clarification regarding multiple authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+// sherpa-onnx/csrc/decode.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
 
 #include "sherpa-onnx/csrc/decode.h"
 
 #include <assert.h>
 
 #include <algorithm>
+#include <utility>
 #include <vector>
 
 namespace sherpa_onnx {
 
-std::vector<int32_t> GreedySearch(RnntModel &model,  // NOLINT
-                                  const Ort::Value &encoder_out) {
+static Ort::Value Clone(Ort::Value *v) {
+  auto type_and_shape = v->GetTensorTypeAndShapeInfo();
+  std::vector<int64_t> shape = type_and_shape.GetShape();
+
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),
+                                  type_and_shape.GetElementCount(),
+                                  shape.data(), shape.size());
+}
+
+static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
   std::vector<int64_t> encoder_out_shape =
-      encoder_out.GetTensorTypeAndShapeInfo().GetShape();
-  assert(encoder_out_shape[0] == 1 && "Only batch_size=1 is implemented");
-  Ort::Value projected_encoder_out =
-      model.RunJoinerEncoderProj(encoder_out.GetTensorData<float>(),
-                                 encoder_out_shape[1], encoder_out_shape[2]);
+      encoder_out->GetTensorTypeAndShapeInfo().GetShape();
+  assert(encoder_out_shape[0] == 1);
 
-  const float *p_projected_encoder_out =
-      projected_encoder_out.GetTensorData<float>();
+  int32_t encoder_out_dim = encoder_out_shape[2];
 
-  int32_t context_size = 2;  // hard-code it to 2
-  int32_t blank_id = 0;      // hard-code it to 0
-  std::vector<int32_t> hyp(context_size, blank_id);
-  std::array<int64_t, 2> decoder_input{blank_id, blank_id};
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
 
-  Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size);
+  std::array<int64_t, 2> shape{1, encoder_out_dim};
 
-  std::vector<int64_t> decoder_out_shape =
-      decoder_out.GetTensorTypeAndShapeInfo().GetShape();
+  return Ort::Value::CreateTensor(
+      memory_info,
+      encoder_out->GetTensorMutableData<float>() + t * encoder_out_dim,
+      encoder_out_dim, shape.data(), shape.size());
+}
 
-  Ort::Value projected_decoder_out = model.RunJoinerDecoderProj(
-      decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
+void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
+                  std::vector<int64_t> *hyp) {
+  std::vector<int64_t> encoder_out_shape =
+      encoder_out.GetTensorTypeAndShapeInfo().GetShape();
 
-  int32_t joiner_dim =
-      projected_decoder_out.GetTensorTypeAndShapeInfo().GetShape()[1];
+  if (encoder_out_shape[0] > 1) {
+    fprintf(stderr, "Only batch_size=1 is implemented. Given: %d\n",
+            static_cast<int32_t>(encoder_out_shape[0]));
+  }
 
-  int32_t T = encoder_out_shape[1];
-  for (int32_t t = 0; t != T; ++t) {
-    Ort::Value logit = model.RunJoiner(
-        p_projected_encoder_out + t * joiner_dim,
-        projected_decoder_out.GetTensorData<float>(), joiner_dim);
+  int32_t num_frames = encoder_out_shape[1];
+  int32_t vocab_size = model->VocabSize();
 
-    int32_t vocab_size = logit.GetTensorTypeAndShapeInfo().GetShape()[1];
+  Ort::Value decoder_input = model->BuildDecoderInput(*hyp);
+  Ort::Value decoder_out = model->RunDecoder(std::move(decoder_input));
 
+  for (int32_t t = 0; t != num_frames; ++t) {
+    Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
+    Ort::Value logit =
+        model->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
     const float *p_logit = logit.GetTensorData<float>();
 
     auto y = static_cast<int32_t>(std::distance(
         static_cast<const float *>(p_logit),
         std::max_element(static_cast<const float *>(p_logit),
                          static_cast<const float *>(p_logit) + vocab_size)));
-
-    if (y != blank_id) {
-      decoder_input[0] = hyp.back();
-      decoder_input[1] = y;
-      hyp.push_back(y);
-      decoder_out = model.RunDecoder(decoder_input.data(), context_size);
-      projected_decoder_out = model.RunJoinerDecoderProj(
-          decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
+    if (y != 0) {
+      hyp->push_back(y);
+      decoder_input = model->BuildDecoderInput(*hyp);
+      decoder_out = model->RunDecoder(std::move(decoder_input));
     }
   }
-
-  return {hyp.begin() + context_size, hyp.end()};
 }
 
 }  // namespace sherpa_onnx
sherpa-onnx/csrc/decode.h

-/**
- * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
- *
- * See LICENSE for clarification regarding multiple authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+// sherpa-onnx/csrc/decode.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
 
 #ifndef SHERPA_ONNX_CSRC_DECODE_H_
 #define SHERPA_ONNX_CSRC_DECODE_H_
 
 #include <vector>
 
-#include "sherpa-onnx/csrc/rnnt-model.h"
+#include "sherpa-onnx/csrc/online-transducer-model.h"
 
 namespace sherpa_onnx {
 
@@ -32,8 +18,8 @@ namespace sherpa_onnx {
 * @param model The RnntModel
 * @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
 */
-std::vector<int32_t> GreedySearch(RnntModel &model,  // NOLINT
-                                  const Ort::Value &encoder_out);
+void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
+                  std::vector<int64_t> *hyp);
 
 }  // namespace sherpa_onnx
 
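For orientation, here is a minimal sketch (not part of the diff) of how the reworked `GreedySearch()` is meant to be driven. The hypothesis buffer is seeded with `ContextSize()` blanks (token id 0), mirroring what the removed code did internally with its hard-coded `context_size` of 2; the `Decode` wrapper name is hypothetical:

```c++
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/decode.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"

// `model` and `encoder_out` are assumed to come from
// OnlineTransducerModel::Create() and RunEncoder(), respectively.
std::vector<int64_t> Decode(sherpa_onnx::OnlineTransducerModel *model,
                            Ort::Value encoder_out) {
  int32_t context_size = model->ContextSize();
  std::vector<int64_t> hyp(context_size, 0);  // seed with blanks

  // GreedySearch() appends decoded token ids to hyp in place.
  sherpa_onnx::GreedySearch(model, std::move(encoder_out), &hyp);

  // Drop the leading blanks; what remains are the decoded token ids.
  return {hyp.begin() + context_size, hyp.end()};
}
```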
sherpa-onnx/csrc/features.cc
0 → 100644
+// sherpa-onnx/csrc/features.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/features.h"
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+namespace sherpa_onnx {
+
+FeatureExtractor::FeatureExtractor() {
+  opts_.frame_opts.dither = 0;
+  opts_.frame_opts.snip_edges = false;
+  opts_.frame_opts.samp_freq = 16000;
+
+  // cache 100 seconds of feature frames, which is more than enough
+  // for real needs
+  opts_.frame_opts.max_feature_vectors = 100 * 100;
+
+  opts_.mel_opts.num_bins = 80;  // feature dim
+
+  fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
+}
+
+FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts)
+    : opts_(opts) {
+  fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
+}
+
+void FeatureExtractor::AcceptWaveform(float sampling_rate,
+                                      const float *waveform, int32_t n) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  fbank_->AcceptWaveform(sampling_rate, waveform, n);
+}
+
+void FeatureExtractor::InputFinished() {
+  std::lock_guard<std::mutex> lock(mutex_);
+  fbank_->InputFinished();
+}
+
+int32_t FeatureExtractor::NumFramesReady() const {
+  std::lock_guard<std::mutex> lock(mutex_);
+  return fbank_->NumFramesReady();
+}
+
+bool FeatureExtractor::IsLastFrame(int32_t frame) const {
+  std::lock_guard<std::mutex> lock(mutex_);
+  return fbank_->IsLastFrame(frame);
+}
+
+std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
+                                               int32_t n) const {
+  if (frame_index + n > NumFramesReady()) {
+    fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
+    exit(-1);
+  }
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  int32_t feature_dim = fbank_->Dim();
+  std::vector<float> features(feature_dim * n);
+
+  float *p = features.data();
+
+  for (int32_t i = 0; i != n; ++i) {
+    const float *f = fbank_->GetFrame(i + frame_index);
+    std::copy(f, f + feature_dim, p);
+    p += feature_dim;
+  }
+
+  return features;
+}
+
+void FeatureExtractor::Reset() {
+  // Lock here as well: Reset() replaces fbank_, which the other
+  // methods may be reading concurrently.
+  std::lock_guard<std::mutex> lock(mutex_);
+  fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
+}
+
+}  // namespace sherpa_onnx
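Every method above guards `fbank_` with `mutex_`, which points at the intended usage pattern: one thread feeds audio as it is captured while another polls for ready frames. A hypothetical sketch of that producer/consumer split (not part of the diff; the `Demo` name and the 0.1 s chunking are illustrative only):

```c++
#include <algorithm>
#include <thread>
#include <vector>

#include "sherpa-onnx/csrc/features.h"

void Demo(const std::vector<float> &recording) {  // samples in [-1, 1]
  sherpa_onnx::FeatureExtractor extractor;

  // Producer: feed 0.1 s (1600 samples at 16 kHz) at a time, roughly as a
  // microphone callback would.
  std::thread producer([&extractor, &recording]() {
    for (size_t i = 0; i < recording.size(); i += 1600) {
      size_t n = std::min<size_t>(1600, recording.size() - i);
      extractor.AcceptWaveform(16000, recording.data() + i, n);
    }
    extractor.InputFinished();
  });

  // Consumer: pull frames as soon as they become ready.
  int32_t consumed = 0;
  while (true) {
    if (consumed < extractor.NumFramesReady()) {
      std::vector<float> frame = extractor.GetFrames(consumed, 1);
      ++consumed;  // hand `frame` to the recognizer here
    } else if (extractor.IsLastFrame(consumed - 1)) {
      break;  // input flushed and every frame consumed
    } else {
      std::this_thread::yield();  // wait for more audio
    }
  }
  producer.join();
}
```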
sherpa-onnx/csrc/features.h
0 → 100644
+// sherpa-onnx/csrc/features.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_FEATURES_H_
+#define SHERPA_ONNX_CSRC_FEATURES_H_
+
+#include <memory>
+#include <mutex>  // NOLINT
+#include <vector>
+
+#include "kaldi-native-fbank/csrc/online-feature.h"
+
+namespace sherpa_onnx {
+
+class FeatureExtractor {
+ public:
+  FeatureExtractor();
+  explicit FeatureExtractor(const knf::FbankOptions &fbank_opts);
+
+  /**
+     @param sampling_rate The sampling rate of the input waveform. Should match
+                          the one expected by the feature extractor.
+     @param waveform Pointer to a 1-D array of size n
+     @param n Number of entries in waveform
+   */
+  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
+
+  // InputFinished() tells the class you won't be providing any
+  // more waveform. This will help flush out the last frame or two
+  // of features, in the case where snip-edges == false; it also
+  // affects the return value of IsLastFrame().
+  void InputFinished();
+
+  int32_t NumFramesReady() const;
+
+  // Note: IsLastFrame() will only ever return true if you have called
+  // InputFinished() (and this frame is the last frame).
+  bool IsLastFrame(int32_t frame) const;
+
+  /** Get n frames starting from the given frame index.
+   *
+   * @param frame_index The starting frame index
+   * @param n Number of frames to get.
+   * @return Return a 2-D tensor of shape (n, feature_dim),
+   *         which is flattened into a 1-D vector in row-major order.
+   */
+  std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
+
+  void Reset();
+  int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
+
+ private:
+  std::unique_ptr<knf::OnlineFbank> fbank_;
+  knf::FbankOptions opts_;
+  mutable std::mutex mutex_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_FEATURES_H_
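A minimal usage sketch (not part of the diff), showing how the flattened `(num_frames, 80)` output of `GetFrames()` lines up with the defaults set in the constructor; the `ComputeFbank` wrapper name is hypothetical:

```c++
#include <vector>

#include "sherpa-onnx/csrc/features.h"

// `samples` are floats in [-1, 1], captured at 16 kHz to match the
// default FbankOptions above.
std::vector<float> ComputeFbank(const float *samples, int32_t n) {
  sherpa_onnx::FeatureExtractor extractor;  // 16 kHz input, 80-dim fbank

  extractor.AcceptWaveform(16000, samples, n);
  extractor.InputFinished();  // flush the trailing frame or two

  int32_t num_frames = extractor.NumFramesReady();

  // A (num_frames, 80) matrix flattened row by row; frame i occupies
  // entries [i * 80, (i + 1) * 80).
  return extractor.GetFrames(0, num_frames);
}
```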
sherpa-onnx/csrc/online-lstm-transducer-model.cc
0 → 100644

+// sherpa-onnx/csrc/online-lstm-transducer-model.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+#define SHERPA_ONNX_READ_META_DATA(dst, src_key)                          \
+  do {                                                                    \
+    auto value =                                                          \
+        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator);   \
+    if (!value) {                                                         \
+      fprintf(stderr, "%s does not exist in the metadata\n", src_key);    \
+      exit(-1);                                                           \
+    }                                                                     \
+    dst = atoi(value.get());                                              \
+    if (dst <= 0) {                                                       \
+      fprintf(stderr, "Invalid value %d for %s\n", dst, src_key);         \
+      exit(-1);                                                           \
+    }                                                                     \
+  } while (0)
+
+namespace sherpa_onnx {
+
+OnlineLstmTransducerModel::OnlineLstmTransducerModel(
+    const OnlineTransducerModelConfig &config)
+    : env_(ORT_LOGGING_LEVEL_WARNING),
+      config_(config),
+      sess_opts_{},
+      allocator_{} {
+  sess_opts_.SetIntraOpNumThreads(config.num_threads);
+  sess_opts_.SetInterOpNumThreads(config.num_threads);
+
+  InitEncoder(config.encoder_filename);
+  InitDecoder(config.decoder_filename);
+  InitJoiner(config.joiner_filename);
+}
+
+void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
+  encoder_sess_ = std::make_unique<Ort::Session>(
+      env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
+
+  GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                &encoder_input_names_ptr_);
+
+  GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                 &encoder_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---encoder---\n";
+    PrintModelMetadata(os, meta_data);
+    fprintf(stderr, "%s\n", os.str().c_str());
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;
+  SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
+  SHERPA_ONNX_READ_META_DATA(T_, "T");
+  SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
+  SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "rnn_hidden_size");
+  SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
+}
+
+void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
+  decoder_sess_ = std::make_unique<Ort::Session>(
+      env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
+
+  GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                &decoder_input_names_ptr_);
+
+  GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                 &decoder_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---decoder---\n";
+    PrintModelMetadata(os, meta_data);
+    fprintf(stderr, "%s\n", os.str().c_str());
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;
+  SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+  SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
+}
+
+void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
+  joiner_sess_ = std::make_unique<Ort::Session>(
+      env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
+
+  GetInputNames(joiner_sess_.get(), &joiner_input_names_,
+                &joiner_input_names_ptr_);
+
+  GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
+                 &joiner_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---joiner---\n";
+    PrintModelMetadata(os, meta_data);
+    fprintf(stderr, "%s\n", os.str().c_str());
+  }
+}
+
+Ort::Value OnlineLstmTransducerModel::StackStates(
+    const std::vector<Ort::Value> &states) const {
+  fprintf(stderr, "implement me: %s:%d!\n", __func__,
+          static_cast<int>(__LINE__));
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  int64_t a;
+  std::array<int64_t, 3> x_shape{1, 1, 1};
+  Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);
+  return x;
+}
+
+std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(
+    Ort::Value states) const {
+  fprintf(stderr, "implement me: %s:%d!\n", __func__,
+          static_cast<int>(__LINE__));
+  return {};
+}
+
+std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
+  // Please see
+  // https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py#L185
+  // for details
+  constexpr int32_t kBatchSize = 1;
+  std::array<int64_t, 3> h_shape{num_encoder_layers_, kBatchSize, d_model_};
+  Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
+                                                 h_shape.size());
+
+  std::fill(h.GetTensorMutableData<float>(),
+            h.GetTensorMutableData<float>() +
+                num_encoder_layers_ * kBatchSize * d_model_,
+            0);
+
+  std::array<int64_t, 3> c_shape{num_encoder_layers_, kBatchSize,
+                                 rnn_hidden_size_};
+  Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
+                                                 c_shape.size());
+
+  std::fill(c.GetTensorMutableData<float>(),
+            c.GetTensorMutableData<float>() +
+                num_encoder_layers_ * kBatchSize * rnn_hidden_size_,
+            0);
+
+  std::vector<Ort::Value> states;
+
+  states.reserve(2);
+  states.push_back(std::move(h));
+  states.push_back(std::move(c));
+
+  return states;
+}
+
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
+                                      std::vector<Ort::Value> &states) {
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  std::array<Ort::Value, 3> encoder_inputs = {
+      std::move(features), std::move(states[0]), std::move(states[1])};
+
+  auto encoder_out = encoder_sess_->Run(
+      {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
+      encoder_inputs.size(), encoder_output_names_ptr_.data(),
+      encoder_output_names_ptr_.size());
+
+  std::vector<Ort::Value> next_states;
+  next_states.reserve(2);
+  next_states.push_back(std::move(encoder_out[1]));
+  next_states.push_back(std::move(encoder_out[2]));
+
+  return {std::move(encoder_out[0]), std::move(next_states)};
+}
+
+Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
+    const std::vector<int64_t> &hyp) {
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  std::array<int64_t, 2> shape{1, context_size_};
+
+  return Ort::Value::CreateTensor(
+      memory_info,
+      const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),
+      context_size_, shape.data(), shape.size());
+}
+
+Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
+  auto decoder_out = decoder_sess_->Run(
+      {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
+      decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
+  return std::move(decoder_out[0]);
+}
+
+Ort::Value OnlineLstmTransducerModel::RunJoiner(Ort::Value encoder_out,
+                                                Ort::Value decoder_out) {
+  std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
+                                            std::move(decoder_out)};
+  auto logit =
+      joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
+                        joiner_input.size(), joiner_output_names_ptr_.data(),
+                        joiner_output_names_ptr_.size());
+
+  return std::move(logit[0]);
+}
+
+}  // namespace sherpa_onnx
sherpa-onnx/csrc/online-lstm-transducer-model.h
0 → 100644

+// sherpa-onnx/csrc/online-lstm-transducer-model.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/csrc/online-transducer-model.h"
+
+namespace sherpa_onnx {
+
+class OnlineLstmTransducerModel : public OnlineTransducerModel {
+ public:
+  explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
+
+  Ort::Value StackStates(const std::vector<Ort::Value> &states) const override;
+
+  std::vector<Ort::Value> UnStackStates(Ort::Value states) const override;
+
+  std::vector<Ort::Value> GetEncoderInitStates() override;
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
+      Ort::Value features, std::vector<Ort::Value> &states) override;
+
+  Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override;
+
+  Ort::Value RunDecoder(Ort::Value decoder_input) override;
+
+  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
+
+  int32_t ContextSize() const override { return context_size_; }
+
+  int32_t ChunkSize() const override { return T_; }
+
+  int32_t ChunkShift() const override { return decode_chunk_len_; }
+
+  int32_t VocabSize() const override { return vocab_size_; }
+
+ private:
+  void InitEncoder(const std::string &encoder_filename);
+  void InitDecoder(const std::string &decoder_filename);
+  void InitJoiner(const std::string &joiner_filename);
+
+ private:
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> encoder_sess_;
+  std::unique_ptr<Ort::Session> decoder_sess_;
+  std::unique_ptr<Ort::Session> joiner_sess_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::vector<std::string> decoder_input_names_;
+  std::vector<const char *> decoder_input_names_ptr_;
+
+  std::vector<std::string> decoder_output_names_;
+  std::vector<const char *> decoder_output_names_ptr_;
+
+  std::vector<std::string> joiner_input_names_;
+  std::vector<const char *> joiner_input_names_ptr_;
+
+  std::vector<std::string> joiner_output_names_;
+  std::vector<const char *> joiner_output_names_ptr_;
+
+  OnlineTransducerModelConfig config_;
+
+  int32_t num_encoder_layers_ = 0;
+  int32_t T_ = 0;
+  int32_t decode_chunk_len_ = 0;
+  int32_t rnn_hidden_size_ = 0;
+  int32_t d_model_ = 0;
+  int32_t context_size_ = 0;
+  int32_t vocab_size_ = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
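To see how these accessors compose, here is a hypothetical streaming loop (not part of the diff): each iteration feeds `ChunkSize()` feature frames to the encoder, advances by `ChunkShift()` frames, and threads the `(h, c)` states returned by `RunEncoder()` into the next call. `FeatureExtractor` and `GreedySearch` are from the files above; the `StreamingDecode` wrapper name is illustrative:

```c++
#include <array>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/decode.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"

void StreamingDecode(sherpa_onnx::OnlineLstmTransducerModel *model,
                     sherpa_onnx::FeatureExtractor *extractor,
                     std::vector<int64_t> *hyp) {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // Zero-initialized (h, c) from GetEncoderInitStates(); shapes are
  // (num_encoder_layers, 1, d_model) and (num_encoder_layers, 1,
  // rnn_hidden_size).
  std::vector<Ort::Value> states = model->GetEncoderInitStates();

  int32_t chunk_size = model->ChunkSize();    // frames consumed per call (T)
  int32_t chunk_shift = model->ChunkShift();  // frames to advance per call
  int32_t feature_dim = extractor->FeatureDim();

  int32_t start = 0;
  while (start + chunk_size <= extractor->NumFramesReady()) {
    std::vector<float> chunk = extractor->GetFrames(start, chunk_size);
    start += chunk_shift;

    std::array<int64_t, 3> shape{1, chunk_size, feature_dim};
    Ort::Value x =
        Ort::Value::CreateTensor(memory_info, chunk.data(), chunk.size(),
                                 shape.data(), shape.size());

    auto res = model->RunEncoder(std::move(x), states);
    states = std::move(res.second);

    // Greedy search appends new tokens to *hyp, which must have been
    // seeded with ContextSize() blanks beforehand.
    sherpa_onnx::GreedySearch(model, std::move(res.first), hyp);
  }
}
```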
sherpa-onnx/csrc/online-transducer-model-config.cc
0 → 100644

+// sherpa-onnx/csrc/online-transducer-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+
+#include <sstream>
+
+namespace sherpa_onnx {
+
+std::string OnlineTransducerModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OnlineTransducerModelConfig(";
+  os << "encoder_filename=\"" << encoder_filename << "\", ";
+  os << "decoder_filename=\"" << decoder_filename << "\", ";
+  os << "joiner_filename=\"" << joiner_filename << "\", ";
+  os << "num_threads=" << num_threads << ", ";
+  os << "debug=" << (debug ? "True" : "False") << ")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
| 1 | +// sherpa-onnx/csrc/online-transducer-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | + | ||
| 11 | +struct OnlineTransducerModelConfig { | ||
| 12 | + std::string encoder_filename; | ||
| 13 | + std::string decoder_filename; | ||
| 14 | + std::string joiner_filename; | ||
| 15 | + int32_t num_threads; | ||
| 16 | + bool debug = false; | ||
| 17 | + | ||
| 18 | + std::string ToString() const; | ||
| 19 | +}; | ||
| 20 | + | ||
| 21 | +} // namespace sherpa_onnx | ||
| 22 | + | ||
| 23 | +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ |
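For orientation, here is a minimal usage sketch of this struct; the model paths below are placeholders, not files shipped with the repo. Note that `num_threads` has no in-class default, so callers must always set it before the config is consumed:

```cpp
#include <iostream>

#include "sherpa-onnx/csrc/online-transducer-model-config.h"

int main() {
  sherpa_onnx::OnlineTransducerModelConfig config;
  config.encoder_filename = "encoder.onnx";  // placeholder paths
  config.decoder_filename = "decoder.onnx";
  config.joiner_filename = "joiner.onnx";
  config.num_threads = 2;  // no in-class default; always set it explicitly
  config.debug = true;

  // ToString() renders every field on one line, which is handy for logging
  std::cout << config.ToString() << "\n";
  return 0;
}
```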
sherpa-onnx/csrc/online-transducer-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-transducer-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/online-transducer-model.h" | ||
| 5 | + | ||
| 6 | +#include <memory> | ||
| 7 | +#include <sstream> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/online-lstm-transducer-model.h" | ||
| 11 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +enum class ModelType { | ||
| 15 | + kLstm, | ||
| 16 | + kUnknown, | ||
| 17 | +}; | ||
| 18 | + | ||
| 19 | +static ModelType GetModelType(const OnlineTransducerModelConfig &config) { | ||
| 20 | + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); | ||
| 21 | + Ort::SessionOptions sess_opts; | ||
| 22 | + | ||
| 23 | + auto sess = std::make_unique<Ort::Session>( | ||
| 24 | + env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts); | ||
| 25 | + | ||
| 26 | + Ort::ModelMetadata meta_data = sess->GetModelMetadata(); | ||
| 27 | + if (config.debug) { | ||
| 28 | + std::ostringstream os; | ||
| 29 | + PrintModelMetadata(os, meta_data); | ||
| 30 | + fprintf(stderr, "%s\n", os.str().c_str()); | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 34 | + auto model_type = | ||
| 35 | + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); | ||
| 36 | + if (!model_type) { | ||
| 37 | + fprintf(stderr, "No model_type in the metadata!\n"); | ||
| 38 | + return ModelType::kUnknown; | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + if (model_type.get() == std::string("lstm")) { | ||
| 42 | + return ModelType::kLstm; | ||
| 43 | + } else { | ||
| 44 | + fprintf(stderr, "Unsupported model_type: %s\n", model_type.get()); | ||
| 45 | + return ModelType::kUnknown; | ||
| 46 | + } | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 50 | + const OnlineTransducerModelConfig &config) { | ||
| 51 | + auto model_type = GetModelType(config); | ||
| 52 | + | ||
| 53 | + switch (model_type) { | ||
| 54 | + case ModelType::kLstm: | ||
| 55 | + return std::make_unique<OnlineLstmTransducerModel>(config); | ||
| 56 | + case ModelType::kUnknown: | ||
| 57 | + return nullptr; | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + // unreachable code | ||
| 61 | + return nullptr; | ||
| 62 | +} | ||
| 63 | + | ||
| 64 | +} // namespace sherpa_onnx |
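A hedged sketch of how the factory above is meant to be called. `Create()` returns `nullptr` both when the encoder's ONNX metadata has no `model_type` entry and when the entry names an unsupported architecture, so callers should always check the result (the paths are placeholders):

```cpp
#include <cstdio>
#include <memory>

#include "sherpa-onnx/csrc/online-transducer-model.h"

std::unique_ptr<sherpa_onnx::OnlineTransducerModel> LoadModel() {
  sherpa_onnx::OnlineTransducerModelConfig config;
  config.encoder_filename = "encoder.onnx";  // placeholder paths
  config.decoder_filename = "decoder.onnx";
  config.joiner_filename = "joiner.onnx";
  config.num_threads = 2;

  auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
  if (!model) {
    // model_type was missing from the ONNX metadata or was not "lstm"
    std::fprintf(stderr, "Failed to create an online transducer model\n");
  }
  return model;
}
```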
sherpa-onnx/csrc/online-transducer-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-transducer-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/online-transducer-model-config.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +class OnlineTransducerModel { | ||
| 17 | + public: | ||
| 18 | + virtual ~OnlineTransducerModel() = default; | ||
| 19 | + | ||
| 20 | + static std::unique_ptr<OnlineTransducerModel> Create( | ||
| 21 | + const OnlineTransducerModelConfig &config); | ||
| 22 | + | ||
| 23 | + /** Stack a list of individual states into a batch. | ||
| 24 | + * | ||
| 25 | + * It is the inverse operation of `UnStackStates`. | ||
| 26 | + * | ||
| 27 | + * @param states states[i] contains the state for the i-th utterance. | ||
| 28 | + * @return Return a single value representing the batched state. | ||
| 29 | + */ | ||
| 30 | + virtual Ort::Value StackStates( | ||
| 31 | + const std::vector<Ort::Value> &states) const = 0; | ||
| 32 | + | ||
| 33 | + /** Unstack a batch state into a list of individual states. | ||
| 34 | + * | ||
| 35 | + * It is the inverse operation of `StackStates`. | ||
| 36 | + * | ||
| 37 | + * @param states A batched state. | ||
| 38 | + * @return ans[i] contains the state for the i-th utterance. | ||
| 39 | + */ | ||
| 40 | + virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0; | ||
| 41 | + | ||
| 42 | + /** Get the initial encoder states. | ||
| 43 | + * | ||
| 44 | + * @return Return the initial encoder state. | ||
| 45 | + */ | ||
| 46 | + virtual std::vector<Ort::Value> GetEncoderInitStates() = 0; | ||
| 47 | + | ||
| 48 | + /** Run the encoder. | ||
| 49 | + * | ||
| 50 | + * @param features A tensor of shape (N, T, C). It is changed in-place. | ||
| 51 | + * @param states Encoder state of the previous chunk. It is changed in-place. | ||
| 52 | + * | ||
| 53 | + * @return Return a tuple containing: | ||
| 54 | + * - encoder_out, a tensor of shape (N, T', encoder_out_dim) | ||
| 55 | + * - next_states Encoder state for the next chunk. | ||
| 56 | + */ | ||
| 57 | + virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( | ||
| 58 | + Ort::Value features, | ||
| 59 | + std::vector<Ort::Value> &states) = 0; // NOLINT | ||
| 60 | + | ||
| 61 | + virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0; | ||
| 62 | + | ||
| 63 | + /** Run the decoder network. | ||
| 64 | + * | ||
| 65 | + * Caution: We assume there are no recurrent connections in the decoder and | ||
| 66 | + * the decoder is stateless. See | ||
| 67 | + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py | ||
| 68 | + * for an example. | ||
| 69 | + * | ||
| 70 | + * @param decoder_input It is usually of shape (N, context_size) | ||
| 71 | + * @return Return a tensor of shape (N, decoder_dim). | ||
| 72 | + */ | ||
| 73 | + virtual Ort::Value RunDecoder(Ort::Value decoder_input) = 0; | ||
| 74 | + | ||
| 75 | + /** Run the joint network. | ||
| 76 | + * | ||
| 77 | + * @param encoder_out Output of the encoder network. A tensor of shape | ||
| 78 | + * (N, joiner_dim). | ||
| 79 | + * @param decoder_out Output of the decoder network. A tensor of shape | ||
| 80 | + * (N, joiner_dim). | ||
| 81 | + * @return Return a tensor of shape (N, vocab_size). In icefall, the | ||
| 82 | + * last layer of the joint network is `nn.Linear`, | ||
| 83 | + * not `nn.LogSoftmax`. | ||
| 84 | + */ | ||
| 85 | + virtual Ort::Value RunJoiner(Ort::Value encoder_out, | ||
| 86 | + Ort::Value decoder_out) = 0; | ||
| 87 | + | ||
| 88 | + /** If we are using a stateless decoder that contains a | ||
| 89 | + * Conv1D, this function returns the kernel size of the convolution layer. | ||
| 90 | + */ | ||
| 91 | + virtual int32_t ContextSize() const = 0; | ||
| 92 | + | ||
| 93 | + /** We send this number of feature frames to the encoder at a time. */ | ||
| 94 | + virtual int32_t ChunkSize() const = 0; | ||
| 95 | + | ||
| 96 | + /** Number of input frames to discard after each call to RunEncoder. | ||
| 97 | + * | ||
| 98 | + * For instance, if we have 30 frames, chunk_size=8, chunk_shift=6. | ||
| 99 | + * | ||
| 100 | + * In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8. | ||
| 101 | + * Then we discard frames 0~5 since chunk_shift is 6. | ||
| 102 | + * In the second call of RunEncoder, we use frames 6~13; and then we discard | ||
| 103 | + * frames 6~11. | ||
| 104 | + * In the third call of RunEncoder, we use frames 12~19; and then we discard | ||
| 105 | + * frames 12~17. | ||
| 106 | + * | ||
| 107 | + * Note: ChunkSize() - ChunkShift() == right context size | ||
| 108 | + */ | ||
| 109 | + virtual int32_t ChunkShift() const = 0; | ||
| 110 | + | ||
| 111 | + virtual int32_t VocabSize() const = 0; | ||
| 112 | + | ||
| 113 | + virtual int32_t SubsamplingFactor() const { return 4; } | ||
| 114 | +}; | ||
| 115 | + | ||
| 116 | +} // namespace sherpa_onnx | ||
| 117 | + | ||
| 118 | +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ |
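The chunk arithmetic in the `ChunkShift()` comment can be checked with a few standalone lines; this sketch reproduces the frame ranges for the 30-frame example above:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  int32_t num_frames = 30, chunk_size = 8, chunk_shift = 6;
  for (int32_t start = 0; start + chunk_size <= num_frames;
       start += chunk_shift) {
    std::printf("use frames %d~%d, then discard frames %d~%d\n", start,
                start + chunk_size - 1, start, start + chunk_shift - 1);
  }
  // Prints: 0~7 / 0~5, then 6~13 / 6~11, then 12~19 / 12~17,
  // then 18~25 / 18~23
  return 0;
}
```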
sherpa-onnx/csrc/onnx-utils.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/onnx-utils.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 5 | + | ||
| 6 | +#include <string> | ||
| 7 | +#include <vector> | ||
| 8 | + | ||
| 9 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names, | ||
| 14 | + std::vector<const char *> *input_names_ptr) { | ||
| 15 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 16 | + size_t node_count = sess->GetInputCount(); | ||
| 17 | + input_names->resize(node_count); | ||
| 18 | + input_names_ptr->resize(node_count); | ||
| 19 | + for (size_t i = 0; i != node_count; ++i) { | ||
| 20 | + auto tmp = sess->GetInputNameAllocated(i, allocator); | ||
| 21 | + (*input_names)[i] = tmp.get(); | ||
| 22 | + (*input_names_ptr)[i] = (*input_names)[i].c_str(); | ||
| 23 | + } | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, | ||
| 27 | + std::vector<const char *> *output_names_ptr) { | ||
| 28 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 29 | + size_t node_count = sess->GetOutputCount(); | ||
| 30 | + output_names->resize(node_count); | ||
| 31 | + output_names_ptr->resize(node_count); | ||
| 32 | + for (size_t i = 0; i != node_count; ++i) { | ||
| 33 | + auto tmp = sess->GetOutputNameAllocated(i, allocator); | ||
| 34 | + (*output_names)[i] = tmp.get(); | ||
| 35 | + (*output_names_ptr)[i] = (*output_names)[i].c_str(); | ||
| 36 | + } | ||
| 37 | +} | ||
| 38 | + | ||
| 39 | +void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { | ||
| 40 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 41 | + std::vector<Ort::AllocatedStringPtr> v = | ||
| 42 | + meta_data.GetCustomMetadataMapKeysAllocated(allocator); | ||
| 43 | + for (const auto &key : v) { | ||
| 44 | + auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator); | ||
| 45 | + os << key.get() << "=" << p.get() << "\n"; | ||
| 46 | + } | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +} // namespace sherpa_onnx |
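A sketch of how these helpers are typically called right after a session is created (`model.onnx` is a placeholder). The `std::string` vector owns the names while the `const char *` vector merely aliases them for `Ort::Session::Run()`, so the two vectors must stay alive together:

```cpp
#include <string>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/onnx-utils.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::SessionOptions opts;
  std::string filename = "model.onnx";  // placeholder
  Ort::Session sess(env, SHERPA_MAYBE_WIDE(filename).c_str(), opts);

  std::vector<std::string> input_names;       // owns the storage
  std::vector<const char *> input_names_ptr;  // aliases input_names
  sherpa_onnx::GetInputNames(&sess, &input_names, &input_names_ptr);

  std::vector<std::string> output_names;
  std::vector<const char *> output_names_ptr;
  sherpa_onnx::GetOutputNames(&sess, &output_names, &output_names_ptr);
  return 0;
}
```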
sherpa-onnx/csrc/onnx-utils.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/onnx-utils.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_ | ||
| 6 | + | ||
| 7 | +#ifdef _MSC_VER | ||
| 8 | +// For ToWide() below | ||
| 9 | +#include <codecvt> | ||
| 10 | +#include <locale> | ||
| 11 | +#endif | ||
| 12 | + | ||
| 13 | +#include <ostream> | ||
| 14 | +#include <string> | ||
| 15 | +#include <vector> | ||
| 16 | + | ||
| 17 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 18 | + | ||
| 19 | +namespace sherpa_onnx { | ||
| 20 | + | ||
| 21 | +#ifdef _MSC_VER | ||
| 22 | +// See | ||
| 23 | +// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t | ||
| 24 | +static std::wstring ToWide(const std::string &s) { | ||
| 25 | + std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter; | ||
| 26 | + return converter.from_bytes(s); | ||
| 27 | +} | ||
| 28 | +#define SHERPA_MAYBE_WIDE(s) ToWide(s) | ||
| 29 | +#else | ||
| 30 | +#define SHERPA_MAYBE_WIDE(s) s | ||
| 31 | +#endif | ||
| 32 | + | ||
| 33 | +/** | ||
| 34 | + * Get the input names of a model. | ||
| 35 | + * | ||
| 36 | + * @param sess An onnxruntime session. | ||
| 37 | + * @param input_names. On return, it contains the input names of the model. | ||
| 38 | + * @param input_names_ptr. On return, input_names_ptr[i] contains | ||
| 39 | + * input_names[i].c_str() | ||
| 40 | + */ | ||
| 41 | +void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names, | ||
| 42 | + std::vector<const char *> *input_names_ptr); | ||
| 43 | + | ||
| 44 | +/** | ||
| 45 | + * Get the output names of a model. | ||
| 46 | + * | ||
| 47 | + * @param sess An onnxruntime session. | ||
| 48 | + * @param output_names. On return, it contains the output names of the model. | ||
| 49 | + * @param output_names_ptr. On return, output_names_ptr[i] contains | ||
| 50 | + * output_names[i].c_str() | ||
| 51 | + */ | ||
| 52 | +void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, | ||
| 53 | + std::vector<const char *> *output_names_ptr); | ||
| 54 | + | ||
| 55 | +void PrintModelMetadata(std::ostream &os, | ||
| 56 | + const Ort::ModelMetadata &meta_data); // NOLINT | ||
| 57 | + | ||
| 58 | +} // namespace sherpa_onnx | ||
| 59 | + | ||
| 60 | +#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ |
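The `SHERPA_MAYBE_WIDE` macro exists because onnxruntime's `Ort::Session` constructor takes a wide-character path (`ORTCHAR_T` is `wchar_t`) on Windows but a plain `char` path elsewhere; wrapping every model filename in the macro keeps the call sites identical on both platforms.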
sherpa-onnx/csrc/rnnt-model.cc
deleted
100644 → 0
| 1 | -/** | ||
| 2 | - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | - * | ||
| 4 | - * See LICENSE for clarification regarding multiple authors | ||
| 5 | - * | ||
| 6 | - * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | - * you may not use this file except in compliance with the License. | ||
| 8 | - * You may obtain a copy of the License at | ||
| 9 | - * | ||
| 10 | - * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | - * | ||
| 12 | - * Unless required by applicable law or agreed to in writing, software | ||
| 13 | - * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | - * See the License for the specific language governing permissions and | ||
| 16 | - * limitations under the License. | ||
| 17 | - */ | ||
| 18 | -#include "sherpa-onnx/csrc/rnnt-model.h" | ||
| 19 | - | ||
| 20 | -#include <array> | ||
| 21 | -#include <utility> | ||
| 22 | -#include <vector> | ||
| 23 | - | ||
| 24 | -#ifdef _MSC_VER | ||
| 25 | -// For ToWide() below | ||
| 26 | -#include <codecvt> | ||
| 27 | -#include <locale> | ||
| 28 | -#endif | ||
| 29 | - | ||
| 30 | -namespace sherpa_onnx { | ||
| 31 | - | ||
| 32 | -#ifdef _MSC_VER | ||
| 33 | -// See | ||
| 34 | -// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t | ||
| 35 | -static std::wstring ToWide(const std::string &s) { | ||
| 36 | - std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter; | ||
| 37 | - return converter.from_bytes(s); | ||
| 38 | -} | ||
| 39 | -#define SHERPA_MAYBE_WIDE(s) ToWide(s) | ||
| 40 | -#else | ||
| 41 | -#define SHERPA_MAYBE_WIDE(s) s | ||
| 42 | -#endif | ||
| 43 | - | ||
| 44 | -/** | ||
| 45 | - * Get the input names of a model. | ||
| 46 | - * | ||
| 47 | - * @param sess An onnxruntime session. | ||
| 48 | - * @param input_names. On return, it contains the input names of the model. | ||
| 49 | - * @param input_names_ptr. On return, input_names_ptr[i] contains | ||
| 50 | - * input_names[i].c_str() | ||
| 51 | - */ | ||
| 52 | -static void GetInputNames(Ort::Session *sess, | ||
| 53 | - std::vector<std::string> *input_names, | ||
| 54 | - std::vector<const char *> *input_names_ptr) { | ||
| 55 | - Ort::AllocatorWithDefaultOptions allocator; | ||
| 56 | - size_t node_count = sess->GetInputCount(); | ||
| 57 | - input_names->resize(node_count); | ||
| 58 | - input_names_ptr->resize(node_count); | ||
| 59 | - for (size_t i = 0; i != node_count; ++i) { | ||
| 60 | - auto tmp = sess->GetInputNameAllocated(i, allocator); | ||
| 61 | - (*input_names)[i] = tmp.get(); | ||
| 62 | - (*input_names_ptr)[i] = (*input_names)[i].c_str(); | ||
| 63 | - } | ||
| 64 | -} | ||
| 65 | - | ||
| 66 | -/** | ||
| 67 | - * Get the output names of a model. | ||
| 68 | - * | ||
| 69 | - * @param sess An onnxruntime session. | ||
| 70 | - * @param output_names. On return, it contains the output names of the model. | ||
| 71 | - * @param output_names_ptr. On return, output_names_ptr[i] contains | ||
| 72 | - * output_names[i].c_str() | ||
| 73 | - */ | ||
| 74 | -static void GetOutputNames(Ort::Session *sess, | ||
| 75 | - std::vector<std::string> *output_names, | ||
| 76 | - std::vector<const char *> *output_names_ptr) { | ||
| 77 | - Ort::AllocatorWithDefaultOptions allocator; | ||
| 78 | - size_t node_count = sess->GetOutputCount(); | ||
| 79 | - output_names->resize(node_count); | ||
| 80 | - output_names_ptr->resize(node_count); | ||
| 81 | - for (size_t i = 0; i != node_count; ++i) { | ||
| 82 | - auto tmp = sess->GetOutputNameAllocated(i, allocator); | ||
| 83 | - (*output_names)[i] = tmp.get(); | ||
| 84 | - (*output_names_ptr)[i] = (*output_names)[i].c_str(); | ||
| 85 | - } | ||
| 86 | -} | ||
| 87 | - | ||
| 88 | -RnntModel::RnntModel(const std::string &encoder_filename, | ||
| 89 | - const std::string &decoder_filename, | ||
| 90 | - const std::string &joiner_filename, | ||
| 91 | - const std::string &joiner_encoder_proj_filename, | ||
| 92 | - const std::string &joiner_decoder_proj_filename, | ||
| 93 | - int32_t num_threads) | ||
| 94 | - : env_(ORT_LOGGING_LEVEL_WARNING) { | ||
| 95 | - sess_opts_.SetIntraOpNumThreads(num_threads); | ||
| 96 | - sess_opts_.SetInterOpNumThreads(num_threads); | ||
| 97 | - | ||
| 98 | - InitEncoder(encoder_filename); | ||
| 99 | - InitDecoder(decoder_filename); | ||
| 100 | - InitJoiner(joiner_filename); | ||
| 101 | - InitJoinerEncoderProj(joiner_encoder_proj_filename); | ||
| 102 | - InitJoinerDecoderProj(joiner_decoder_proj_filename); | ||
| 103 | -} | ||
| 104 | - | ||
| 105 | -void RnntModel::InitEncoder(const std::string &filename) { | ||
| 106 | - encoder_sess_ = std::make_unique<Ort::Session>( | ||
| 107 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | ||
| 108 | - GetInputNames(encoder_sess_.get(), &encoder_input_names_, | ||
| 109 | - &encoder_input_names_ptr_); | ||
| 110 | - | ||
| 111 | - GetOutputNames(encoder_sess_.get(), &encoder_output_names_, | ||
| 112 | - &encoder_output_names_ptr_); | ||
| 113 | -} | ||
| 114 | - | ||
| 115 | -void RnntModel::InitDecoder(const std::string &filename) { | ||
| 116 | - decoder_sess_ = std::make_unique<Ort::Session>( | ||
| 117 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | ||
| 118 | - | ||
| 119 | - GetInputNames(decoder_sess_.get(), &decoder_input_names_, | ||
| 120 | - &decoder_input_names_ptr_); | ||
| 121 | - | ||
| 122 | - GetOutputNames(decoder_sess_.get(), &decoder_output_names_, | ||
| 123 | - &decoder_output_names_ptr_); | ||
| 124 | -} | ||
| 125 | - | ||
| 126 | -void RnntModel::InitJoiner(const std::string &filename) { | ||
| 127 | - joiner_sess_ = std::make_unique<Ort::Session>( | ||
| 128 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | ||
| 129 | - | ||
| 130 | - GetInputNames(joiner_sess_.get(), &joiner_input_names_, | ||
| 131 | - &joiner_input_names_ptr_); | ||
| 132 | - | ||
| 133 | - GetOutputNames(joiner_sess_.get(), &joiner_output_names_, | ||
| 134 | - &joiner_output_names_ptr_); | ||
| 135 | -} | ||
| 136 | - | ||
| 137 | -void RnntModel::InitJoinerEncoderProj(const std::string &filename) { | ||
| 138 | - joiner_encoder_proj_sess_ = std::make_unique<Ort::Session>( | ||
| 139 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | ||
| 140 | - | ||
| 141 | - GetInputNames(joiner_encoder_proj_sess_.get(), | ||
| 142 | - &joiner_encoder_proj_input_names_, | ||
| 143 | - &joiner_encoder_proj_input_names_ptr_); | ||
| 144 | - | ||
| 145 | - GetOutputNames(joiner_encoder_proj_sess_.get(), | ||
| 146 | - &joiner_encoder_proj_output_names_, | ||
| 147 | - &joiner_encoder_proj_output_names_ptr_); | ||
| 148 | -} | ||
| 149 | - | ||
| 150 | -void RnntModel::InitJoinerDecoderProj(const std::string &filename) { | ||
| 151 | - joiner_decoder_proj_sess_ = std::make_unique<Ort::Session>( | ||
| 152 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | ||
| 153 | - | ||
| 154 | - GetInputNames(joiner_decoder_proj_sess_.get(), | ||
| 155 | - &joiner_decoder_proj_input_names_, | ||
| 156 | - &joiner_decoder_proj_input_names_ptr_); | ||
| 157 | - | ||
| 158 | - GetOutputNames(joiner_decoder_proj_sess_.get(), | ||
| 159 | - &joiner_decoder_proj_output_names_, | ||
| 160 | - &joiner_decoder_proj_output_names_ptr_); | ||
| 161 | -} | ||
| 162 | - | ||
| 163 | -Ort::Value RnntModel::RunEncoder(const float *features, int32_t T, | ||
| 164 | - int32_t feature_dim) { | ||
| 165 | - auto memory_info = | ||
| 166 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 167 | - std::array<int64_t, 3> x_shape{1, T, feature_dim}; | ||
| 168 | - Ort::Value x = | ||
| 169 | - Ort::Value::CreateTensor(memory_info, const_cast<float *>(features), | ||
| 170 | - T * feature_dim, x_shape.data(), x_shape.size()); | ||
| 171 | - | ||
| 172 | - std::array<int64_t, 1> x_lens_shape{1}; | ||
| 173 | - int64_t x_lens_tmp = T; | ||
| 174 | - | ||
| 175 | - Ort::Value x_lens = Ort::Value::CreateTensor( | ||
| 176 | - memory_info, &x_lens_tmp, 1, x_lens_shape.data(), x_lens_shape.size()); | ||
| 177 | - | ||
| 178 | - std::array<Ort::Value, 2> encoder_inputs{std::move(x), std::move(x_lens)}; | ||
| 179 | - | ||
| 180 | - // Note: We discard encoder_out_lens since we only implement | ||
| 181 | - // batch==1. | ||
| 182 | - auto encoder_out = encoder_sess_->Run( | ||
| 183 | - {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), | ||
| 184 | - encoder_inputs.size(), encoder_output_names_ptr_.data(), | ||
| 185 | - encoder_output_names_ptr_.size()); | ||
| 186 | - return std::move(encoder_out[0]); | ||
| 187 | -} | ||
| 188 | -Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T, | ||
| 189 | - int32_t encoder_out_dim) { | ||
| 190 | - auto memory_info = | ||
| 191 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 192 | - | ||
| 193 | - std::array<int64_t, 2> in_shape{T, encoder_out_dim}; | ||
| 194 | - Ort::Value in = Ort::Value::CreateTensor( | ||
| 195 | - memory_info, const_cast<float *>(encoder_out), T * encoder_out_dim, | ||
| 196 | - in_shape.data(), in_shape.size()); | ||
| 197 | - | ||
| 198 | - auto encoder_proj_out = joiner_encoder_proj_sess_->Run( | ||
| 199 | - {}, joiner_encoder_proj_input_names_ptr_.data(), &in, 1, | ||
| 200 | - joiner_encoder_proj_output_names_ptr_.data(), | ||
| 201 | - joiner_encoder_proj_output_names_ptr_.size()); | ||
| 202 | - return std::move(encoder_proj_out[0]); | ||
| 203 | -} | ||
| 204 | - | ||
| 205 | -Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input, | ||
| 206 | - int32_t context_size) { | ||
| 207 | - auto memory_info = | ||
| 208 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 209 | - | ||
| 210 | - int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 | ||
| 211 | - std::array<int64_t, 2> shape{batch_size, context_size}; | ||
| 212 | - Ort::Value in = Ort::Value::CreateTensor( | ||
| 213 | - memory_info, const_cast<int64_t *>(decoder_input), | ||
| 214 | - batch_size * context_size, shape.data(), shape.size()); | ||
| 215 | - | ||
| 216 | - auto decoder_out = decoder_sess_->Run( | ||
| 217 | - {}, decoder_input_names_ptr_.data(), &in, 1, | ||
| 218 | - decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); | ||
| 219 | - return std::move(decoder_out[0]); | ||
| 220 | -} | ||
| 221 | - | ||
| 222 | -Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out, | ||
| 223 | - int32_t decoder_out_dim) { | ||
| 224 | - auto memory_info = | ||
| 225 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 226 | - | ||
| 227 | - int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 | ||
| 228 | - std::array<int64_t, 2> shape{batch_size, decoder_out_dim}; | ||
| 229 | - Ort::Value in = Ort::Value::CreateTensor( | ||
| 230 | - memory_info, const_cast<float *>(decoder_out), | ||
| 231 | - batch_size * decoder_out_dim, shape.data(), shape.size()); | ||
| 232 | - | ||
| 233 | - auto decoder_proj_out = joiner_decoder_proj_sess_->Run( | ||
| 234 | - {}, joiner_decoder_proj_input_names_ptr_.data(), &in, 1, | ||
| 235 | - joiner_decoder_proj_output_names_ptr_.data(), | ||
| 236 | - joiner_decoder_proj_output_names_ptr_.size()); | ||
| 237 | - return std::move(decoder_proj_out[0]); | ||
| 238 | -} | ||
| 239 | - | ||
| 240 | -Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out, | ||
| 241 | - const float *projected_decoder_out, | ||
| 242 | - int32_t joiner_dim) { | ||
| 243 | - auto memory_info = | ||
| 244 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 245 | - int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 | ||
| 246 | - std::array<int64_t, 2> shape{batch_size, joiner_dim}; | ||
| 247 | - | ||
| 248 | - Ort::Value enc = Ort::Value::CreateTensor( | ||
| 249 | - memory_info, const_cast<float *>(projected_encoder_out), | ||
| 250 | - batch_size * joiner_dim, shape.data(), shape.size()); | ||
| 251 | - | ||
| 252 | - Ort::Value dec = Ort::Value::CreateTensor( | ||
| 253 | - memory_info, const_cast<float *>(projected_decoder_out), | ||
| 254 | - batch_size * joiner_dim, shape.data(), shape.size()); | ||
| 255 | - | ||
| 256 | - std::array<Ort::Value, 2> inputs{std::move(enc), std::move(dec)}; | ||
| 257 | - | ||
| 258 | - auto logit = joiner_sess_->Run( | ||
| 259 | - {}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 260 | - joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size()); | ||
| 261 | - | ||
| 262 | - return std::move(logit[0]); | ||
| 263 | -} | ||
| 264 | - | ||
| 265 | -} // namespace sherpa_onnx |
sherpa-onnx/csrc/rnnt-model.h
deleted
100644 → 0
| 1 | -/** | ||
| 2 | - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | - * | ||
| 4 | - * See LICENSE for clarification regarding multiple authors | ||
| 5 | - * | ||
| 6 | - * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | - * you may not use this file except in compliance with the License. | ||
| 8 | - * You may obtain a copy of the License at | ||
| 9 | - * | ||
| 10 | - * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | - * | ||
| 12 | - * Unless required by applicable law or agreed to in writing, software | ||
| 13 | - * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | - * See the License for the specific language governing permissions and | ||
| 16 | - * limitations under the License. | ||
| 17 | - */ | ||
| 18 | - | ||
| 19 | -#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_ | ||
| 20 | -#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_ | ||
| 21 | - | ||
| 22 | -#include <memory> | ||
| 23 | -#include <string> | ||
| 24 | -#include <vector> | ||
| 25 | - | ||
| 26 | -#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 27 | - | ||
| 28 | -namespace sherpa_onnx { | ||
| 29 | - | ||
| 30 | -class RnntModel { | ||
| 31 | - public: | ||
| 32 | - /** | ||
| 33 | - * @param encoder_filename Path to the encoder model | ||
| 34 | - * @param decoder_filename Path to the decoder model | ||
| 35 | - * @param joiner_filename Path to the joiner model | ||
| 36 | - * @param joiner_encoder_proj_filename Path to the joiner encoder_proj model | ||
| 37 | - * @param joiner_decoder_proj_filename Path to the joiner decoder_proj model | ||
| 38 | - * @param num_threads Number of threads to use to run the models | ||
| 39 | - */ | ||
| 40 | - RnntModel(const std::string &encoder_filename, | ||
| 41 | - const std::string &decoder_filename, | ||
| 42 | - const std::string &joiner_filename, | ||
| 43 | - const std::string &joiner_encoder_proj_filename, | ||
| 44 | - const std::string &joiner_decoder_proj_filename, | ||
| 45 | - int32_t num_threads); | ||
| 46 | - | ||
| 47 | - /** Run the encoder model. | ||
| 48 | - * | ||
| 49 | - * @TODO(fangjun): Support batch_size > 1 | ||
| 50 | - * | ||
| 51 | - * @param features A tensor of shape (batch_size, T, feature_dim) | ||
| 52 | - * @param T Number of feature frames | ||
| 53 | - * @param feature_dim Dimension of the feature. | ||
| 54 | - * | ||
| 55 | - * @return Return a tensor of shape (batch_size, T', encoder_out_dim) | ||
| 56 | - */ | ||
| 57 | - Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim); | ||
| 58 | - | ||
| 59 | - /** Run the joiner encoder_proj model. | ||
| 60 | - * | ||
| 61 | - * @param encoder_out A tensor of shape (T, encoder_out_dim) | ||
| 62 | - * @param T Number of frames in encoder_out. | ||
| 63 | - * @param encoder_out_dim Dimension of encoder_out. | ||
| 64 | - * | ||
| 65 | - * @return Return a tensor of shape (T, joiner_dim) | ||
| 66 | - * | ||
| 67 | - */ | ||
| 68 | - Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T, | ||
| 69 | - int32_t encoder_out_dim); | ||
| 70 | - | ||
| 71 | - /** Run the decoder model. | ||
| 72 | - * | ||
| 73 | - * @TODO(fangjun): Support batch_size > 1 | ||
| 74 | - * | ||
| 75 | - * @param decoder_input A tensor of shape (batch_size, context_size). | ||
| 76 | - * @return Return a tensor of shape (batch_size, 1, decoder_out_dim) | ||
| 77 | - */ | ||
| 78 | - Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size); | ||
| 79 | - | ||
| 80 | - /** Run joiner decoder_proj model. | ||
| 81 | - * | ||
| 82 | - * @TODO(fangjun): Support batch_size > 1 | ||
| 83 | - * | ||
| 84 | - * @param decoder_out A tensor of shape (batch_size, decoder_out_dim) | ||
| 85 | - * @param decoder_out_dim Output dimension of the decoder_out. | ||
| 86 | - * | ||
| 87 | - * @return Return a tensor of shape (batch_size, joiner_dim); | ||
| 88 | - */ | ||
| 89 | - Ort::Value RunJoinerDecoderProj(const float *decoder_out, | ||
| 90 | - int32_t decoder_out_dim); | ||
| 91 | - | ||
| 92 | - /** Run the joiner model. | ||
| 93 | - * | ||
| 94 | - * @TODO(fangjun): Support batch_size > 1 | ||
| 95 | - * | ||
| 96 | - * @param projected_encoder_out A tensor of shape (batch_size, joiner_dim). | ||
| 97 | - * @param projected_decoder_out A tensor of shape (batch_size, joiner_dim). | ||
| 98 | - * | ||
| 99 | - * @return Return a tensor of shape (batch_size, vocab_size) | ||
| 100 | - */ | ||
| 101 | - Ort::Value RunJoiner(const float *projected_encoder_out, | ||
| 102 | - const float *projected_decoder_out, int32_t joiner_dim); | ||
| 103 | - | ||
| 104 | - private: | ||
| 105 | - void InitEncoder(const std::string &encoder_filename); | ||
| 106 | - void InitDecoder(const std::string &decoder_filename); | ||
| 107 | - void InitJoiner(const std::string &joiner_filename); | ||
| 108 | - void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename); | ||
| 109 | - void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename); | ||
| 110 | - | ||
| 111 | - private: | ||
| 112 | - Ort::Env env_; | ||
| 113 | - Ort::SessionOptions sess_opts_; | ||
| 114 | - std::unique_ptr<Ort::Session> encoder_sess_; | ||
| 115 | - std::unique_ptr<Ort::Session> decoder_sess_; | ||
| 116 | - std::unique_ptr<Ort::Session> joiner_sess_; | ||
| 117 | - std::unique_ptr<Ort::Session> joiner_encoder_proj_sess_; | ||
| 118 | - std::unique_ptr<Ort::Session> joiner_decoder_proj_sess_; | ||
| 119 | - | ||
| 120 | - std::vector<std::string> encoder_input_names_; | ||
| 121 | - std::vector<const char *> encoder_input_names_ptr_; | ||
| 122 | - std::vector<std::string> encoder_output_names_; | ||
| 123 | - std::vector<const char *> encoder_output_names_ptr_; | ||
| 124 | - | ||
| 125 | - std::vector<std::string> decoder_input_names_; | ||
| 126 | - std::vector<const char *> decoder_input_names_ptr_; | ||
| 127 | - std::vector<std::string> decoder_output_names_; | ||
| 128 | - std::vector<const char *> decoder_output_names_ptr_; | ||
| 129 | - | ||
| 130 | - std::vector<std::string> joiner_input_names_; | ||
| 131 | - std::vector<const char *> joiner_input_names_ptr_; | ||
| 132 | - std::vector<std::string> joiner_output_names_; | ||
| 133 | - std::vector<const char *> joiner_output_names_ptr_; | ||
| 134 | - | ||
| 135 | - std::vector<std::string> joiner_encoder_proj_input_names_; | ||
| 136 | - std::vector<const char *> joiner_encoder_proj_input_names_ptr_; | ||
| 137 | - std::vector<std::string> joiner_encoder_proj_output_names_; | ||
| 138 | - std::vector<const char *> joiner_encoder_proj_output_names_ptr_; | ||
| 139 | - | ||
| 140 | - std::vector<std::string> joiner_decoder_proj_input_names_; | ||
| 141 | - std::vector<const char *> joiner_decoder_proj_input_names_ptr_; | ||
| 142 | - std::vector<std::string> joiner_decoder_proj_output_names_; | ||
| 143 | - std::vector<const char *> joiner_decoder_proj_output_names_ptr_; | ||
| 144 | -}; | ||
| 145 | - | ||
| 146 | -} // namespace sherpa_onnx | ||
| 147 | - | ||
| 148 | -#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_ |
| 1 | -/** | ||
| 2 | - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | - * | ||
| 4 | - * See LICENSE for clarification regarding multiple authors | ||
| 5 | - * | ||
| 6 | - * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | - * you may not use this file except in compliance with the License. | ||
| 8 | - * You may obtain a copy of the License at | ||
| 9 | - * | ||
| 10 | - * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | - * | ||
| 12 | - * Unless required by applicable law or agreed to in writing, software | ||
| 13 | - * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | - * See the License for the specific language governing permissions and | ||
| 16 | - * limitations under the License. | ||
| 17 | - */ | 1 | +// sherpa-onnx/csrc/sherpa-onnx.cc |
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 18 | 4 | ||
| 5 | +#include <array>  // NOLINT | ||
| 5 | +#include <chrono>  // NOLINT | ||
| 5 | +#include <utility>  // NOLINT | ||
| 19 | #include <iostream> | 6 | #include <iostream> |
| 20 | #include <string> | 7 | #include <string> |
| 21 | #include <vector> | 8 | #include <vector> |
| 22 | 9 | ||
| 23 | #include "kaldi-native-fbank/csrc/online-feature.h" | 10 | #include "kaldi-native-fbank/csrc/online-feature.h" |
| 24 | #include "sherpa-onnx/csrc/decode.h" | 11 | #include "sherpa-onnx/csrc/decode.h" |
| 25 | -#include "sherpa-onnx/csrc/rnnt-model.h" | 12 | +#include "sherpa-onnx/csrc/features.h" |
| 13 | +#include "sherpa-onnx/csrc/online-transducer-model-config.h" | ||
| 14 | +#include "sherpa-onnx/csrc/online-transducer-model.h" | ||
| 26 | #include "sherpa-onnx/csrc/symbol-table.h" | 15 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 27 | #include "sherpa-onnx/csrc/wave-reader.h" | 16 | #include "sherpa-onnx/csrc/wave-reader.h" |
| 28 | 17 | ||
| 29 | -/** Compute fbank features of the input wave filename. | ||
| 30 | - * | ||
| 31 | - * @param wav_filename. Path to a mono wave file. | ||
| 32 | - * @param expected_sampling_rate Expected sampling rate of the input wave file. | ||
| 33 | - * @param num_frames On return, it contains the number of feature frames. | ||
| 34 | - * @return Return the computed feature of shape (num_frames, feature_dim) | ||
| 35 | - * stored in row-major. | ||
| 36 | - */ | ||
| 37 | -static std::vector<float> ComputeFeatures(const std::string &wav_filename, | ||
| 38 | - float expected_sampling_rate, | ||
| 39 | - int32_t *num_frames) { | ||
| 40 | - std::vector<float> samples = | ||
| 41 | - sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate); | ||
| 42 | - | ||
| 43 | - float duration = samples.size() / expected_sampling_rate; | ||
| 44 | - | ||
| 45 | - std::cout << "wav filename: " << wav_filename << "\n"; | ||
| 46 | - std::cout << "wav duration (s): " << duration << "\n"; | ||
| 47 | - | ||
| 48 | - knf::FbankOptions opts; | ||
| 49 | - opts.frame_opts.dither = 0; | ||
| 50 | - opts.frame_opts.snip_edges = false; | ||
| 51 | - opts.frame_opts.samp_freq = expected_sampling_rate; | ||
| 52 | - | ||
| 53 | - int32_t feature_dim = 80; | ||
| 54 | - | ||
| 55 | - opts.mel_opts.num_bins = feature_dim; | ||
| 56 | - | ||
| 57 | - knf::OnlineFbank fbank(opts); | ||
| 58 | - fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); | ||
| 59 | - fbank.InputFinished(); | ||
| 60 | - | ||
| 61 | - *num_frames = fbank.NumFramesReady(); | ||
| 62 | - | ||
| 63 | - std::vector<float> features(*num_frames * feature_dim); | ||
| 64 | - float *p = features.data(); | ||
| 65 | - | ||
| 66 | - for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) { | ||
| 67 | - const float *f = fbank.GetFrame(i); | ||
| 68 | - std::copy(f, f + feature_dim, p); | ||
| 69 | - } | ||
| 70 | - | ||
| 71 | - return features; | ||
| 72 | -} | ||
| 73 | - | ||
| 74 | int main(int32_t argc, char *argv[]) { | 18 | int main(int32_t argc, char *argv[]) { |
| 75 | - if (argc < 8 || argc > 9) { | 19 | + if (argc < 6 || argc > 7) { |
| 76 | const char *usage = R"usage( | 20 | const char *usage = R"usage( |
| 77 | Usage: | 21 | Usage: |
| 78 | ./bin/sherpa-onnx \ | 22 | ./bin/sherpa-onnx \ |
| @@ -80,12 +24,11 @@ Usage: | @@ -80,12 +24,11 @@ Usage: | ||
| 80 | /path/to/encoder.onnx \ | 24 | /path/to/encoder.onnx \ |
| 81 | /path/to/decoder.onnx \ | 25 | /path/to/decoder.onnx \ |
| 82 | /path/to/joiner.onnx \ | 26 | /path/to/joiner.onnx \ |
| 83 | - /path/to/joiner_encoder_proj.onnx \ | ||
| 84 | - /path/to/joiner_decoder_proj.onnx \ | ||
| 85 | /path/to/foo.wav [num_threads] | 27 | /path/to/foo.wav [num_threads] |
| 86 | 28 | ||
| 87 | -You can download pre-trained models from the following repository: | ||
| 88 | -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 | 29 | +Please refer to |
| 30 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 31 | +for a list of pre-trained models to download. | ||
| 89 | )usage"; | 32 | )usage"; |
| 90 | std::cerr << usage << "\n"; | 33 | std::cerr << usage << "\n"; |
| 91 | 34 | ||
| @@ -93,37 +36,102 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat | @@ -93,37 +36,102 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat | ||
| 93 | } | 36 | } |
| 94 | 37 | ||
| 95 | std::string tokens = argv[1]; | 38 | std::string tokens = argv[1]; |
| 96 | - std::string encoder = argv[2]; | ||
| 97 | - std::string decoder = argv[3]; | ||
| 98 | - std::string joiner = argv[4]; | ||
| 99 | - std::string joiner_encoder_proj = argv[5]; | ||
| 100 | - std::string joiner_decoder_proj = argv[6]; | ||
| 101 | - std::string wav_filename = argv[7]; | ||
| 102 | - int32_t num_threads = 4; | ||
| 103 | - if (argc == 9) { | ||
| 104 | - num_threads = atoi(argv[8]); | 39 | + sherpa_onnx::OnlineTransducerModelConfig config; |
| 40 | + config.debug = true; | ||
| 41 | + config.encoder_filename = argv[2]; | ||
| 42 | + config.decoder_filename = argv[3]; | ||
| 43 | + config.joiner_filename = argv[4]; | ||
| 44 | + std::string wav_filename = argv[5]; | ||
| 45 | + | ||
| 46 | + config.num_threads = 2; | ||
| 47 | + if (argc == 7) { | ||
| 48 | + config.num_threads = atoi(argv[6]); | ||
| 105 | } | 49 | } |
| 50 | + std::cout << config.ToString() << "\n"; | ||
| 51 | + | ||
| 52 | + auto model = sherpa_onnx::OnlineTransducerModel::Create(config); | ||
| 106 | 53 | ||
| 107 | sherpa_onnx::SymbolTable sym(tokens); | 54 | sherpa_onnx::SymbolTable sym(tokens); |
| 108 | 55 | ||
| 109 | - int32_t num_frames; | ||
| 110 | - auto features = ComputeFeatures(wav_filename, 16000, &num_frames); | ||
| 111 | - int32_t feature_dim = features.size() / num_frames; | 56 | + Ort::AllocatorWithDefaultOptions allocator; |
| 57 | + | ||
| 58 | + int32_t chunk_size = model->ChunkSize(); | ||
| 59 | + int32_t chunk_shift = model->ChunkShift(); | ||
| 60 | + | ||
| 61 | + auto memory_info = | ||
| 62 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 63 | + | ||
| 64 | + std::vector<Ort::Value> states = model->GetEncoderInitStates(); | ||
| 112 | 65 | ||
| 113 | - sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj, | ||
| 114 | - joiner_decoder_proj, num_threads); | ||
| 115 | - Ort::Value encoder_out = | ||
| 116 | - model.RunEncoder(features.data(), num_frames, feature_dim); | 66 | + std::vector<int64_t> hyp(model->ContextSize(), 0); |
| 117 | 67 | ||
| 118 | - auto hyp = sherpa_onnx::GreedySearch(model, encoder_out); | 68 | + int32_t expected_sampling_rate = 16000; |
| 119 | 69 | ||
| 70 | + bool is_ok = false; | ||
| 71 | + std::vector<float> samples = | ||
| 72 | + sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); | ||
| 73 | + | ||
| 74 | + if (!is_ok) { | ||
| 75 | + std::cerr << "Failed to read " << wav_filename << "\n"; | ||
| 76 | + return -1; | ||
| 77 | + } | ||
| 78 | + | ||
| 79 | + const float duration = | ||
| 79 | + samples.size() / static_cast<float>(expected_sampling_rate); | ||
| 80 | + | ||
| 81 | + std::cout << "wav filename: " << wav_filename << "\n"; | ||
| 82 | + std::cout << "wav duration (s): " << duration << "\n"; | ||
| 83 | + | ||
| 84 | + auto begin = std::chrono::steady_clock::now(); | ||
| 85 | + std::cout << "Started!\n"; | ||
| 86 | + | ||
| 87 | + sherpa_onnx::FeatureExtractor feat_extractor; | ||
| 88 | + feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), | ||
| 89 | + samples.size()); | ||
| 90 | + | ||
| 91 | + std::vector<float> tail_paddings( | ||
| 92 | + static_cast<int>(0.2 * expected_sampling_rate)); | ||
| 93 | + feat_extractor.AcceptWaveform(expected_sampling_rate, tail_paddings.data(), | ||
| 94 | + tail_paddings.size()); | ||
| 95 | + feat_extractor.InputFinished(); | ||
| 96 | + | ||
| 97 | + int32_t num_frames = feat_extractor.NumFramesReady(); | ||
| 98 | + int32_t feature_dim = feat_extractor.FeatureDim(); | ||
| 99 | + | ||
| 100 | + std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim}; | ||
| 101 | + | ||
| 102 | + for (int32_t start = 0; start + chunk_size < num_frames; | ||
| 103 | + start += chunk_shift) { | ||
| 104 | + std::vector<float> features = feat_extractor.GetFrames(start, chunk_size); | ||
| 105 | + | ||
| 106 | + Ort::Value x = | ||
| 107 | + Ort::Value::CreateTensor(memory_info, features.data(), features.size(), | ||
| 108 | + x_shape.data(), x_shape.size()); | ||
| 109 | + auto pair = model->RunEncoder(std::move(x), states); | ||
| 110 | + states = std::move(pair.second); | ||
| 111 | + sherpa_onnx::GreedySearch(model.get(), std::move(pair.first), &hyp); | ||
| 112 | + } | ||
| 120 | std::string text; | 113 | std::string text; |
| 121 | - for (auto i : hyp) { | ||
| 122 | - text += sym[i]; | 114 | + for (size_t i = model->ContextSize(); i != hyp.size(); ++i) { |
| 115 | + text += sym[hyp[i]]; | ||
| 123 | } | 116 | } |
| 124 | 117 | ||
| 118 | + std::cout << "Done!\n"; | ||
| 119 | + | ||
| 125 | std::cout << "Recognition result for " << wav_filename << "\n" | 120 | std::cout << "Recognition result for " << wav_filename << "\n" |
| 126 | << text << "\n"; | 121 | << text << "\n"; |
| 127 | 122 | ||
| 123 | + auto end = std::chrono::steady_clock::now(); | ||
| 124 | + float elapsed_seconds = | ||
| 125 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 126 | + .count() / | ||
| 127 | + 1000.; | ||
| 128 | + | ||
| 129 | + std::cout << "num threads: " << config.num_threads << "\n"; | ||
| 130 | + | ||
| 131 | + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||
| 132 | + float rtf = elapsed_seconds / duration; | ||
| 133 | + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", | ||
| 134 | + elapsed_seconds, duration, rtf); | ||
| 135 | + | ||
| 128 | return 0; | 136 | return 0; |
| 129 | } | 137 | } |
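As a concrete reading of the RTF line at the end: if a 6.0 s wave takes, say, 1.5 s to decode, the program prints 1.500 / 6.000 = 0.250. Any value below 1 means the pipeline runs faster than real time, so smaller is better.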
| 1 | -/** | ||
| 2 | - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | - * | ||
| 4 | - * See LICENSE for clarification regarding multiple authors | ||
| 5 | - * | ||
| 6 | - * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | - * you may not use this file except in compliance with the License. | ||
| 8 | - * You may obtain a copy of the License at | ||
| 9 | - * | ||
| 10 | - * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | - * | ||
| 12 | - * Unless required by applicable law or agreed to in writing, software | ||
| 13 | - * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | - * See the License for the specific language governing permissions and | ||
| 16 | - * limitations under the License. | ||
| 17 | - */ | 1 | +// sherpa-onnx/csrc/show-onnx-info.cc |
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 18 | #include <iostream> | 5 | #include <iostream> |
| 19 | #include <sstream> | 6 | #include <sstream> |
| 20 | 7 |
| 1 | -/** | ||
| 2 | - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | - * | ||
| 4 | - * See LICENSE for clarification regarding multiple authors | ||
| 5 | - * | ||
| 6 | - * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | - * you may not use this file except in compliance with the License. | ||
| 8 | - * You may obtain a copy of the License at | ||
| 9 | - * | ||
| 10 | - * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | - * | ||
| 12 | - * Unless required by applicable law or agreed to in writing, software | ||
| 13 | - * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | - * See the License for the specific language governing permissions and | ||
| 16 | - * limitations under the License. | ||
| 17 | - */ | 1 | +// sherpa-onnx/csrc/symbol-table.cc |
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 18 | 4 | ||
| 19 | #include "sherpa-onnx/csrc/symbol-table.h" | 5 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 20 | 6 |
| 1 | -/** | ||
| 2 | - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | - * | ||
| 4 | - * See LICENSE for clarification regarding multiple authors | ||
| 5 | - * | ||
| 6 | - * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | - * you may not use this file except in compliance with the License. | ||
| 8 | - * You may obtain a copy of the License at | ||
| 9 | - * | ||
| 10 | - * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | - * | ||
| 12 | - * Unless required by applicable law or agreed to in writing, software | ||
| 13 | - * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | - * See the License for the specific language governing permissions and | ||
| 16 | - * limitations under the License. | ||
| 17 | - */ | 1 | +// sherpa-onnx/csrc/symbol-table.h
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 18 | 4 | ||
| 19 | #ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ |
| 20 | #define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ | 6 | #define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ |
| 1 | -/** | ||
| 2 | - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | - * | ||
| 4 | - * See LICENSE for clarification regarding multiple authors | ||
| 5 | - * | ||
| 6 | - * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | - * you may not use this file except in compliance with the License. | ||
| 8 | - * You may obtain a copy of the License at | ||
| 9 | - * | ||
| 10 | - * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | - * | ||
| 12 | - * Unless required by applicable law or agreed to in writing, software | ||
| 13 | - * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | - * See the License for the specific language governing permissions and | ||
| 16 | - * limitations under the License. | ||
| 17 | - */ | 1 | +// sherpa-onnx/csrc/wave-reader.cc
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 18 | 4 | ||
| 19 | #include "sherpa-onnx/csrc/wave-reader.h" | 5 | #include "sherpa-onnx/csrc/wave-reader.h" |
| 20 | 6 | ||
| @@ -31,19 +17,44 @@ namespace { | @@ -31,19 +17,44 @@ namespace { | ||
| 31 | // Note: We assume little endian here | 17 | // Note: We assume little endian here |
| 32 | // TODO(fangjun): Support big endian | 18 | // TODO(fangjun): Support big endian |
| 33 | struct WaveHeader { | 19 | struct WaveHeader { |
| 34 | - void Validate() const { | ||
| 35 | - // F F I R | ||
| 36 | - assert(chunk_id == 0x46464952); | ||
| 37 | - assert(chunk_size == 36 + subchunk2_size); | ||
| 38 | - // E V A W | ||
| 39 | - assert(format == 0x45564157); | ||
| 40 | - assert(subchunk1_id == 0x20746d66); | ||
| 41 | - assert(subchunk1_size == 16); // 16 for PCM | ||
| 42 | - assert(audio_format == 1); // 1 for PCM | ||
| 43 | - assert(num_channels == 1); // we support only single channel for now | ||
| 44 | - assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8); | ||
| 45 | - assert(block_align == num_channels * bits_per_sample / 8); | ||
| 46 | - assert(bits_per_sample == 16); // we support only 16 bits per sample | 20 | + bool Validate() const { |
| 21 | + // F F I R | ||
| 22 | + if (chunk_id != 0x46464952) { | ||
| 23 | + return false; | ||
| 24 | + } | ||
| 25 | + // E V A W | ||
| 26 | + if (format != 0x45564157) { | ||
| 27 | + return false; | ||
| 28 | + } | ||
| 29 | + | ||
| 30 | + if (subchunk1_id != 0x20746d66) { | ||
| 31 | + return false; | ||
| 32 | + } | ||
| 33 | + | ||
| 34 | + if (subchunk1_size != 16) { // 16 for PCM | ||
| 35 | + return false; | ||
| 36 | + } | ||
| 37 | + | ||
| 38 | + if (audio_format != 1) { // 1 for PCM | ||
| 39 | + return false; | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + if (num_channels != 1) { // we support only single channel for now | ||
| 43 | + return false; | ||
| 44 | + } | ||
| 45 | + if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) { | ||
| 46 | + return false; | ||
| 47 | + } | ||
| 48 | + | ||
| 49 | + if (block_align != (num_channels * bits_per_sample / 8)) { | ||
| 50 | + return false; | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | + if (bits_per_sample != 16) { // we support only 16 bits per sample | ||
| 54 | + return false; | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + return true; | ||
| 47 | } | 58 | } |
| 48 | 59 | ||
| 49 | // See | 60 | // See |
| @@ -52,7 +63,7 @@ struct WaveHeader { | @@ -52,7 +63,7 @@ struct WaveHeader { | ||
| 52 | // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf | 63 | // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf |
| 53 | void SeekToDataChunk(std::istream &is) { | 64 | void SeekToDataChunk(std::istream &is) { |
| 54 | // a t a d | 65 | // a t a d |
| 55 | - while (subchunk2_id != 0x61746164) { | 66 | + while (is && subchunk2_id != 0x61746164) { |
| 56 | // const char *p = reinterpret_cast<const char *>(&subchunk2_id); | 67 | // const char *p = reinterpret_cast<const char *>(&subchunk2_id); |
| 57 | // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0], | 68 | // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0], |
| 58 | // p[1], p[2], p[3], subchunk2_size); | 69 | // p[1], p[2], p[3], subchunk2_size); |
| @@ -80,44 +91,61 @@ static_assert(sizeof(WaveHeader) == 44, ""); | @@ -80,44 +91,61 @@ static_assert(sizeof(WaveHeader) == 44, ""); | ||
| 80 | 91 | ||
| 81 | // Read a wave file of mono-channel. | 92 | // Read a wave file of mono-channel. |
| 82 | // Return its samples normalized to the range [-1, 1). | 93 | // Return its samples normalized to the range [-1, 1). |
| 83 | -std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) { | 94 | +std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate, |
| 95 | + bool *is_ok) { | ||
| 84 | WaveHeader header; | 96 | WaveHeader header; |
| 85 | is.read(reinterpret_cast<char *>(&header), sizeof(header)); | 97 | is.read(reinterpret_cast<char *>(&header), sizeof(header)); |
| 86 | - assert(static_cast<bool>(is)); | ||
| 87 | - header.Validate(); | 98 | + if (!is) { |
| 99 | + *is_ok = false; | ||
| 100 | + return {}; | ||
| 101 | + } | ||
| 102 | + | ||
| 103 | + if (!header.Validate()) { | ||
| 104 | + *is_ok = false; | ||
| 105 | + return {}; | ||
| 106 | + } | ||
| 88 | 107 | ||
| 89 | header.SeekToDataChunk(is); | 108 | header.SeekToDataChunk(is); |
| 109 | + if (!is) { | ||
| 110 | + *is_ok = false; | ||
| 111 | + return {}; | ||
| 112 | + } | ||
| 90 | 113 | ||
| 91 | - *sample_rate = header.sample_rate; | 114 | + if (expected_sample_rate != header.sample_rate) { |
| 115 | + *is_ok = false; | ||
| 116 | + return {}; | ||
| 117 | + } | ||
| 92 | 118 | ||
| 93 | // header.subchunk2_size contains the number of bytes in the data. | 119 | // header.subchunk2_size contains the number of bytes in the data. |
| 94 | // As we assume each sample contains two bytes, so it is divided by 2 here | 120 | // As we assume each sample contains two bytes, so it is divided by 2 here |
| 95 | std::vector<int16_t> samples(header.subchunk2_size / 2); | 121 | std::vector<int16_t> samples(header.subchunk2_size / 2); |
| 96 | 122 | ||
| 97 | is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size); | 123 | is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size); |
| 98 | - | ||
| 99 | - assert(static_cast<bool>(is)); | 124 | + if (!is) { |
| 125 | + *is_ok = false; | ||
| 126 | + return {}; | ||
| 127 | + } | ||
| 100 | 128 | ||
| 101 | std::vector<float> ans(samples.size()); | 129 | std::vector<float> ans(samples.size()); |
| 102 | for (int32_t i = 0; i != ans.size(); ++i) { | 130 | for (int32_t i = 0; i != ans.size(); ++i) { |
| 103 | ans[i] = samples[i] / 32768.; | 131 | ans[i] = samples[i] / 32768.; |
| 104 | } | 132 | } |
| 105 | 133 | ||
| 134 | + *is_ok = true; | ||
| 106 | return ans; | 135 | return ans; |
| 107 | } | 136 | } |
| 108 | 137 | ||
| 109 | } // namespace | 138 | } // namespace |
| 110 | 139 | ||
| 111 | std::vector<float> ReadWave(const std::string &filename, | 140 | std::vector<float> ReadWave(const std::string &filename, |
| 112 | - float expected_sample_rate) { | 141 | + float expected_sample_rate, bool *is_ok) { |
| 113 | std::ifstream is(filename, std::ifstream::binary); | 142 | std::ifstream is(filename, std::ifstream::binary); |
| 114 | - float sample_rate; | ||
| 115 | - auto samples = ReadWaveImpl(is, &sample_rate); | ||
| 116 | - if (expected_sample_rate != sample_rate) { | ||
| 117 | - std::cerr << "Expected sample rate: " << expected_sample_rate | ||
| 118 | - << ". Given: " << sample_rate << ".\n"; | ||
| 119 | - exit(-1); | ||
| 120 | - } | 143 | + return ReadWave(is, expected_sample_rate, is_ok); |
| 144 | +} | ||
| 145 | + | ||
| 146 | +std::vector<float> ReadWave(std::istream &is, float expected_sample_rate, | ||
| 147 | + bool *is_ok) { | ||
| 148 | + auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok); | ||
| 121 | return samples; | 149 | return samples; |
| 122 | } | 150 | } |
| 123 | 151 |
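The hex constants compared in `Validate()` are simply the four ASCII tag bytes of each chunk read into a little-endian `int32_t`, which is why the comments spell the tags backwards (`F F I R` for `RIFF`). A standalone sketch:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  const char tag[4] = {'R', 'I', 'F', 'F'};
  int32_t chunk_id = 0;
  // memcpy mimics how is.read() fills the packed header struct
  std::memcpy(&chunk_id, tag, sizeof(chunk_id));
  std::printf("0x%08x\n", static_cast<unsigned>(chunk_id));
  // prints 0x46464952 on little-endian machines
  return 0;
}
```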
| 1 | -/** | ||
| 2 | - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | - * | ||
| 4 | - * See LICENSE for clarification regarding multiple authors | ||
| 5 | - * | ||
| 6 | - * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | - * you may not use this file except in compliance with the License. | ||
| 8 | - * You may obtain a copy of the License at | ||
| 9 | - * | ||
| 10 | - * http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | - * | ||
| 12 | - * Unless required by applicable law or agreed to in writing, software | ||
| 13 | - * distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | - * See the License for the specific language governing permissions and | ||
| 16 | - * limitations under the License. | ||
| 17 | - */ | 1 | +// sherpa-onnx/csrc/wave-reader.h
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 18 | 4 | ||
| 19 | #ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_ |
| 20 | #define SHERPA_ONNX_CSRC_WAVE_READER_H_ | 6 | #define SHERPA_ONNX_CSRC_WAVE_READER_H_ |
| @@ -30,11 +16,15 @@ namespace sherpa_onnx { | @@ -30,11 +16,15 @@ namespace sherpa_onnx { | ||
| 30 | @param filename Path to a wave file. It MUST be single channel, PCM encoded. | 16 | @param filename Path to a wave file. It MUST be single channel, PCM encoded. |
| 31 | @param expected_sample_rate Expected sample rate of the wave file. If the | 17 | @param expected_sample_rate Expected sample rate of the wave file. If the |
| 32 | sample rate doesn't match, the read fails. | 18 | sample rate doesn't match, the read fails.
| 19 | + @param is_ok On return it is true if the reading succeeded; false otherwise. | ||
| 33 | 20 | ||
| 34 | @return Return wave samples normalized to the range [-1, 1). | 21 | @return Return wave samples normalized to the range [-1, 1). |
| 35 | */ | 22 | */ |
| 36 | std::vector<float> ReadWave(const std::string &filename, | 23 | std::vector<float> ReadWave(const std::string &filename, |
| 37 | - float expected_sample_rate); | 24 | + float expected_sample_rate, bool *is_ok); |
| 25 | + | ||
| 26 | +std::vector<float> ReadWave(std::istream &is, float expected_sample_rate, | ||
| 27 | + bool *is_ok); | ||
| 38 | 28 | ||
| 39 | } // namespace sherpa_onnx | 29 | } // namespace sherpa_onnx |
| 40 | 30 |
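A sketch of the new error-handling contract (`test.wav` is a placeholder): instead of aborting the process on a bad file, the reader now reports failure through `is_ok` and lets the caller decide what to do:

```cpp
#include <iostream>
#include <vector>

#include "sherpa-onnx/csrc/wave-reader.h"

int main() {
  bool is_ok = false;
  std::vector<float> samples =
      sherpa_onnx::ReadWave("test.wav", 16000, &is_ok);
  if (!is_ok) {
    std::cerr << "Failed to read test.wav\n";
    return -1;
  }
  // samples are normalized to [-1, 1)
  std::cout << "duration (s): " << samples.size() / 16000.0f << "\n";
  return 0;
}
```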