Fangjun Kuang
Committed by GitHub

Begin to support CTC models (#119)

Please see https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html for a list of pre-trained CTC models from NeMo.
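For a quick sanity check of the new support, the Python API gains an `OfflineRecognizer.from_nemo_ctc` constructor (used by `offline-decode-files.py` in this commit). The snippet below is only an illustrative sketch: it assumes you have downloaded the sherpa-onnx-nemo-ctc-en-citrinet-512 model referenced above, and it reuses the stream-based decoding calls that `offline-decode-files.py` relies on.

```python
#!/usr/bin/env python3
# Illustrative sketch only (not part of this commit).
# Assumes the sherpa-onnx-nemo-ctc-en-citrinet-512 directory from Hugging Face
# and the stream API already used by python-api-examples/offline-decode-files.py.
import wave

import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
    model="./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx",
    tokens="./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt",
    num_threads=2,
    sample_rate=16000,
    feature_dim=80,
    decoding_method="greedy_search",
)

with wave.open("./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav") as f:
    # 16-bit mono PCM assumed; scale to [-1, 1] as offline-decode-files.py does
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768
    sample_rate = f.getframerate()

stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
recognizer.decode_stream(stream)
print(stream.result.text)
```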
Showing 40 changed files with 1,244 additions and 60 deletions
  1 +#!/usr/bin/env bash
  2 +
  3 +set -e
  4 +
  5 +log() {
  6 + # This function is from espnet
  7 + local fname=${BASH_SOURCE[1]##*/}
  8 + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
  9 +}
  10 +
  11 +echo "EXE is $EXE"
  12 +echo "PATH: $PATH"
  13 +
  14 +which $EXE
  15 +
  16 +log "------------------------------------------------------------"
  17 +log "Run Citrinet (stt_en_citrinet_512, English)"
  18 +log "------------------------------------------------------------"
  19 +
  20 +repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512
  21 +log "Start testing ${repo_url}"
  22 +repo=$(basename $repo_url)
  23 +log "Download pretrained model and test-data from $repo_url"
  24 +
  25 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  26 +pushd $repo
  27 +git lfs pull --include "*.onnx"
  28 +ls -lh *.onnx
  29 +popd
  30 +
  31 +time $EXE \
  32 + --tokens=$repo/tokens.txt \
  33 + --nemo-ctc-model=$repo/model.onnx \
  34 + --num-threads=2 \
  35 + $repo/test_wavs/0.wav \
  36 + $repo/test_wavs/1.wav \
  37 + $repo/test_wavs/8k.wav
  38 +
  39 +time $EXE \
  40 + --tokens=$repo/tokens.txt \
  41 + --nemo-ctc-model=$repo/model.int8.onnx \
  42 + --num-threads=2 \
  43 + $repo/test_wavs/0.wav \
  44 + $repo/test_wavs/1.wav \
  45 + $repo/test_wavs/8k.wav
  46 +
  47 +rm -rf $repo
@@ -95,6 +95,8 @@ python3 ./python-api-examples/offline-decode-files.py \
95 95
96 96 python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
97 97
  98 +rm -rf $repo
  99 +
98 log "Test non-streaming paraformer models" 100 log "Test non-streaming paraformer models"
99 101
100 pushd $dir 102 pushd $dir
@@ -128,3 +130,39 @@ python3 ./python-api-examples/offline-decode-files.py \ @@ -128,3 +130,39 @@ python3 ./python-api-examples/offline-decode-files.py \
128 $repo/test_wavs/8k.wav 130 $repo/test_wavs/8k.wav
129 131
130 python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose 132 python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
  133 +
  134 +rm -rf $repo
  135 +
  136 +log "Test non-streaming NeMo CTC models"
  137 +
  138 +pushd $dir
  139 +repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512
  140 +
  141 +log "Start testing ${repo_url}"
  142 +repo=$dir/$(basename $repo_url)
  143 +log "Download pretrained model and test-data from $repo_url"
  144 +
  145 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  146 +cd $repo
  147 +git lfs pull --include "*.onnx"
  148 +popd
  149 +
  150 +ls -lh $repo
  151 +
  152 +python3 ./python-api-examples/offline-decode-files.py \
  153 + --tokens=$repo/tokens.txt \
  154 + --nemo-ctc=$repo/model.onnx \
  155 + $repo/test_wavs/0.wav \
  156 + $repo/test_wavs/1.wav \
  157 + $repo/test_wavs/8k.wav
  158 +
  159 +python3 ./python-api-examples/offline-decode-files.py \
  160 + --tokens=$repo/tokens.txt \
  161 + --nemo-ctc=$repo/model.int8.onnx \
  162 + $repo/test_wavs/0.wav \
  163 + $repo/test_wavs/1.wav \
  164 + $repo/test_wavs/8k.wav
  165 +
  166 +python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
  167 +
  168 +rm -rf $repo
@@ -8,6 +8,7 @@ on:
8 8 - '.github/workflows/linux.yaml'
9 9 - '.github/scripts/test-online-transducer.sh'
10 10 - '.github/scripts/test-offline-transducer.sh'
  11 + - '.github/scripts/test-offline-ctc.sh'
11 12 - 'CMakeLists.txt'
12 13 - 'cmake/**'
13 14 - 'sherpa-onnx/csrc/*'
@@ -20,6 +21,7 @@ on:
20 21 - '.github/workflows/linux.yaml'
21 22 - '.github/scripts/test-online-transducer.sh'
22 23 - '.github/scripts/test-offline-transducer.sh'
  24 + - '.github/scripts/test-offline-ctc.sh'
23 25 - 'CMakeLists.txt'
24 26 - 'cmake/**'
25 27 - 'sherpa-onnx/csrc/*'
@@ -68,6 +70,14 @@ jobs:
68 70 file build/bin/sherpa-onnx
69 71 readelf -d build/bin/sherpa-onnx
70 72
  73 + - name: Test offline CTC
  74 + shell: bash
  75 + run: |
  76 + export PATH=$PWD/build/bin:$PATH
  77 + export EXE=sherpa-onnx-offline
  78 +
  79 + .github/scripts/test-offline-ctc.sh
  80 +
71 81 - name: Test offline transducer
72 82 shell: bash
73 83 run: |
@@ -8,6 +8,7 @@ on:
8 8 - '.github/workflows/macos.yaml'
9 9 - '.github/scripts/test-online-transducer.sh'
10 10 - '.github/scripts/test-offline-transducer.sh'
  11 + - '.github/scripts/test-offline-ctc.sh'
11 12 - 'CMakeLists.txt'
12 13 - 'cmake/**'
13 14 - 'sherpa-onnx/csrc/*'
@@ -18,6 +19,7 @@ on:
18 19 - '.github/workflows/macos.yaml'
19 20 - '.github/scripts/test-online-transducer.sh'
20 21 - '.github/scripts/test-offline-transducer.sh'
  22 + - '.github/scripts/test-offline-ctc.sh'
21 23 - 'CMakeLists.txt'
22 24 - 'cmake/**'
23 25 - 'sherpa-onnx/csrc/*'
@@ -67,6 +69,14 @@ jobs:
67 69 otool -L build/bin/sherpa-onnx
68 70 otool -l build/bin/sherpa-onnx
69 71
  72 + - name: Test offline CTC
  73 + shell: bash
  74 + run: |
  75 + export PATH=$PWD/build/bin:$PATH
  76 + export EXE=sherpa-onnx-offline
  77 +
  78 + .github/scripts/test-offline-ctc.sh
  79 +
70 80 - name: Test offline transducer
71 81 shell: bash
72 82 run: |
@@ -8,6 +8,7 @@ on:
8 8 - '.github/workflows/windows-x64.yaml'
9 9 - '.github/scripts/test-online-transducer.sh'
10 10 - '.github/scripts/test-offline-transducer.sh'
  11 + - '.github/scripts/test-offline-ctc.sh'
11 12 - 'CMakeLists.txt'
12 13 - 'cmake/**'
13 14 - 'sherpa-onnx/csrc/*'
@@ -18,6 +19,7 @@ on:
18 19 - '.github/workflows/windows-x64.yaml'
19 20 - '.github/scripts/test-online-transducer.sh'
20 21 - '.github/scripts/test-offline-transducer.sh'
  22 + - '.github/scripts/test-offline-ctc.sh'
21 23 - 'CMakeLists.txt'
22 24 - 'cmake/**'
23 25 - 'sherpa-onnx/csrc/*'
@@ -73,6 +75,14 @@ jobs:
73 75
74 76 ls -lh ./bin/Release/sherpa-onnx.exe
75 77
  78 + - name: Test offline CTC for windows x64
  79 + shell: bash
  80 + run: |
  81 + export PATH=$PWD/build/bin/Release:$PATH
  82 + export EXE=sherpa-onnx-offline.exe
  83 +
  84 + .github/scripts/test-offline-ctc.sh
  85 +
76 86 - name: Test offline transducer for Windows x64
77 87 shell: bash
78 88 run: |
@@ -8,6 +8,7 @@ on:
8 8 - '.github/workflows/windows-x86.yaml'
9 9 - '.github/scripts/test-online-transducer.sh'
10 10 - '.github/scripts/test-offline-transducer.sh'
  11 + - '.github/scripts/test-offline-ctc.sh'
11 12 - 'CMakeLists.txt'
12 13 - 'cmake/**'
13 14 - 'sherpa-onnx/csrc/*'
@@ -18,6 +19,7 @@ on:
18 19 - '.github/workflows/windows-x86.yaml'
19 20 - '.github/scripts/test-online-transducer.sh'
20 21 - '.github/scripts/test-offline-transducer.sh'
  22 + - '.github/scripts/test-offline-ctc.sh'
21 23 - 'CMakeLists.txt'
22 24 - 'cmake/**'
23 25 - 'sherpa-onnx/csrc/*'
@@ -31,6 +33,7 @@ permissions:
31 33
32 34 jobs:
33 35 windows_x86:
  36 + if: false # disable windows x86 CI for now
34 37 runs-on: ${{ matrix.os }}
35 38 name: ${{ matrix.vs-version }}
36 39 strategy:
@@ -73,6 +76,14 @@ jobs:
73 76
74 77 ls -lh ./bin/Release/sherpa-onnx.exe
75 78
  79 + - name: Test offline CTC for windows x86
  80 + shell: bash
  81 + run: |
  82 + export PATH=$PWD/build/bin/Release:$PATH
  83 + export EXE=sherpa-onnx-offline.exe
  84 +
  85 + .github/scripts/test-offline-ctc.sh
  86 +
76 87 - name: Test offline transducer for Windows x86
77 88 shell: bash
78 89 run: |
@@ -52,3 +52,6 @@ run-offline-websocket-client-*.sh
52 52 run-sherpa-onnx-*.sh
53 53 sherpa-onnx-zipformer-en-2023-03-30
54 54 sherpa-onnx-zipformer-en-2023-04-01
  55 +run-offline-decode-files.sh
  56 +sherpa-onnx-nemo-ctc-en-citrinet-512
  57 +run-offline-decode-files-nemo-ctc.sh
@@ -6,7 +6,7 @@
6 6 This file demonstrates how to use sherpa-onnx Python API to transcribe
7 7 file(s) with a non-streaming model.
8 8
9 -paraformer Usage:
  9 +(1) For paraformer
10 10 ./python-api-examples/offline-decode-files.py \
11 11 --tokens=/path/to/tokens.txt \
12 12 --paraformer=/path/to/paraformer.onnx \
@@ -18,7 +18,7 @@ paraformer Usage:
18 18 /path/to/0.wav \
19 19 /path/to/1.wav
20 20
21 -transducer Usage:
  21 +(2) For transducer models from icefall
22 22 ./python-api-examples/offline-decode-files.py \
23 23 --tokens=/path/to/tokens.txt \
24 24 --encoder=/path/to/encoder.onnx \
@@ -32,6 +32,8 @@ transducer Usage:
32 32 /path/to/0.wav \
33 33 /path/to/1.wav
34 34
  35 +(3) For CTC models from NeMo
  36 +
35 37 Please refer to
36 38 https://k2-fsa.github.io/sherpa/onnx/index.html
37 39 to install sherpa-onnx and to download the pre-trained models
@@ -83,7 +85,14 @@ def get_args():
83 85 "--paraformer",
84 86 default="",
85 87 type=str,
86 - help="Path to the paraformer model",
  88 + help="Path to the model.onnx from Paraformer",
  89 + )
  90 +
  91 + parser.add_argument(
  92 + "--nemo-ctc",
  93 + default="",
  94 + type=str,
  95 + help="Path to the model.onnx from NeMo CTC",
87 96 )
88 97
89 98 parser.add_argument(
@@ -171,11 +180,14 @@ def main():
171 180 args = get_args()
172 181 assert_file_exists(args.tokens)
173 182 assert args.num_threads > 0, args.num_threads
174 - if len(args.encoder) > 0:
  183 + if args.encoder:
  184 + assert len(args.paraformer) == 0, args.paraformer
  185 + assert len(args.nemo_ctc) == 0, args.nemo_ctc
  186 +
175 187 assert_file_exists(args.encoder)
176 188 assert_file_exists(args.decoder)
177 189 assert_file_exists(args.joiner)
178 - assert len(args.paraformer) == 0, args.paraformer
  190 +
179 191 recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
180 192 encoder=args.encoder,
181 193 decoder=args.decoder,
@@ -187,8 +199,10 @@ def main():
187 199 decoding_method=args.decoding_method,
188 200 debug=args.debug,
189 201 )
190 - else:
  202 + elif args.paraformer:
  203 + assert len(args.nemo_ctc) == 0, args.nemo_ctc
191 204 assert_file_exists(args.paraformer)
  205 +
192 206 recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
193 207 paraformer=args.paraformer,
194 208 tokens=args.tokens,
@@ -198,6 +212,19 @@ def main():
198 212 decoding_method=args.decoding_method,
199 213 debug=args.debug,
200 214 )
  215 + elif args.nemo_ctc:
  216 + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
  217 + model=args.nemo_ctc,
  218 + tokens=args.tokens,
  219 + num_threads=args.num_threads,
  220 + sample_rate=args.sample_rate,
  221 + feature_dim=args.feature_dim,
  222 + decoding_method=args.decoding_method,
  223 + debug=args.debug,
  224 + )
  225 + else:
  226 + print("Please specify at least one model")
  227 + return
201 228
202 229 print("Started!")
203 230 start_time = time.time()
@@ -225,12 +252,14 @@ def main():
225 252 print("-" * 10)
226 253
227 254 elapsed_seconds = end_time - start_time
228 - rtf = elapsed_seconds / duration
  255 + rtf = elapsed_seconds / total_duration
229 256 print(f"num_threads: {args.num_threads}")
230 257 print(f"decoding_method: {args.decoding_method}")
231 - print(f"Wave duration: {duration:.3f} s")
  258 + print(f"Wave duration: {total_duration:.3f} s")
232 259 print(f"Elapsed time: {elapsed_seconds:.3f} s")
233 - print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
  260 + print(
  261 + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
  262 + )
234 263
235 264
236 265 if __name__ == "__main__":
@@ -172,12 +172,14 @@ def main():
172 172 print("-" * 10)
173 173
174 174 elapsed_seconds = end_time - start_time
175 - rtf = elapsed_seconds / duration
  175 + rtf = elapsed_seconds / total_duration
176 176 print(f"num_threads: {args.num_threads}")
177 177 print(f"decoding_method: {args.decoding_method}")
178 - print(f"Wave duration: {duration:.3f} s")
  178 + print(f"Wave duration: {total_duration:.3f} s")
179 179 print(f"Elapsed time: {elapsed_seconds:.3f} s")
180 - print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
  180 + print(
  181 + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
  182 + )
181 183
182 184
183 185 if __name__ == "__main__":
@@ -16,7 +16,11 @@ set(sources
16 16 features.cc
17 17 file-utils.cc
18 18 hypothesis.cc
  19 + offline-ctc-greedy-search-decoder.cc
  20 + offline-ctc-model.cc
19 21 offline-model-config.cc
  22 + offline-nemo-enc-dec-ctc-model-config.cc
  23 + offline-nemo-enc-dec-ctc-model.cc
20 24 offline-paraformer-greedy-search-decoder.cc
21 25 offline-paraformer-model-config.cc
22 26 offline-paraformer-model.cc
@@ -11,15 +11,19 @@
11 11 #include "android/log.h"
12 12 #define SHERPA_ONNX_LOGE(...) \
13 13 do { \
  14 + fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \
  15 + static_cast<int>(__LINE__)); \
14 16 fprintf(stderr, ##__VA_ARGS__); \
15 17 fprintf(stderr, "\n"); \
16 18 __android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
17 19 } while (0)
18 20 #else
19 -#define SHERPA_ONNX_LOGE(...) \  
20 - do { \  
21 - fprintf(stderr, ##__VA_ARGS__); \  
22 - fprintf(stderr, "\n"); \
  21 +#define SHERPA_ONNX_LOGE(...) \
  22 + do { \
  23 + fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \
  24 + static_cast<int>(__LINE__)); \
  25 + fprintf(stderr, ##__VA_ARGS__); \
  26 + fprintf(stderr, "\n"); \
23 27 } while (0)
24 28 #endif
25 29
  1 +// sherpa-onnx/csrc/offline-ctc-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "onnxruntime_cxx_api.h" // NOLINT
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct OfflineCtcDecoderResult {
  15 + /// The decoded token IDs
  16 + std::vector<int64_t> tokens;
  17 +
  18 + /// timestamps[i] contains the output frame index where tokens[i] is decoded.
  19 + /// Note: The index is after subsampling
  20 + std::vector<int32_t> timestamps;
  21 +};
  22 +
  23 +class OfflineCtcDecoder {
  24 + public:
  25 + virtual ~OfflineCtcDecoder() = default;
  26 +
  27 + /** Run CTC decoding given the output from the encoder model.
  28 + *
  29 + * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing
  30 + * log_probs.
  31 + * @param log_probs_length A 1-D tensor of shape (N,) containing number
  32 + * of valid frames in log_probs before padding.
  33 + *
  34 + * @return Return a vector of size `N` containing the decoded results.
  35 + */
  36 + virtual std::vector<OfflineCtcDecoderResult> Decode(
  37 + Ort::Value log_probs, Ort::Value log_probs_length) = 0;
  38 +};
  39 +
  40 +} // namespace sherpa_onnx
  41 +
  42 +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
  1 +// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
  6 +
  7 +#include <algorithm>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +std::vector<OfflineCtcDecoderResult> OfflineCtcGreedySearchDecoder::Decode(
  16 + Ort::Value log_probs, Ort::Value log_probs_length) {
  17 + std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
  18 + int32_t batch_size = static_cast<int32_t>(shape[0]);
  19 + int32_t num_frames = static_cast<int32_t>(shape[1]);
  20 + int32_t vocab_size = static_cast<int32_t>(shape[2]);
  21 +
  22 + const int64_t *p_log_probs_length = log_probs_length.GetTensorData<int64_t>();
  23 +
  24 + std::vector<OfflineCtcDecoderResult> ans;
  25 + ans.reserve(batch_size);
  26 +
  27 + for (int32_t b = 0; b != batch_size; ++b) {
  28 + const float *p_log_probs =
  29 + log_probs.GetTensorData<float>() + b * num_frames * vocab_size;
  30 +
  31 + OfflineCtcDecoderResult r;
  32 + int64_t prev_id = -1;
  33 +
  34 + for (int32_t t = 0; t != static_cast<int32_t>(p_log_probs_length[b]); ++t) {
  35 + auto y = static_cast<int64_t>(std::distance(
  36 + static_cast<const float *>(p_log_probs),
  37 + std::max_element(
  38 + static_cast<const float *>(p_log_probs),
  39 + static_cast<const float *>(p_log_probs) + vocab_size)));
  40 + p_log_probs += vocab_size;
  41 +
  42 + if (y != blank_id_ && y != prev_id) {
  43 + r.tokens.push_back(y);
  44 + r.timestamps.push_back(t);
  45 + prev_id = y;
  46 + }
  47 + } // for (int32_t t = 0; ...)
  48 +
  49 + ans.push_back(std::move(r));
  50 + }
  51 + return ans;
  52 +}
  53 +
  54 +} // namespace sherpa_onnx
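The greedy search above takes a per-frame argmax over the vocabulary, skips the blank symbol, and only emits a token when it differs from the previously emitted one. A rough NumPy equivalent of the loop, for illustration (blank_id is whatever `<blk>` maps to in tokens.txt):

```python
import numpy as np

def ctc_greedy_search(log_probs: np.ndarray, num_frames: int, blank_id: int):
    """log_probs: (T, vocab_size) array for one utterance; num_frames <= T."""
    tokens, timestamps = [], []
    prev_id = -1
    for t in range(num_frames):
        y = int(np.argmax(log_probs[t]))  # best token for this frame
        if y != blank_id and y != prev_id:
            tokens.append(y)
            timestamps.append(t)  # frame index after subsampling
            prev_id = y
    return tokens, timestamps
```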
  1 +// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder {
  15 + public:
  16 + explicit OfflineCtcGreedySearchDecoder(int32_t blank_id)
  17 + : blank_id_(blank_id) {}
  18 +
  19 + std::vector<OfflineCtcDecoderResult> Decode(
  20 + Ort::Value log_probs, Ort::Value log_probs_length) override;
  21 +
  22 + private:
  23 + int32_t blank_id_;
  24 +};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
  1 +// sherpa-onnx/csrc/offline-ctc-model.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-ctc-model.h"
  6 +
  7 +#include <algorithm>
  8 +#include <memory>
  9 +#include <sstream>
  10 +#include <string>
  11 +
  12 +#include "sherpa-onnx/csrc/macros.h"
  13 +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
  14 +#include "sherpa-onnx/csrc/onnx-utils.h"
  15 +
  16 +namespace {
  17 +
  18 +enum class ModelType {
  19 + kEncDecCTCModelBPE,
  20 + kUnkown,
  21 +};
  22 +
  23 +}
  24 +
  25 +namespace sherpa_onnx {
  26 +
  27 +static ModelType GetModelType(char *model_data, size_t model_data_length,
  28 + bool debug) {
  29 + Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  30 + Ort::SessionOptions sess_opts;
  31 +
  32 + auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
  33 + sess_opts);
  34 +
  35 + Ort::ModelMetadata meta_data = sess->GetModelMetadata();
  36 + if (debug) {
  37 + std::ostringstream os;
  38 + PrintModelMetadata(os, meta_data);
  39 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  40 + }
  41 +
  42 + Ort::AllocatorWithDefaultOptions allocator;
  43 + auto model_type =
  44 + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
  45 + if (!model_type) {
  46 + SHERPA_ONNX_LOGE(
  47 + "No model_type in the metadata!\n"
  48 + "If you are using models from NeMo, please refer to\n"
  49 + "https://huggingface.co/csukuangfj/"
  50 + "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
  51 + "\n"
  52 + "for how to add metadta to model.onnx\n");
  53 + return ModelType::kUnkown;
  54 + }
  55 +
  56 + if (model_type.get() == std::string("EncDecCTCModelBPE")) {
  57 + return ModelType::kEncDecCTCModelBPE;
  58 + } else {
  59 + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
  60 + return ModelType::kUnkown;
  61 + }
  62 +}
  63 +
  64 +std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
  65 + const OfflineModelConfig &config) {
  66 + ModelType model_type = ModelType::kUnkown;
  67 +
  68 + {
  69 + auto buffer = ReadFile(config.nemo_ctc.model);
  70 +
  71 + model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
  72 + }
  73 +
  74 + switch (model_type) {
  75 + case ModelType::kEncDecCTCModelBPE:
  76 + return std::make_unique<OfflineNemoEncDecCtcModel>(config);
  77 + break;
  78 + case ModelType::kUnkown:
  79 + SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
  80 + return nullptr;
  81 + }
  82 +
  83 + return nullptr;
  84 +}
  85 +
  86 +} // namespace sherpa_onnx
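GetModelType above refuses models whose ONNX metadata has no model_type entry and points users at an add-model-metadata.py script on Hugging Face. A sketch of what such a script boils down to is shown below; it is illustrative only, the values are placeholders, and the authoritative keys are the ones read by this commit (model_type, vocab_size, subsampling_factor, normalize_type).

```python
import onnx

model = onnx.load("model.onnx")

# Placeholder values for illustration; they must match the exported NeMo model.
metadata = {
    "model_type": "EncDecCTCModelBPE",
    "vocab_size": "1024",          # number of entries in tokens.txt, incl. <blk>
    "subsampling_factor": "4",     # Citrinet models typically subsample by 4
    "normalize_type": "per_feature",
}
for key, value in metadata.items():
    entry = model.metadata_props.add()
    entry.key = key
    entry.value = value

onnx.save(model, "model.onnx")
```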
  1 +// sherpa-onnx/csrc/offline-ctc-model.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <utility>
  10 +
  11 +#include "onnxruntime_cxx_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/offline-model-config.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +class OfflineCtcModel {
  17 + public:
  18 + virtual ~OfflineCtcModel() = default;
  19 + static std::unique_ptr<OfflineCtcModel> Create(
  20 + const OfflineModelConfig &config);
  21 +
  22 + /** Run the forward method of the model.
  23 + *
  24 + * @param features A tensor of shape (N, T, C). It is changed in-place.
  25 + * @param features_length A 1-D tensor of shape (N,) containing number of
  26 + * valid frames in `features` before padding.
  27 + * Its dtype is int64_t.
  28 + *
  29 + * @return Return a pair containing:
  30 + * - log_probs: A 3-D tensor of shape (N, T', vocab_size).
  31 + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
  32 + */
  33 + virtual std::pair<Ort::Value, Ort::Value> Forward(
  34 + Ort::Value features, Ort::Value features_length) = 0;
  35 +
  36 + /** Return the vocabulary size of the model
  37 + */
  38 + virtual int32_t VocabSize() const = 0;
  39 +
  40 + /** SubsamplingFactor of the model
  41 + *
  42 + * For Citrinet, the subsampling factor is usually 4.
  43 + * For Conformer CTC, the subsampling factor is usually 8.
  44 + */
  45 + virtual int32_t SubsamplingFactor() const = 0;
  46 +
  47 + /** Return an allocator for allocating memory
  48 + */
  49 + virtual OrtAllocator *Allocator() const = 0;
  50 +
  51 + /** For some models, e.g., those from NeMo, they require some preprocessing
  52 + * for the features.
  53 + */
  54 + virtual std::string FeatureNormalizationMethod() const { return {}; }
  55 +};
  56 +
  57 +} // namespace sherpa_onnx
  58 +
  59 +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
@@ -13,6 +13,7 @@ namespace sherpa_onnx {
13 13 void OfflineModelConfig::Register(ParseOptions *po) {
14 14 transducer.Register(po);
15 15 paraformer.Register(po);
  16 + nemo_ctc.Register(po);
16 17
17 18 po->Register("tokens", &tokens, "Path to tokens.txt");
18 19
@@ -38,6 +39,10 @@ bool OfflineModelConfig::Validate() const {
38 39 return paraformer.Validate();
39 40 }
40 41
  42 + if (!nemo_ctc.model.empty()) {
  43 + return nemo_ctc.Validate();
  44 + }
  45 +
41 46 return transducer.Validate();
42 47 }
43 48
@@ -47,6 +52,7 @@ std::string OfflineModelConfig::ToString() const {
47 52 os << "OfflineModelConfig(";
48 53 os << "transducer=" << transducer.ToString() << ", ";
49 54 os << "paraformer=" << paraformer.ToString() << ", ";
  55 + os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
50 56 os << "tokens=\"" << tokens << "\", ";
51 57 os << "num_threads=" << num_threads << ", ";
52 58 os << "debug=" << (debug ? "True" : "False") << ")";
@@ -6,6 +6,7 @@
6 6
7 7 #include <string>
8 8
  9 +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
9 10 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
10 11 #include "sherpa-onnx/csrc/offline-transducer-model-config.h"
11 12
@@ -14,6 +15,7 @@ namespace sherpa_onnx {
14 15 struct OfflineModelConfig {
15 16 OfflineTransducerModelConfig transducer;
16 17 OfflineParaformerModelConfig paraformer;
  18 + OfflineNemoEncDecCtcModelConfig nemo_ctc;
17 19
18 20 std::string tokens;
19 21 int32_t num_threads = 2;
@@ -22,9 +24,11 @@ struct OfflineModelConfig {
22 24 OfflineModelConfig() = default;
23 25 OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
24 26 const OfflineParaformerModelConfig &paraformer,
  27 + const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
25 28 const std::string &tokens, int32_t num_threads, bool debug)
26 29 : transducer(transducer),
27 30 paraformer(paraformer),
  31 + nemo_ctc(nemo_ctc),
28 32 tokens(tokens),
29 33 num_threads(num_threads),
30 34 debug(debug) {}
  1 +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
  13 + po->Register("nemo-ctc-model", &model,
  14 + "Path to model.onnx of Nemo EncDecCtcModel.");
  15 +}
  16 +
  17 +bool OfflineNemoEncDecCtcModelConfig::Validate() const {
  18 + if (!FileExists(model)) {
  19 + SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
  20 + return false;
  21 + }
  22 +
  23 + return true;
  24 +}
  25 +
  26 +std::string OfflineNemoEncDecCtcModelConfig::ToString() const {
  27 + std::ostringstream os;
  28 +
  29 + os << "OfflineNemoEncDecCtcModelConfig(";
  30 + os << "model=\"" << model << "\")";
  31 +
  32 + return os.str();
  33 +}
  34 +
  35 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OfflineNemoEncDecCtcModelConfig {
  14 + std::string model;
  15 +
  16 + OfflineNemoEncDecCtcModelConfig() = default;
  17 + explicit OfflineNemoEncDecCtcModelConfig(const std::string &model)
  18 + : model(model) {}
  19 +
  20 + void Register(ParseOptions *po);
  21 + bool Validate() const;
  22 +
  23 + std::string ToString() const;
  24 +};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/csrc/onnx-utils.h"
  9 +#include "sherpa-onnx/csrc/text-utils.h"
  10 +#include "sherpa-onnx/csrc/transpose.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class OfflineNemoEncDecCtcModel::Impl {
  15 + public:
  16 + explicit Impl(const OfflineModelConfig &config)
  17 + : config_(config),
  18 + env_(ORT_LOGGING_LEVEL_ERROR),
  19 + sess_opts_{},
  20 + allocator_{} {
  21 + sess_opts_.SetIntraOpNumThreads(config_.num_threads);
  22 + sess_opts_.SetInterOpNumThreads(config_.num_threads);
  23 +
  24 + Init();
  25 + }
  26 +
  27 + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
  28 + Ort::Value features_length) {
  29 + std::vector<int64_t> shape =
  30 + features_length.GetTensorTypeAndShapeInfo().GetShape();
  31 +
  32 + Ort::Value out_features_length = Ort::Value::CreateTensor<int64_t>(
  33 + allocator_, shape.data(), shape.size());
  34 +
  35 + const int64_t *src = features_length.GetTensorData<int64_t>();
  36 + int64_t *dst = out_features_length.GetTensorMutableData<int64_t>();
  37 + for (int64_t i = 0; i != shape[0]; ++i) {
  38 + dst[i] = src[i] / subsampling_factor_;
  39 + }
  40 +
  41 + // (B, T, C) -> (B, C, T)
  42 + features = Transpose12(allocator_, &features);
  43 +
  44 + std::array<Ort::Value, 2> inputs = {std::move(features),
  45 + std::move(features_length)};
  46 + auto out =
  47 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  48 + output_names_ptr_.data(), output_names_ptr_.size());
  49 +
  50 + return {std::move(out[0]), std::move(out_features_length)};
  51 + }
  52 +
  53 + int32_t VocabSize() const { return vocab_size_; }
  54 +
  55 + int32_t SubsamplingFactor() const { return subsampling_factor_; }
  56 +
  57 + OrtAllocator *Allocator() const { return allocator_; }
  58 +
  59 + std::string FeatureNormalizationMethod() const { return normalize_type_; }
  60 +
  61 + private:
  62 + void Init() {
  63 + auto buf = ReadFile(config_.nemo_ctc.model);
  64 +
  65 + sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
  66 + sess_opts_);
  67 +
  68 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  69 +
  70 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  71 +
  72 + // get meta data
  73 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  74 + if (config_.debug) {
  75 + std::ostringstream os;
  76 + PrintModelMetadata(os, meta_data);
  77 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  78 + }
  79 +
  80 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  81 + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  82 + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
  83 + SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
  84 + }
  85 +
  86 + private:
  87 + OfflineModelConfig config_;
  88 + Ort::Env env_;
  89 + Ort::SessionOptions sess_opts_;
  90 + Ort::AllocatorWithDefaultOptions allocator_;
  91 +
  92 + std::unique_ptr<Ort::Session> sess_;
  93 +
  94 + std::vector<std::string> input_names_;
  95 + std::vector<const char *> input_names_ptr_;
  96 +
  97 + std::vector<std::string> output_names_;
  98 + std::vector<const char *> output_names_ptr_;
  99 +
  100 + int32_t vocab_size_ = 0;
  101 + int32_t subsampling_factor_ = 0;
  102 + std::string normalize_type_;
  103 +};
  104 +
  105 +OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
  106 + const OfflineModelConfig &config)
  107 + : impl_(std::make_unique<Impl>(config)) {}
  108 +
  109 +OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
  110 +
  111 +std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
  112 + Ort::Value features, Ort::Value features_length) {
  113 + return impl_->Forward(std::move(features), std::move(features_length));
  114 +}
  115 +
  116 +int32_t OfflineNemoEncDecCtcModel::VocabSize() const {
  117 + return impl_->VocabSize();
  118 +}
  119 +int32_t OfflineNemoEncDecCtcModel::SubsamplingFactor() const {
  120 + return impl_->SubsamplingFactor();
  121 +}
  122 +
  123 +OrtAllocator *OfflineNemoEncDecCtcModel::Allocator() const {
  124 + return impl_->Allocator();
  125 +}
  126 +
  127 +std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const {
  128 + return impl_->FeatureNormalizationMethod();
  129 +}
  130 +
  131 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
  6 +#include <memory>
  7 +#include <string>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "onnxruntime_cxx_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/offline-ctc-model.h"
  13 +#include "sherpa-onnx/csrc/offline-model-config.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +/** This class implements the EncDecCTCModelBPE model from NeMo.
  18 + *
  19 + * See
  20 + * https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_bpe_models.py
  21 + * https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_models.py
  22 + */
  23 +class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
  24 + public:
  25 + explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config);
  26 + ~OfflineNemoEncDecCtcModel() override;
  27 +
  28 + /** Run the forward method of the model.
  29 + *
  30 + * @param features A tensor of shape (N, T, C). It is changed in-place.
  31 + * @param features_length A 1-D tensor of shape (N,) containing number of
  32 + * valid frames in `features` before padding.
  33 + * Its dtype is int64_t.
  34 + *
  35 + * @return Return a pair containing:
  36 + * - log_probs: A 3-D tensor of shape (N, T', vocab_size).
  37 + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
  38 + */
  39 + std::pair<Ort::Value, Ort::Value> Forward(
  40 + Ort::Value features, Ort::Value features_length) override;
  41 +
  42 + /** Return the vocabulary size of the model
  43 + */
  44 + int32_t VocabSize() const override;
  45 +
  46 + /** SubsamplingFactor of the model
  47 + *
  48 + * For Citrinet, the subsampling factor is usually 4.
  49 + * For Conformer CTC, the subsampling factor is usually 8.
  50 + */
  51 + int32_t SubsamplingFactor() const override;
  52 +
  53 + /** Return an allocator for allocating memory
  54 + */
  55 + OrtAllocator *Allocator() const override;
  56 +
  57 + // Possible values:
  58 + // - per_feature
  59 + // - all_features (not implemented yet)
  60 + // - fixed_mean (not implemented)
  61 + // - fixed_std (not implemented)
  62 + // - or just leave it to empty
  63 + // See
  64 + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
  65 + // for details
  66 + std::string FeatureNormalizationMethod() const override;
  67 +
  68 + private:
  69 + class Impl;
  70 + std::unique_ptr<Impl> impl_;
  71 +};
  72 +
  73 +} // namespace sherpa_onnx
  74 +
  75 +#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
  1 +// sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <utility>
  11 +#include <vector>
  12 +
  13 +#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
  14 +#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
  15 +#include "sherpa-onnx/csrc/offline-ctc-model.h"
  16 +#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
  17 +#include "sherpa-onnx/csrc/pad-sequence.h"
  18 +#include "sherpa-onnx/csrc/symbol-table.h"
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
  23 + const SymbolTable &sym_table) {
  24 + OfflineRecognitionResult r;
  25 + r.tokens.reserve(src.tokens.size());
  26 +
  27 + std::string text;
  28 +
  29 + for (int32_t i = 0; i != src.tokens.size(); ++i) {
  30 + auto sym = sym_table[src.tokens[i]];
  31 + text.append(sym);
  32 + r.tokens.push_back(std::move(sym));
  33 + }
  34 + r.text = std::move(text);
  35 +
  36 + return r;
  37 +}
  38 +
  39 +class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
  40 + public:
  41 + explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config)
  42 + : config_(config),
  43 + symbol_table_(config_.model_config.tokens),
  44 + model_(OfflineCtcModel::Create(config_.model_config)) {
  45 + config_.feat_config.nemo_normalize_type =
  46 + model_->FeatureNormalizationMethod();
  47 +
  48 + if (config.decoding_method == "greedy_search") {
  49 + if (!symbol_table_.contains("<blk>")) {
  50 + SHERPA_ONNX_LOGE(
  51 + "We expect that tokens.txt contains "
  52 + "the symbol <blk> and its ID.");
  53 + exit(-1);
  54 + }
  55 +
  56 + int32_t blank_id = symbol_table_["<blk>"];
  57 + decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
  58 + } else {
  59 + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
  60 + config.decoding_method.c_str());
  61 + exit(-1);
  62 + }
  63 + }
  64 +
  65 + std::unique_ptr<OfflineStream> CreateStream() const override {
  66 + return std::make_unique<OfflineStream>(config_.feat_config);
  67 + }
  68 +
  69 + void DecodeStreams(OfflineStream **ss, int32_t n) const override {
  70 + auto memory_info =
  71 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  72 +
  73 + int32_t feat_dim = config_.feat_config.feature_dim;
  74 +
  75 + std::vector<Ort::Value> features;
  76 + features.reserve(n);
  77 +
  78 + std::vector<std::vector<float>> features_vec(n);
  79 + std::vector<int64_t> features_length_vec(n);
  80 +
  81 + for (int32_t i = 0; i != n; ++i) {
  82 + std::vector<float> f = ss[i]->GetFrames();
  83 +
  84 + int32_t num_frames = f.size() / feat_dim;
  85 + features_vec[i] = std::move(f);
  86 +
  87 + features_length_vec[i] = num_frames;
  88 +
  89 + std::array<int64_t, 2> shape = {num_frames, feat_dim};
  90 +
  91 + Ort::Value x = Ort::Value::CreateTensor(
  92 + memory_info, features_vec[i].data(), features_vec[i].size(),
  93 + shape.data(), shape.size());
  94 + features.push_back(std::move(x));
  95 + } // for (int32_t i = 0; i != n; ++i)
  96 +
  97 + std::vector<const Ort::Value *> features_pointer(n);
  98 + for (int32_t i = 0; i != n; ++i) {
  99 + features_pointer[i] = &features[i];
  100 + }
  101 +
  102 + std::array<int64_t, 1> features_length_shape = {n};
  103 + Ort::Value x_length = Ort::Value::CreateTensor(
  104 + memory_info, features_length_vec.data(), n,
  105 + features_length_shape.data(), features_length_shape.size());
  106 +
  107 + Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
  108 + -23.025850929940457f);
  109 + auto t = model_->Forward(std::move(x), std::move(x_length));
  110 +
  111 + auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
  112 +
  113 + for (int32_t i = 0; i != n; ++i) {
  114 + auto r = Convert(results[i], symbol_table_);
  115 + ss[i]->SetResult(r);
  116 + }
  117 + }
  118 +
  119 + private:
  120 + OfflineRecognizerConfig config_;
  121 + SymbolTable symbol_table_;
  122 + std::unique_ptr<OfflineCtcModel> model_;
  123 + std::unique_ptr<OfflineCtcDecoder> decoder_;
  124 +};
  125 +
  126 +} // namespace sherpa_onnx
  127 +
  128 +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
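One small note on the constant used in DecodeStreams above: the padding value -23.025850929940457 passed to PadSequence is the natural log of 1e-10, i.e., padded frames are filled with the log of a negligibly small filterbank value rather than with zeros. A one-line check, for illustration:

```python
import math

print(math.log(1e-10))  # ≈ -23.025850929940457, the constant passed to PadSequence
```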
@@ -8,6 +8,7 @@
8 8
9 9 #include "onnxruntime_cxx_api.h" // NOLINT
10 10 #include "sherpa-onnx/csrc/macros.h"
  11 +#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
11 12 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
12 13 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
13 14 #include "sherpa-onnx/csrc/onnx-utils.h"
@@ -25,6 +26,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
25 26 model_filename = config.model_config.transducer.encoder_filename;
26 27 } else if (!config.model_config.paraformer.model.empty()) {
27 28 model_filename = config.model_config.paraformer.model;
  29 + } else if (!config.model_config.nemo_ctc.model.empty()) {
  30 + model_filename = config.model_config.nemo_ctc.model;
28 31 } else {
29 32 SHERPA_ONNX_LOGE("Please provide a model");
30 33 exit(-1);
@@ -39,8 +42,30 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
39 42
40 43 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
41 44
42 - std::string model_type;
43 - SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
  45 + auto model_type_ptr =
  46 + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
  47 + if (!model_type_ptr) {
  48 + SHERPA_ONNX_LOGE(
  49 + "No model_type in the metadata!\n\n"
  50 + "Please refer to the following URLs to add metadata"
  51 + "\n"
  52 + "(0) Transducer models from icefall"
  53 + "\n "
  54 + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
  55 + "pruned_transducer_stateless7/export-onnx.py#L303"
  56 + "\n"
  57 + "(1) Nemo CTC models\n "
  58 + "https://huggingface.co/csukuangfj/"
  59 + "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
  60 + "\n"
  61 + "(2) Paraformer"
  62 + "\n "
  63 + "https://huggingface.co/csukuangfj/"
  64 + "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
  65 + "\n");
  66 + exit(-1);
  67 + }
  68 + std::string model_type(model_type_ptr.get());
44 69
45 70 if (model_type == "conformer" || model_type == "zipformer") {
46 71 return std::make_unique<OfflineRecognizerTransducerImpl>(config);
@@ -50,11 +75,16 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
50 75 return std::make_unique<OfflineRecognizerParaformerImpl>(config);
51 76 }
52 77
  78 + if (model_type == "EncDecCTCModelBPE") {
  79 + return std::make_unique<OfflineRecognizerCtcImpl>(config);
  80 + }
  81 +
53 82 SHERPA_ONNX_LOGE(
54 83 "\nUnsupported model_type: %s\n"
55 84 "We support only the following model types at present: \n"
56 - " - transducer models from icefall\n"
57 - " - Paraformer models from FunASR\n",
  85 + " - Non-streaming transducer models from icefall\n"
  86 + " - Non-streaming Paraformer models from FunASR\n"
  87 + " - EncDecCTCModelBPE models from NeMo\n",
58 88 model_type.c_str());
59 89
60 90 exit(-1);
@@ -7,6 +7,7 @@
7 7 #include <assert.h>
8 8
9 9 #include <algorithm>
  10 +#include <cmath>
10 11
11 12 #include "kaldi-native-fbank/csrc/online-feature.h"
12 13 #include "sherpa-onnx/csrc/macros.h"
@@ -15,6 +16,41 @@
15 16
16 17 namespace sherpa_onnx {
17 18
  19 +/* Compute mean and inverse stddev over rows.
  20 + *
  21 + * @param p A pointer to a 2-d array of shape (num_rows, num_cols)
  22 + * @param num_rows Number of rows
  23 + * @param num_cols Number of columns
  24 + * @param mean On return, it contains p.mean(axis=0)
  25 + * @param inv_stddev On return, it contains 1/p.std(axis=0)
  26 + */
  27 +static void ComputeMeanAndInvStd(const float *p, int32_t num_rows,
  28 + int32_t num_cols, std::vector<float> *mean,
  29 + std::vector<float> *inv_stddev) {
  30 + std::vector<float> sum(num_cols);
  31 + std::vector<float> sum_sq(num_cols);
  32 +
  33 + for (int32_t i = 0; i != num_rows; ++i) {
  34 + for (int32_t c = 0; c != num_cols; ++c) {
  35 + auto t = p[c];
  36 + sum[c] += t;
  37 + sum_sq[c] += t * t;
  38 + }
  39 + p += num_cols;
  40 + }
  41 +
  42 + mean->resize(num_cols);
  43 + inv_stddev->resize(num_cols);
  44 +
  45 + for (int32_t i = 0; i != num_cols; ++i) {
  46 + auto t = sum[i] / num_rows;
  47 + (*mean)[i] = t;
  48 +
  49 + float stddev = std::sqrt(sum_sq[i] / num_rows - t * t);
  50 + (*inv_stddev)[i] = 1.0f / (stddev + 1e-5f);
  51 + }
  52 +}
  53 +
18 54 void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
19 55 po->Register("sample-rate", &sampling_rate,
20 56 "Sampling rate of the input waveform. "
@@ -106,6 +142,8 @@ class OfflineStream::Impl {
106 142 p += feature_dim;
107 143 }
108 144
  145 + NemoNormalizeFeatures(features.data(), n, feature_dim);
  146 +
109 147 return features;
110 148 }
111 149
@@ -114,6 +152,38 @@ class OfflineStream::Impl {
114 152 const OfflineRecognitionResult &GetResult() const { return r_; }
115 153
116 154 private:
  155 + void NemoNormalizeFeatures(float *p, int32_t num_frames,
  156 + int32_t feature_dim) const {
  157 + if (config_.nemo_normalize_type.empty()) {
  158 + return;
  159 + }
  160 +
  161 + if (config_.nemo_normalize_type != "per_feature") {
  162 + SHERPA_ONNX_LOGE(
  163 + "Only normalize_type=per_feature is implemented. Given: %s",
  164 + config_.nemo_normalize_type.c_str());
  165 + exit(-1);
  166 + }
  167 +
  168 + NemoNormalizePerFeature(p, num_frames, feature_dim);
  169 + }
  170 +
  171 + static void NemoNormalizePerFeature(float *p, int32_t num_frames,
  172 + int32_t feature_dim) {
  173 + std::vector<float> mean;
  174 + std::vector<float> inv_stddev;
  175 +
  176 + ComputeMeanAndInvStd(p, num_frames, feature_dim, &mean, &inv_stddev);
  177 +
  178 + for (int32_t n = 0; n != num_frames; ++n) {
  179 + for (int32_t i = 0; i != feature_dim; ++i) {
  180 + p[i] = (p[i] - mean[i]) * inv_stddev[i];
  181 + }
  182 + p += feature_dim;
  183 + }
  184 + }
  185 +
  186 + private:
117 187 OfflineFeatureExtractorConfig config_;
118 188 std::unique_ptr<knf::OnlineFbank> fbank_;
119 189 knf::FbankOptions opts_;
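The per_feature normalization introduced above subtracts a per-dimension mean and multiplies by the inverse standard deviation, both computed over the frames of the single utterance. An equivalent NumPy sketch, for illustration only:

```python
import numpy as np

def nemo_normalize_per_feature(features: np.ndarray) -> np.ndarray:
    """features: (num_frames, feature_dim) fbank matrix of one utterance."""
    mean = features.mean(axis=0)                       # per-dimension mean over time
    inv_stddev = 1.0 / (features.std(axis=0) + 1e-5)   # population std, as in the C++ code
    return (features - mean) * inv_stddev
```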
@@ -37,13 +37,26 @@ struct OfflineFeatureExtractorConfig {
37 37 // Feature dimension
38 38 int32_t feature_dim = 80;
39 39
40 - // Set internally by some models, e.g., paraformer
  40 + // Set internally by some models, e.g., paraformer sets it to false.
41 41 // This parameter is not exposed to users from the commandline
42 42 // If true, the feature extractor expects inputs to be normalized to
43 43 // the range [-1, 1].
44 44 // If false, we will multiply the inputs by 32768
45 45 bool normalize_samples = true;
46 46
  47 + // For models from NeMo
  48 + // This option is not exposed and is set internally when loading models.
  49 + // Possible values:
  50 + // - per_feature
  51 + // - all_features (not implemented yet)
  52 + // - fixed_mean (not implemented)
  53 + // - fixed_std (not implemented)
  54 + // - or just leave it to empty
  55 + // See
  56 + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
  57 + // for details
  58 + std::string nemo_normalize_type;
  59 +
47 60 std::string ToString() const;
48 61
49 62 void Register(ParseOptions *po);
@@ -14,10 +14,12 @@
14 14 #include <sstream>
15 15 #include <string>
16 16
  17 +#include "sherpa-onnx/csrc/macros.h"
17 18 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
18 19 #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
19 20 #include "sherpa-onnx/csrc/onnx-utils.h"
20 -namespace sherpa_onnx {
  21 +
  22 +namespace {
21 23
22 24 enum class ModelType {
23 25 kLstm,
@@ -25,6 +27,10 @@ enum class ModelType {
25 27 kUnkown,
26 28 };
27 29
  30 +}
  31 +
  32 +namespace sherpa_onnx {
  33 +
28 34 static ModelType GetModelType(char *model_data, size_t model_data_length,
29 35 bool debug) {
30 36 Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
@@ -37,14 +43,17 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
37 43 if (debug) {
38 44 std::ostringstream os;
39 45 PrintModelMetadata(os, meta_data);
40 - fprintf(stderr, "%s\n", os.str().c_str());
  46 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
41 47 }
42 48
43 49 Ort::AllocatorWithDefaultOptions allocator;
44 50 auto model_type =
45 51 meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
46 52 if (!model_type) {
47 - fprintf(stderr, "No model_type in the metadata!\n");
  53 + SHERPA_ONNX_LOGE(
  54 + "No model_type in the metadata!\n"
  55 + "Please make sure you are using the latest export-onnx.py from icefall "
  56 + "to export your transducer models");
48 57 return ModelType::kUnkown;
49 58 }
50 59
@@ -53,7 +62,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
53 62 } else if (model_type.get() == std::string("zipformer")) {
54 63 return ModelType::kZipformer;
55 64 } else {
56 - fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
  65 + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
57 66 return ModelType::kUnkown;
58 67 }
59 68 }
@@ -74,6 +83,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
74 83 case ModelType::kZipformer:
75 84 return std::make_unique<OnlineZipformerTransducerModel>(config);
76 85 case ModelType::kUnkown:
  86 + SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
77 87 return nullptr;
78 88 }
79 89
@@ -127,6 +137,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
127 137 case ModelType::kZipformer:
128 138 return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
129 139 case ModelType::kUnkown:
  140 + SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
130 141 return nullptr;
131 142 }
132 143
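The GetModelType() helper above reads the custom model_type key from the ONNX model metadata and now logs through SHERPA_ONNX_LOGE when that key is missing or unsupported. For illustration only, a minimal Python sketch (assuming onnxruntime is installed; encoder.onnx is a placeholder path) that inspects the same metadata key of an exported model:

    import onnxruntime as ort

    # Load an exported encoder and read its custom metadata map;
    # GetModelType() in the C++ code above looks up the same "model_type" key.
    sess = ort.InferenceSession("encoder.onnx")  # placeholder path
    meta = sess.get_modelmeta()
    print(meta.custom_metadata_map.get("model_type", "<missing>"))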
@@ -35,4 +35,28 @@ TEST(Tranpose, Tranpose01) { @@ -35,4 +35,28 @@ TEST(Tranpose, Tranpose01) {
35 } 35 }
36 } 36 }
37 37
  38 +TEST(Tranpose, Tranpose12) {
  39 + Ort::AllocatorWithDefaultOptions allocator;
  40 + std::array<int64_t, 3> shape{3, 2, 5};
  41 + Ort::Value v =
  42 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  43 + float *p = v.GetTensorMutableData<float>();
  44 +
  45 + std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
  46 +
  47 + auto ans = Transpose12(allocator, &v);
  48 + auto v2 = Transpose12(allocator, &ans);
  49 +
  50 + Print3D(&v);
  51 + Print3D(&ans);
  52 + Print3D(&v2);
  53 +
  54 + const float *q = v2.GetTensorData<float>();
  55 +
  56 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
  57 + ++i) {
  58 + EXPECT_EQ(p[i], q[i]);
  59 + }
  60 +}
  61 +
38 } // namespace sherpa_onnx 62 } // namespace sherpa_onnx
@@ -17,8 +17,8 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) { @@ -17,8 +17,8 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
17 assert(shape.size() == 3); 17 assert(shape.size() == 3);
18 18
19 std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]}; 19 std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]};
20 - Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),  
21 - ans_shape.size()); 20 + Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
  21 + ans_shape.size());
22 22
23 T *dst = ans.GetTensorMutableData<T>(); 23 T *dst = ans.GetTensorMutableData<T>();
24 auto plane_offset = shape[1] * shape[2]; 24 auto plane_offset = shape[1] * shape[2];
@@ -35,7 +35,32 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) { @@ -35,7 +35,32 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
35 return ans; 35 return ans;
36 } 36 }
37 37
  38 +template <typename T /*= float*/>
  39 +Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v) {
  40 + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
  41 + assert(shape.size() == 3);
  42 +
  43 + std::array<int64_t, 3> ans_shape{shape[0], shape[2], shape[1]};
  44 + Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
  45 + ans_shape.size());
  46 + T *dst = ans.GetTensorMutableData<T>();
  47 + auto row_stride = shape[2];
  48 + for (int64_t b = 0; b != ans_shape[0]; ++b) {
  49 + const T *src = v->GetTensorData<T>() + b * shape[1] * shape[2];
  50 + for (int64_t i = 0; i != ans_shape[1]; ++i) {
  51 + for (int64_t k = 0; k != ans_shape[2]; ++k, ++dst) {
  52 + *dst = (src + k * row_stride)[i];
  53 + }
  54 + }
  55 + }
  56 +
  57 + return ans;
  58 +}
  59 +
38 template Ort::Value Transpose01<float>(OrtAllocator *allocator, 60 template Ort::Value Transpose01<float>(OrtAllocator *allocator,
39 const Ort::Value *v); 61 const Ort::Value *v);
40 62
  63 +template Ort::Value Transpose12<float>(OrtAllocator *allocator,
  64 + const Ort::Value *v);
  65 +
41 } // namespace sherpa_onnx 66 } // namespace sherpa_onnx
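Transpose12() above swaps the last two axes, converting a (B, T, C) tensor into (B, C, T); the Tranpose12 unit test earlier checks that applying it twice restores the original data. A rough numpy equivalent, shown only to illustrate the expected layout (not part of the patch):

    import numpy as np

    b, t, c = 3, 2, 5
    x = np.arange(b * t * c, dtype=np.float32).reshape(b, t, c)

    # Same effect as Transpose12: (B, T, C) -> (B, C, T)
    y = x.transpose(0, 2, 1)

    # A second transpose restores the original tensor, mirroring
    # the round-trip check in the C++ unit test.
    assert np.array_equal(x, y.transpose(0, 2, 1))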
@@ -10,13 +10,23 @@ namespace sherpa_onnx { @@ -10,13 +10,23 @@ namespace sherpa_onnx {
10 /** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C). 10 /** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
11 * 11 *
12 * @param allocator 12 * @param allocator
13 - * @param v A 3-D tensor of shape (B, T, C). Its dataype is T. 13 + * @param v A 3-D tensor of shape (B, T, C). Its dataype is type.
14 * 14 *
15 - * @return Return a 3-D tensor of shape (T, B, C). Its datatype is T. 15 + * @return Return a 3-D tensor of shape (T, B, C). Its datatype is type.
16 */ 16 */
17 -template <typename T = float> 17 +template <typename type = float>
18 Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); 18 Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
19 19
  20 +/** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T).
  21 + *
  22 + * @param allocator
  23 + * @param v A 3-D tensor of shape (B, T, C). Its dataype is type.
  24 + *
  25 + * @return Return a 3-D tensor of shape (B, C, T). Its datatype is type.
  26 + */
  27 +template <typename type = float>
  28 +Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v);
  29 +
20 } // namespace sherpa_onnx 30 } // namespace sherpa_onnx
21 31
22 #endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_ 32 #endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_
@@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
5 endpoint.cc 5 endpoint.cc
6 features.cc 6 features.cc
7 offline-model-config.cc 7 offline-model-config.cc
  8 + offline-nemo-enc-dec-ctc-model-config.cc
8 offline-paraformer-model-config.cc 9 offline-paraformer-model-config.cc
9 offline-recognizer.cc 10 offline-recognizer.cc
10 offline-stream.cc 11 offline-stream.cc
@@ -7,26 +7,31 @@ @@ -7,26 +7,31 @@
7 #include <string> 7 #include <string>
8 #include <vector> 8 #include <vector>
9 9
10 -#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"  
11 -#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"  
12 -  
13 #include "sherpa-onnx/csrc/offline-model-config.h" 10 #include "sherpa-onnx/csrc/offline-model-config.h"
  11 +#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
  12 +#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
  13 +#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
14 14
15 namespace sherpa_onnx { 15 namespace sherpa_onnx {
16 16
17 void PybindOfflineModelConfig(py::module *m) { 17 void PybindOfflineModelConfig(py::module *m) {
18 PybindOfflineTransducerModelConfig(m); 18 PybindOfflineTransducerModelConfig(m);
19 PybindOfflineParaformerModelConfig(m); 19 PybindOfflineParaformerModelConfig(m);
  20 + PybindOfflineNemoEncDecCtcModelConfig(m);
20 21
21 using PyClass = OfflineModelConfig; 22 using PyClass = OfflineModelConfig;
22 py::class_<PyClass>(*m, "OfflineModelConfig") 23 py::class_<PyClass>(*m, "OfflineModelConfig")
23 - .def(py::init<OfflineTransducerModelConfig &,  
24 - OfflineParaformerModelConfig &,  
25 - const std::string &, int32_t, bool>(),  
26 - py::arg("transducer"), py::arg("paraformer"), py::arg("tokens"),  
27 - py::arg("num_threads"), py::arg("debug") = false) 24 + .def(py::init<const OfflineTransducerModelConfig &,
  25 + const OfflineParaformerModelConfig &,
  26 + const OfflineNemoEncDecCtcModelConfig &,
  27 + const std::string &, int32_t, bool>(),
  28 + py::arg("transducer") = OfflineTransducerModelConfig(),
  29 + py::arg("paraformer") = OfflineParaformerModelConfig(),
  30 + py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
  31 + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false)
28 .def_readwrite("transducer", &PyClass::transducer) 32 .def_readwrite("transducer", &PyClass::transducer)
29 .def_readwrite("paraformer", &PyClass::paraformer) 33 .def_readwrite("paraformer", &PyClass::paraformer)
  34 + .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
30 .def_readwrite("tokens", &PyClass::tokens) 35 .def_readwrite("tokens", &PyClass::tokens)
31 .def_readwrite("num_threads", &PyClass::num_threads) 36 .def_readwrite("num_threads", &PyClass::num_threads)
32 .def_readwrite("debug", &PyClass::debug) 37 .def_readwrite("debug", &PyClass::debug)
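With the updated binding, transducer, paraformer, and nemo_ctc all default to empty configs, so Python callers only need to fill in the branch they actually use. A small sketch of the keyword-style construction, importing directly from the compiled _sherpa_onnx module as the wrapper code does (model.onnx and tokens.txt are placeholder paths):

    from _sherpa_onnx import (
        OfflineModelConfig,
        OfflineNemoEncDecCtcModelConfig,
    )

    config = OfflineModelConfig(
        nemo_ctc=OfflineNemoEncDecCtcModelConfig(model="model.onnx"),
        tokens="tokens.txt",
        num_threads=2,
    )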
  1 +// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOfflineNemoEncDecCtcModelConfig(py::module *m) {
  15 + using PyClass = OfflineNemoEncDecCtcModelConfig;
  16 + py::class_<PyClass>(*m, "OfflineNemoEncDecCtcModelConfig")
  17 + .def(py::init<const std::string &>(), py::arg("model"))
  18 + .def_readwrite("model", &PyClass::model)
  19 + .def("__str__", &PyClass::ToString);
  20 +}
  21 +
  22 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineNemoEncDecCtcModelConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
@@ -4,7 +4,6 @@ @@ -4,7 +4,6 @@
4 4
5 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" 5 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
6 6
7 -  
8 #include <string> 7 #include <string>
9 #include <vector> 8 #include <vector>
10 9
@@ -15,8 +14,7 @@ namespace sherpa_onnx { @@ -15,8 +14,7 @@ namespace sherpa_onnx {
15 void PybindOfflineParaformerModelConfig(py::module *m) { 14 void PybindOfflineParaformerModelConfig(py::module *m) {
16 using PyClass = OfflineParaformerModelConfig; 15 using PyClass = OfflineParaformerModelConfig;
17 py::class_<PyClass>(*m, "OfflineParaformerModelConfig") 16 py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
18 - .def(py::init<const std::string &>(),  
19 - py::arg("model")) 17 + .def(py::init<const std::string &>(), py::arg("model"))
20 .def_readwrite("model", &PyClass::model) 18 .def_readwrite("model", &PyClass::model)
21 .def("__str__", &PyClass::ToString); 19 .def("__str__", &PyClass::ToString);
22 } 20 }
@@ -11,8 +11,6 @@ @@ -11,8 +11,6 @@
11 11
12 namespace sherpa_onnx { 12 namespace sherpa_onnx {
13 13
14 -  
15 -  
16 static void PybindOfflineRecognizerConfig(py::module *m) { 14 static void PybindOfflineRecognizerConfig(py::module *m) {
17 using PyClass = OfflineRecognizerConfig; 15 using PyClass = OfflineRecognizerConfig;
18 py::class_<PyClass>(*m, "OfflineRecognizerConfig") 16 py::class_<PyClass>(*m, "OfflineRecognizerConfig")
@@ -31,7 +31,6 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT @@ -31,7 +31,6 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
31 "timestamps", [](const PyClass &self) { return self.timestamps; }); 31 "timestamps", [](const PyClass &self) { return self.timestamps; });
32 } 32 }
33 33
34 -  
35 static void PybindOfflineFeatureExtractorConfig(py::module *m) { 34 static void PybindOfflineFeatureExtractorConfig(py::module *m) {
36 using PyClass = OfflineFeatureExtractorConfig; 35 using PyClass = OfflineFeatureExtractorConfig;
37 py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig") 36 py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig")
@@ -42,7 +41,6 @@ static void PybindOfflineFeatureExtractorConfig(py::module *m) { @@ -42,7 +41,6 @@ static void PybindOfflineFeatureExtractorConfig(py::module *m) {
42 .def("__str__", &PyClass::ToString); 41 .def("__str__", &PyClass::ToString);
43 } 42 }
44 43
45 -  
46 void PybindOfflineStream(py::module *m) { 44 void PybindOfflineStream(py::module *m) {
47 PybindOfflineFeatureExtractorConfig(m); 45 PybindOfflineFeatureExtractorConfig(m);
48 PybindOfflineRecognitionResult(m); 46 PybindOfflineRecognitionResult(m);
@@ -55,7 +53,7 @@ void PybindOfflineStream(py::module *m) { @@ -55,7 +53,7 @@ void PybindOfflineStream(py::module *m) {
55 self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); 53 self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
56 }, 54 },
57 py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage) 55 py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
58 - .def_property_readonly("result", &PyClass::GetResult); 56 + .def_property_readonly("result", &PyClass::GetResult);
59 } 57 }
60 58
61 } // namespace sherpa_onnx 59 } // namespace sherpa_onnx
@@ -7,15 +7,12 @@ @@ -7,15 +7,12 @@
7 #include "sherpa-onnx/python/csrc/display.h" 7 #include "sherpa-onnx/python/csrc/display.h"
8 #include "sherpa-onnx/python/csrc/endpoint.h" 8 #include "sherpa-onnx/python/csrc/endpoint.h"
9 #include "sherpa-onnx/python/csrc/features.h" 9 #include "sherpa-onnx/python/csrc/features.h"
10 -#include "sherpa-onnx/python/csrc/online-recognizer.h"  
11 -#include "sherpa-onnx/python/csrc/online-stream.h"  
12 -#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"  
13 -  
14 #include "sherpa-onnx/python/csrc/offline-model-config.h" 10 #include "sherpa-onnx/python/csrc/offline-model-config.h"
15 -#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"  
16 #include "sherpa-onnx/python/csrc/offline-recognizer.h" 11 #include "sherpa-onnx/python/csrc/offline-recognizer.h"
17 #include "sherpa-onnx/python/csrc/offline-stream.h" 12 #include "sherpa-onnx/python/csrc/offline-stream.h"
18 -#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" 13 +#include "sherpa-onnx/python/csrc/online-recognizer.h"
  14 +#include "sherpa-onnx/python/csrc/online-stream.h"
  15 +#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
19 16
20 namespace sherpa_onnx { 17 namespace sherpa_onnx {
21 18
@@ -4,12 +4,15 @@ from typing import List @@ -4,12 +4,15 @@ from typing import List
4 4
5 from _sherpa_onnx import ( 5 from _sherpa_onnx import (
6 OfflineFeatureExtractorConfig, 6 OfflineFeatureExtractorConfig,
7 - OfflineRecognizer as _Recognizer, 7 + OfflineModelConfig,
  8 + OfflineNemoEncDecCtcModelConfig,
  9 + OfflineParaformerModelConfig,
  10 +)
  11 +from _sherpa_onnx import OfflineRecognizer as _Recognizer
  12 +from _sherpa_onnx import (
8 OfflineRecognizerConfig, 13 OfflineRecognizerConfig,
9 OfflineStream, 14 OfflineStream,
10 - OfflineModelConfig,  
11 OfflineTransducerModelConfig, 15 OfflineTransducerModelConfig,
12 - OfflineParaformerModelConfig,  
13 ) 16 )
14 17
15 18
@@ -75,7 +78,6 @@ class OfflineRecognizer(object): @@ -75,7 +78,6 @@ class OfflineRecognizer(object):
75 decoder_filename=decoder, 78 decoder_filename=decoder,
76 joiner_filename=joiner, 79 joiner_filename=joiner,
77 ), 80 ),
78 - paraformer=OfflineParaformerModelConfig(model=""),  
79 tokens=tokens, 81 tokens=tokens,
80 num_threads=num_threads, 82 num_threads=num_threads,
81 debug=debug, 83 debug=debug,
@@ -119,7 +121,7 @@ class OfflineRecognizer(object): @@ -119,7 +121,7 @@ class OfflineRecognizer(object):
119 symbol integer_id 121 symbol integer_id
120 122
121 paraformer: 123 paraformer:
122 - Path to ``paraformer.onnx``. 124 + Path to ``model.onnx``.
123 num_threads: 125 num_threads:
124 Number of threads for neural network computation. 126 Number of threads for neural network computation.
125 sample_rate: 127 sample_rate:
@@ -133,9 +135,6 @@ class OfflineRecognizer(object): @@ -133,9 +135,6 @@ class OfflineRecognizer(object):
133 """ 135 """
134 self = cls.__new__(cls) 136 self = cls.__new__(cls)
135 model_config = OfflineModelConfig( 137 model_config = OfflineModelConfig(
136 - transducer=OfflineTransducerModelConfig(  
137 - encoder_filename="", decoder_filename="", joiner_filename=""  
138 - ),  
139 paraformer=OfflineParaformerModelConfig(model=paraformer), 138 paraformer=OfflineParaformerModelConfig(model=paraformer),
140 tokens=tokens, 139 tokens=tokens,
141 num_threads=num_threads, 140 num_threads=num_threads,
@@ -155,6 +154,64 @@ class OfflineRecognizer(object): @@ -155,6 +154,64 @@ class OfflineRecognizer(object):
155 self.recognizer = _Recognizer(recognizer_config) 154 self.recognizer = _Recognizer(recognizer_config)
156 return self 155 return self
157 156
  157 + @classmethod
  158 + def from_nemo_ctc(
  159 + cls,
  160 + model: str,
  161 + tokens: str,
  162 + num_threads: int,
  163 + sample_rate: int = 16000,
  164 + feature_dim: int = 80,
  165 + decoding_method: str = "greedy_search",
  166 + debug: bool = False,
  167 + ):
  168 + """
  169 + Please refer to
  170 + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
  171 + to download pre-trained models for different languages, e.g., Chinese,
  172 + English, etc.
  173 +
  174 + Args:
  175 + tokens:
  176 + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
  177 + columns::
  178 +
  179 + symbol integer_id
  180 +
  181 + model:
  182 + Path to ``model.onnx``.
  183 + num_threads:
  184 + Number of threads for neural network computation.
  185 + sample_rate:
  186 + Sample rate of the training data used to train the model.
  187 + feature_dim:
  188 + Dimension of the feature used to train the model.
  189 + decoding_method:
  190 + Valid values are greedy_search, modified_beam_search.
  191 + debug:
  192 + True to show debug messages.
  193 + """
  194 + self = cls.__new__(cls)
  195 + model_config = OfflineModelConfig(
  196 + nemo_ctc=OfflineNemoEncDecCtcModelConfig(model=model),
  197 + tokens=tokens,
  198 + num_threads=num_threads,
  199 + debug=debug,
  200 + )
  201 +
  202 + feat_config = OfflineFeatureExtractorConfig(
  203 + sampling_rate=sample_rate,
  204 + feature_dim=feature_dim,
  205 + )
  206 +
  207 + recognizer_config = OfflineRecognizerConfig(
  208 + feat_config=feat_config,
  209 + model_config=model_config,
  210 + decoding_method=decoding_method,
  211 + )
  212 + self.recognizer = _Recognizer(recognizer_config)
  213 + return self
  214 +
158 def create_stream(self): 215 def create_stream(self):
159 return self.recognizer.create_stream() 216 return self.recognizer.create_stream()
160 217
@@ -196,6 +196,71 @@ class TestOfflineRecognizer(unittest.TestCase): @@ -196,6 +196,71 @@ class TestOfflineRecognizer(unittest.TestCase):
196 print(s2.result.text) 196 print(s2.result.text)
197 print(s3.result.text) 197 print(s3.result.text)
198 198
  199 + def test_nemo_ctc_single_file(self):
  200 + for use_int8 in [True, False]:
  201 + if use_int8:
  202 + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx"
  203 + else:
  204 + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx"
  205 +
  206 + tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt"
  207 + wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav"
  208 +
  209 + if not Path(model).is_file():
  210 + print("skipping test_nemo_ctc_single_file()")
  211 + return
  212 +
  213 + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
  214 + model=model,
  215 + tokens=tokens,
  216 + num_threads=1,
  217 + )
  218 +
  219 + s = recognizer.create_stream()
  220 + samples, sample_rate = read_wave(wave0)
  221 + s.accept_waveform(sample_rate, samples)
  222 + recognizer.decode_stream(s)
  223 + print(s.result.text)
  224 +
  225 + def test_nemo_ctc_multiple_files(self):
  226 + for use_int8 in [True, False]:
  227 + if use_int8:
  228 + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx"
  229 + else:
  230 + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx"
  231 +
  232 + tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt"
  233 + wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav"
  234 + wave1 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav"
  235 + wave2 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav"
  236 +
  237 + if not Path(model).is_file():
  238 + print("skipping test_nemo_ctc_multiple_files()")
  239 + return
  240 +
  241 + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
  242 + model=model,
  243 + tokens=tokens,
  244 + num_threads=1,
  245 + )
  246 +
  247 + s0 = recognizer.create_stream()
  248 + samples0, sample_rate0 = read_wave(wave0)
  249 + s0.accept_waveform(sample_rate0, samples0)
  250 +
  251 + s1 = recognizer.create_stream()
  252 + samples1, sample_rate1 = read_wave(wave1)
  253 + s1.accept_waveform(sample_rate1, samples1)
  254 +
  255 + s2 = recognizer.create_stream()
  256 + samples2, sample_rate2 = read_wave(wave2)
  257 + s2.accept_waveform(sample_rate2, samples2)
  258 +
  259 + recognizer.decode_streams([s0, s1, s2])
  260 + print(s0.result.text)
  261 + print(s1.result.text)
  262 + print(s2.result.text)
  263 +
199 264
200 if __name__ == "__main__": 265 if __name__ == "__main__":
201 unittest.main() 266 unittest.main()