Committed by
GitHub
Begin to support CTC models (#119)
Please see https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html for a list of pre-trained CTC models from NeMo.
正在显示
40 个修改的文件
包含
1244 行增加
和
60 行删除
.github/scripts/test-offline-ctc.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -e | ||
| 4 | + | ||
| 5 | +log() { | ||
| 6 | + # This function is from espnet | ||
| 7 | + local fname=${BASH_SOURCE[1]##*/} | ||
| 8 | + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | ||
| 9 | +} | ||
| 10 | + | ||
| 11 | +echo "EXE is $EXE" | ||
| 12 | +echo "PATH: $PATH" | ||
| 13 | + | ||
| 14 | +which $EXE | ||
| 15 | + | ||
| 16 | +log "------------------------------------------------------------" | ||
| 17 | +log "Run Citrinet (stt_en_citrinet_512, English)" | ||
| 18 | +log "------------------------------------------------------------" | ||
| 19 | + | ||
| 20 | +repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512 | ||
| 21 | +log "Start testing ${repo_url}" | ||
| 22 | +repo=$(basename $repo_url) | ||
| 23 | +log "Download pretrained model and test-data from $repo_url" | ||
| 24 | + | ||
| 25 | +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 26 | +pushd $repo | ||
| 27 | +git lfs pull --include "*.onnx" | ||
| 28 | +ls -lh *.onnx | ||
| 29 | +popd | ||
| 30 | + | ||
| 31 | +time $EXE \ | ||
| 32 | + --tokens=$repo/tokens.txt \ | ||
| 33 | + --nemo-ctc-model=$repo/model.onnx \ | ||
| 34 | + --num-threads=2 \ | ||
| 35 | + $repo/test_wavs/0.wav \ | ||
| 36 | + $repo/test_wavs/1.wav \ | ||
| 37 | + $repo/test_wavs/8k.wav | ||
| 38 | + | ||
| 39 | +time $EXE \ | ||
| 40 | + --tokens=$repo/tokens.txt \ | ||
| 41 | + --nemo-ctc-model=$repo/model.int8.onnx \ | ||
| 42 | + --num-threads=2 \ | ||
| 43 | + $repo/test_wavs/0.wav \ | ||
| 44 | + $repo/test_wavs/1.wav \ | ||
| 45 | + $repo/test_wavs/8k.wav | ||
| 46 | + | ||
| 47 | +rm -rf $repo |
| @@ -95,6 +95,8 @@ python3 ./python-api-examples/offline-decode-files.py \ | @@ -95,6 +95,8 @@ python3 ./python-api-examples/offline-decode-files.py \ | ||
| 95 | 95 | ||
| 96 | python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose | 96 | python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose |
| 97 | 97 | ||
| 98 | +rm -rf $repo | ||
| 99 | + | ||
| 98 | log "Test non-streaming paraformer models" | 100 | log "Test non-streaming paraformer models" |
| 99 | 101 | ||
| 100 | pushd $dir | 102 | pushd $dir |
| @@ -128,3 +130,39 @@ python3 ./python-api-examples/offline-decode-files.py \ | @@ -128,3 +130,39 @@ python3 ./python-api-examples/offline-decode-files.py \ | ||
| 128 | $repo/test_wavs/8k.wav | 130 | $repo/test_wavs/8k.wav |
| 129 | 131 | ||
| 130 | python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose | 132 | python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose |
| 133 | + | ||
| 134 | +rm -rf $repo | ||
| 135 | + | ||
| 136 | +log "Test non-streaming NeMo CTC models" | ||
| 137 | + | ||
| 138 | +pushd $dir | ||
| 139 | +repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512 | ||
| 140 | + | ||
| 141 | +log "Start testing ${repo_url}" | ||
| 142 | +repo=$dir/$(basename $repo_url) | ||
| 143 | +log "Download pretrained model and test-data from $repo_url" | ||
| 144 | + | ||
| 145 | +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 146 | +cd $repo | ||
| 147 | +git lfs pull --include "*.onnx" | ||
| 148 | +popd | ||
| 149 | + | ||
| 150 | +ls -lh $repo | ||
| 151 | + | ||
| 152 | +python3 ./python-api-examples/offline-decode-files.py \ | ||
| 153 | + --tokens=$repo/tokens.txt \ | ||
| 154 | + --nemo-ctc=$repo/model.onnx \ | ||
| 155 | + $repo/test_wavs/0.wav \ | ||
| 156 | + $repo/test_wavs/1.wav \ | ||
| 157 | + $repo/test_wavs/8k.wav | ||
| 158 | + | ||
| 159 | +python3 ./python-api-examples/offline-decode-files.py \ | ||
| 160 | + --tokens=$repo/tokens.txt \ | ||
| 161 | + --nemo-ctc=$repo/model.int8.onnx \ | ||
| 162 | + $repo/test_wavs/0.wav \ | ||
| 163 | + $repo/test_wavs/1.wav \ | ||
| 164 | + $repo/test_wavs/8k.wav | ||
| 165 | + | ||
| 166 | +python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose | ||
| 167 | + | ||
| 168 | +rm -rf $repo |
| @@ -8,6 +8,7 @@ on: | @@ -8,6 +8,7 @@ on: | ||
| 8 | - '.github/workflows/linux.yaml' | 8 | - '.github/workflows/linux.yaml' |
| 9 | - '.github/scripts/test-online-transducer.sh' | 9 | - '.github/scripts/test-online-transducer.sh' |
| 10 | - '.github/scripts/test-offline-transducer.sh' | 10 | - '.github/scripts/test-offline-transducer.sh' |
| 11 | + - '.github/scripts/test-offline-ctc.sh' | ||
| 11 | - 'CMakeLists.txt' | 12 | - 'CMakeLists.txt' |
| 12 | - 'cmake/**' | 13 | - 'cmake/**' |
| 13 | - 'sherpa-onnx/csrc/*' | 14 | - 'sherpa-onnx/csrc/*' |
| @@ -20,6 +21,7 @@ on: | @@ -20,6 +21,7 @@ on: | ||
| 20 | - '.github/workflows/linux.yaml' | 21 | - '.github/workflows/linux.yaml' |
| 21 | - '.github/scripts/test-online-transducer.sh' | 22 | - '.github/scripts/test-online-transducer.sh' |
| 22 | - '.github/scripts/test-offline-transducer.sh' | 23 | - '.github/scripts/test-offline-transducer.sh' |
| 24 | + - '.github/scripts/test-offline-ctc.sh' | ||
| 23 | - 'CMakeLists.txt' | 25 | - 'CMakeLists.txt' |
| 24 | - 'cmake/**' | 26 | - 'cmake/**' |
| 25 | - 'sherpa-onnx/csrc/*' | 27 | - 'sherpa-onnx/csrc/*' |
| @@ -68,6 +70,14 @@ jobs: | @@ -68,6 +70,14 @@ jobs: | ||
| 68 | file build/bin/sherpa-onnx | 70 | file build/bin/sherpa-onnx |
| 69 | readelf -d build/bin/sherpa-onnx | 71 | readelf -d build/bin/sherpa-onnx |
| 70 | 72 | ||
| 73 | + - name: Test offline CTC | ||
| 74 | + shell: bash | ||
| 75 | + run: | | ||
| 76 | + export PATH=$PWD/build/bin:$PATH | ||
| 77 | + export EXE=sherpa-onnx-offline | ||
| 78 | + | ||
| 79 | + .github/scripts/test-offline-ctc.sh | ||
| 80 | + | ||
| 71 | - name: Test offline transducer | 81 | - name: Test offline transducer |
| 72 | shell: bash | 82 | shell: bash |
| 73 | run: | | 83 | run: | |
| @@ -8,6 +8,7 @@ on: | @@ -8,6 +8,7 @@ on: | ||
| 8 | - '.github/workflows/macos.yaml' | 8 | - '.github/workflows/macos.yaml' |
| 9 | - '.github/scripts/test-online-transducer.sh' | 9 | - '.github/scripts/test-online-transducer.sh' |
| 10 | - '.github/scripts/test-offline-transducer.sh' | 10 | - '.github/scripts/test-offline-transducer.sh' |
| 11 | + - '.github/scripts/test-offline-ctc.sh' | ||
| 11 | - 'CMakeLists.txt' | 12 | - 'CMakeLists.txt' |
| 12 | - 'cmake/**' | 13 | - 'cmake/**' |
| 13 | - 'sherpa-onnx/csrc/*' | 14 | - 'sherpa-onnx/csrc/*' |
| @@ -18,6 +19,7 @@ on: | @@ -18,6 +19,7 @@ on: | ||
| 18 | - '.github/workflows/macos.yaml' | 19 | - '.github/workflows/macos.yaml' |
| 19 | - '.github/scripts/test-online-transducer.sh' | 20 | - '.github/scripts/test-online-transducer.sh' |
| 20 | - '.github/scripts/test-offline-transducer.sh' | 21 | - '.github/scripts/test-offline-transducer.sh' |
| 22 | + - '.github/scripts/test-offline-ctc.sh' | ||
| 21 | - 'CMakeLists.txt' | 23 | - 'CMakeLists.txt' |
| 22 | - 'cmake/**' | 24 | - 'cmake/**' |
| 23 | - 'sherpa-onnx/csrc/*' | 25 | - 'sherpa-onnx/csrc/*' |
| @@ -67,6 +69,14 @@ jobs: | @@ -67,6 +69,14 @@ jobs: | ||
| 67 | otool -L build/bin/sherpa-onnx | 69 | otool -L build/bin/sherpa-onnx |
| 68 | otool -l build/bin/sherpa-onnx | 70 | otool -l build/bin/sherpa-onnx |
| 69 | 71 | ||
| 72 | + - name: Test offline CTC | ||
| 73 | + shell: bash | ||
| 74 | + run: | | ||
| 75 | + export PATH=$PWD/build/bin:$PATH | ||
| 76 | + export EXE=sherpa-onnx-offline | ||
| 77 | + | ||
| 78 | + .github/scripts/test-offline-ctc.sh | ||
| 79 | + | ||
| 70 | - name: Test offline transducer | 80 | - name: Test offline transducer |
| 71 | shell: bash | 81 | shell: bash |
| 72 | run: | | 82 | run: | |
| @@ -8,6 +8,7 @@ on: | @@ -8,6 +8,7 @@ on: | ||
| 8 | - '.github/workflows/windows-x64.yaml' | 8 | - '.github/workflows/windows-x64.yaml' |
| 9 | - '.github/scripts/test-online-transducer.sh' | 9 | - '.github/scripts/test-online-transducer.sh' |
| 10 | - '.github/scripts/test-offline-transducer.sh' | 10 | - '.github/scripts/test-offline-transducer.sh' |
| 11 | + - '.github/scripts/test-offline-ctc.sh' | ||
| 11 | - 'CMakeLists.txt' | 12 | - 'CMakeLists.txt' |
| 12 | - 'cmake/**' | 13 | - 'cmake/**' |
| 13 | - 'sherpa-onnx/csrc/*' | 14 | - 'sherpa-onnx/csrc/*' |
| @@ -18,6 +19,7 @@ on: | @@ -18,6 +19,7 @@ on: | ||
| 18 | - '.github/workflows/windows-x64.yaml' | 19 | - '.github/workflows/windows-x64.yaml' |
| 19 | - '.github/scripts/test-online-transducer.sh' | 20 | - '.github/scripts/test-online-transducer.sh' |
| 20 | - '.github/scripts/test-offline-transducer.sh' | 21 | - '.github/scripts/test-offline-transducer.sh' |
| 22 | + - '.github/scripts/test-offline-ctc.sh' | ||
| 21 | - 'CMakeLists.txt' | 23 | - 'CMakeLists.txt' |
| 22 | - 'cmake/**' | 24 | - 'cmake/**' |
| 23 | - 'sherpa-onnx/csrc/*' | 25 | - 'sherpa-onnx/csrc/*' |
| @@ -73,6 +75,14 @@ jobs: | @@ -73,6 +75,14 @@ jobs: | ||
| 73 | 75 | ||
| 74 | ls -lh ./bin/Release/sherpa-onnx.exe | 76 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 75 | 77 | ||
| 78 | + - name: Test offline CTC for windows x64 | ||
| 79 | + shell: bash | ||
| 80 | + run: | | ||
| 81 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 82 | + export EXE=sherpa-onnx-offline.exe | ||
| 83 | + | ||
| 84 | + .github/scripts/test-offline-ctc.sh | ||
| 85 | + | ||
| 76 | - name: Test offline transducer for Windows x64 | 86 | - name: Test offline transducer for Windows x64 |
| 77 | shell: bash | 87 | shell: bash |
| 78 | run: | | 88 | run: | |
| @@ -8,6 +8,7 @@ on: | @@ -8,6 +8,7 @@ on: | ||
| 8 | - '.github/workflows/windows-x86.yaml' | 8 | - '.github/workflows/windows-x86.yaml' |
| 9 | - '.github/scripts/test-online-transducer.sh' | 9 | - '.github/scripts/test-online-transducer.sh' |
| 10 | - '.github/scripts/test-offline-transducer.sh' | 10 | - '.github/scripts/test-offline-transducer.sh' |
| 11 | + - '.github/scripts/test-offline-ctc.sh' | ||
| 11 | - 'CMakeLists.txt' | 12 | - 'CMakeLists.txt' |
| 12 | - 'cmake/**' | 13 | - 'cmake/**' |
| 13 | - 'sherpa-onnx/csrc/*' | 14 | - 'sherpa-onnx/csrc/*' |
| @@ -18,6 +19,7 @@ on: | @@ -18,6 +19,7 @@ on: | ||
| 18 | - '.github/workflows/windows-x86.yaml' | 19 | - '.github/workflows/windows-x86.yaml' |
| 19 | - '.github/scripts/test-online-transducer.sh' | 20 | - '.github/scripts/test-online-transducer.sh' |
| 20 | - '.github/scripts/test-offline-transducer.sh' | 21 | - '.github/scripts/test-offline-transducer.sh' |
| 22 | + - '.github/scripts/test-offline-ctc.sh' | ||
| 21 | - 'CMakeLists.txt' | 23 | - 'CMakeLists.txt' |
| 22 | - 'cmake/**' | 24 | - 'cmake/**' |
| 23 | - 'sherpa-onnx/csrc/*' | 25 | - 'sherpa-onnx/csrc/*' |
| @@ -31,6 +33,7 @@ permissions: | @@ -31,6 +33,7 @@ permissions: | ||
| 31 | 33 | ||
| 32 | jobs: | 34 | jobs: |
| 33 | windows_x86: | 35 | windows_x86: |
| 36 | + if: false # disable windows x86 CI for now | ||
| 34 | runs-on: ${{ matrix.os }} | 37 | runs-on: ${{ matrix.os }} |
| 35 | name: ${{ matrix.vs-version }} | 38 | name: ${{ matrix.vs-version }} |
| 36 | strategy: | 39 | strategy: |
| @@ -73,6 +76,14 @@ jobs: | @@ -73,6 +76,14 @@ jobs: | ||
| 73 | 76 | ||
| 74 | ls -lh ./bin/Release/sherpa-onnx.exe | 77 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 75 | 78 | ||
| 79 | + - name: Test offline CTC for windows x86 | ||
| 80 | + shell: bash | ||
| 81 | + run: | | ||
| 82 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 83 | + export EXE=sherpa-onnx-offline.exe | ||
| 84 | + | ||
| 85 | + .github/scripts/test-offline-ctc.sh | ||
| 86 | + | ||
| 76 | - name: Test offline transducer for Windows x86 | 87 | - name: Test offline transducer for Windows x86 |
| 77 | shell: bash | 88 | shell: bash |
| 78 | run: | | 89 | run: | |
| @@ -52,3 +52,6 @@ run-offline-websocket-client-*.sh | @@ -52,3 +52,6 @@ run-offline-websocket-client-*.sh | ||
| 52 | run-sherpa-onnx-*.sh | 52 | run-sherpa-onnx-*.sh |
| 53 | sherpa-onnx-zipformer-en-2023-03-30 | 53 | sherpa-onnx-zipformer-en-2023-03-30 |
| 54 | sherpa-onnx-zipformer-en-2023-04-01 | 54 | sherpa-onnx-zipformer-en-2023-04-01 |
| 55 | +run-offline-decode-files.sh | ||
| 56 | +sherpa-onnx-nemo-ctc-en-citrinet-512 | ||
| 57 | +run-offline-decode-files-nemo-ctc.sh |
| @@ -6,7 +6,7 @@ | @@ -6,7 +6,7 @@ | ||
| 6 | This file demonstrates how to use sherpa-onnx Python API to transcribe | 6 | This file demonstrates how to use sherpa-onnx Python API to transcribe |
| 7 | file(s) with a non-streaming model. | 7 | file(s) with a non-streaming model. |
| 8 | 8 | ||
| 9 | -paraformer Usage: | 9 | +(1) For paraformer |
| 10 | ./python-api-examples/offline-decode-files.py \ | 10 | ./python-api-examples/offline-decode-files.py \ |
| 11 | --tokens=/path/to/tokens.txt \ | 11 | --tokens=/path/to/tokens.txt \ |
| 12 | --paraformer=/path/to/paraformer.onnx \ | 12 | --paraformer=/path/to/paraformer.onnx \ |
| @@ -18,7 +18,7 @@ paraformer Usage: | @@ -18,7 +18,7 @@ paraformer Usage: | ||
| 18 | /path/to/0.wav \ | 18 | /path/to/0.wav \ |
| 19 | /path/to/1.wav | 19 | /path/to/1.wav |
| 20 | 20 | ||
| 21 | -transducer Usage: | 21 | +(2) For transducer models from icefall |
| 22 | ./python-api-examples/offline-decode-files.py \ | 22 | ./python-api-examples/offline-decode-files.py \ |
| 23 | --tokens=/path/to/tokens.txt \ | 23 | --tokens=/path/to/tokens.txt \ |
| 24 | --encoder=/path/to/encoder.onnx \ | 24 | --encoder=/path/to/encoder.onnx \ |
| @@ -32,6 +32,8 @@ transducer Usage: | @@ -32,6 +32,8 @@ transducer Usage: | ||
| 32 | /path/to/0.wav \ | 32 | /path/to/0.wav \ |
| 33 | /path/to/1.wav | 33 | /path/to/1.wav |
| 34 | 34 | ||
| 35 | +(3) For CTC models from NeMo | ||
| 36 | + | ||
| 35 | Please refer to | 37 | Please refer to |
| 36 | https://k2-fsa.github.io/sherpa/onnx/index.html | 38 | https://k2-fsa.github.io/sherpa/onnx/index.html |
| 37 | to install sherpa-onnx and to download the pre-trained models | 39 | to install sherpa-onnx and to download the pre-trained models |
| @@ -83,7 +85,14 @@ def get_args(): | @@ -83,7 +85,14 @@ def get_args(): | ||
| 83 | "--paraformer", | 85 | "--paraformer", |
| 84 | default="", | 86 | default="", |
| 85 | type=str, | 87 | type=str, |
| 86 | - help="Path to the paraformer model", | 88 | + help="Path to the model.onnx from Paraformer", |
| 89 | + ) | ||
| 90 | + | ||
| 91 | + parser.add_argument( | ||
| 92 | + "--nemo-ctc", | ||
| 93 | + default="", | ||
| 94 | + type=str, | ||
| 95 | + help="Path to the model.onnx from NeMo CTC", | ||
| 87 | ) | 96 | ) |
| 88 | 97 | ||
| 89 | parser.add_argument( | 98 | parser.add_argument( |
| @@ -171,11 +180,14 @@ def main(): | @@ -171,11 +180,14 @@ def main(): | ||
| 171 | args = get_args() | 180 | args = get_args() |
| 172 | assert_file_exists(args.tokens) | 181 | assert_file_exists(args.tokens) |
| 173 | assert args.num_threads > 0, args.num_threads | 182 | assert args.num_threads > 0, args.num_threads |
| 174 | - if len(args.encoder) > 0: | 183 | + if args.encoder: |
| 184 | + assert len(args.paraformer) == 0, args.paraformer | ||
| 185 | + assert len(args.nemo_ctc) == 0, args.nemo_ctc | ||
| 186 | + | ||
| 175 | assert_file_exists(args.encoder) | 187 | assert_file_exists(args.encoder) |
| 176 | assert_file_exists(args.decoder) | 188 | assert_file_exists(args.decoder) |
| 177 | assert_file_exists(args.joiner) | 189 | assert_file_exists(args.joiner) |
| 178 | - assert len(args.paraformer) == 0, args.paraformer | 190 | + |
| 179 | recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( | 191 | recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( |
| 180 | encoder=args.encoder, | 192 | encoder=args.encoder, |
| 181 | decoder=args.decoder, | 193 | decoder=args.decoder, |
| @@ -187,8 +199,10 @@ def main(): | @@ -187,8 +199,10 @@ def main(): | ||
| 187 | decoding_method=args.decoding_method, | 199 | decoding_method=args.decoding_method, |
| 188 | debug=args.debug, | 200 | debug=args.debug, |
| 189 | ) | 201 | ) |
| 190 | - else: | 202 | + elif args.paraformer: |
| 203 | + assert len(args.nemo_ctc) == 0, args.nemo_ctc | ||
| 191 | assert_file_exists(args.paraformer) | 204 | assert_file_exists(args.paraformer) |
| 205 | + | ||
| 192 | recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( | 206 | recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( |
| 193 | paraformer=args.paraformer, | 207 | paraformer=args.paraformer, |
| 194 | tokens=args.tokens, | 208 | tokens=args.tokens, |
| @@ -198,6 +212,19 @@ def main(): | @@ -198,6 +212,19 @@ def main(): | ||
| 198 | decoding_method=args.decoding_method, | 212 | decoding_method=args.decoding_method, |
| 199 | debug=args.debug, | 213 | debug=args.debug, |
| 200 | ) | 214 | ) |
| 215 | + elif args.nemo_ctc: | ||
| 216 | + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc( | ||
| 217 | + model=args.nemo_ctc, | ||
| 218 | + tokens=args.tokens, | ||
| 219 | + num_threads=args.num_threads, | ||
| 220 | + sample_rate=args.sample_rate, | ||
| 221 | + feature_dim=args.feature_dim, | ||
| 222 | + decoding_method=args.decoding_method, | ||
| 223 | + debug=args.debug, | ||
| 224 | + ) | ||
| 225 | + else: | ||
| 226 | + print("Please specify at least one model") | ||
| 227 | + return | ||
| 201 | 228 | ||
| 202 | print("Started!") | 229 | print("Started!") |
| 203 | start_time = time.time() | 230 | start_time = time.time() |
| @@ -225,12 +252,14 @@ def main(): | @@ -225,12 +252,14 @@ def main(): | ||
| 225 | print("-" * 10) | 252 | print("-" * 10) |
| 226 | 253 | ||
| 227 | elapsed_seconds = end_time - start_time | 254 | elapsed_seconds = end_time - start_time |
| 228 | - rtf = elapsed_seconds / duration | 255 | + rtf = elapsed_seconds / total_duration |
| 229 | print(f"num_threads: {args.num_threads}") | 256 | print(f"num_threads: {args.num_threads}") |
| 230 | print(f"decoding_method: {args.decoding_method}") | 257 | print(f"decoding_method: {args.decoding_method}") |
| 231 | - print(f"Wave duration: {duration:.3f} s") | 258 | + print(f"Wave duration: {total_duration:.3f} s") |
| 232 | print(f"Elapsed time: {elapsed_seconds:.3f} s") | 259 | print(f"Elapsed time: {elapsed_seconds:.3f} s") |
| 233 | - print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") | 260 | + print( |
| 261 | + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" | ||
| 262 | + ) | ||
| 234 | 263 | ||
| 235 | 264 | ||
| 236 | if __name__ == "__main__": | 265 | if __name__ == "__main__": |
| @@ -172,12 +172,14 @@ def main(): | @@ -172,12 +172,14 @@ def main(): | ||
| 172 | print("-" * 10) | 172 | print("-" * 10) |
| 173 | 173 | ||
| 174 | elapsed_seconds = end_time - start_time | 174 | elapsed_seconds = end_time - start_time |
| 175 | - rtf = elapsed_seconds / duration | 175 | + rtf = elapsed_seconds / total_duration |
| 176 | print(f"num_threads: {args.num_threads}") | 176 | print(f"num_threads: {args.num_threads}") |
| 177 | print(f"decoding_method: {args.decoding_method}") | 177 | print(f"decoding_method: {args.decoding_method}") |
| 178 | - print(f"Wave duration: {duration:.3f} s") | 178 | + print(f"Wave duration: {total_duration:.3f} s") |
| 179 | print(f"Elapsed time: {elapsed_seconds:.3f} s") | 179 | print(f"Elapsed time: {elapsed_seconds:.3f} s") |
| 180 | - print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") | 180 | + print( |
| 181 | + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" | ||
| 182 | + ) | ||
| 181 | 183 | ||
| 182 | 184 | ||
| 183 | if __name__ == "__main__": | 185 | if __name__ == "__main__": |
| @@ -16,7 +16,11 @@ set(sources | @@ -16,7 +16,11 @@ set(sources | ||
| 16 | features.cc | 16 | features.cc |
| 17 | file-utils.cc | 17 | file-utils.cc |
| 18 | hypothesis.cc | 18 | hypothesis.cc |
| 19 | + offline-ctc-greedy-search-decoder.cc | ||
| 20 | + offline-ctc-model.cc | ||
| 19 | offline-model-config.cc | 21 | offline-model-config.cc |
| 22 | + offline-nemo-enc-dec-ctc-model-config.cc | ||
| 23 | + offline-nemo-enc-dec-ctc-model.cc | ||
| 20 | offline-paraformer-greedy-search-decoder.cc | 24 | offline-paraformer-greedy-search-decoder.cc |
| 21 | offline-paraformer-model-config.cc | 25 | offline-paraformer-model-config.cc |
| 22 | offline-paraformer-model.cc | 26 | offline-paraformer-model.cc |
| @@ -11,15 +11,19 @@ | @@ -11,15 +11,19 @@ | ||
| 11 | #include "android/log.h" | 11 | #include "android/log.h" |
| 12 | #define SHERPA_ONNX_LOGE(...) \ | 12 | #define SHERPA_ONNX_LOGE(...) \ |
| 13 | do { \ | 13 | do { \ |
| 14 | + fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \ | ||
| 15 | + static_cast<int>(__LINE__)); \ | ||
| 14 | fprintf(stderr, ##__VA_ARGS__); \ | 16 | fprintf(stderr, ##__VA_ARGS__); \ |
| 15 | fprintf(stderr, "\n"); \ | 17 | fprintf(stderr, "\n"); \ |
| 16 | __android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \ | 18 | __android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \ |
| 17 | } while (0) | 19 | } while (0) |
| 18 | #else | 20 | #else |
| 19 | -#define SHERPA_ONNX_LOGE(...) \ | ||
| 20 | - do { \ | ||
| 21 | - fprintf(stderr, ##__VA_ARGS__); \ | ||
| 22 | - fprintf(stderr, "\n"); \ | 21 | +#define SHERPA_ONNX_LOGE(...) \ |
| 22 | + do { \ | ||
| 23 | + fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \ | ||
| 24 | + static_cast<int>(__LINE__)); \ | ||
| 25 | + fprintf(stderr, ##__VA_ARGS__); \ | ||
| 26 | + fprintf(stderr, "\n"); \ | ||
| 23 | } while (0) | 27 | } while (0) |
| 24 | #endif | 28 | #endif |
| 25 | 29 |
sherpa-onnx/csrc/offline-ctc-decoder.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-ctc-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OfflineCtcDecoderResult { | ||
| 15 | + /// The decoded token IDs | ||
| 16 | + std::vector<int64_t> tokens; | ||
| 17 | + | ||
| 18 | + /// timestamps[i] contains the output frame index where tokens[i] is decoded. | ||
| 19 | + /// Note: The index is after subsampling | ||
| 20 | + std::vector<int32_t> timestamps; | ||
| 21 | +}; | ||
| 22 | + | ||
| 23 | +class OfflineCtcDecoder { | ||
| 24 | + public: | ||
| 25 | + virtual ~OfflineCtcDecoder() = default; | ||
| 26 | + | ||
| 27 | + /** Run CTC decoding given the output from the encoder model. | ||
| 28 | + * | ||
| 29 | + * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing | ||
| 30 | + * lob_probs. | ||
| 31 | + * @param log_probs_length A 1-D tensor of shape (N,) containing number | ||
| 32 | + * of valid frames in log_probs before padding. | ||
| 33 | + * | ||
| 34 | + * @return Return a vector of size `N` containing the decoded results. | ||
| 35 | + */ | ||
| 36 | + virtual std::vector<OfflineCtcDecoderResult> Decode( | ||
| 37 | + Ort::Value log_probs, Ort::Value log_probs_length) = 0; | ||
| 38 | +}; | ||
| 39 | + | ||
| 40 | +} // namespace sherpa_onnx | ||
| 41 | + | ||
| 42 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +std::vector<OfflineCtcDecoderResult> OfflineCtcGreedySearchDecoder::Decode( | ||
| 16 | + Ort::Value log_probs, Ort::Value log_probs_length) { | ||
| 17 | + std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 18 | + int32_t batch_size = static_cast<int32_t>(shape[0]); | ||
| 19 | + int32_t num_frames = static_cast<int32_t>(shape[1]); | ||
| 20 | + int32_t vocab_size = static_cast<int32_t>(shape[2]); | ||
| 21 | + | ||
| 22 | + const int64_t *p_log_probs_length = log_probs_length.GetTensorData<int64_t>(); | ||
| 23 | + | ||
| 24 | + std::vector<OfflineCtcDecoderResult> ans; | ||
| 25 | + ans.reserve(batch_size); | ||
| 26 | + | ||
| 27 | + for (int32_t b = 0; b != batch_size; ++b) { | ||
| 28 | + const float *p_log_probs = | ||
| 29 | + log_probs.GetTensorData<float>() + b * num_frames * vocab_size; | ||
| 30 | + | ||
| 31 | + OfflineCtcDecoderResult r; | ||
| 32 | + int64_t prev_id = -1; | ||
| 33 | + | ||
| 34 | + for (int32_t t = 0; t != static_cast<int32_t>(p_log_probs_length[b]); ++t) { | ||
| 35 | + auto y = static_cast<int64_t>(std::distance( | ||
| 36 | + static_cast<const float *>(p_log_probs), | ||
| 37 | + std::max_element( | ||
| 38 | + static_cast<const float *>(p_log_probs), | ||
| 39 | + static_cast<const float *>(p_log_probs) + vocab_size))); | ||
| 40 | + p_log_probs += vocab_size; | ||
| 41 | + | ||
| 42 | + if (y != blank_id_ && y != prev_id) { | ||
| 43 | + r.tokens.push_back(y); | ||
| 44 | + r.timestamps.push_back(t); | ||
| 45 | + prev_id = y; | ||
| 46 | + } | ||
| 47 | + } // for (int32_t t = 0; ...) | ||
| 48 | + | ||
| 49 | + ans.push_back(std::move(r)); | ||
| 50 | + } | ||
| 51 | + return ans; | ||
| 52 | +} | ||
| 53 | + | ||
| 54 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-ctc-decoder.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder { | ||
| 15 | + public: | ||
| 16 | + explicit OfflineCtcGreedySearchDecoder(int32_t blank_id) | ||
| 17 | + : blank_id_(blank_id) {} | ||
| 18 | + | ||
| 19 | + std::vector<OfflineCtcDecoderResult> Decode( | ||
| 20 | + Ort::Value log_probs, Ort::Value log_probs_length) override; | ||
| 21 | + | ||
| 22 | + private: | ||
| 23 | + int32_t blank_id_; | ||
| 24 | +}; | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx | ||
| 27 | + | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_ |
sherpa-onnx/csrc/offline-ctc-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-ctc-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-ctc-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <memory> | ||
| 9 | +#include <sstream> | ||
| 10 | +#include <string> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" | ||
| 14 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 15 | + | ||
| 16 | +namespace { | ||
| 17 | + | ||
| 18 | +enum class ModelType { | ||
| 19 | + kEncDecCTCModelBPE, | ||
| 20 | + kUnkown, | ||
| 21 | +}; | ||
| 22 | + | ||
| 23 | +} | ||
| 24 | + | ||
| 25 | +namespace sherpa_onnx { | ||
| 26 | + | ||
| 27 | +static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 28 | + bool debug) { | ||
| 29 | + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); | ||
| 30 | + Ort::SessionOptions sess_opts; | ||
| 31 | + | ||
| 32 | + auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length, | ||
| 33 | + sess_opts); | ||
| 34 | + | ||
| 35 | + Ort::ModelMetadata meta_data = sess->GetModelMetadata(); | ||
| 36 | + if (debug) { | ||
| 37 | + std::ostringstream os; | ||
| 38 | + PrintModelMetadata(os, meta_data); | ||
| 39 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 43 | + auto model_type = | ||
| 44 | + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); | ||
| 45 | + if (!model_type) { | ||
| 46 | + SHERPA_ONNX_LOGE( | ||
| 47 | + "No model_type in the metadata!\n" | ||
| 48 | + "If you are using models from NeMo, please refer to\n" | ||
| 49 | + "https://huggingface.co/csukuangfj/" | ||
| 50 | + "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py" | ||
| 51 | + "\n" | ||
| 52 | + "for how to add metadta to model.onnx\n"); | ||
| 53 | + return ModelType::kUnkown; | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | + if (model_type.get() == std::string("EncDecCTCModelBPE")) { | ||
| 57 | + return ModelType::kEncDecCTCModelBPE; | ||
| 58 | + } else { | ||
| 59 | + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | ||
| 60 | + return ModelType::kUnkown; | ||
| 61 | + } | ||
| 62 | +} | ||
| 63 | + | ||
| 64 | +std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 65 | + const OfflineModelConfig &config) { | ||
| 66 | + ModelType model_type = ModelType::kUnkown; | ||
| 67 | + | ||
| 68 | + { | ||
| 69 | + auto buffer = ReadFile(config.nemo_ctc.model); | ||
| 70 | + | ||
| 71 | + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + switch (model_type) { | ||
| 75 | + case ModelType::kEncDecCTCModelBPE: | ||
| 76 | + return std::make_unique<OfflineNemoEncDecCtcModel>(config); | ||
| 77 | + break; | ||
| 78 | + case ModelType::kUnkown: | ||
| 79 | + SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); | ||
| 80 | + return nullptr; | ||
| 81 | + } | ||
| 82 | + | ||
| 83 | + return nullptr; | ||
| 84 | +} | ||
| 85 | + | ||
| 86 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-ctc-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-ctc-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | + | ||
| 11 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
class OfflineCtcModel {
 public:
  virtual ~OfflineCtcModel() = default;

