Fangjun Kuang
Committed by GitHub

Add online LSTM transducer model (#25)

  1 +#!/usr/bin/env bash
  2 +
  3 +set -e
  4 +
  5 +log() {
  6 + # This function is from espnet
  7 + local fname=${BASH_SOURCE[1]##*/}
  8 + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
  9 +}
  10 +
  11 +echo "EXE is $EXE"
  12 +echo "PATH: $PATH"
  13 +
  14 +which $EXE
  15 +
  16 +log "------------------------------------------------------------"
  17 +log "Run LSTM transducer (English)"
  18 +log "------------------------------------------------------------"
  19 +
  20 +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17
  21 +
  22 +log "Start testing ${repo_url}"
  23 +repo=$(basename $repo_url)
  24 +log "Download pretrained model and test-data from $repo_url"
  25 +
  26 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  27 +pushd $repo
  28 +git lfs pull --include "*.onnx"
  29 +popd
  30 +
  31 +waves=(
  32 +$repo/test_wavs/1089-134686-0001.wav
  33 +$repo/test_wavs/1221-135766-0001.wav
  34 +$repo/test_wavs/1221-135766-0002.wav
  35 +)
  36 +
  37 +for wave in ${waves[@]}; do
  38 + time $EXE \
  39 + $repo/tokens.txt \
  40 + $repo/encoder-epoch-99-avg-1.onnx \
  41 + $repo/decoder-epoch-99-avg-1.onnx \
  42 + $repo/joiner-epoch-99-avg-1.onnx \
  43 + $wave \
  44 + 4
  45 +done
  46 +
  47 +rm -rf $repo
  1 +name: linux
  2 +
  3 +on:
  4 + push:
  5 + branches:
  6 + - master
  7 + paths:
  8 + - '.github/workflows/linux.yaml'
  9 + - '.github/scripts/test-online-transducer.sh'
  10 + - 'CMakeLists.txt'
  11 + - 'cmake/**'
  12 + - 'sherpa-onnx/csrc/*'
  13 + pull_request:
  14 + branches:
  15 + - master
  16 + paths:
  17 + - '.github/workflows/linux.yaml'
  18 + - '.github/scripts/test-online-transducer.sh'
  19 + - 'CMakeLists.txt'
  20 + - 'cmake/**'
  21 + - 'sherpa-onnx/csrc/*'
  22 +
  23 +concurrency:
  24 + group: linux-${{ github.ref }}
  25 + cancel-in-progress: true
  26 +
  27 +permissions:
  28 + contents: read
  29 +
  30 +jobs:
  31 + linux:
  32 + runs-on: ${{ matrix.os }}
  33 + strategy:
  34 + fail-fast: false
  35 + matrix:
  36 + os: [ubuntu-latest]
  37 +
  38 + steps:
  39 + - uses: actions/checkout@v2
  40 + with:
  41 + fetch-depth: 0
  42 +
  43 + - name: Configure CMake
  44 + shell: bash
  45 + run: |
  46 + mkdir build
  47 + cd build
  48 + cmake -D CMAKE_BUILD_TYPE=Release ..
  49 +
  50 + - name: Build sherpa-onnx for ubuntu
  51 + shell: bash
  52 + run: |
  53 + cd build
  54 + make -j2
  55 +
  56 + ls -lh lib
  57 + ls -lh bin
  58 +
  59 + - name: Display dependencies of sherpa-onnx for linux
  60 + shell: bash
  61 + run: |
  62 + file build/bin/sherpa-onnx
  63 + readelf -d build/bin/sherpa-onnx
  64 +
  65 + - name: Test online transducer
  66 + shell: bash
  67 + run: |
  68 + export PATH=$PWD/build/bin:$PATH
  69 + export EXE=sherpa-onnx
  70 +
  71 + .github/scripts/test-online-transducer.sh
  1 +name: macos
  2 +
  3 +on:
  4 + push:
  5 + branches:
  6 + - master
  7 + paths:
  8 + - '.github/workflows/macos.yaml'
  9 + - '.github/scripts/test-online-transducer.sh'
  10 + - 'CMakeLists.txt'
  11 + - 'cmake/**'
  12 + - 'sherpa-onnx/csrc/*'
  13 + pull_request:
  14 + branches:
  15 + - master
  16 + paths:
  17 + - '.github/workflows/macos.yaml'
  18 + - '.github/scripts/test-online-transducer.sh'
  19 + - 'CMakeLists.txt'
  20 + - 'cmake/**'
  21 + - 'sherpa-onnx/csrc/*'
  22 +
  23 +concurrency:
  24 + group: macos-${{ github.ref }}
  25 + cancel-in-progress: true
  26 +
  27 +permissions:
  28 + contents: read
  29 +
  30 +jobs:
  31 + macos:
  32 + runs-on: ${{ matrix.os }}
  33 + strategy:
  34 + fail-fast: false
  35 + matrix:
  36 + os: [macos-latest]
  37 +
  38 + steps:
  39 + - uses: actions/checkout@v2
  40 + with:
  41 + fetch-depth: 0
  42 +
  43 + - name: Configure CMake
  44 + shell: bash
  45 + run: |
  46 + mkdir build
  47 + cd build
  48 + cmake -D CMAKE_BUILD_TYPE=Release ..
  49 +
  50 + - name: Build sherpa-onnx for macos
  51 + shell: bash
  52 + run: |
  53 + cd build
  54 + make -j2
  55 +
  56 + ls -lh lib
  57 + ls -lh bin
  58 +
  59 +
  60 + - name: Display dependencies of sherpa-onnx for macos
  61 + shell: bash
  62 + run: |
  63 + file build/bin/sherpa-onnx
  64 + otool -L build/bin/sherpa-onnx
  65 + otool -l build/bin/sherpa-onnx
  66 +
  67 + - name: Test online transducer
  68 + shell: bash
  69 + run: |
  70 + export PATH=$PWD/build/bin:$PATH
  71 + export EXE=sherpa-onnx
  72 +
  73 + .github/scripts/test-online-transducer.sh
1 -name: test-linux-macos-windows  
2 -  
3 -on:  
4 - push:  
5 - branches:  
6 - - master  
7 - paths:  
8 - - '.github/workflows/test-linux-macos-windows.yaml'  
9 - - 'CMakeLists.txt'  
10 - - 'cmake/**'  
11 - - 'sherpa-onnx/csrc/*'  
12 - pull_request:  
13 - branches:  
14 - - master  
15 - paths:  
16 - - '.github/workflows/test-linux-macos-windows.yaml'  
17 - - 'CMakeLists.txt'  
18 - - 'cmake/**'  
19 - - 'sherpa-onnx/csrc/*'  
20 -  
21 -concurrency:  
22 - group: test-linux-macos-windows-${{ github.ref }}  
23 - cancel-in-progress: true  
24 -  
25 -permissions:  
26 - contents: read  
27 -  
28 -jobs:  
29 - test-linux-macos-windows:  
30 - runs-on: ${{ matrix.os }}  
31 - strategy:  
32 - fail-fast: false  
33 - matrix:  
34 - os: [ubuntu-latest, macos-latest, windows-latest]  
35 -  
36 - steps:  
37 - - uses: actions/checkout@v2  
38 - with:  
39 - fetch-depth: 0  
40 -  
41 - # see https://github.com/microsoft/setup-msbuild  
42 - - name: Add msbuild to PATH  
43 - if: startsWith(matrix.os, 'windows')  
44 - uses: microsoft/setup-msbuild@v1.0.2  
45 -  
46 - - name: Download pretrained model and test-data (English)  
47 - shell: bash  
48 - run: |  
49 - git lfs install  
50 - GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13  
51 - cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13  
52 - ls -lh exp/onnx/*.onnx  
53 - git lfs pull --include "exp/onnx/*.onnx"  
54 - ls -lh exp/onnx/*.onnx  
55 -  
56 - - name: Download pretrained model and test-data (Chinese)  
57 - shell: bash  
58 - run: |  
59 - GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2  
60 - cd icefall_asr_wenetspeech_pruned_transducer_stateless2  
61 - ls -lh exp/*.onnx  
62 - git lfs pull --include "exp/*.onnx"  
63 - ls -lh exp/*.onnx  
64 -  
65 - - name: Configure CMake  
66 - shell: bash  
67 - run: |  
68 - mkdir build  
69 - cd build  
70 - cmake -D CMAKE_BUILD_TYPE=Release ..  
71 -  
72 - - name: Build sherpa-onnx for ubuntu/macos  
73 - if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')  
74 - shell: bash  
75 - run: |  
76 - cd build  
77 - make VERBOSE=1 -j3  
78 -  
79 - - name: Build sherpa-onnx for Windows  
80 - if: startsWith(matrix.os, 'windows')  
81 - shell: bash  
82 - run: |  
83 - cmake --build ./build --config Release  
84 -  
85 - - name: Run tests for ubuntu/macos (English)  
86 - if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')  
87 - shell: bash  
88 - run: |  
89 - time ./build/bin/sherpa-onnx \  
90 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \  
91 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \  
92 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \  
93 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \  
94 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \  
95 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \  
96 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav  
97 -  
98 - time ./build/bin/sherpa-onnx \  
99 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \  
100 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \  
101 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \  
102 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \  
103 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \  
104 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \  
105 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav  
106 -  
107 - time ./build/bin/sherpa-onnx \  
108 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \  
109 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \  
110 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \  
111 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \  
112 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \  
113 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \  
114 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav  
115 -  
116 - - name: Run tests for Windows (English)  
117 - if: startsWith(matrix.os, 'windows')  
118 - shell: bash  
119 - run: |  
120 - ./build/bin/Release/sherpa-onnx \  
121 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \  
122 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \  
123 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \  
124 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \  
125 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \  
126 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \  
127 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav  
128 -  
129 - ./build/bin/Release/sherpa-onnx \  
130 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \  
131 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \  
132 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \  
133 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \  
134 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \  
135 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \  
136 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav  
137 -  
138 - ./build/bin/Release/sherpa-onnx \  
139 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \  
140 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \  
141 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \  
142 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \  
143 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \  
144 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \  
145 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav  
146 -  
147 - - name: Run tests for ubuntu/macos (Chinese)  
148 - if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')  
149 - shell: bash  
150 - run: |  
151 - time ./build/bin/sherpa-onnx \  
152 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \  
153 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \  
154 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \  
155 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \  
156 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \  
157 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \  
158 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav  
159 -  
160 - time ./build/bin/sherpa-onnx \  
161 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \  
162 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \  
163 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \  
164 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \  
165 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \  
166 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \  
167 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000001.wav  
168 -  
169 - time ./build/bin/sherpa-onnx \  
170 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \  
171 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \  
172 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \  
173 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \  
174 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \  
175 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \  
176 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000002.wav  
177 -  
178 - - name: Run tests for windows (Chinese)  
179 - if: startsWith(matrix.os, 'windows')  
180 - shell: bash  
181 - run: |  
182 - ./build/bin/Release/sherpa-onnx \  
183 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \  
184 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \  
185 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \  
186 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \  
187 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \  
188 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \  
189 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav  
190 -  
191 - ./build/bin/Release/sherpa-onnx \  
192 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \  
193 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \  
194 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \  
195 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \  
196 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \  
197 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \  
198 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000001.wav  
199 -  
200 - ./build/bin/Release/sherpa-onnx \  
201 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \  
202 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \  
203 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \  
204 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \  
205 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \  
206 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \  
207 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000002.wav  
  1 +name: windows-x64
  2 +
  3 +on:
  4 + push:
  5 + branches:
  6 + - master
  7 + paths:
  8 + - '.github/workflows/windows-x64.yaml'
  9 + - '.github/scripts/test-online-transducer.sh'
  10 + - 'CMakeLists.txt'
  11 + - 'cmake/**'
  12 + - 'sherpa-onnx/csrc/*'
  13 + pull_request:
  14 + branches:
  15 + - master
  16 + paths:
  17 + - '.github/workflows/windows-x64.yaml'
  18 + - '.github/scripts/test-online-transducer.sh'
  19 + - 'CMakeLists.txt'
  20 + - 'cmake/**'
  21 + - 'sherpa-onnx/csrc/*'
  22 +
  23 +concurrency:
  24 + group: windows-x64-${{ github.ref }}
  25 + cancel-in-progress: true
  26 +
  27 +permissions:
  28 + contents: read
  29 +
  30 +jobs:
  31 + windows_x64:
  32 + runs-on: ${{ matrix.os }}
  33 + name: ${{ matrix.vs-version }}
  34 + strategy:
  35 + fail-fast: false
  36 + matrix:
  37 + include:
  38 + - vs-version: vs2015
  39 + toolset-version: v140
  40 + os: windows-2019
  41 +
  42 + - vs-version: vs2017
  43 + toolset-version: v141
  44 + os: windows-2019
  45 +
  46 + - vs-version: vs2019
  47 + toolset-version: v142
  48 + os: windows-2022
  49 +
  50 + - vs-version: vs2022
  51 + toolset-version: v143
  52 + os: windows-2022
  53 +
  54 + steps:
  55 + - uses: actions/checkout@v2
  56 + with:
  57 + fetch-depth: 0
  58 +
  59 + - name: Configure CMake
  60 + shell: bash
  61 + run: |
  62 + mkdir build
  63 + cd build
  64 + cmake -T ${{ matrix.toolset-version }},host=x64 -A x64 -D CMAKE_BUILD_TYPE=Release ..
  65 +
  66 + - name: Build sherpa-onnx for windows
  67 + shell: bash
  68 + run: |
  69 + cd build
  70 + cmake --build . --config Release -- -m:2
  71 +
  72 + ls -lh ./bin/Release/sherpa-onnx.exe
  73 +
  74 + - name: Test sherpa-onnx for Windows x64
  75 + shell: bash
  76 + run: |
  77 + export PATH=$PWD/build/bin/Release:$PATH
  78 + export EXE=sherpa-onnx.exe
  79 +
  80 + .github/scripts/test-online-transducer.sh
@@ -4,3 +4,4 @@ build
4 4 onnxruntime-*
5 5 icefall-*
6 6 run.sh
  7 +sherpa-onnx-*
@@ -2,89 +2,7 @@
2 2
3 3 Documentation: <https://k2-fsa.github.io/sherpa/onnx/index.html>
4 4
5 -Try it in colab:
6 -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tmQbdlYeTl_klmtaGiUb7a7ZPz-AkBSH?usp=sharing)
7 -
8 5 See <https://github.com/k2-fsa/sherpa>
9 6
10 7 This repo uses [onnxruntime](https://github.com/microsoft/onnxruntime) and
11 8 does not depend on libtorch.
12 -  
13 -We provide exported models in onnx format and they can be downloaded using  
14 -the following links:  
15 -  
16 -- English: <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13>  
17 -- Chinese: <https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2>  
18 -  
19 -**NOTE**: We provide only non-streaming models at present.  
20 -  
21 -  
22 -**HINT**: The script for exporting the English model can be found at  
23 -<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py>  
24 -  
25 -**HINT**: The script for exporting the Chinese model can be found at  
26 -<https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py>  
27 -  
28 -## Build for Linux/macOS  
29 -  
30 -```bash  
31 -git clone https://github.com/k2-fsa/sherpa-onnx  
32 -cd sherpa-onnx  
33 -mkdir build  
34 -cd build  
35 -cmake -DCMAKE_BUILD_TYPE=Release ..  
36 -make -j6  
37 -cd ..  
38 -```  
39 -  
40 -## Build for Windows  
41 -  
42 -```bash  
43 -git clone https://github.com/k2-fsa/sherpa-onnx  
44 -cd sherpa-onnx  
45 -mkdir build  
46 -cd build  
47 -cmake -DCMAKE_BUILD_TYPE=Release ..  
48 -cmake --build . --config Release  
49 -cd ..  
50 -```  
51 -  
52 -## Download the pretrained model (English)  
53 -  
54 -```bash  
55 -GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13  
56 -cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13  
57 -git lfs pull --include "exp/onnx/*.onnx"  
58 -cd ..  
59 -  
60 -./build/bin/sherpa-onnx --help  
61 -  
62 -./build/bin/sherpa-onnx \  
63 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \  
64 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \  
65 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \  
66 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \  
67 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \  
68 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \  
69 - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav  
70 -```  
71 -  
72 -## Download the pretrained model (Chinese)  
73 -  
74 -```bash  
75 -GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2  
76 -cd icefall_asr_wenetspeech_pruned_transducer_stateless2  
77 -git lfs pull --include "exp/*.onnx"  
78 -cd ..  
79 -  
80 -./build/bin/sherpa-onnx --help  
81 -  
82 -./build/bin/sherpa-onnx \  
83 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \  
84 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \  
85 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \  
86 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \  
87 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \  
88 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \  
89 - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav  
90 -```  
@@ -2,7 +2,11 @@ include_directories(${CMAKE_SOURCE_DIR})
2 2
3 3 add_executable(sherpa-onnx
4 4 decode.cc
5 - rnnt-model.cc
  5 + features.cc
  6 + online-lstm-transducer-model.cc
  7 + online-transducer-model-config.cc
  8 + online-transducer-model.cc
  9 + onnx-utils.cc
6 10 sherpa-onnx.cc
7 11 symbol-table.cc
8 12 wave-reader.cc
@@ -13,5 +17,5 @@ target_link_libraries(sherpa-onnx
13 17 kaldi-native-fbank-core
14 18 )
15 19
16 -# add_executable(sherpa-show-onnx-info show-onnx-info.cc)
17 -# target_link_libraries(sherpa-show-onnx-info onnxruntime)
  20 +add_executable(sherpa-onnx-show-info show-onnx-info.cc)
  21 +target_link_libraries(sherpa-onnx-show-info onnxruntime)
1 -/**  
2 - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)  
3 - *  
4 - * See LICENSE for clarification regarding multiple authors  
5 - *  
6 - * Licensed under the Apache License, Version 2.0 (the "License");  
7 - * you may not use this file except in compliance with the License.  
8 - * You may obtain a copy of the License at  
9 - *  
10 - * http://www.apache.org/licenses/LICENSE-2.0  
11 - *  
12 - * Unless required by applicable law or agreed to in writing, software  
13 - * distributed under the License is distributed on an "AS IS" BASIS,  
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
15 - * See the License for the specific language governing permissions and  
16 - * limitations under the License.  
17 - */
  1 +// sherpa/csrc/decode.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
18 4
19 5 #include "sherpa-onnx/csrc/decode.h"
20 6
21 7 #include <assert.h>
22 8
23 9 #include <algorithm>
  10 +#include <utility>
24 11 #include <vector>
25 12
26 13 namespace sherpa_onnx {
27 14
28 -std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
29 - const Ort::Value &encoder_out) {
  15 +static Ort::Value Clone(Ort::Value *v) {
  16 + auto type_and_shape = v->GetTensorTypeAndShapeInfo();
  17 + std::vector<int64_t> shape = type_and_shape.GetShape();
  18 +
  19 + auto memory_info =
  20 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  21 +
  22 + return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),
  23 + type_and_shape.GetElementCount(),
  24 + shape.data(), shape.size());
  25 +}
  26 +
  27 +static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
30 28 std::vector<int64_t> encoder_out_shape =
31 - encoder_out.GetTensorTypeAndShapeInfo().GetShape();
32 - assert(encoder_out_shape[0] == 1 && "Only batch_size=1 is implemented");
33 - Ort::Value projected_encoder_out =
34 - model.RunJoinerEncoderProj(encoder_out.GetTensorData<float>(),
35 - encoder_out_shape[1], encoder_out_shape[2]);
  29 + encoder_out->GetTensorTypeAndShapeInfo().GetShape();
  30 + assert(encoder_out_shape[0] == 1);
36 31
37 - const float *p_projected_encoder_out =
38 - projected_encoder_out.GetTensorData<float>();
  32 + int32_t encoder_out_dim = encoder_out_shape[2];
39 33
40 - int32_t context_size = 2; // hard-code it to 2
41 - int32_t blank_id = 0; // hard-code it to 0
42 - std::vector<int32_t> hyp(context_size, blank_id);
43 - std::array<int64_t, 2> decoder_input{blank_id, blank_id};
  34 + auto memory_info =
  35 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
44 36
45 - Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size);
  37 + std::array<int64_t, 2> shape{1, encoder_out_dim};
46 38
47 - std::vector<int64_t> decoder_out_shape =
48 - decoder_out.GetTensorTypeAndShapeInfo().GetShape();
  39 + return Ort::Value::CreateTensor(
  40 + memory_info,
  41 + encoder_out->GetTensorMutableData<float>() + t * encoder_out_dim,
  42 + encoder_out_dim, shape.data(), shape.size());
  43 +}
49 44
50 - Ort::Value projected_decoder_out = model.RunJoinerDecoderProj(
51 - decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
  45 +void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
  46 + std::vector<int64_t> *hyp) {
  47 + std::vector<int64_t> encoder_out_shape =
  48 + encoder_out.GetTensorTypeAndShapeInfo().GetShape();
52 49
53 - int32_t joiner_dim =
54 - projected_decoder_out.GetTensorTypeAndShapeInfo().GetShape()[1];
  50 + if (encoder_out_shape[0] > 1) {
  51 + fprintf(stderr, "Only batch_size=1 is implemented. Given: %d\n",
  52 + static_cast<int32_t>(encoder_out_shape[0]));
  53 + }
55 54
56 - int32_t T = encoder_out_shape[1];
57 - for (int32_t t = 0; t != T; ++t) {
58 - Ort::Value logit = model.RunJoiner(
59 - p_projected_encoder_out + t * joiner_dim,
60 - projected_decoder_out.GetTensorData<float>(), joiner_dim);
  55 + int32_t num_frames = encoder_out_shape[1];
  56 + int32_t vocab_size = model->VocabSize();
61 57
62 - int32_t vocab_size = logit.GetTensorTypeAndShapeInfo().GetShape()[1];
  58 + Ort::Value decoder_input = model->BuildDecoderInput(*hyp);
  59 + Ort::Value decoder_out = model->RunDecoder(std::move(decoder_input));
63 60
  61 + for (int32_t t = 0; t != num_frames; ++t) {
  62 + Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
  63 + Ort::Value logit =
  64 + model->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
64 65 const float *p_logit = logit.GetTensorData<float>();
65 66
66 67 auto y = static_cast<int32_t>(std::distance(
67 68 static_cast<const float *>(p_logit),
68 69 std::max_element(static_cast<const float *>(p_logit),
69 70 static_cast<const float *>(p_logit) + vocab_size)));
70 -
71 - if (y != blank_id) {
72 - decoder_input[0] = hyp.back();
73 - decoder_input[1] = y;
74 - hyp.push_back(y);
75 - decoder_out = model.RunDecoder(decoder_input.data(), context_size);
76 - projected_decoder_out = model.RunJoinerDecoderProj(
77 - decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
  71 + if (y != 0) {
  72 + hyp->push_back(y);
  73 + decoder_input = model->BuildDecoderInput(*hyp);
  74 + decoder_out = model->RunDecoder(std::move(decoder_input));
78 75 }
79 76 }
80 -
81 - return {hyp.begin() + context_size, hyp.end()};
82 77 }
83 78
84 79 } // namespace sherpa_onnx
1 -/**  
2 - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)  
3 - *  
4 - * See LICENSE for clarification regarding multiple authors  
5 - *  
6 - * Licensed under the Apache License, Version 2.0 (the "License");  
7 - * you may not use this file except in compliance with the License.  
8 - * You may obtain a copy of the License at  
9 - *  
10 - * http://www.apache.org/licenses/LICENSE-2.0  
11 - *  
12 - * Unless required by applicable law or agreed to in writing, software  
13 - * distributed under the License is distributed on an "AS IS" BASIS,  
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
15 - * See the License for the specific language governing permissions and  
16 - * limitations under the License.  
17 - */
  1 +// sherpa/csrc/decode.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
18 4
19 5 #ifndef SHERPA_ONNX_CSRC_DECODE_H_
20 6 #define SHERPA_ONNX_CSRC_DECODE_H_
21 7
22 8 #include <vector>
23 9
24 -#include "sherpa-onnx/csrc/rnnt-model.h"
  10 +#include "sherpa-onnx/csrc/online-transducer-model.h"
25 11
26 12 namespace sherpa_onnx {
27 13
@@ -32,8 +18,8 @@ namespace sherpa_onnx {
32 18 * @param model The RnntModel
33 19 * @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
34 20 */
35 -std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
36 - const Ort::Value &encoder_out);
  21 +void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
  22 + std::vector<int64_t> *hyp);
37 23
38 24 } // namespace sherpa_onnx
39 25
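The new `GreedySearch` fills `hyp` in place instead of returning the token ids. A minimal caller sketch, not part of this diff: it seeds `hyp` with `ContextSize()` blanks and strips them afterwards, mirroring what the removed `RnntModel` version did internally (`DecodeChunk` is a hypothetical helper name).

```c++
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/decode.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"

// Hypothetical helper, for illustration only.
std::vector<int64_t> DecodeChunk(sherpa_onnx::OnlineTransducerModel *model,
                                 Ort::Value encoder_out) {
  // BuildDecoderInput() reads the last context_size entries of hyp,
  // so seed it with context_size blank tokens (blank id is 0).
  std::vector<int64_t> hyp(model->ContextSize(), 0);

  sherpa_onnx::GreedySearch(model, std::move(encoder_out), &hyp);

  // Drop the leading blanks; the rest are the decoded token ids.
  return {hyp.begin() + model->ContextSize(), hyp.end()};
}
```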
  1 +// sherpa/csrc/features.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/features.h"
  6 +
  7 +#include <algorithm>
  8 +#include <memory>
  9 +#include <vector>
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +FeatureExtractor::FeatureExtractor() {
  14 + opts_.frame_opts.dither = 0;
  15 + opts_.frame_opts.snip_edges = false;
  16 + opts_.frame_opts.samp_freq = 16000;
  17 +
  18 + // cache 100 seconds of feature frames, which is more than enough
  19 + // for real needs
  20 + opts_.frame_opts.max_feature_vectors = 100 * 100;
  21 +
  22 + opts_.mel_opts.num_bins = 80; // feature dim
  23 +
  24 + fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
  25 +}
  26 +
  27 +FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts)
  28 + : opts_(opts) {
  29 + fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
  30 +}
  31 +
  32 +void FeatureExtractor::AcceptWaveform(float sampling_rate,
  33 + const float *waveform, int32_t n) {
  34 + std::lock_guard<std::mutex> lock(mutex_);
  35 + fbank_->AcceptWaveform(sampling_rate, waveform, n);
  36 +}
  37 +
  38 +void FeatureExtractor::InputFinished() {
  39 + std::lock_guard<std::mutex> lock(mutex_);
  40 + fbank_->InputFinished();
  41 +}
  42 +
  43 +int32_t FeatureExtractor::NumFramesReady() const {
  44 + std::lock_guard<std::mutex> lock(mutex_);
  45 + return fbank_->NumFramesReady();
  46 +}
  47 +
  48 +bool FeatureExtractor::IsLastFrame(int32_t frame) const {
  49 + std::lock_guard<std::mutex> lock(mutex_);
  50 + return fbank_->IsLastFrame(frame);
  51 +}
  52 +
  53 +std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
  54 + int32_t n) const {
  55 + if (frame_index + n > NumFramesReady()) {
  56 + fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
  57 + exit(-1);
  58 + }
  59 + std::lock_guard<std::mutex> lock(mutex_);
  60 +
  61 + int32_t feature_dim = fbank_->Dim();
  62 + std::vector<float> features(feature_dim * n);
  63 +
  64 + float *p = features.data();
  65 +
  66 + for (int32_t i = 0; i != n; ++i) {
  67 + const float *f = fbank_->GetFrame(i + frame_index);
  68 + std::copy(f, f + feature_dim, p);
  69 + p += feature_dim;
  70 + }
  71 +
  72 + return features;
  73 +}
  74 +
  75 +void FeatureExtractor::Reset() {
  76 + fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
  77 +}
  78 +
  79 +} // namespace sherpa_onnx
  1 +// sherpa/csrc/features.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_FEATURES_H_
  6 +#define SHERPA_ONNX_CSRC_FEATURES_H_
  7 +
  8 +#include <memory>
  9 +#include <mutex> // NOLINT
  10 +#include <vector>
  11 +
  12 +#include "kaldi-native-fbank/csrc/online-feature.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +class FeatureExtractor {
  17 + public:
  18 + FeatureExtractor();
  19 + explicit FeatureExtractor(const knf::FbankOptions &fbank_opts);
  20 +
  21 + /**
  22 + @param sampling_rate The sampling_rate of the input waveform. Should match
  23 + the one expected by the feature extractor.
  24 + @param waveform Pointer to a 1-D array of size n
  25 + @param n Number of entries in waveform
  26 + */
  27 + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
  28 +
  29 + // InputFinished() tells the class you won't be providing any
  30 + // more waveform. This will help flush out the last frame or two
  31 + // of features, in the case where snip-edges == false; it also
  32 + // affects the return value of IsLastFrame().
  33 + void InputFinished();
  34 +
  35 + int32_t NumFramesReady() const;
  36 +
  37 + // Note: IsLastFrame() will only ever return true if you have called
  38 + // InputFinished() (and this frame is the last frame).
  39 + bool IsLastFrame(int32_t frame) const;
  40 +
  41 + /** Get n frames starting from the given frame index.
  42 + *
  43 + * @param frame_index The starting frame index
  44 + * @param n Number of frames to get.
  45 + * @return Return a 2-D tensor of shape (n, feature_dim),
  46 + * which is flattened into a 1-D vector (in row-major order).
  47 + */
  48 + std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
  49 +
  50 + void Reset();
  51 + int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
  52 +
  53 + private:
  54 + std::unique_ptr<knf::OnlineFbank> fbank_;
  55 + knf::FbankOptions opts_;
  56 + mutable std::mutex mutex_;
  57 +};
  58 +
  59 +} // namespace sherpa_onnx
  60 +
  61 +#endif // SHERPA_ONNX_CSRC_FEATURES_H_
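A short usage sketch for `FeatureExtractor`, not part of this diff; the audio below is dummy data, whereas real callers feed decoded wave samples:

```c++
#include <vector>

#include "sherpa-onnx/csrc/features.h"

int main() {
  sherpa_onnx::FeatureExtractor extractor;  // defaults: 16 kHz, 80-dim fbank

  // Feed audio in arbitrarily sized pieces; here, 1 second of silence.
  std::vector<float> samples(16000, 0.0f);
  extractor.AcceptWaveform(16000, samples.data(),
                           static_cast<int32_t>(samples.size()));
  extractor.InputFinished();  // flush the last frame or two

  // Each frame holds FeatureDim() floats; GetFrames() returns them
  // flattened in row-major order.
  int32_t n = extractor.NumFramesReady();
  std::vector<float> frames = extractor.GetFrames(0, n);

  return 0;
}
```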
  1 +// sherpa/csrc/online-lstm-transducer-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
  5 +
  6 +#include <memory>
  7 +#include <sstream>
  8 +#include <string>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "onnxruntime_cxx_api.h" // NOLINT
  13 +#include "sherpa-onnx/csrc/onnx-utils.h"
  14 +
  15 +#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
  16 + do { \
  17 + auto value = \
  18 + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
  19 + if (!value) { \
  20 + fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
  21 + exit(-1); \
  22 + } \
  23 + dst = atoi(value.get()); \
  24 + if (dst <= 0) { \
  25 + fprintf(stderr, "Invalid value %d for %s\n", dst, src_key); \
  26 + exit(-1); \
  27 + } \
  28 + } while (0)
  29 +
  30 +namespace sherpa_onnx {
  31 +
  32 +OnlineLstmTransducerModel::OnlineLstmTransducerModel(
  33 + const OnlineTransducerModelConfig &config)
  34 + : env_(ORT_LOGGING_LEVEL_WARNING),
  35 + config_(config),
  36 + sess_opts_{},
  37 + allocator_{} {
  38 + sess_opts_.SetIntraOpNumThreads(config.num_threads);
  39 + sess_opts_.SetInterOpNumThreads(config.num_threads);
  40 +
  41 + InitEncoder(config.encoder_filename);
  42 + InitDecoder(config.decoder_filename);
  43 + InitJoiner(config.joiner_filename);
  44 +}
  45 +
  46 +void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
  47 + encoder_sess_ = std::make_unique<Ort::Session>(
  48 + env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
  49 +
  50 + GetInputNames(encoder_sess_.get(), &encoder_input_names_,
  51 + &encoder_input_names_ptr_);
  52 +
  53 + GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
  54 + &encoder_output_names_ptr_);
  55 +
  56 + // get meta data
  57 + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
  58 + if (config_.debug) {
  59 + std::ostringstream os;
  60 + os << "---encoder---\n";
  61 + PrintModelMetadata(os, meta_data);
  62 + fprintf(stderr, "%s\n", os.str().c_str());
  63 + }
  64 +
  65 + Ort::AllocatorWithDefaultOptions allocator;
  66 + SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
  67 + SHERPA_ONNX_READ_META_DATA(T_, "T");
  68 + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
  69 + SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "rnn_hidden_size");
  70 + SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
  71 +}
  72 +
  73 +void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
  74 + decoder_sess_ = std::make_unique<Ort::Session>(
  75 + env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
  76 +
  77 + GetInputNames(decoder_sess_.get(), &decoder_input_names_,
  78 + &decoder_input_names_ptr_);
  79 +
  80 + GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
  81 + &decoder_output_names_ptr_);
  82 +
  83 + // get meta data
  84 + Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
  85 + if (config_.debug) {
  86 + std::ostringstream os;
  87 + os << "---decoder---\n";
  88 + PrintModelMetadata(os, meta_data);
  89 + fprintf(stderr, "%s\n", os.str().c_str());
  90 + }
  91 +
  92 + Ort::AllocatorWithDefaultOptions allocator;
  93 + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  94 + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
  95 +}
  96 +
  97 +void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
  98 + joiner_sess_ = std::make_unique<Ort::Session>(
  99 + env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
  100 +
  101 + GetInputNames(joiner_sess_.get(), &joiner_input_names_,
  102 + &joiner_input_names_ptr_);
  103 +
  104 + GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
  105 + &joiner_output_names_ptr_);
  106 +
  107 + // get meta data
  108 + Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
  109 + if (config_.debug) {
  110 + std::ostringstream os;
  111 + os << "---joiner---\n";
  112 + PrintModelMetadata(os, meta_data);
  113 + fprintf(stderr, "%s\n", os.str().c_str());
  114 + }
  115 +}
  116 +
  117 +Ort::Value OnlineLstmTransducerModel::StackStates(
  118 + const std::vector<Ort::Value> &states) const {
  119 + fprintf(stderr, "implement me: %s:%d!\n", __func__,
  120 + static_cast<int>(__LINE__));
  121 + auto memory_info =
  122 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  123 + int64_t a;
  124 + std::array<int64_t, 3> x_shape{1, 1, 1};
  125 + Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);
  126 + return x;
  127 +}
  128 +
  129 +std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(
  130 + Ort::Value states) const {
  131 + fprintf(stderr, "implement me: %s:%d!\n", __func__,
  132 + static_cast<int>(__LINE__));
  133 + return {};
  134 +}
  135 +
  136 +std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
  137 + // Please see
  138 + // https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py#L185
  139 + // for details
  140 + constexpr int32_t kBatchSize = 1;
  141 + std::array<int64_t, 3> h_shape{num_encoder_layers_, kBatchSize, d_model_};
  142 + Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
  143 + h_shape.size());
  144 +
  145 + std::fill(h.GetTensorMutableData<float>(),
  146 + h.GetTensorMutableData<float>() +
  147 + num_encoder_layers_ * kBatchSize * d_model_,
  148 + 0);
  149 +
  150 + std::array<int64_t, 3> c_shape{num_encoder_layers_, kBatchSize,
  151 + rnn_hidden_size_};
  152 + Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
  153 + c_shape.size());
  154 +
  155 + std::fill(c.GetTensorMutableData<float>(),
  156 + c.GetTensorMutableData<float>() +
  157 + num_encoder_layers_ * kBatchSize * rnn_hidden_size_,
  158 + 0);
  159 +
  160 + std::vector<Ort::Value> states;
  161 +
  162 + states.reserve(2);
  163 + states.push_back(std::move(h));
  164 + states.push_back(std::move(c));
  165 +
  166 + return states;
  167 +}
  168 +
  169 +std::pair<Ort::Value, std::vector<Ort::Value>>
  170 +OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
  171 + std::vector<Ort::Value> &states) {
  172 + auto memory_info =
  173 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  174 +
  175 + std::array<Ort::Value, 3> encoder_inputs = {
  176 + std::move(features), std::move(states[0]), std::move(states[1])};
  177 +
  178 + auto encoder_out = encoder_sess_->Run(
  179 + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
  180 + encoder_inputs.size(), encoder_output_names_ptr_.data(),
  181 + encoder_output_names_ptr_.size());
  182 +
  183 + std::vector<Ort::Value> next_states;
  184 + next_states.reserve(2);
  185 + next_states.push_back(std::move(encoder_out[1]));
  186 + next_states.push_back(std::move(encoder_out[2]));
  187 +
  188 + return {std::move(encoder_out[0]), std::move(next_states)};
  189 +}
  190 +
  191 +Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
  192 + const std::vector<int64_t> &hyp) {
  193 + auto memory_info =
  194 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  195 +
  196 + std::array<int64_t, 2> shape{1, context_size_};
  197 +
  198 + return Ort::Value::CreateTensor(
  199 + memory_info,
  200 + const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),
  201 + context_size_, shape.data(), shape.size());
  202 +}
  203 +
  204 +Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
  205 + auto decoder_out = decoder_sess_->Run(
  206 + {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
  207 + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
  208 + return std::move(decoder_out[0]);
  209 +}
  210 +
  211 +Ort::Value OnlineLstmTransducerModel::RunJoiner(Ort::Value encoder_out,
  212 + Ort::Value decoder_out) {
  213 + std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
  214 + std::move(decoder_out)};
  215 + auto logit =
  216 + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
  217 + joiner_input.size(), joiner_output_names_ptr_.data(),
  218 + joiner_output_names_ptr_.size());
  219 +
  220 + return std::move(logit[0]);
  221 +}
  222 +
  223 +} // namespace sherpa_onnx
  1 +// sherpa/csrc/online-lstm-transducer-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "onnxruntime_cxx_api.h" // NOLINT
  13 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  14 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  15 +
  16 +namespace sherpa_onnx {
  17 +
  18 +class OnlineLstmTransducerModel : public OnlineTransducerModel {
  19 + public:
  20 + explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
  21 +
  22 + Ort::Value StackStates(const std::vector<Ort::Value> &states) const override;
  23 +
  24 + std::vector<Ort::Value> UnStackStates(Ort::Value states) const override;
  25 +
  26 + std::vector<Ort::Value> GetEncoderInitStates() override;
  27 +
  28 + std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
  29 + Ort::Value features, std::vector<Ort::Value> &states) override;
  30 +
  31 + Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override;
  32 +
  33 + Ort::Value RunDecoder(Ort::Value decoder_input) override;
  34 +
  35 + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
  36 +
  37 + int32_t ContextSize() const override { return context_size_; }
  38 +
  39 + int32_t ChunkSize() const override { return T_; }
  40 +
  41 + int32_t ChunkShift() const override { return decode_chunk_len_; }
  42 +
  43 + int32_t VocabSize() const override { return vocab_size_; }
  44 +
  45 + private:
  46 + void InitEncoder(const std::string &encoder_filename);
  47 + void InitDecoder(const std::string &decoder_filename);
  48 + void InitJoiner(const std::string &joiner_filename);
  49 +
  50 + private:
  51 + Ort::Env env_;
  52 + Ort::SessionOptions sess_opts_;
  53 +
  54 + Ort::AllocatorWithDefaultOptions allocator_;
  55 +
  56 + std::unique_ptr<Ort::Session> encoder_sess_;
  57 + std::unique_ptr<Ort::Session> decoder_sess_;
  58 + std::unique_ptr<Ort::Session> joiner_sess_;
  59 +
  60 + std::vector<std::string> encoder_input_names_;
  61 + std::vector<const char *> encoder_input_names_ptr_;
  62 +
  63 + std::vector<std::string> encoder_output_names_;
  64 + std::vector<const char *> encoder_output_names_ptr_;
  65 +
  66 + std::vector<std::string> decoder_input_names_;
  67 + std::vector<const char *> decoder_input_names_ptr_;
  68 +
  69 + std::vector<std::string> decoder_output_names_;
  70 + std::vector<const char *> decoder_output_names_ptr_;
  71 +
  72 + std::vector<std::string> joiner_input_names_;
  73 + std::vector<const char *> joiner_input_names_ptr_;
  74 +
  75 + std::vector<std::string> joiner_output_names_;
  76 + std::vector<const char *> joiner_output_names_ptr_;
  77 +
  78 + OnlineTransducerModelConfig config_;
  79 +
  80 + int32_t num_encoder_layers_ = 0;
  81 + int32_t T_ = 0;
  82 + int32_t decode_chunk_len_ = 0;
  83 + int32_t rnn_hidden_size_ = 0;
  84 + int32_t d_model_ = 0;
  85 + int32_t context_size_ = 0;
  86 + int32_t vocab_size_ = 0;
  87 +};
  88 +
  89 +} // namespace sherpa_onnx
  90 +
  91 +#endif // SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
  1 +// sherpa/csrc/online-transducer-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  5 +
  6 +#include <sstream>
  7 +
  8 +namespace sherpa_onnx {
  9 +
  10 +std::string OnlineTransducerModelConfig::ToString() const {
  11 + std::ostringstream os;
  12 +
  13 + os << "OnlineTransducerModelConfig(";
  14 + os << "encoder_filename=\"" << encoder_filename << "\", ";
  15 + os << "decoder_filename=\"" << decoder_filename << "\", ";
  16 + os << "joiner_filename=\"" << joiner_filename << "\", ";
  17 + os << "num_threads=" << num_threads << ", ";
  18 + os << "debug=" << (debug ? "True" : "False") << ")";
  19 +
  20 + return os.str();
  21 +}
  22 +
  23 +} // namespace sherpa_onnx
  1 +// sherpa/csrc/online-transducer-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +struct OnlineTransducerModelConfig {
  12 + std::string encoder_filename;
  13 + std::string decoder_filename;
  14 + std::string joiner_filename;
  15 + int32_t num_threads;
  16 + bool debug = false;
  17 +
  18 + std::string ToString() const;
  19 +};
  20 +
  21 +} // namespace sherpa_onnx
  22 +
  23 +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
  1 +// sherpa/csrc/online-transducer-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  5 +
  6 +#include <memory>
  7 +#include <sstream>
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
  11 +#include "sherpa-onnx/csrc/onnx-utils.h"
  12 +namespace sherpa_onnx {
  13 +
  14 +enum class ModelType {
  15 + kLstm,
  16 + kUnknown,
  17 +};
  18 +
  19 +static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
  20 + Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  21 + Ort::SessionOptions sess_opts;
  22 +
  23 + auto sess = std::make_unique<Ort::Session>(
  24 + env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts);
  25 +
  26 + Ort::ModelMetadata meta_data = sess->GetModelMetadata();
  27 + if (config.debug) {
  28 + std::ostringstream os;
  29 + PrintModelMetadata(os, meta_data);
  30 + fprintf(stderr, "%s\n", os.str().c_str());
  31 + }
  32 +
  33 + Ort::AllocatorWithDefaultOptions allocator;
  34 + auto model_type =
  35 + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
  36 + if (!model_type) {
  37 + fprintf(stderr, "No model_type in the metadata!\n");
  38 + return ModelType::kUnknown;
  39 + }
  40 +
  41 + if (model_type.get() == std::string("lstm")) {
  42 + return ModelType::kLstm;
  43 + } else {
  44 + fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
  45 + return ModelType::kUnknown;
  46 + }
  47 +}
  48 +
  49 +std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
  50 + const OnlineTransducerModelConfig &config) {
  51 + auto model_type = GetModelType(config);
  52 +
  53 + switch (model_type) {
  54 + case ModelType::kLstm:
  55 + return std::make_unique<OnlineLstmTransducerModel>(config);
  56 + case ModelType::kUnknown:
  57 + return nullptr;
  58 + }
  59 +
  60 + // unreachable code
  61 + return nullptr;
  62 +}
  63 +
  64 +} // namespace sherpa_onnx
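`Create()` dispatches on the `model_type` entry stored in the encoder's ONNX metadata, so callers never name the concrete class. A minimal sketch, not part of this diff, with filenames taken from the test script above:

```c++
#include "sherpa-onnx/csrc/online-transducer-model.h"

int main() {
  sherpa_onnx::OnlineTransducerModelConfig config;
  config.encoder_filename = "encoder-epoch-99-avg-1.onnx";
  config.decoder_filename = "decoder-epoch-99-avg-1.onnx";
  config.joiner_filename = "joiner-epoch-99-avg-1.onnx";
  config.num_threads = 4;
  config.debug = false;

  // Returns an OnlineLstmTransducerModel when the encoder metadata says
  // model_type=lstm, or nullptr for an unknown model type.
  auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
  return model ? 0 : 1;
}
```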
  1 +// sherpa/csrc/online-transducer-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "onnxruntime_cxx_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +class OnlineTransducerModel {
  17 + public:
  18 + virtual ~OnlineTransducerModel() = default;
  19 +
  20 + static std::unique_ptr<OnlineTransducerModel> Create(
  21 + const OnlineTransducerModelConfig &config);
  22 +
  23 + /** Stack a list of individual states into a batch.
  24 + *
  25 + * It is the inverse operation of `UnStackStates`.
  26 + *
  27 + * @param states states[i] contains the state for the i-th utterance.
  28 + * @return Return a single value representing the batched state.
  29 + */
  30 + virtual Ort::Value StackStates(
  31 + const std::vector<Ort::Value> &states) const = 0;
  32 +
  33 + /** Unstack a batch state into a list of individual states.
  34 + *
  35 + * It is the inverse operation of `StackStates`.
  36 + *
  37 + * @param states A batched state.
  38 + * @return ans[i] contains the state for the i-th utterance.
  39 + */
  40 + virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0;
  41 +
  42 + /** Get the initial encoder states.
  43 + *
  44 + * @return Return the initial encoder state.
  45 + */
  46 + virtual std::vector<Ort::Value> GetEncoderInitStates() = 0;
  47 +
  48 + /** Run the encoder.
  49 + *
  50 + * @param features A tensor of shape (N, T, C). It is changed in-place.
  51 + * @param states Encoder state of the previous chunk. It is changed in-place.
  52 + *
  53 + * @return Return a tuple containing:
  54 + * - encoder_out, a tensor of shape (N, T', encoder_out_dim)
  55 + * - next_states Encoder state for the next chunk.
  56 + */
  57 + virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
  58 + Ort::Value features,
  59 + std::vector<Ort::Value> &states) = 0; // NOLINT
  60 +
  61 + virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0;
  62 +
  63 + /** Run the decoder network.
  64 + *
  65 + * Caution: We assume there are no recurrent connections in the decoder and
  66 + * the decoder is stateless. See
  67 + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
  68 + * for an example
  69 + *
  70 + * @param decoder_input It is usually of shape (N, context_size)
  71 + * @return Return a tensor of shape (N, decoder_dim).
  72 + */
  73 + virtual Ort::Value RunDecoder(Ort::Value decoder_input) = 0;
  74 +
  75 + /** Run the joint network.
  76 + *
  77 + * @param encoder_out Output of the encoder network. A tensor of shape
  78 + * (N, joiner_dim).
  79 + * @param decoder_out Output of the decoder network. A tensor of shape
  80 + * (N, joiner_dim).
  81 + * @return Return a tensor of shape (N, vocab_size). In icefall, the
  82 + * last layer of the joint network is `nn.Linear`,
  83 + * not `nn.LogSoftmax`.
  84 + */
  85 + virtual Ort::Value RunJoiner(Ort::Value encoder_out,
  86 + Ort::Value decoder_out) = 0;
  87 +
  88 + /** If we are using a stateless decoder and if it contains a
  89 + * Conv1D, this function returns the kernel size of the convolution layer.
  90 + */
  91 + virtual int32_t ContextSize() const = 0;
  92 +
  93 + /** We send this number of feature frames to the encoder at a time. */
  94 + virtual int32_t ChunkSize() const = 0;
  95 +
  96 + /** Number of input frames to discard after each call to RunEncoder.
  97 + *
  98 + * For instance, if we have 30 frames, chunk_size=8, chunk_shift=6.
  99 + *
  100 + * In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8.
  101 + * Then we discard frames 0~5 since chunk_shift is 6.
  102 + * In the second call of RunEncoder, we use frames 6~13; and then we discard
  103 + * frames 6~11.
  104 + * In the third call of RunEncoder, we use frames 12~19; and then we discard
  105 + * frames 12~17.
  106 + *
  107 + * Note: ChunkSize() - ChunkShift() == right context size
  108 + */
  109 + virtual int32_t ChunkShift() const = 0;
  110 +
  111 + virtual int32_t VocabSize() const = 0;
  112 +
  113 + virtual int32_t SubsamplingFactor() const { return 4; }
  114 +};
  115 +
  116 +} // namespace sherpa_onnx
  117 +
  118 +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
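A hypothetical streaming loop, not part of this diff, showing the chunking contract documented on `ChunkSize()`/`ChunkShift()` above: feed `chunk_size` feature frames per `RunEncoder()` call, then advance the window by `chunk_shift`, so consecutive chunks overlap by the right-context frames (`StreamingEncode` is an illustrative name):

```c++
#include <array>
#include <utility>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"

void StreamingEncode(sherpa_onnx::OnlineTransducerModel *model,
                     sherpa_onnx::FeatureExtractor *extractor) {
  int32_t chunk_size = model->ChunkSize();    // frames fed per call
  int32_t chunk_shift = model->ChunkShift();  // frames consumed per call
  int32_t feat_dim = extractor->FeatureDim();

  std::vector<Ort::Value> states = model->GetEncoderInitStates();
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  int32_t start = 0;
  while (start + chunk_size <= extractor->NumFramesReady()) {
    std::vector<float> frames = extractor->GetFrames(start, chunk_size);
    start += chunk_shift;  // windows overlap by chunk_size - chunk_shift

    std::array<int64_t, 3> shape{1, chunk_size, feat_dim};
    Ort::Value x =
        Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(),
                                 shape.data(), shape.size());

    auto out = model->RunEncoder(std::move(x), states);
    states = std::move(out.second);
    // out.first is encoder_out, ready to be passed to GreedySearch().
  }
}
```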
  1 +// sherpa/csrc/onnx-utils.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/onnx-utils.h"
  5 +
  6 +#include <string>
  7 +#include <vector>
  8 +
  9 +#include "onnxruntime_cxx_api.h" // NOLINT
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
  14 + std::vector<const char *> *input_names_ptr) {
  15 + Ort::AllocatorWithDefaultOptions allocator;
  16 + size_t node_count = sess->GetInputCount();
  17 + input_names->resize(node_count);
  18 + input_names_ptr->resize(node_count);
  19 + for (size_t i = 0; i != node_count; ++i) {
  20 + auto tmp = sess->GetInputNameAllocated(i, allocator);
  21 + (*input_names)[i] = tmp.get();
  22 + (*input_names_ptr)[i] = (*input_names)[i].c_str();
  23 + }
  24 +}
  25 +
  26 +void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
  27 + std::vector<const char *> *output_names_ptr) {
  28 + Ort::AllocatorWithDefaultOptions allocator;
  29 + size_t node_count = sess->GetOutputCount();
  30 + output_names->resize(node_count);
  31 + output_names_ptr->resize(node_count);
  32 + for (size_t i = 0; i != node_count; ++i) {
  33 + auto tmp = sess->GetOutputNameAllocated(i, allocator);
  34 + (*output_names)[i] = tmp.get();
  35 + (*output_names_ptr)[i] = (*output_names)[i].c_str();
  36 + }
  37 +}
  38 +
  39 +void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
  40 + Ort::AllocatorWithDefaultOptions allocator;
  41 + std::vector<Ort::AllocatedStringPtr> v =
  42 + meta_data.GetCustomMetadataMapKeysAllocated(allocator);
  43 + for (const auto &key : v) {
  44 + auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator);
  45 + os << key.get() << "=" << p.get() << "\n";
  46 + }
  47 +}
  48 +
  49 +} // namespace sherpa_onnx
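   A typical call site looks roughly like the following sketch; `sess` (a std::unique_ptr<Ort::Session>) and `inputs` (the input tensors) are assumptions. The std::string vectors own the storage that the const char * vectors alias, so the two must stay alive together for as long as Run() is called:

       std::vector<std::string> input_names, output_names;
       std::vector<const char *> input_names_ptr, output_names_ptr;
       GetInputNames(sess.get(), &input_names, &input_names_ptr);
       GetOutputNames(sess.get(), &output_names, &output_names_ptr);

       // The pointer vectors can be reused for every Run() call.
       auto outputs =
           sess->Run({}, input_names_ptr.data(), inputs.data(), inputs.size(),
                     output_names_ptr.data(), output_names_ptr.size());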
  1 +// sherpa-onnx/csrc/onnx-utils.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
  5 +#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_
  6 +
  7 +#ifdef _MSC_VER
  8 +// For ToWide() below
  9 +#include <codecvt>
  10 +#include <locale>
  11 +#endif
  12 +
  13 +#include <ostream>
  14 +#include <string>
  15 +#include <vector>
  16 +
  17 +#include "onnxruntime_cxx_api.h" // NOLINT
  18 +
  19 +namespace sherpa_onnx {
  20 +
  21 +#ifdef _MSC_VER
  22 +// See
  23 +// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t
  24 +static std::wstring ToWide(const std::string &s) {
  25 + std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
  26 + return converter.from_bytes(s);
  27 +}
  28 +#define SHERPA_MAYBE_WIDE(s) ToWide(s)
  29 +#else
  30 +#define SHERPA_MAYBE_WIDE(s) s
  31 +#endif
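   The macro lets every call site pass a UTF-8 std::string; on Windows, where Ort::Session expects a wide-character path, the conversion happens transparently. A one-line usage sketch (env, filename, and sess_opts are assumed to exist):

       // Compiles unchanged on Windows and POSIX; `filename` is UTF-8.
       auto sess = std::make_unique<Ort::Session>(
           env, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts);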
  32 +
  33 +/**
  34 + * Get the input names of a model.
  35 + *
  36 + * @param sess An onnxruntime session.
  37 + * @param input_names On return, it contains the input names of the model.
  38 + * @param input_names_ptr On return, input_names_ptr[i] contains
  39 + * input_names[i].c_str()
  40 + */
  41 +void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
  42 + std::vector<const char *> *input_names_ptr);
  43 +
  44 +/**
  45 + * Get the output names of a model.
  46 + *
  47 + * @param sess An onnxruntime session.
  48 + * @param output_names On return, it contains the output names of the model.
  49 + * @param output_names_ptr On return, output_names_ptr[i] contains
  50 + * output_names[i].c_str()
  51 + */
  52 +void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
  53 + std::vector<const char *> *output_names_ptr);
  54 +
  55 +void PrintModelMetadata(std::ostream &os,
  56 + const Ort::ModelMetadata &meta_data); // NOLINT
  57 +
  58 +} // namespace sherpa_onnx
  59 +
  60 +#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
1 -/**  
2 - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)  
3 - *  
4 - * See LICENSE for clarification regarding multiple authors  
5 - *  
6 - * Licensed under the Apache License, Version 2.0 (the "License");  
7 - * you may not use this file except in compliance with the License.  
8 - * You may obtain a copy of the License at  
9 - *  
10 - * http://www.apache.org/licenses/LICENSE-2.0  
11 - *  
12 - * Unless required by applicable law or agreed to in writing, software  
13 - * distributed under the License is distributed on an "AS IS" BASIS,  
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
15 - * See the License for the specific language governing permissions and  
16 - * limitations under the License.  
17 - */  
18 -#include "sherpa-onnx/csrc/rnnt-model.h"  
19 -  
20 -#include <array>  
21 -#include <utility>  
22 -#include <vector>  
23 -  
24 -#ifdef _MSC_VER  
25 -// For ToWide() below  
26 -#include <codecvt>  
27 -#include <locale>  
28 -#endif  
29 -  
30 -namespace sherpa_onnx {  
31 -  
32 -#ifdef _MSC_VER  
33 -// See  
34 -// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t  
35 -static std::wstring ToWide(const std::string &s) {  
36 - std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;  
37 - return converter.from_bytes(s);  
38 -}  
39 -#define SHERPA_MAYBE_WIDE(s) ToWide(s)  
40 -#else  
41 -#define SHERPA_MAYBE_WIDE(s) s  
42 -#endif  
43 -  
44 -/**  
45 - * Get the input names of a model.  
46 - *  
47 - * @param sess An onnxruntime session.  
48 - * @param input_names. On return, it contains the input names of the model.  
49 - * @param input_names_ptr. On return, input_names_ptr[i] contains  
50 - * input_names[i].c_str()  
51 - */  
52 -static void GetInputNames(Ort::Session *sess,  
53 - std::vector<std::string> *input_names,  
54 - std::vector<const char *> *input_names_ptr) {  
55 - Ort::AllocatorWithDefaultOptions allocator;  
56 - size_t node_count = sess->GetInputCount();  
57 - input_names->resize(node_count);  
58 - input_names_ptr->resize(node_count);  
59 - for (size_t i = 0; i != node_count; ++i) {  
60 - auto tmp = sess->GetInputNameAllocated(i, allocator);  
61 - (*input_names)[i] = tmp.get();  
62 - (*input_names_ptr)[i] = (*input_names)[i].c_str();  
63 - }  
64 -}  
65 -  
66 -/**  
67 - * Get the output names of a model.  
68 - *  
69 - * @param sess An onnxruntime session.  
70 - * @param output_names. On return, it contains the output names of the model.  
71 - * @param output_names_ptr. On return, output_names_ptr[i] contains  
72 - * output_names[i].c_str()  
73 - */  
74 -static void GetOutputNames(Ort::Session *sess,  
75 - std::vector<std::string> *output_names,  
76 - std::vector<const char *> *output_names_ptr) {  
77 - Ort::AllocatorWithDefaultOptions allocator;  
78 - size_t node_count = sess->GetOutputCount();  
79 - output_names->resize(node_count);  
80 - output_names_ptr->resize(node_count);  
81 - for (size_t i = 0; i != node_count; ++i) {  
82 - auto tmp = sess->GetOutputNameAllocated(i, allocator);  
83 - (*output_names)[i] = tmp.get();  
84 - (*output_names_ptr)[i] = (*output_names)[i].c_str();  
85 - }  
86 -}  
87 -  
88 -RnntModel::RnntModel(const std::string &encoder_filename,  
89 - const std::string &decoder_filename,  
90 - const std::string &joiner_filename,  
91 - const std::string &joiner_encoder_proj_filename,  
92 - const std::string &joiner_decoder_proj_filename,  
93 - int32_t num_threads)  
94 - : env_(ORT_LOGGING_LEVEL_WARNING) {  
95 - sess_opts_.SetIntraOpNumThreads(num_threads);  
96 - sess_opts_.SetInterOpNumThreads(num_threads);  
97 -  
98 - InitEncoder(encoder_filename);  
99 - InitDecoder(decoder_filename);  
100 - InitJoiner(joiner_filename);  
101 - InitJoinerEncoderProj(joiner_encoder_proj_filename);  
102 - InitJoinerDecoderProj(joiner_decoder_proj_filename);  
103 -}  
104 -  
105 -void RnntModel::InitEncoder(const std::string &filename) {  
106 - encoder_sess_ = std::make_unique<Ort::Session>(  
107 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);  
108 - GetInputNames(encoder_sess_.get(), &encoder_input_names_,  
109 - &encoder_input_names_ptr_);  
110 -  
111 - GetOutputNames(encoder_sess_.get(), &encoder_output_names_,  
112 - &encoder_output_names_ptr_);  
113 -}  
114 -  
115 -void RnntModel::InitDecoder(const std::string &filename) {  
116 - decoder_sess_ = std::make_unique<Ort::Session>(  
117 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);  
118 -  
119 - GetInputNames(decoder_sess_.get(), &decoder_input_names_,  
120 - &decoder_input_names_ptr_);  
121 -  
122 - GetOutputNames(decoder_sess_.get(), &decoder_output_names_,  
123 - &decoder_output_names_ptr_);  
124 -}  
125 -  
126 -void RnntModel::InitJoiner(const std::string &filename) {  
127 - joiner_sess_ = std::make_unique<Ort::Session>(  
128 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);  
129 -  
130 - GetInputNames(joiner_sess_.get(), &joiner_input_names_,  
131 - &joiner_input_names_ptr_);  
132 -  
133 - GetOutputNames(joiner_sess_.get(), &joiner_output_names_,  
134 - &joiner_output_names_ptr_);  
135 -}  
136 -  
137 -void RnntModel::InitJoinerEncoderProj(const std::string &filename) {  
138 - joiner_encoder_proj_sess_ = std::make_unique<Ort::Session>(  
139 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);  
140 -  
141 - GetInputNames(joiner_encoder_proj_sess_.get(),  
142 - &joiner_encoder_proj_input_names_,  
143 - &joiner_encoder_proj_input_names_ptr_);  
144 -  
145 - GetOutputNames(joiner_encoder_proj_sess_.get(),  
146 - &joiner_encoder_proj_output_names_,  
147 - &joiner_encoder_proj_output_names_ptr_);  
148 -}  
149 -  
150 -void RnntModel::InitJoinerDecoderProj(const std::string &filename) {  
151 - joiner_decoder_proj_sess_ = std::make_unique<Ort::Session>(  
152 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);  
153 -  
154 - GetInputNames(joiner_decoder_proj_sess_.get(),  
155 - &joiner_decoder_proj_input_names_,  
156 - &joiner_decoder_proj_input_names_ptr_);  
157 -  
158 - GetOutputNames(joiner_decoder_proj_sess_.get(),  
159 - &joiner_decoder_proj_output_names_,  
160 - &joiner_decoder_proj_output_names_ptr_);  
161 -}  
162 -  
163 -Ort::Value RnntModel::RunEncoder(const float *features, int32_t T,  
164 - int32_t feature_dim) {  
165 - auto memory_info =  
166 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
167 - std::array<int64_t, 3> x_shape{1, T, feature_dim};  
168 - Ort::Value x =  
169 - Ort::Value::CreateTensor(memory_info, const_cast<float *>(features),  
170 - T * feature_dim, x_shape.data(), x_shape.size());  
171 -  
172 - std::array<int64_t, 1> x_lens_shape{1};  
173 - int64_t x_lens_tmp = T;  
174 -  
175 - Ort::Value x_lens = Ort::Value::CreateTensor(  
176 - memory_info, &x_lens_tmp, 1, x_lens_shape.data(), x_lens_shape.size());  
177 -  
178 - std::array<Ort::Value, 2> encoder_inputs{std::move(x), std::move(x_lens)};  
179 -  
180 - // Note: We discard encoder_out_lens since we only implement  
181 - // batch==1.  
182 - auto encoder_out = encoder_sess_->Run(  
183 - {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),  
184 - encoder_inputs.size(), encoder_output_names_ptr_.data(),  
185 - encoder_output_names_ptr_.size());  
186 - return std::move(encoder_out[0]);  
187 -}  
188 -Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T,  
189 - int32_t encoder_out_dim) {  
190 - auto memory_info =  
191 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
192 -  
193 - std::array<int64_t, 2> in_shape{T, encoder_out_dim};  
194 - Ort::Value in = Ort::Value::CreateTensor(  
195 - memory_info, const_cast<float *>(encoder_out), T * encoder_out_dim,  
196 - in_shape.data(), in_shape.size());  
197 -  
198 - auto encoder_proj_out = joiner_encoder_proj_sess_->Run(  
199 - {}, joiner_encoder_proj_input_names_ptr_.data(), &in, 1,  
200 - joiner_encoder_proj_output_names_ptr_.data(),  
201 - joiner_encoder_proj_output_names_ptr_.size());  
202 - return std::move(encoder_proj_out[0]);  
203 -}  
204 -  
205 -Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input,  
206 - int32_t context_size) {  
207 - auto memory_info =  
208 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
209 -  
210 - int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1  
211 - std::array<int64_t, 2> shape{batch_size, context_size};  
212 - Ort::Value in = Ort::Value::CreateTensor(  
213 - memory_info, const_cast<int64_t *>(decoder_input),  
214 - batch_size * context_size, shape.data(), shape.size());  
215 -  
216 - auto decoder_out = decoder_sess_->Run(  
217 - {}, decoder_input_names_ptr_.data(), &in, 1,  
218 - decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());  
219 - return std::move(decoder_out[0]);  
220 -}  
221 -  
222 -Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out,  
223 - int32_t decoder_out_dim) {  
224 - auto memory_info =  
225 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
226 -  
227 - int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1  
228 - std::array<int64_t, 2> shape{batch_size, decoder_out_dim};  
229 - Ort::Value in = Ort::Value::CreateTensor(  
230 - memory_info, const_cast<float *>(decoder_out),  
231 - batch_size * decoder_out_dim, shape.data(), shape.size());  
232 -  
233 - auto decoder_proj_out = joiner_decoder_proj_sess_->Run(  
234 - {}, joiner_decoder_proj_input_names_ptr_.data(), &in, 1,  
235 - joiner_decoder_proj_output_names_ptr_.data(),  
236 - joiner_decoder_proj_output_names_ptr_.size());  
237 - return std::move(decoder_proj_out[0]);  
238 -}  
239 -  
240 -Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out,  
241 - const float *projected_decoder_out,  
242 - int32_t joiner_dim) {  
243 - auto memory_info =  
244 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
245 - int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1  
246 - std::array<int64_t, 2> shape{batch_size, joiner_dim};  
247 -  
248 - Ort::Value enc = Ort::Value::CreateTensor(  
249 - memory_info, const_cast<float *>(projected_encoder_out),  
250 - batch_size * joiner_dim, shape.data(), shape.size());  
251 -  
252 - Ort::Value dec = Ort::Value::CreateTensor(  
253 - memory_info, const_cast<float *>(projected_decoder_out),  
254 - batch_size * joiner_dim, shape.data(), shape.size());  
255 -  
256 - std::array<Ort::Value, 2> inputs{std::move(enc), std::move(dec)};  
257 -  
258 - auto logit = joiner_sess_->Run(  
259 - {}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(),  
260 - joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size());  
261 -  
262 - return std::move(logit[0]);  
263 -}  
264 -  
265 -} // namespace sherpa_onnx  
1 -/**  
2 - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)  
3 - *  
4 - * See LICENSE for clarification regarding multiple authors  
5 - *  
6 - * Licensed under the Apache License, Version 2.0 (the "License");  
7 - * you may not use this file except in compliance with the License.  
8 - * You may obtain a copy of the License at  
9 - *  
10 - * http://www.apache.org/licenses/LICENSE-2.0  
11 - *  
12 - * Unless required by applicable law or agreed to in writing, software  
13 - * distributed under the License is distributed on an "AS IS" BASIS,  
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
15 - * See the License for the specific language governing permissions and  
16 - * limitations under the License.  
17 - */  
18 -  
19 -#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_  
20 -#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_  
21 -  
22 -#include <memory>  
23 -#include <string>  
24 -#include <vector>  
25 -  
26 -#include "onnxruntime_cxx_api.h" // NOLINT  
27 -  
28 -namespace sherpa_onnx {  
29 -  
30 -class RnntModel {  
31 - public:  
32 - /**  
33 - * @param encoder_filename Path to the encoder model  
34 - * @param decoder_filename Path to the decoder model  
35 - * @param joiner_filename Path to the joiner model  
36 - * @param joiner_encoder_proj_filename Path to the joiner encoder_proj model  
37 - * @param joiner_decoder_proj_filename Path to the joiner decoder_proj model  
38 - * @param num_threads Number of threads to use to run the models  
39 - */  
40 - RnntModel(const std::string &encoder_filename,  
41 - const std::string &decoder_filename,  
42 - const std::string &joiner_filename,  
43 - const std::string &joiner_encoder_proj_filename,  
44 - const std::string &joiner_decoder_proj_filename,  
45 - int32_t num_threads);  
46 -  
47 - /** Run the encoder model.  
48 - *  
49 - * @TODO(fangjun): Support batch_size > 1  
50 - *  
51 - * @param features A tensor of shape (batch_size, T, feature_dim)  
52 - * @param T Number of feature frames  
53 - * @param feature_dim Dimension of the feature.  
54 - *  
55 - * @return Return a tensor of shape (batch_size, T', encoder_out_dim)  
56 - */  
57 - Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim);  
58 -  
59 - /** Run the joiner encoder_proj model.  
60 - *  
61 - * @param encoder_out A tensor of shape (T, encoder_out_dim)  
62 - * @param T Number of frames in encoder_out.  
63 - * @param encoder_out_dim Dimension of encoder_out.  
64 - *  
65 - * @return Return a tensor of shape (T, joiner_dim)  
66 - *  
67 - */  
68 - Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T,  
69 - int32_t encoder_out_dim);  
70 -  
71 - /** Run the decoder model.  
72 - *  
73 - * @TODO(fangjun): Support batch_size > 1  
74 - *  
75 - * @param decoder_input A tensor of shape (batch_size, context_size).  
76 - * @return Return a tensor of shape (batch_size, 1, decoder_out_dim)  
77 - */  
78 - Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size);  
79 -  
80 - /** Run joiner decoder_proj model.  
81 - *  
82 - * @TODO(fangjun): Support batch_size > 1  
83 - *  
84 - * @param decoder_out A tensor of shape (batch_size, decoder_out_dim)  
85 - * @param decoder_out_dim Output dimension of the decoder_out.  
86 - *  
87 - * @return Return a tensor of shape (batch_size, joiner_dim);  
88 - */  
89 - Ort::Value RunJoinerDecoderProj(const float *decoder_out,  
90 - int32_t decoder_out_dim);  
91 -  
92 - /** Run the joiner model.  
93 - *  
94 - * @TODO(fangjun): Support batch_size > 1  
95 - *  
96 - * @param projected_encoder_out A tensor of shape (batch_size, joiner_dim).  
97 - * @param projected_decoder_out A tensor of shape (batch_size, joiner_dim).  
98 - *  
99 - * @return Return a tensor of shape (batch_size, vocab_size)  
100 - */  
101 - Ort::Value RunJoiner(const float *projected_encoder_out,  
102 - const float *projected_decoder_out, int32_t joiner_dim);  
103 -  
104 - private:  
105 - void InitEncoder(const std::string &encoder_filename);  
106 - void InitDecoder(const std::string &decoder_filename);  
107 - void InitJoiner(const std::string &joiner_filename);  
108 - void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename);  
109 - void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename);  
110 -  
111 - private:  
112 - Ort::Env env_;  
113 - Ort::SessionOptions sess_opts_;  
114 - std::unique_ptr<Ort::Session> encoder_sess_;  
115 - std::unique_ptr<Ort::Session> decoder_sess_;  
116 - std::unique_ptr<Ort::Session> joiner_sess_;  
117 - std::unique_ptr<Ort::Session> joiner_encoder_proj_sess_;  
118 - std::unique_ptr<Ort::Session> joiner_decoder_proj_sess_;  
119 -  
120 - std::vector<std::string> encoder_input_names_;  
121 - std::vector<const char *> encoder_input_names_ptr_;  
122 - std::vector<std::string> encoder_output_names_;  
123 - std::vector<const char *> encoder_output_names_ptr_;  
124 -  
125 - std::vector<std::string> decoder_input_names_;  
126 - std::vector<const char *> decoder_input_names_ptr_;  
127 - std::vector<std::string> decoder_output_names_;  
128 - std::vector<const char *> decoder_output_names_ptr_;  
129 -  
130 - std::vector<std::string> joiner_input_names_;  
131 - std::vector<const char *> joiner_input_names_ptr_;  
132 - std::vector<std::string> joiner_output_names_;  
133 - std::vector<const char *> joiner_output_names_ptr_;  
134 -  
135 - std::vector<std::string> joiner_encoder_proj_input_names_;  
136 - std::vector<const char *> joiner_encoder_proj_input_names_ptr_;  
137 - std::vector<std::string> joiner_encoder_proj_output_names_;  
138 - std::vector<const char *> joiner_encoder_proj_output_names_ptr_;  
139 -  
140 - std::vector<std::string> joiner_decoder_proj_input_names_;  
141 - std::vector<const char *> joiner_decoder_proj_input_names_ptr_;  
142 - std::vector<std::string> joiner_decoder_proj_output_names_;  
143 - std::vector<const char *> joiner_decoder_proj_output_names_ptr_;  
144 -};  
145 -  
146 -} // namespace sherpa_onnx  
147 -  
148 -#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_  
1 -/**  
2 - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)  
3 - *  
4 - * See LICENSE for clarification regarding multiple authors  
5 - *  
6 - * Licensed under the Apache License, Version 2.0 (the "License");  
7 - * you may not use this file except in compliance with the License.  
8 - * You may obtain a copy of the License at  
9 - *  
10 - * http://www.apache.org/licenses/LICENSE-2.0  
11 - *  
12 - * Unless required by applicable law or agreed to in writing, software  
13 - * distributed under the License is distributed on an "AS IS" BASIS,  
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
15 - * See the License for the specific language governing permissions and  
16 - * limitations under the License.  
17 - */ 1 +// sherpa-onnx/csrc/sherpa-onnx.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
18 4
  5 +#include <chrono> // NOLINT
19 #include <iostream> 6 #include <iostream>
20 #include <string> 7 #include <string>
21 #include <vector> 8 #include <vector>
22 9
23 #include "kaldi-native-fbank/csrc/online-feature.h" 10 #include "kaldi-native-fbank/csrc/online-feature.h"
24 #include "sherpa-onnx/csrc/decode.h" 11 #include "sherpa-onnx/csrc/decode.h"
25 -#include "sherpa-onnx/csrc/rnnt-model.h" 12 +#include "sherpa-onnx/csrc/features.h"
  13 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  14 +#include "sherpa-onnx/csrc/online-transducer-model.h"
26 #include "sherpa-onnx/csrc/symbol-table.h" 15 #include "sherpa-onnx/csrc/symbol-table.h"
27 #include "sherpa-onnx/csrc/wave-reader.h" 16 #include "sherpa-onnx/csrc/wave-reader.h"
28 17
29 -/** Compute fbank features of the input wave filename.  
30 - *  
31 - * @param wav_filename. Path to a mono wave file.  
32 - * @param expected_sampling_rate Expected sampling rate of the input wave file.  
33 - * @param num_frames On return, it contains the number of feature frames.  
34 - * @return Return the computed feature of shape (num_frames, feature_dim)  
35 - * stored in row-major.  
36 - */  
37 -static std::vector<float> ComputeFeatures(const std::string &wav_filename,  
38 - float expected_sampling_rate,  
39 - int32_t *num_frames) {  
40 - std::vector<float> samples =  
41 - sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate);  
42 -  
43 - float duration = samples.size() / expected_sampling_rate;  
44 -  
45 - std::cout << "wav filename: " << wav_filename << "\n";  
46 - std::cout << "wav duration (s): " << duration << "\n";  
47 -  
48 - knf::FbankOptions opts;  
49 - opts.frame_opts.dither = 0;  
50 - opts.frame_opts.snip_edges = false;  
51 - opts.frame_opts.samp_freq = expected_sampling_rate;  
52 -  
53 - int32_t feature_dim = 80;  
54 -  
55 - opts.mel_opts.num_bins = feature_dim;  
56 -  
57 - knf::OnlineFbank fbank(opts);  
58 - fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());  
59 - fbank.InputFinished();  
60 -  
61 - *num_frames = fbank.NumFramesReady();  
62 -  
63 - std::vector<float> features(*num_frames * feature_dim);  
64 - float *p = features.data();  
65 -  
66 - for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) {  
67 - const float *f = fbank.GetFrame(i);  
68 - std::copy(f, f + feature_dim, p);  
69 - }  
70 -  
71 - return features;  
72 -}  
73 -  
74 int main(int32_t argc, char *argv[]) { 18 int main(int32_t argc, char *argv[]) {
75 - if (argc < 8 || argc > 9) { 19 + if (argc < 6 || argc > 7) {
76 const char *usage = R"usage( 20 const char *usage = R"usage(
77 Usage: 21 Usage:
78 ./bin/sherpa-onnx \ 22 ./bin/sherpa-onnx \
@@ -80,12 +24,11 @@ Usage: @@ -80,12 +24,11 @@ Usage:
80 /path/to/encoder.onnx \ 24 /path/to/encoder.onnx \
81 /path/to/decoder.onnx \ 25 /path/to/decoder.onnx \
82 /path/to/joiner.onnx \ 26 /path/to/joiner.onnx \
83 - /path/to/joiner_encoder_proj.onnx \  
84 - /path/to/joiner_decoder_proj.onnx \  
85 /path/to/foo.wav [num_threads] 27 /path/to/foo.wav [num_threads]
86 28
87 -You can download pre-trained models from the following repository:  
88 -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 29 +Please refer to
  30 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  31 +for a list of pre-trained models to download.
89 )usage"; 32 )usage";
90 std::cerr << usage << "\n"; 33 std::cerr << usage << "\n";
91 34
@@ -93,37 +36,102 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat @@ -93,37 +36,102 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
93 } 36 }
94 37
95 std::string tokens = argv[1]; 38 std::string tokens = argv[1];
96 - std::string encoder = argv[2];  
97 - std::string decoder = argv[3];  
98 - std::string joiner = argv[4];  
99 - std::string joiner_encoder_proj = argv[5];  
100 - std::string joiner_decoder_proj = argv[6];  
101 - std::string wav_filename = argv[7];  
102 - int32_t num_threads = 4;  
103 - if (argc == 9) {  
104 - num_threads = atoi(argv[8]); 39 + sherpa_onnx::OnlineTransducerModelConfig config;
  40 + config.debug = true;
  41 + config.encoder_filename = argv[2];
  42 + config.decoder_filename = argv[3];
  43 + config.joiner_filename = argv[4];
  44 + std::string wav_filename = argv[5];
  45 +
  46 + config.num_threads = 2;
  47 + if (argc == 7) {
  48 + config.num_threads = atoi(argv[6]);
105 } 49 }
  50 + std::cout << config.ToString().c_str() << "\n";
  51 +
  52 + auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
106 53
107 sherpa_onnx::SymbolTable sym(tokens); 54 sherpa_onnx::SymbolTable sym(tokens);
108 55
109 - int32_t num_frames;  
110 - auto features = ComputeFeatures(wav_filename, 16000, &num_frames);  
111 - int32_t feature_dim = features.size() / num_frames; 56 + Ort::AllocatorWithDefaultOptions allocator;
  57 +
  58 + int32_t chunk_size = model->ChunkSize();
  59 + int32_t chunk_shift = model->ChunkShift();
  60 +
  61 + auto memory_info =
  62 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  63 +
  64 + std::vector<Ort::Value> states = model->GetEncoderInitStates();
112 65
113 - sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj,  
114 - joiner_decoder_proj, num_threads);  
115 - Ort::Value encoder_out =  
116 - model.RunEncoder(features.data(), num_frames, feature_dim); 66 + std::vector<int64_t> hyp(model->ContextSize(), 0);
117 67
118 - auto hyp = sherpa_onnx::GreedySearch(model, encoder_out); 68 + int32_t expected_sampling_rate = 16000;
119 69
  70 + bool is_ok = false;
  71 + std::vector<float> samples =
  72 + sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
  73 +
  74 + if (!is_ok) {
  75 + std::cerr << "Failed to read " << wav_filename << "\n";
  76 + return -1;
  77 + }
  78 +
  79 + const float duration = static_cast<float>(samples.size()) / expected_sampling_rate;
  80 +
  81 + std::cout << "wav filename: " << wav_filename << "\n";
  82 + std::cout << "wav duration (s): " << duration << "\n";
  83 +
  84 + auto begin = std::chrono::steady_clock::now();
  85 + std::cout << "Started!\n";
  86 +
  87 + sherpa_onnx::FeatureExtractor feat_extractor;
  88 + feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(),
  89 + samples.size());
  90 +
  91 + std::vector<float> tail_paddings(
  92 + static_cast<int>(0.2 * expected_sampling_rate));
  93 + feat_extractor.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
  94 + tail_paddings.size());
  95 + feat_extractor.InputFinished();
  96 +
  97 + int32_t num_frames = feat_extractor.NumFramesReady();
  98 + int32_t feature_dim = feat_extractor.FeatureDim();
  99 +
  100 + std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
  101 +
  102 + for (int32_t start = 0; start + chunk_size < num_frames;
  103 + start += chunk_shift) {
  104 + std::vector<float> features = feat_extractor.GetFrames(start, chunk_size);
  105 +
  106 + Ort::Value x =
  107 + Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
  108 + x_shape.data(), x_shape.size());
  109 + auto pair = model->RunEncoder(std::move(x), states);
  110 + states = std::move(pair.second);
  111 + sherpa_onnx::GreedySearch(model.get(), std::move(pair.first), &hyp);
  112 + }
120 std::string text; 113 std::string text;
121 - for (auto i : hyp) {  
122 - text += sym[i]; 114 + for (size_t i = model->ContextSize(); i != hyp.size(); ++i) {
  115 + text += sym[hyp[i]];
123 } 116 }
124 117
  118 + std::cout << "Done!\n";
  119 +
125 std::cout << "Recognition result for " << wav_filename << "\n" 120 std::cout << "Recognition result for " << wav_filename << "\n"
126 << text << "\n"; 121 << text << "\n";
127 122
  123 + auto end = std::chrono::steady_clock::now();
  124 + float elapsed_seconds =
  125 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  126 + .count() /
  127 + 1000.;
  128 +
  129 + std::cout << "num threads: " << config.num_threads << "\n";
  130 +
  131 + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  132 + float rtf = elapsed_seconds / duration;
  133 + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
  134 + elapsed_seconds, duration, rtf);
  135 +
128 return 0; 136 return 0;
129 } 137 }
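To make the RTF line concrete: if decoding a 6.6 s wave took, say, 0.5 s of wall-clock time, the program would print RTF = 0.5 / 6.6 ≈ 0.076; any value below 1 means decoding runs faster than real time.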
1 -/**  
2 - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)  
3 - *  
4 - * See LICENSE for clarification regarding multiple authors  
5 - *  
6 - * Licensed under the Apache License, Version 2.0 (the "License");  
7 - * you may not use this file except in compliance with the License.  
8 - * You may obtain a copy of the License at  
9 - *  
10 - * http://www.apache.org/licenses/LICENSE-2.0  
11 - *  
12 - * Unless required by applicable law or agreed to in writing, software  
13 - * distributed under the License is distributed on an "AS IS" BASIS,  
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
15 - * See the License for the specific language governing permissions and  
16 - * limitations under the License.  
17 - */ 1 +// sherpa-onnx/csrc/show-onnx-info.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
18 #include <iostream> 5 #include <iostream>
19 #include <sstream> 6 #include <sstream>
20 7
1 -/**  
2 - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)  
3 - *  
4 - * See LICENSE for clarification regarding multiple authors  
5 - *  
6 - * Licensed under the Apache License, Version 2.0 (the "License");  
7 - * you may not use this file except in compliance with the License.  
8 - * You may obtain a copy of the License at  
9 - *  
10 - * http://www.apache.org/licenses/LICENSE-2.0  
11 - *  
12 - * Unless required by applicable law or agreed to in writing, software  
13 - * distributed under the License is distributed on an "AS IS" BASIS,  
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
15 - * See the License for the specific language governing permissions and  
16 - * limitations under the License.  
17 - */ 1 +// sherpa-onnx/csrc/symbol-table.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
18 4
19 #include "sherpa-onnx/csrc/symbol-table.h" 5 #include "sherpa-onnx/csrc/symbol-table.h"
20 6
1 -/**  
2 - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)  
3 - *  
4 - * See LICENSE for clarification regarding multiple authors  
5 - *  
6 - * Licensed under the Apache License, Version 2.0 (the "License");  
7 - * you may not use this file except in compliance with the License.  
8 - * You may obtain a copy of the License at  
9 - *  
10 - * http://www.apache.org/licenses/LICENSE-2.0  
11 - *  
12 - * Unless required by applicable law or agreed to in writing, software  
13 - * distributed under the License is distributed on an "AS IS" BASIS,  
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
15 - * See the License for the specific language governing permissions and  
16 - * limitations under the License.  
17 - */ 1 +// sherpa-onnx/csrc/symbol-table.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
18 4
19 #ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ 5 #ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
20 #define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ 6 #define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
1 -/**  
2 - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)  
3 - *  
4 - * See LICENSE for clarification regarding multiple authors  
5 - *  
6 - * Licensed under the Apache License, Version 2.0 (the "License");  
7 - * you may not use this file except in compliance with the License.  
8 - * You may obtain a copy of the License at  
9 - *  
10 - * http://www.apache.org/licenses/LICENSE-2.0  
11 - *  
12 - * Unless required by applicable law or agreed to in writing, software  
13 - * distributed under the License is distributed on an "AS IS" BASIS,  
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
15 - * See the License for the specific language governing permissions and  
16 - * limitations under the License.  
17 - */ 1 +// sherpa-onnx/csrc/wave-reader.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
18 4
19 #include "sherpa-onnx/csrc/wave-reader.h" 5 #include "sherpa-onnx/csrc/wave-reader.h"
20 6
@@ -31,19 +17,44 @@ namespace { @@ -31,19 +17,44 @@ namespace {
31 // Note: We assume little endian here 17 // Note: We assume little endian here
32 // TODO(fangjun): Support big endian 18 // TODO(fangjun): Support big endian
33 struct WaveHeader { 19 struct WaveHeader {
34 - void Validate() const {  
35 - // F F I R  
36 - assert(chunk_id == 0x46464952);  
37 - assert(chunk_size == 36 + subchunk2_size);  
38 - // E V A W  
39 - assert(format == 0x45564157);  
40 - assert(subchunk1_id == 0x20746d66);  
41 - assert(subchunk1_size == 16); // 16 for PCM  
42 - assert(audio_format == 1); // 1 for PCM  
43 - assert(num_channels == 1); // we support only single channel for now  
44 - assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8);  
45 - assert(block_align == num_channels * bits_per_sample / 8);  
46 - assert(bits_per_sample == 16); // we support only 16 bits per sample 20 + bool Validate() const {
  21 + // F F I R
  22 + if (chunk_id != 0x46464952) {
  23 + return false;
  24 + }
  25 + // E V A W
  26 + if (format != 0x45564157) {
  27 + return false;
  28 + }
  29 +
  30 + if (subchunk1_id != 0x20746d66) {
  31 + return false;
  32 + }
  33 +
  34 + if (subchunk1_size != 16) { // 16 for PCM
  35 + return false;
  36 + }
  37 +
  38 + if (audio_format != 1) { // 1 for PCM
  39 + return false;
  40 + }
  41 +
  42 + if (num_channels != 1) { // we support only single channel for now
  43 + return false;
  44 + }
  45 + if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
  46 + return false;
  47 + }
  48 +
  49 + if (block_align != (num_channels * bits_per_sample / 8)) {
  50 + return false;
  51 + }
  52 +
  53 + if (bits_per_sample != 16) { // we support only 16 bits per sample
  54 + return false;
  55 + }
  56 +
  57 + return true;
47 } 58 }
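   The hex constants compared against above are just the four ASCII tag bytes of the RIFF format read as a little-endian uint32. A small illustrative sketch (not part of this commit) makes the correspondence explicit:

       #include <cstdint>

       constexpr uint32_t Tag(char a, char b, char c, char d) {
         return static_cast<uint32_t>(static_cast<unsigned char>(a)) |
                (static_cast<uint32_t>(static_cast<unsigned char>(b)) << 8) |
                (static_cast<uint32_t>(static_cast<unsigned char>(c)) << 16) |
                (static_cast<uint32_t>(static_cast<unsigned char>(d)) << 24);
       }
       static_assert(Tag('R', 'I', 'F', 'F') == 0x46464952, "RIFF");
       static_assert(Tag('W', 'A', 'V', 'E') == 0x45564157, "WAVE");
       static_assert(Tag('f', 'm', 't', ' ') == 0x20746d66, "fmt ");
       static_assert(Tag('d', 'a', 't', 'a') == 0x61746164, "data");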
48 59
49 // See 60 // See
@@ -52,7 +63,7 @@ struct WaveHeader { @@ -52,7 +63,7 @@ struct WaveHeader {
52 // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf 63 // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
53 void SeekToDataChunk(std::istream &is) { 64 void SeekToDataChunk(std::istream &is) {
54 // a t a d 65 // a t a d
55 - while (subchunk2_id != 0x61746164) { 66 + while (is && subchunk2_id != 0x61746164) {
56 // const char *p = reinterpret_cast<const char *>(&subchunk2_id); 67 // const char *p = reinterpret_cast<const char *>(&subchunk2_id);
57 // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0], 68 // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
58 // p[1], p[2], p[3], subchunk2_size); 69 // p[1], p[2], p[3], subchunk2_size);
@@ -80,44 +91,61 @@ static_assert(sizeof(WaveHeader) == 44, ""); @@ -80,44 +91,61 @@ static_assert(sizeof(WaveHeader) == 44, "");
80 91
81 // Read a wave file of mono-channel. 92 // Read a wave file of mono-channel.
82 // Return its samples normalized to the range [-1, 1). 93 // Return its samples normalized to the range [-1, 1).
83 -std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) { 94 +std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
  95 + bool *is_ok) {
84 WaveHeader header; 96 WaveHeader header;
85 is.read(reinterpret_cast<char *>(&header), sizeof(header)); 97 is.read(reinterpret_cast<char *>(&header), sizeof(header));
86 - assert(static_cast<bool>(is));  
87 - header.Validate(); 98 + if (!is) {
  99 + *is_ok = false;
  100 + return {};
  101 + }
  102 +
  103 + if (!header.Validate()) {
  104 + *is_ok = false;
  105 + return {};
  106 + }
88 107
89 header.SeekToDataChunk(is); 108 header.SeekToDataChunk(is);
  109 + if (!is) {
  110 + *is_ok = false;
  111 + return {};
  112 + }
90 113
91 - *sample_rate = header.sample_rate; 114 + if (expected_sample_rate != header.sample_rate) {
  115 + *is_ok = false;
  116 + return {};
  117 + }
92 118
93 // header.subchunk2_size contains the number of bytes in the data. 119 // header.subchunk2_size contains the number of bytes in the data.
94 // As we assume each sample contains two bytes, so it is divided by 2 here 120 // As we assume each sample contains two bytes, so it is divided by 2 here
95 std::vector<int16_t> samples(header.subchunk2_size / 2); 121 std::vector<int16_t> samples(header.subchunk2_size / 2);
96 122
97 is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size); 123 is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
98 -  
99 - assert(static_cast<bool>(is)); 124 + if (!is) {
  125 + *is_ok = false;
  126 + return {};
  127 + }
100 128
101 std::vector<float> ans(samples.size()); 129 std::vector<float> ans(samples.size());
102 for (int32_t i = 0; i != ans.size(); ++i) { 130 for (int32_t i = 0; i != ans.size(); ++i) {
103 ans[i] = samples[i] / 32768.; 131 ans[i] = samples[i] / 32768.;
104 } 132 }
105 133
  134 + *is_ok = true;
106 return ans; 135 return ans;
107 } 136 }
108 137
109 } // namespace 138 } // namespace
110 139
111 std::vector<float> ReadWave(const std::string &filename, 140 std::vector<float> ReadWave(const std::string &filename,
112 - float expected_sample_rate) { 141 + float expected_sample_rate, bool *is_ok) {
113 std::ifstream is(filename, std::ifstream::binary); 142 std::ifstream is(filename, std::ifstream::binary);
114 - float sample_rate;  
115 - auto samples = ReadWaveImpl(is, &sample_rate);  
116 - if (expected_sample_rate != sample_rate) {  
117 - std::cerr << "Expected sample rate: " << expected_sample_rate  
118 - << ". Given: " << sample_rate << ".\n";  
119 - exit(-1);  
120 - } 143 + return ReadWave(is, expected_sample_rate, is_ok);
  144 +}
  145 +
  146 +std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
  147 + bool *is_ok) {
  148 + auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok);
121 return samples; 149 return samples;
122 } 150 }
123 151
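   The new std::istream overload also makes it possible to decode a wave that already lives in memory; a hedged sketch, assuming `data` is a std::vector<char> holding the raw bytes of a complete wave file:

       #include <sstream>

       std::istringstream is(std::string(data.begin(), data.end()));
       bool is_ok = false;
       std::vector<float> samples = sherpa_onnx::ReadWave(is, 16000, &is_ok);
       if (!is_ok) {
         // handle the error
       }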
1 -/**  
2 - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)  
3 - *  
4 - * See LICENSE for clarification regarding multiple authors  
5 - *  
6 - * Licensed under the Apache License, Version 2.0 (the "License");  
7 - * you may not use this file except in compliance with the License.  
8 - * You may obtain a copy of the License at  
9 - *  
10 - * http://www.apache.org/licenses/LICENSE-2.0  
11 - *  
12 - * Unless required by applicable law or agreed to in writing, software  
13 - * distributed under the License is distributed on an "AS IS" BASIS,  
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
15 - * See the License for the specific language governing permissions and  
16 - * limitations under the License.  
17 - */ 1 +// sherpa-onnx/csrc/wave-reader.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
18 4
19 #ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_ 5 #ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_
20 #define SHERPA_ONNX_CSRC_WAVE_READER_H_ 6 #define SHERPA_ONNX_CSRC_WAVE_READER_H_
@@ -30,11 +16,15 @@ namespace sherpa_onnx { @@ -30,11 +16,15 @@ namespace sherpa_onnx {
30 @param filename Path to a wave file. It MUST be single channel, PCM encoded. 16 @param filename Path to a wave file. It MUST be single channel, PCM encoded.
31 @param expected_sample_rate Expected sample rate of the wave file. If the 17 @param expected_sample_rate Expected sample rate of the wave file. If the
32 - sample rate don't match, it throws an exception. 18 + sample rate doesn't match, it sets is_ok to false.
  19 + @param is_ok On return it is true if the reading succeeded; false otherwise.
33 20
34 @return Return wave samples normalized to the range [-1, 1). 21 @return Return wave samples normalized to the range [-1, 1).
35 */ 22 */
36 std::vector<float> ReadWave(const std::string &filename, 23 std::vector<float> ReadWave(const std::string &filename,
37 - float expected_sample_rate); 24 + float expected_sample_rate, bool *is_ok);
  25 +
  26 +std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
  27 + bool *is_ok);
38 28
39 } // namespace sherpa_onnx 29 } // namespace sherpa_onnx
40 30