  /** Factory method.
   *
   * It inspects `config` and returns a concrete subclass instance
   * (e.g., the NeMo EncDecCTC implementation).
   */
  static std::unique_ptr<OfflineCtcModel> Create(
      const OfflineModelConfig &config);

  /** Run the forward method of the model.
   *
   * @param features A tensor of shape (N, T, C). It is changed in-place.
   * @param features_length A 1-D tensor of shape (N,) containing number of
   *                        valid frames in `features` before padding.
   *                        Its dtype is int64_t.
   *
   * @return Return a pair containing:
   *  - log_probs: A 3-D tensor of shape (N, T', vocab_size).
   *  - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
   */
  virtual std::pair<Ort::Value, Ort::Value> Forward(
      Ort::Value features, Ort::Value features_length) = 0;

  /** Return the vocabulary size of the model
   */
  virtual int32_t VocabSize() const = 0;

  /** SubsamplingFactor of the model
   *
   * For Citrinet, the subsampling factor is usually 4.
   * For Conformer CTC, the subsampling factor is usually 8.
   */
  virtual int32_t SubsamplingFactor() const = 0;

  /** Return an allocator for allocating memory
   */
  virtual OrtAllocator *Allocator() const = 0;

  /** For some models, e.g., those from NeMo, they require some preprocessing
   * for the features.
   *
   * The default implementation returns an empty string, meaning
   * no extra feature normalization is requested.
   */
  virtual std::string FeatureNormalizationMethod() const { return {}; }
};
| 56 | + | ||
| 57 | +} // namespace sherpa_onnx | ||
| 58 | + | ||
| 59 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_ |
| @@ -13,6 +13,7 @@ namespace sherpa_onnx { | @@ -13,6 +13,7 @@ namespace sherpa_onnx { | ||
| 13 | void OfflineModelConfig::Register(ParseOptions *po) { | 13 | void OfflineModelConfig::Register(ParseOptions *po) { |
| 14 | transducer.Register(po); | 14 | transducer.Register(po); |
| 15 | paraformer.Register(po); | 15 | paraformer.Register(po); |
| 16 | + nemo_ctc.Register(po); | ||
| 16 | 17 | ||
| 17 | po->Register("tokens", &tokens, "Path to tokens.txt"); | 18 | po->Register("tokens", &tokens, "Path to tokens.txt"); |
| 18 | 19 | ||
| @@ -38,6 +39,10 @@ bool OfflineModelConfig::Validate() const { | @@ -38,6 +39,10 @@ bool OfflineModelConfig::Validate() const { | ||
| 38 | return paraformer.Validate(); | 39 | return paraformer.Validate(); |
| 39 | } | 40 | } |
| 40 | 41 | ||
| 42 | + if (!nemo_ctc.model.empty()) { | ||
| 43 | + return nemo_ctc.Validate(); | ||
| 44 | + } | ||
| 45 | + | ||
| 41 | return transducer.Validate(); | 46 | return transducer.Validate(); |
| 42 | } | 47 | } |
| 43 | 48 | ||
| @@ -47,6 +52,7 @@ std::string OfflineModelConfig::ToString() const { | @@ -47,6 +52,7 @@ std::string OfflineModelConfig::ToString() const { | ||
| 47 | os << "OfflineModelConfig("; | 52 | os << "OfflineModelConfig("; |
| 48 | os << "transducer=" << transducer.ToString() << ", "; | 53 | os << "transducer=" << transducer.ToString() << ", "; |
| 49 | os << "paraformer=" << paraformer.ToString() << ", "; | 54 | os << "paraformer=" << paraformer.ToString() << ", "; |
| 55 | + os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; | ||
| 50 | os << "tokens=\"" << tokens << "\", "; | 56 | os << "tokens=\"" << tokens << "\", "; |
| 51 | os << "num_threads=" << num_threads << ", "; | 57 | os << "num_threads=" << num_threads << ", "; |
| 52 | os << "debug=" << (debug ? "True" : "False") << ")"; | 58 | os << "debug=" << (debug ? "True" : "False") << ")"; |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | 6 | ||
| 7 | #include <string> | 7 | #include <string> |
| 8 | 8 | ||
| 9 | +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" | ||
| 9 | #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" |
| 10 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" | 11 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" |
| 11 | 12 | ||
| @@ -14,6 +15,7 @@ namespace sherpa_onnx { | @@ -14,6 +15,7 @@ namespace sherpa_onnx { | ||
| 14 | struct OfflineModelConfig { | 15 | struct OfflineModelConfig { |
| 15 | OfflineTransducerModelConfig transducer; | 16 | OfflineTransducerModelConfig transducer; |
| 16 | OfflineParaformerModelConfig paraformer; | 17 | OfflineParaformerModelConfig paraformer; |
| 18 | + OfflineNemoEncDecCtcModelConfig nemo_ctc; | ||
| 17 | 19 | ||
| 18 | std::string tokens; | 20 | std::string tokens; |
| 19 | int32_t num_threads = 2; | 21 | int32_t num_threads = 2; |
| @@ -22,9 +24,11 @@ struct OfflineModelConfig { | @@ -22,9 +24,11 @@ struct OfflineModelConfig { | ||
| 22 | OfflineModelConfig() = default; | 24 | OfflineModelConfig() = default; |
| 23 | OfflineModelConfig(const OfflineTransducerModelConfig &transducer, | 25 | OfflineModelConfig(const OfflineTransducerModelConfig &transducer, |
| 24 | const OfflineParaformerModelConfig ¶former, | 26 | const OfflineParaformerModelConfig ¶former, |
| 27 | + const OfflineNemoEncDecCtcModelConfig &nemo_ctc, | ||
| 25 | const std::string &tokens, int32_t num_threads, bool debug) | 28 | const std::string &tokens, int32_t num_threads, bool debug) |
| 26 | : transducer(transducer), | 29 | : transducer(transducer), |
| 27 | paraformer(paraformer), | 30 | paraformer(paraformer), |
| 31 | + nemo_ctc(nemo_ctc), | ||
| 28 | tokens(tokens), | 32 | tokens(tokens), |
| 29 | num_threads(num_threads), | 33 | num_threads(num_threads), |
| 30 | debug(debug) {} | 34 | debug(debug) {} |
| 1 | +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
// Expose the --nemo-ctc-model command-line option, bound to `model`.
void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
  po->Register("nemo-ctc-model", &model,
               "Path to model.onnx of Nemo EncDecCtcModel.");
}
| 16 | + | ||
| 17 | +bool OfflineNemoEncDecCtcModelConfig::Validate() const { | ||
| 18 | + if (!FileExists(model)) { | ||
| 19 | + SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); | ||
| 20 | + return false; | ||
| 21 | + } | ||
| 22 | + | ||
| 23 | + return true; | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +std::string OfflineNemoEncDecCtcModelConfig::ToString() const { | ||
| 27 | + std::ostringstream os; | ||
| 28 | + | ||
| 29 | + os << "OfflineNemoEncDecCtcModelConfig("; | ||
| 30 | + os << "model=\"" << model << "\")"; | ||
| 31 | + | ||
| 32 | + return os.str(); | ||
| 33 | +} | ||
| 34 | + | ||
| 35 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
// Configuration for an offline NeMo EncDecCtc model.
struct OfflineNemoEncDecCtcModelConfig {
  // Path to the exported model.onnx of a NeMo EncDecCtcModel.
  std::string model;

  OfflineNemoEncDecCtcModelConfig() = default;
  explicit OfflineNemoEncDecCtcModelConfig(const std::string &model)
      : model(model) {}

  // Registers the --nemo-ctc-model command-line option.
  void Register(ParseOptions *po);

  // Returns true if `model` points to an existing file.
  bool Validate() const;

  std::string ToString() const;
};
| 26 | +} // namespace sherpa_onnx | ||
| 27 | + | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 8 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 9 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 10 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
class OfflineNemoEncDecCtcModel::Impl {
 public:
  // Creates an ONNX Runtime session for config.nemo_ctc.model and reads the
  // required metadata (vocab_size, subsampling_factor, normalize_type).
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_{},
        allocator_{} {
    sess_opts_.SetIntraOpNumThreads(config_.num_threads);
    sess_opts_.SetInterOpNumThreads(config_.num_threads);

    Init();
  }

  // See OfflineCtcModel::Forward() for the argument/return contract.
  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
                                            Ort::Value features_length) {
    std::vector<int64_t> shape =
        features_length.GetTensorTypeAndShapeInfo().GetShape();

    // Output lengths are derived here by integer division with the
    // subsampling factor; they are not taken from a model output.
    Ort::Value out_features_length = Ort::Value::CreateTensor<int64_t>(
        allocator_, shape.data(), shape.size());

    const int64_t *src = features_length.GetTensorData<int64_t>();
    int64_t *dst = out_features_length.GetTensorMutableData<int64_t>();
    for (int64_t i = 0; i != shape[0]; ++i) {
      dst[i] = src[i] / subsampling_factor_;
    }

    // (B, T, C) -> (B, C, T): the session is fed channel-major features.
    features = Transpose12(allocator_, &features);

    // Note: the model receives the ORIGINAL (pre-subsampling) lengths.
    std::array<Ort::Value, 2> inputs = {std::move(features),
                                        std::move(features_length)};
    auto out =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());

    return {std::move(out[0]), std::move(out_features_length)};
  }

  int32_t VocabSize() const { return vocab_size_; }

  int32_t SubsamplingFactor() const { return subsampling_factor_; }

  OrtAllocator *Allocator() const { return allocator_; }

  // Normalization method read from the model metadata, e.g. "per_feature".
  std::string FeatureNormalizationMethod() const { return normalize_type_; }

 private:
  void Init() {
    auto buf = ReadFile(config_.nemo_ctc.model);

    sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    // These keys must be present in the exported model's metadata; see the
    // add-model-metadata.py scripts referenced in offline-recognizer-impl.cc.
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
    SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t vocab_size_ = 0;
  int32_t subsampling_factor_ = 0;
  std::string normalize_type_;
};
| 104 | + | ||
// The public class is a thin pimpl wrapper; all work is delegated to Impl.
OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
    const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;

std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
    Ort::Value features, Ort::Value features_length) {
  return impl_->Forward(std::move(features), std::move(features_length));
}

int32_t OfflineNemoEncDecCtcModel::VocabSize() const {
  return impl_->VocabSize();
}
int32_t OfflineNemoEncDecCtcModel::SubsamplingFactor() const {
  return impl_->SubsamplingFactor();
}

OrtAllocator *OfflineNemoEncDecCtcModel::Allocator() const {
  return impl_->Allocator();
}

std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const {
  return impl_->FeatureNormalizationMethod();
}
| 130 | + | ||
| 131 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_ | ||
| 6 | +#include <memory> | ||
| 7 | +#include <string> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/offline-ctc-model.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
/** This class implements the EncDecCTCModelBPE model from NeMo.
 *
 * See
 * https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_bpe_models.py
 * https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_models.py
 */
class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
 public:
  explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config);
  ~OfflineNemoEncDecCtcModel() override;

  /** Run the forward method of the model.
   *
   * @param features A tensor of shape (N, T, C). It is changed in-place.
   * @param features_length A 1-D tensor of shape (N,) containing number of
   *                        valid frames in `features` before padding.
   *                        Its dtype is int64_t.
   *
   * @return Return a pair containing:
   *  - log_probs: A 3-D tensor of shape (N, T', vocab_size).
   *  - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
   */
  std::pair<Ort::Value, Ort::Value> Forward(
      Ort::Value features, Ort::Value features_length) override;

  /** Return the vocabulary size of the model
   */
  int32_t VocabSize() const override;

  /** SubsamplingFactor of the model
   *
   * For Citrinet, the subsampling factor is usually 4.
   * For Conformer CTC, the subsampling factor is usually 8.
   */
  int32_t SubsamplingFactor() const override;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const override;

  // Possible values:
  // - per_feature
  // - all_features (not implemented yet)
  // - fixed_mean (not implemented)
  // - fixed_std (not implemented)
  // - or just leave it to empty
  // See
  // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
  // for details
  std::string FeatureNormalizationMethod() const override;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
| 72 | + | ||
| 73 | +} // namespace sherpa_onnx | ||
| 74 | + | ||
| 75 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-recognizer-ctc-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <utility> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#include "sherpa-onnx/csrc/offline-ctc-decoder.h" | ||
| 14 | +#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" | ||
| 15 | +#include "sherpa-onnx/csrc/offline-ctc-model.h" | ||
| 16 | +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" | ||
| 17 | +#include "sherpa-onnx/csrc/pad-sequence.h" | ||
| 18 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | ||
| 23 | + const SymbolTable &sym_table) { | ||
| 24 | + OfflineRecognitionResult r; | ||
| 25 | + r.tokens.reserve(src.tokens.size()); | ||
| 26 | + | ||
| 27 | + std::string text; | ||
| 28 | + | ||
| 29 | + for (int32_t i = 0; i != src.tokens.size(); ++i) { | ||
| 30 | + auto sym = sym_table[src.tokens[i]]; | ||
| 31 | + text.append(sym); | ||
| 32 | + r.tokens.push_back(std::move(sym)); | ||
| 33 | + } | ||
| 34 | + r.text = std::move(text); | ||
| 35 | + | ||
| 36 | + return r; | ||
| 37 | +} | ||
| 38 | + | ||
// Offline recognizer for CTC models (currently NeMo EncDecCTCModelBPE).
class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config)
      : config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(OfflineCtcModel::Create(config_.model_config)) {
    // Propagate the model's normalization method (read from its metadata)
    // into the feature config so streams created below apply it.
    config_.feat_config.nemo_normalize_type =
        model_->FeatureNormalizationMethod();

    if (config.decoding_method == "greedy_search") {
      // Greedy search needs the blank ID, which must be listed in tokens.txt
      // under the symbol <blk>.
      if (!symbol_table_.contains("<blk>")) {
        SHERPA_ONNX_LOGE(
            "We expect that tokens.txt contains "
            "the symbol <blk> and its ID.");
        exit(-1);
      }

      int32_t blank_id = symbol_table_["<blk>"];
      decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
    } else {
      SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
                       config.decoding_method.c_str());
      exit(-1);
    }
  }

  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(config_.feat_config);
  }

  // Batch-decode n streams: gather features, pad them into one batch tensor,
  // run the model, then decode and attach a result to each stream.
  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    int32_t feat_dim = config_.feat_config.feature_dim;

    std::vector<Ort::Value> features;
    features.reserve(n);

    // features_vec owns the storage referenced by the Ort::Value tensors
    // below, so it must stay alive until after PadSequence copies the data.
    std::vector<std::vector<float>> features_vec(n);
    std::vector<int64_t> features_length_vec(n);

    for (int32_t i = 0; i != n; ++i) {
      std::vector<float> f = ss[i]->GetFrames();

      int32_t num_frames = f.size() / feat_dim;
      features_vec[i] = std::move(f);

      features_length_vec[i] = num_frames;

      std::array<int64_t, 2> shape = {num_frames, feat_dim};

      // Zero-copy tensor over features_vec[i]'s buffer.
      Ort::Value x = Ort::Value::CreateTensor(
          memory_info, features_vec[i].data(), features_vec[i].size(),
          shape.data(), shape.size());
      features.push_back(std::move(x));
    }  // for (int32_t i = 0; i != n; ++i)

    std::vector<const Ort::Value *> features_pointer(n);
    for (int32_t i = 0; i != n; ++i) {
      features_pointer[i] = &features[i];
    }

    std::array<int64_t, 1> features_length_shape = {n};
    Ort::Value x_length = Ort::Value::CreateTensor(
        memory_info, features_length_vec.data(), n,
        features_length_shape.data(), features_length_shape.size());

    // Pad shorter utterances with -23.025850929940457, which equals
    // log(1e-10) — presumably the log-fbank floor value; confirm against the
    // feature extractor if this ever changes.
    Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
                               -23.025850929940457f);
    auto t = model_->Forward(std::move(x), std::move(x_length));

    auto results = decoder_->Decode(std::move(t.first), std::move(t.second));

    for (int32_t i = 0; i != n; ++i) {
      auto r = Convert(results[i], symbol_table_);
      ss[i]->SetResult(r);
    }
  }

 private:
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OfflineCtcModel> model_;
  std::unique_ptr<OfflineCtcDecoder> decoder_;
};
| 125 | + | ||
| 126 | +} // namespace sherpa_onnx | ||
| 127 | + | ||
| 128 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_ |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | 8 | ||
| 9 | #include "onnxruntime_cxx_api.h" // NOLINT | 9 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 10 | #include "sherpa-onnx/csrc/macros.h" | 10 | #include "sherpa-onnx/csrc/macros.h" |
| 11 | +#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" | ||
| 11 | #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" | 12 | #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" |
| 12 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" | 13 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" |
| 13 | #include "sherpa-onnx/csrc/onnx-utils.h" | 14 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| @@ -25,6 +26,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -25,6 +26,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 25 | model_filename = config.model_config.transducer.encoder_filename; | 26 | model_filename = config.model_config.transducer.encoder_filename; |
| 26 | } else if (!config.model_config.paraformer.model.empty()) { | 27 | } else if (!config.model_config.paraformer.model.empty()) { |
| 27 | model_filename = config.model_config.paraformer.model; | 28 | model_filename = config.model_config.paraformer.model; |
| 29 | + } else if (!config.model_config.nemo_ctc.model.empty()) { | ||
| 30 | + model_filename = config.model_config.nemo_ctc.model; | ||
| 28 | } else { | 31 | } else { |
| 29 | SHERPA_ONNX_LOGE("Please provide a model"); | 32 | SHERPA_ONNX_LOGE("Please provide a model"); |
| 30 | exit(-1); | 33 | exit(-1); |
| @@ -39,8 +42,30 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -39,8 +42,30 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 39 | 42 | ||
| 40 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | 43 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below |
| 41 | 44 | ||
| 42 | - std::string model_type; | ||
| 43 | - SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type"); | 45 | + auto model_type_ptr = |
| 46 | + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); | ||
| 47 | + if (!model_type_ptr) { | ||
| 48 | + SHERPA_ONNX_LOGE( | ||
| 49 | + "No model_type in the metadata!\n\n" | ||
| 50 | + "Please refer to the following URLs to add metadata" | ||
| 51 | + "\n" | ||
| 52 | + "(0) Transducer models from icefall" | ||
| 53 | + "\n " | ||
| 54 | + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" | ||
| 55 | + "pruned_transducer_stateless7/export-onnx.py#L303" | ||
| 56 | + "\n" | ||
| 57 | + "(1) Nemo CTC models\n " | ||
| 58 | + "https://huggingface.co/csukuangfj/" | ||
| 59 | + "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py" | ||
| 60 | + "\n" | ||
| 61 | + "(2) Paraformer" | ||
| 62 | + "\n " | ||
| 63 | + "https://huggingface.co/csukuangfj/" | ||
| 64 | + "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py" | ||
| 65 | + "\n"); | ||
| 66 | + exit(-1); | ||
| 67 | + } | ||
| 68 | + std::string model_type(model_type_ptr.get()); | ||
| 44 | 69 | ||
| 45 | if (model_type == "conformer" || model_type == "zipformer") { | 70 | if (model_type == "conformer" || model_type == "zipformer") { |
| 46 | return std::make_unique<OfflineRecognizerTransducerImpl>(config); | 71 | return std::make_unique<OfflineRecognizerTransducerImpl>(config); |
| @@ -50,11 +75,16 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -50,11 +75,16 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 50 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); | 75 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); |
| 51 | } | 76 | } |
| 52 | 77 | ||
| 78 | + if (model_type == "EncDecCTCModelBPE") { | ||
| 79 | + return std::make_unique<OfflineRecognizerCtcImpl>(config); | ||
| 80 | + } | ||
| 81 | + | ||
| 53 | SHERPA_ONNX_LOGE( | 82 | SHERPA_ONNX_LOGE( |
| 54 | "\nUnsupported model_type: %s\n" | 83 | "\nUnsupported model_type: %s\n" |
| 55 | "We support only the following model types at present: \n" | 84 | "We support only the following model types at present: \n" |
| 56 | - " - transducer models from icefall\n" | ||
| 57 | - " - Paraformer models from FunASR\n", | 85 | + " - Non-streaming transducer models from icefall\n" |
| 86 | + " - Non-streaming Paraformer models from FunASR\n" | ||
| 87 | + " - EncDecCTCModelBPE models from NeMo\n", | ||
| 58 | model_type.c_str()); | 88 | model_type.c_str()); |
| 59 | 89 | ||
| 60 | exit(-1); | 90 | exit(-1); |
| @@ -7,6 +7,7 @@ | @@ -7,6 +7,7 @@ | ||
| 7 | #include <assert.h> | 7 | #include <assert.h> |
| 8 | 8 | ||
| 9 | #include <algorithm> | 9 | #include <algorithm> |
| 10 | +#include <cmath> | ||
| 10 | 11 | ||
| 11 | #include "kaldi-native-fbank/csrc/online-feature.h" | 12 | #include "kaldi-native-fbank/csrc/online-feature.h" |
| 12 | #include "sherpa-onnx/csrc/macros.h" | 13 | #include "sherpa-onnx/csrc/macros.h" |
| @@ -15,6 +16,41 @@ | @@ -15,6 +16,41 @@ | ||
| 15 | 16 | ||
| 16 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| 17 | 18 | ||
/* Compute per-column mean and inverse stddev of a row-major 2-D array.
 *
 * @param p A pointer to a 2-d array of shape (num_rows, num_cols)
 * @param num_rows Number of rows
 * @param num_cols Number of columns
 * @param mean On return, it contains p.mean(axis=0)
 * @param inv_stddev On return, it contains 1/(p.std(axis=0) + 1e-5)
 *
 * NOTE(review): this uses the population (biased) standard deviation;
 * NeMo's reference normalize_batch uses torch.std (unbiased) — confirm if
 * exact parity with NeMo matters.
 */
static void ComputeMeanAndInvStd(const float *p, int32_t num_rows,
                                 int32_t num_cols, std::vector<float> *mean,
                                 std::vector<float> *inv_stddev) {
  std::vector<float> sum(num_cols);
  std::vector<float> sum_sq(num_cols);

  // Single pass accumulating sum and sum of squares per column.
  for (int32_t i = 0; i != num_rows; ++i) {
    for (int32_t c = 0; c != num_cols; ++c) {
      auto t = p[c];
      sum[c] += t;
      sum_sq[c] += t * t;
    }
    p += num_cols;
  }

  mean->resize(num_cols);
  inv_stddev->resize(num_cols);

  for (int32_t i = 0; i != num_cols; ++i) {
    auto t = sum[i] / num_rows;
    (*mean)[i] = t;

    // E[x^2] - E[x]^2 can come out slightly negative due to floating-point
    // cancellation when a column is (nearly) constant; clamp at 0 so that
    // std::sqrt never produces NaN.
    float var = sum_sq[i] / num_rows - t * t;
    if (var < 0) {
      var = 0;
    }

    // 1e-5 guards against division by zero for constant columns.
    (*inv_stddev)[i] = 1.0f / (std::sqrt(var) + 1e-5f);
  }
}
| 53 | + | ||
| 18 | void OfflineFeatureExtractorConfig::Register(ParseOptions *po) { | 54 | void OfflineFeatureExtractorConfig::Register(ParseOptions *po) { |
| 19 | po->Register("sample-rate", &sampling_rate, | 55 | po->Register("sample-rate", &sampling_rate, |
| 20 | "Sampling rate of the input waveform. " | 56 | "Sampling rate of the input waveform. " |
| @@ -106,6 +142,8 @@ class OfflineStream::Impl { | @@ -106,6 +142,8 @@ class OfflineStream::Impl { | ||
| 106 | p += feature_dim; | 142 | p += feature_dim; |
| 107 | } | 143 | } |
| 108 | 144 | ||
| 145 | + NemoNormalizeFeatures(features.data(), n, feature_dim); | ||
| 146 | + | ||
| 109 | return features; | 147 | return features; |
| 110 | } | 148 | } |
| 111 | 149 | ||
| @@ -114,6 +152,38 @@ class OfflineStream::Impl { | @@ -114,6 +152,38 @@ class OfflineStream::Impl { | ||
| 114 | const OfflineRecognitionResult &GetResult() const { return r_; } | 152 | const OfflineRecognitionResult &GetResult() const { return r_; } |
| 115 | 153 | ||
| 116 | private: | 154 | private: |
| 155 | + void NemoNormalizeFeatures(float *p, int32_t num_frames, | ||
| 156 | + int32_t feature_dim) const { | ||
| 157 | + if (config_.nemo_normalize_type.empty()) { | ||
| 158 | + return; | ||
| 159 | + } | ||
| 160 | + | ||
| 161 | + if (config_.nemo_normalize_type != "per_feature") { | ||
| 162 | + SHERPA_ONNX_LOGE( | ||
| 163 | + "Only normalize_type=per_feature is implemented. Given: %s", | ||
| 164 | + config_.nemo_normalize_type.c_str()); | ||
| 165 | + exit(-1); | ||
| 166 | + } | ||
| 167 | + | ||
| 168 | + NemoNormalizePerFeature(p, num_frames, feature_dim); | ||
| 169 | + } | ||
| 170 | + | ||
| 171 | + static void NemoNormalizePerFeature(float *p, int32_t num_frames, | ||
| 172 | + int32_t feature_dim) { | ||
| 173 | + std::vector<float> mean; | ||
| 174 | + std::vector<float> inv_stddev; | ||
| 175 | + | ||
| 176 | + ComputeMeanAndInvStd(p, num_frames, feature_dim, &mean, &inv_stddev); | ||
| 177 | + | ||
| 178 | + for (int32_t n = 0; n != num_frames; ++n) { | ||
| 179 | + for (int32_t i = 0; i != feature_dim; ++i) { | ||
| 180 | + p[i] = (p[i] - mean[i]) * inv_stddev[i]; | ||
| 181 | + } | ||
| 182 | + p += feature_dim; | ||
| 183 | + } | ||
| 184 | + } | ||
| 185 | + | ||
| 186 | + private: | ||
| 117 | OfflineFeatureExtractorConfig config_; | 187 | OfflineFeatureExtractorConfig config_; |
| 118 | std::unique_ptr<knf::OnlineFbank> fbank_; | 188 | std::unique_ptr<knf::OnlineFbank> fbank_; |
| 119 | knf::FbankOptions opts_; | 189 | knf::FbankOptions opts_; |
| @@ -37,13 +37,26 @@ struct OfflineFeatureExtractorConfig { | @@ -37,13 +37,26 @@ struct OfflineFeatureExtractorConfig { | ||
| 37 | // Feature dimension | 37 | // Feature dimension |
| 38 | int32_t feature_dim = 80; | 38 | int32_t feature_dim = 80; |
| 39 | 39 | ||
| 40 | - // Set internally by some models, e.g., paraformer | 40 | + // Set internally by some models, e.g., paraformer sets it to false. |
| 41 | // This parameter is not exposed to users from the commandline | 41 | // This parameter is not exposed to users from the commandline |
| 42 | // If true, the feature extractor expects inputs to be normalized to | 42 | // If true, the feature extractor expects inputs to be normalized to |
| 43 | // the range [-1, 1]. | 43 | // the range [-1, 1]. |
| 44 | // If false, we will multiply the inputs by 32768 | 44 | // If false, we will multiply the inputs by 32768 |
| 45 | bool normalize_samples = true; | 45 | bool normalize_samples = true; |
| 46 | 46 | ||
| 47 | + // For models from NeMo | ||
| 48 | + // This option is not exposed and is set internally when loading models. | ||
| 49 | + // Possible values: | ||
| 50 | + // - per_feature | ||
| 51 | + // - all_features (not implemented yet) | ||
| 52 | + // - fixed_mean (not implemented) | ||
| 53 | + // - fixed_std (not implemented) | ||
| 54 | + // - or just leave it to empty | ||
| 55 | + // See | ||
| 56 | + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 | ||
| 57 | + // for details | ||
| 58 | + std::string nemo_normalize_type; | ||
| 59 | + | ||
| 47 | std::string ToString() const; | 60 | std::string ToString() const; |
| 48 | 61 | ||
| 49 | void Register(ParseOptions *po); | 62 | void Register(ParseOptions *po); |
| @@ -14,10 +14,12 @@ | @@ -14,10 +14,12 @@ | ||
| 14 | #include <sstream> | 14 | #include <sstream> |
| 15 | #include <string> | 15 | #include <string> |
| 16 | 16 | ||
| 17 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 17 | #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" | 18 | #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" |
| 18 | #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" | 19 | #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" |
| 19 | #include "sherpa-onnx/csrc/onnx-utils.h" | 20 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 20 | -namespace sherpa_onnx { | 21 | + |
| 22 | +namespace { | ||
| 21 | 23 | ||
| 22 | enum class ModelType { | 24 | enum class ModelType { |
| 23 | kLstm, | 25 | kLstm, |
| @@ -25,6 +27,10 @@ enum class ModelType { | @@ -25,6 +27,10 @@ enum class ModelType { | ||
| 25 | kUnkown, | 27 | kUnkown, |
| 26 | }; | 28 | }; |
| 27 | 29 | ||
| 30 | +} | ||
| 31 | + | ||
| 32 | +namespace sherpa_onnx { | ||
| 33 | + | ||
| 28 | static ModelType GetModelType(char *model_data, size_t model_data_length, | 34 | static ModelType GetModelType(char *model_data, size_t model_data_length, |
| 29 | bool debug) { | 35 | bool debug) { |
| 30 | Ort::Env env(ORT_LOGGING_LEVEL_WARNING); | 36 | Ort::Env env(ORT_LOGGING_LEVEL_WARNING); |
| @@ -37,14 +43,17 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -37,14 +43,17 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 37 | if (debug) { | 43 | if (debug) { |
| 38 | std::ostringstream os; | 44 | std::ostringstream os; |
| 39 | PrintModelMetadata(os, meta_data); | 45 | PrintModelMetadata(os, meta_data); |
| 40 | - fprintf(stderr, "%s\n", os.str().c_str()); | 46 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); |
| 41 | } | 47 | } |
| 42 | 48 | ||
| 43 | Ort::AllocatorWithDefaultOptions allocator; | 49 | Ort::AllocatorWithDefaultOptions allocator; |
| 44 | auto model_type = | 50 | auto model_type = |
| 45 | meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); | 51 | meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); |
| 46 | if (!model_type) { | 52 | if (!model_type) { |
| 47 | - fprintf(stderr, "No model_type in the metadata!\n"); | 53 | + SHERPA_ONNX_LOGE( |
| 54 | + "No model_type in the metadata!\n" | ||
| 55 | + "Please make sure you are using the latest export-onnx.py from icefall " | ||
| 56 | + "to export your transducer models"); | ||
| 48 | return ModelType::kUnkown; | 57 | return ModelType::kUnkown; |
| 49 | } | 58 | } |
| 50 | 59 | ||
| @@ -53,7 +62,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -53,7 +62,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 53 | } else if (model_type.get() == std::string("zipformer")) { | 62 | } else if (model_type.get() == std::string("zipformer")) { |
| 54 | return ModelType::kZipformer; | 63 | return ModelType::kZipformer; |
| 55 | } else { | 64 | } else { |
| 56 | - fprintf(stderr, "Unsupported model_type: %s\n", model_type.get()); | 65 | + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); |
| 57 | return ModelType::kUnkown; | 66 | return ModelType::kUnkown; |
| 58 | } | 67 | } |
| 59 | } | 68 | } |
| @@ -74,6 +83,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -74,6 +83,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 74 | case ModelType::kZipformer: | 83 | case ModelType::kZipformer: |
| 75 | return std::make_unique<OnlineZipformerTransducerModel>(config); | 84 | return std::make_unique<OnlineZipformerTransducerModel>(config); |
| 76 | case ModelType::kUnkown: | 85 | case ModelType::kUnkown: |
| 86 | + SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); | ||
| 77 | return nullptr; | 87 | return nullptr; |
| 78 | } | 88 | } |
| 79 | 89 | ||
| @@ -127,6 +137,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -127,6 +137,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 127 | case ModelType::kZipformer: | 137 | case ModelType::kZipformer: |
| 128 | return std::make_unique<OnlineZipformerTransducerModel>(mgr, config); | 138 | return std::make_unique<OnlineZipformerTransducerModel>(mgr, config); |
| 129 | case ModelType::kUnkown: | 139 | case ModelType::kUnkown: |
| 140 | + SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); | ||
| 130 | return nullptr; | 141 | return nullptr; |
| 131 | } | 142 | } |
| 132 | 143 |
| @@ -35,4 +35,28 @@ TEST(Tranpose, Tranpose01) { | @@ -35,4 +35,28 @@ TEST(Tranpose, Tranpose01) { | ||
| 35 | } | 35 | } |
| 36 | } | 36 | } |
| 37 | 37 | ||
| 38 | +TEST(Tranpose, Tranpose12) { | ||
| 39 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 40 | + std::array<int64_t, 3> shape{3, 2, 5}; | ||
| 41 | + Ort::Value v = | ||
| 42 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 43 | + float *p = v.GetTensorMutableData<float>(); | ||
| 44 | + | ||
| 45 | + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); | ||
| 46 | + | ||
| 47 | + auto ans = Transpose12(allocator, &v); | ||
| 48 | + auto v2 = Transpose12(allocator, &ans); | ||
| 49 | + | ||
| 50 | + Print3D(&v); | ||
| 51 | + Print3D(&ans); | ||
| 52 | + Print3D(&v2); | ||
| 53 | + | ||
| 54 | + const float *q = v2.GetTensorData<float>(); | ||
| 55 | + | ||
| 56 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]); | ||
| 57 | + ++i) { | ||
| 58 | + EXPECT_EQ(p[i], q[i]); | ||
| 59 | + } | ||
| 60 | +} | ||
| 61 | + | ||
| 38 | } // namespace sherpa_onnx | 62 | } // namespace sherpa_onnx |
| @@ -17,8 +17,8 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) { | @@ -17,8 +17,8 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) { | ||
| 17 | assert(shape.size() == 3); | 17 | assert(shape.size() == 3); |
| 18 | 18 | ||
| 19 | std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]}; | 19 | std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]}; |
| 20 | - Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(), | ||
| 21 | - ans_shape.size()); | 20 | + Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(), |
| 21 | + ans_shape.size()); | ||
| 22 | 22 | ||
| 23 | T *dst = ans.GetTensorMutableData<T>(); | 23 | T *dst = ans.GetTensorMutableData<T>(); |
| 24 | auto plane_offset = shape[1] * shape[2]; | 24 | auto plane_offset = shape[1] * shape[2]; |
| @@ -35,7 +35,32 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) { | @@ -35,7 +35,32 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) { | ||
| 35 | return ans; | 35 | return ans; |
| 36 | } | 36 | } |
| 37 | 37 | ||
| 38 | +template <typename T /*= float*/> | ||
| 39 | +Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v) { | ||
| 40 | + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 41 | + assert(shape.size() == 3); | ||
| 42 | + | ||
| 43 | + std::array<int64_t, 3> ans_shape{shape[0], shape[2], shape[1]}; | ||
| 44 | + Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(), | ||
| 45 | + ans_shape.size()); | ||
| 46 | + T *dst = ans.GetTensorMutableData<T>(); | ||
| 47 | + auto row_stride = shape[2]; | ||
| 48 | + for (int64_t b = 0; b != ans_shape[0]; ++b) { | ||
| 49 | + const T *src = v->GetTensorData<T>() + b * shape[1] * shape[2]; | ||
| 50 | + for (int64_t i = 0; i != ans_shape[1]; ++i) { | ||
| 51 | + for (int64_t k = 0; k != ans_shape[2]; ++k, ++dst) { | ||
| 52 | + *dst = (src + k * row_stride)[i]; | ||
| 53 | + } | ||
| 54 | + } | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + return ans; | ||
| 58 | +} | ||
| 59 | + | ||
| 38 | template Ort::Value Transpose01<float>(OrtAllocator *allocator, | 60 | template Ort::Value Transpose01<float>(OrtAllocator *allocator, |
| 39 | const Ort::Value *v); | 61 | const Ort::Value *v); |
| 40 | 62 | ||
| 63 | +template Ort::Value Transpose12<float>(OrtAllocator *allocator, | ||
| 64 | + const Ort::Value *v); | ||
| 65 | + | ||
| 41 | } // namespace sherpa_onnx | 66 | } // namespace sherpa_onnx |
| @@ -10,13 +10,23 @@ namespace sherpa_onnx { | @@ -10,13 +10,23 @@ namespace sherpa_onnx { | ||
| 10 | /** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C). | 10 | /** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C). |
| 11 | * | 11 | * |
| 12 | * @param allocator | 12 | * @param allocator |
| 13 | - * @param v A 3-D tensor of shape (B, T, C). Its dataype is T. | 13 | + * @param v A 3-D tensor of shape (B, T, C). Its datatype is type. |
| 14 | * | 14 | * |
| 15 | - * @return Return a 3-D tensor of shape (T, B, C). Its datatype is T. | 15 | + * @return Return a 3-D tensor of shape (T, B, C). Its datatype is type. |
| 16 | */ | 16 | */ |
| 17 | -template <typename T = float> | 17 | +template <typename type = float> |
| 18 | Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); | 18 | Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); |
| 19 | 19 | ||
| 20 | +/** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T). | ||
| 21 | + * | ||
| 22 | + * @param allocator | ||
| 23 | + * @param v A 3-D tensor of shape (B, T, C). Its datatype is type. |
| 24 | + * | ||
| 25 | + * @return Return a 3-D tensor of shape (B, C, T). Its datatype is type. | ||
| 26 | + */ | ||
| 27 | +template <typename type = float> | ||
| 28 | +Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v); | ||
| 29 | + | ||
| 20 | } // namespace sherpa_onnx | 30 | } // namespace sherpa_onnx |
| 21 | 31 | ||
| 22 | #endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_ | 32 | #endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_ |
| @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx | @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 5 | endpoint.cc | 5 | endpoint.cc |
| 6 | features.cc | 6 | features.cc |
| 7 | offline-model-config.cc | 7 | offline-model-config.cc |
| 8 | + offline-nemo-enc-dec-ctc-model-config.cc | ||
| 8 | offline-paraformer-model-config.cc | 9 | offline-paraformer-model-config.cc |
| 9 | offline-recognizer.cc | 10 | offline-recognizer.cc |
| 10 | offline-stream.cc | 11 | offline-stream.cc |
| @@ -7,26 +7,31 @@ | @@ -7,26 +7,31 @@ | ||
| 7 | #include <string> | 7 | #include <string> |
| 8 | #include <vector> | 8 | #include <vector> |
| 9 | 9 | ||
| 10 | -#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" | ||
| 11 | -#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" | ||
| 12 | - | ||
| 13 | #include "sherpa-onnx/csrc/offline-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 11 | +#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" | ||
| 12 | +#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" | ||
| 13 | +#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" | ||
| 14 | 14 | ||
| 15 | namespace sherpa_onnx { | 15 | namespace sherpa_onnx { |
| 16 | 16 | ||
| 17 | void PybindOfflineModelConfig(py::module *m) { | 17 | void PybindOfflineModelConfig(py::module *m) { |
| 18 | PybindOfflineTransducerModelConfig(m); | 18 | PybindOfflineTransducerModelConfig(m); |
| 19 | PybindOfflineParaformerModelConfig(m); | 19 | PybindOfflineParaformerModelConfig(m); |
| 20 | + PybindOfflineNemoEncDecCtcModelConfig(m); | ||
| 20 | 21 | ||
| 21 | using PyClass = OfflineModelConfig; | 22 | using PyClass = OfflineModelConfig; |
| 22 | py::class_<PyClass>(*m, "OfflineModelConfig") | 23 | py::class_<PyClass>(*m, "OfflineModelConfig") |
| 23 | - .def(py::init<OfflineTransducerModelConfig &, | ||
| 24 | - OfflineParaformerModelConfig &, | ||
| 25 | - const std::string &, int32_t, bool>(), | ||
| 26 | - py::arg("transducer"), py::arg("paraformer"), py::arg("tokens"), | ||
| 27 | - py::arg("num_threads"), py::arg("debug") = false) | 24 | + .def(py::init<const OfflineTransducerModelConfig &, |
| 25 | + const OfflineParaformerModelConfig &, | ||
| 26 | + const OfflineNemoEncDecCtcModelConfig &, | ||
| 27 | + const std::string &, int32_t, bool>(), | ||
| 28 | + py::arg("transducer") = OfflineTransducerModelConfig(), | ||
| 29 | + py::arg("paraformer") = OfflineParaformerModelConfig(), | ||
| 30 | + py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), | ||
| 31 | + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false) | ||
| 28 | .def_readwrite("transducer", &PyClass::transducer) | 32 | .def_readwrite("transducer", &PyClass::transducer) |
| 29 | .def_readwrite("paraformer", &PyClass::paraformer) | 33 | .def_readwrite("paraformer", &PyClass::paraformer) |
| 34 | + .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) | ||
| 30 | .def_readwrite("tokens", &PyClass::tokens) | 35 | .def_readwrite("tokens", &PyClass::tokens) |
| 31 | .def_readwrite("num_threads", &PyClass::num_threads) | 36 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 32 | .def_readwrite("debug", &PyClass::debug) | 37 | .def_readwrite("debug", &PyClass::debug) |
| 1 | +// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindOfflineNemoEncDecCtcModelConfig(py::module *m) { | ||
| 15 | + using PyClass = OfflineNemoEncDecCtcModelConfig; | ||
| 16 | + py::class_<PyClass>(*m, "OfflineNemoEncDecCtcModelConfig") | ||
| 17 | + .def(py::init<const std::string &>(), py::arg("model")) | ||
| 18 | + .def_readwrite("model", &PyClass::model) | ||
| 19 | + .def("__str__", &PyClass::ToString); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOfflineNemoEncDecCtcModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ |
| @@ -4,7 +4,6 @@ | @@ -4,7 +4,6 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" | 5 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" |
| 6 | 6 | ||
| 7 | - | ||
| 8 | #include <string> | 7 | #include <string> |
| 9 | #include <vector> | 8 | #include <vector> |
| 10 | 9 | ||
| @@ -15,8 +14,7 @@ namespace sherpa_onnx { | @@ -15,8 +14,7 @@ namespace sherpa_onnx { | ||
| 15 | void PybindOfflineParaformerModelConfig(py::module *m) { | 14 | void PybindOfflineParaformerModelConfig(py::module *m) { |
| 16 | using PyClass = OfflineParaformerModelConfig; | 15 | using PyClass = OfflineParaformerModelConfig; |
| 17 | py::class_<PyClass>(*m, "OfflineParaformerModelConfig") | 16 | py::class_<PyClass>(*m, "OfflineParaformerModelConfig") |
| 18 | - .def(py::init<const std::string &>(), | ||
| 19 | - py::arg("model")) | 17 | + .def(py::init<const std::string &>(), py::arg("model")) |
| 20 | .def_readwrite("model", &PyClass::model) | 18 | .def_readwrite("model", &PyClass::model) |
| 21 | .def("__str__", &PyClass::ToString); | 19 | .def("__str__", &PyClass::ToString); |
| 22 | } | 20 | } |
| @@ -11,8 +11,6 @@ | @@ -11,8 +11,6 @@ | ||
| 11 | 11 | ||
| 12 | namespace sherpa_onnx { | 12 | namespace sherpa_onnx { |
| 13 | 13 | ||
| 14 | - | ||
| 15 | - | ||
| 16 | static void PybindOfflineRecognizerConfig(py::module *m) { | 14 | static void PybindOfflineRecognizerConfig(py::module *m) { |
| 17 | using PyClass = OfflineRecognizerConfig; | 15 | using PyClass = OfflineRecognizerConfig; |
| 18 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") | 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") |
| @@ -31,7 +31,6 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT | @@ -31,7 +31,6 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT | ||
| 31 | "timestamps", [](const PyClass &self) { return self.timestamps; }); | 31 | "timestamps", [](const PyClass &self) { return self.timestamps; }); |
| 32 | } | 32 | } |
| 33 | 33 | ||
| 34 | - | ||
| 35 | static void PybindOfflineFeatureExtractorConfig(py::module *m) { | 34 | static void PybindOfflineFeatureExtractorConfig(py::module *m) { |
| 36 | using PyClass = OfflineFeatureExtractorConfig; | 35 | using PyClass = OfflineFeatureExtractorConfig; |
| 37 | py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig") | 36 | py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig") |
| @@ -42,7 +41,6 @@ static void PybindOfflineFeatureExtractorConfig(py::module *m) { | @@ -42,7 +41,6 @@ static void PybindOfflineFeatureExtractorConfig(py::module *m) { | ||
| 42 | .def("__str__", &PyClass::ToString); | 41 | .def("__str__", &PyClass::ToString); |
| 43 | } | 42 | } |
| 44 | 43 | ||
| 45 | - | ||
| 46 | void PybindOfflineStream(py::module *m) { | 44 | void PybindOfflineStream(py::module *m) { |
| 47 | PybindOfflineFeatureExtractorConfig(m); | 45 | PybindOfflineFeatureExtractorConfig(m); |
| 48 | PybindOfflineRecognitionResult(m); | 46 | PybindOfflineRecognitionResult(m); |
| @@ -55,7 +53,7 @@ void PybindOfflineStream(py::module *m) { | @@ -55,7 +53,7 @@ void PybindOfflineStream(py::module *m) { | ||
| 55 | self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); | 53 | self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); |
| 56 | }, | 54 | }, |
| 57 | py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage) | 55 | py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage) |
| 58 | - .def_property_readonly("result", &PyClass::GetResult); | 56 | + .def_property_readonly("result", &PyClass::GetResult); |
| 59 | } | 57 | } |
| 60 | 58 | ||
| 61 | } // namespace sherpa_onnx | 59 | } // namespace sherpa_onnx |
| @@ -7,15 +7,12 @@ | @@ -7,15 +7,12 @@ | ||
| 7 | #include "sherpa-onnx/python/csrc/display.h" | 7 | #include "sherpa-onnx/python/csrc/display.h" |
| 8 | #include "sherpa-onnx/python/csrc/endpoint.h" | 8 | #include "sherpa-onnx/python/csrc/endpoint.h" |
| 9 | #include "sherpa-onnx/python/csrc/features.h" | 9 | #include "sherpa-onnx/python/csrc/features.h" |
| 10 | -#include "sherpa-onnx/python/csrc/online-recognizer.h" | ||
| 11 | -#include "sherpa-onnx/python/csrc/online-stream.h" | ||
| 12 | -#include "sherpa-onnx/python/csrc/online-transducer-model-config.h" | ||
| 13 | - | ||
| 14 | #include "sherpa-onnx/python/csrc/offline-model-config.h" | 10 | #include "sherpa-onnx/python/csrc/offline-model-config.h" |
| 15 | -#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" | ||
| 16 | #include "sherpa-onnx/python/csrc/offline-recognizer.h" | 11 | #include "sherpa-onnx/python/csrc/offline-recognizer.h" |
| 17 | #include "sherpa-onnx/python/csrc/offline-stream.h" | 12 | #include "sherpa-onnx/python/csrc/offline-stream.h" |
| 18 | -#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" | 13 | +#include "sherpa-onnx/python/csrc/online-recognizer.h" |
| 14 | +#include "sherpa-onnx/python/csrc/online-stream.h" | ||
| 15 | +#include "sherpa-onnx/python/csrc/online-transducer-model-config.h" | ||
| 19 | 16 | ||
| 20 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| 21 | 18 |
| @@ -4,12 +4,15 @@ from typing import List | @@ -4,12 +4,15 @@ from typing import List | ||
| 4 | 4 | ||
| 5 | from _sherpa_onnx import ( | 5 | from _sherpa_onnx import ( |
| 6 | OfflineFeatureExtractorConfig, | 6 | OfflineFeatureExtractorConfig, |
| 7 | - OfflineRecognizer as _Recognizer, | 7 | + OfflineModelConfig, |
| 8 | + OfflineNemoEncDecCtcModelConfig, | ||
| 9 | + OfflineParaformerModelConfig, | ||
| 10 | +) | ||
| 11 | +from _sherpa_onnx import OfflineRecognizer as _Recognizer | ||
| 12 | +from _sherpa_onnx import ( | ||
| 8 | OfflineRecognizerConfig, | 13 | OfflineRecognizerConfig, |
| 9 | OfflineStream, | 14 | OfflineStream, |
| 10 | - OfflineModelConfig, | ||
| 11 | OfflineTransducerModelConfig, | 15 | OfflineTransducerModelConfig, |
| 12 | - OfflineParaformerModelConfig, | ||
| 13 | ) | 16 | ) |
| 14 | 17 | ||
| 15 | 18 | ||
| @@ -75,7 +78,6 @@ class OfflineRecognizer(object): | @@ -75,7 +78,6 @@ class OfflineRecognizer(object): | ||
| 75 | decoder_filename=decoder, | 78 | decoder_filename=decoder, |
| 76 | joiner_filename=joiner, | 79 | joiner_filename=joiner, |
| 77 | ), | 80 | ), |
| 78 | - paraformer=OfflineParaformerModelConfig(model=""), | ||
| 79 | tokens=tokens, | 81 | tokens=tokens, |
| 80 | num_threads=num_threads, | 82 | num_threads=num_threads, |
| 81 | debug=debug, | 83 | debug=debug, |
| @@ -119,7 +121,7 @@ class OfflineRecognizer(object): | @@ -119,7 +121,7 @@ class OfflineRecognizer(object): | ||
| 119 | symbol integer_id | 121 | symbol integer_id |
| 120 | 122 | ||
| 121 | paraformer: | 123 | paraformer: |
| 122 | - Path to ``paraformer.onnx``. | 124 | + Path to ``model.onnx``. |
| 123 | num_threads: | 125 | num_threads: |
| 124 | Number of threads for neural network computation. | 126 | Number of threads for neural network computation. |
| 125 | sample_rate: | 127 | sample_rate: |
| @@ -133,9 +135,6 @@ class OfflineRecognizer(object): | @@ -133,9 +135,6 @@ class OfflineRecognizer(object): | ||
| 133 | """ | 135 | """ |
| 134 | self = cls.__new__(cls) | 136 | self = cls.__new__(cls) |
| 135 | model_config = OfflineModelConfig( | 137 | model_config = OfflineModelConfig( |
| 136 | - transducer=OfflineTransducerModelConfig( | ||
| 137 | - encoder_filename="", decoder_filename="", joiner_filename="" | ||
| 138 | - ), | ||
| 139 | paraformer=OfflineParaformerModelConfig(model=paraformer), | 138 | paraformer=OfflineParaformerModelConfig(model=paraformer), |
| 140 | tokens=tokens, | 139 | tokens=tokens, |
| 141 | num_threads=num_threads, | 140 | num_threads=num_threads, |
| @@ -155,6 +154,64 @@ class OfflineRecognizer(object): | @@ -155,6 +154,64 @@ class OfflineRecognizer(object): | ||
| 155 | self.recognizer = _Recognizer(recognizer_config) | 154 | self.recognizer = _Recognizer(recognizer_config) |
| 156 | return self | 155 | return self |
| 157 | 156 | ||
| 157 | + @classmethod | ||
| 158 | + def from_nemo_ctc( | ||
| 159 | + cls, | ||
| 160 | + model: str, | ||
| 161 | + tokens: str, | ||
| 162 | + num_threads: int, | ||
| 163 | + sample_rate: int = 16000, | ||
| 164 | + feature_dim: int = 80, | ||
| 165 | + decoding_method: str = "greedy_search", | ||
| 166 | + debug: bool = False, | ||
| 167 | + ): | ||
| 168 | + """ | ||
| 169 | + Please refer to | ||
| 170 | + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ | ||
| 171 | + to download pre-trained models for different languages, e.g., Chinese, | ||
| 172 | + English, etc. | ||
| 173 | + | ||
| 174 | + Args: | ||
| 175 | + tokens: | ||
| 176 | + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two | ||
| 177 | + columns:: | ||
| 178 | + | ||
| 179 | + symbol integer_id | ||
| 180 | + | ||
| 181 | + model: | ||
| 182 | + Path to ``model.onnx``. | ||
| 183 | + num_threads: | ||
| 184 | + Number of threads for neural network computation. | ||
| 185 | + sample_rate: | ||
| 186 | + Sample rate of the training data used to train the model. | ||
| 187 | + feature_dim: | ||
| 188 | + Dimension of the feature used to train the model. | ||
| 189 | + decoding_method: | ||
| 190 | + Valid values are greedy_search, modified_beam_search. | ||
| 191 | + debug: | ||
| 192 | + True to show debug messages. | ||
| 193 | + """ | ||
| 194 | + self = cls.__new__(cls) | ||
| 195 | + model_config = OfflineModelConfig( | ||
| 196 | + nemo_ctc=OfflineNemoEncDecCtcModelConfig(model=model), | ||
| 197 | + tokens=tokens, | ||
| 198 | + num_threads=num_threads, | ||
| 199 | + debug=debug, | ||
| 200 | + ) | ||
| 201 | + | ||
| 202 | + feat_config = OfflineFeatureExtractorConfig( | ||
| 203 | + sampling_rate=sample_rate, | ||
| 204 | + feature_dim=feature_dim, | ||
| 205 | + ) | ||
| 206 | + | ||
| 207 | + recognizer_config = OfflineRecognizerConfig( | ||
| 208 | + feat_config=feat_config, | ||
| 209 | + model_config=model_config, | ||
| 210 | + decoding_method=decoding_method, | ||
| 211 | + ) | ||
| 212 | + self.recognizer = _Recognizer(recognizer_config) | ||
| 213 | + return self | ||
| 214 | + | ||
| 158 | def create_stream(self): | 215 | def create_stream(self): |
| 159 | return self.recognizer.create_stream() | 216 | return self.recognizer.create_stream() |
| 160 | 217 |
| @@ -196,6 +196,71 @@ class TestOfflineRecognizer(unittest.TestCase): | @@ -196,6 +196,71 @@ class TestOfflineRecognizer(unittest.TestCase): | ||
| 196 | print(s2.result.text) | 196 | print(s2.result.text) |
| 197 | print(s3.result.text) | 197 | print(s3.result.text) |
| 198 | 198 | ||
| 199 | + def test_nemo_ctc_single_file(self): | ||
| 200 | + for use_int8 in [True, False]: | ||
| 201 | + if use_int8: | ||
| 202 | + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx" | ||
| 203 | + else: | ||
| 204 | + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx" | ||
| 205 | + | ||
| 206 | + tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt" | ||
| 207 | + wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav" | ||
| 208 | + | ||
| 209 | + if not Path(model).is_file(): | ||
| 210 | + print("skipping test_nemo_ctc_single_file()") | ||
| 211 | + return | ||
| 212 | + | ||
| 213 | + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc( | ||
| 214 | + model=model, | ||
| 215 | + tokens=tokens, | ||
| 216 | + num_threads=1, | ||
| 217 | + ) | ||
| 218 | + | ||
| 219 | + s = recognizer.create_stream() | ||
| 220 | + samples, sample_rate = read_wave(wave0) | ||
| 221 | + s.accept_waveform(sample_rate, samples) | ||
| 222 | + recognizer.decode_stream(s) | ||
| 223 | + print(s.result.text) | ||
| 224 | + | ||
| 225 | + def test_nemo_ctc_multiple_files(self): | ||
| 226 | + for use_int8 in [True, False]: | ||
| 227 | + if use_int8: | ||
| 228 | + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx" | ||
| 229 | + else: | ||
| 230 | + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx" | ||
| 231 | + | ||
| 232 | + tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt" | ||
| 233 | + wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav" | ||
| 234 | + wave1 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav" | ||
| 235 | + wave2 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav" | ||
| 236 | + | ||
| 237 | + if not Path(model).is_file(): | ||
| 238 | + print("skipping test_nemo_ctc_multiple_files()") | ||
| 239 | + return | ||
| 240 | + | ||
| 241 | + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc( | ||
| 242 | + model=model, | ||
| 243 | + tokens=tokens, | ||
| 244 | + num_threads=1, | ||
| 245 | + ) | ||
| 246 | + | ||
| 247 | + s0 = recognizer.create_stream() | ||
| 248 | + samples0, sample_rate0 = read_wave(wave0) | ||
| 249 | + s0.accept_waveform(sample_rate0, samples0) | ||
| 250 | + | ||
| 251 | + s1 = recognizer.create_stream() | ||
| 252 | + samples1, sample_rate1 = read_wave(wave1) | ||
| 253 | + s1.accept_waveform(sample_rate1, samples1) | ||
| 254 | + | ||
| 255 | + s2 = recognizer.create_stream() | ||
| 256 | + samples2, sample_rate2 = read_wave(wave2) | ||
| 257 | + s2.accept_waveform(sample_rate2, samples2) | ||
| 258 | + | ||
| 259 | + recognizer.decode_streams([s0, s1, s2]) | ||
| 260 | + print(s0.result.text) | ||
| 261 | + print(s1.result.text) | ||
| 262 | + print(s2.result.text) | ||
| 263 | + | ||
| 199 | 264 | ||
| 200 | if __name__ == "__main__": | 265 | if __name__ == "__main__": |
| 201 | unittest.main() | 266 | unittest.main() |
-
请 注册 或 登录 后发表评论