Add C++ runtime and Python APIs for Moonshine models (#1473)
Showing 33 changed files with 1572 additions and 36 deletions.
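For Python users the core of the change is a new constructor, sherpa_onnx.OfflineRecognizer.from_moonshine(), which wires together the four Moonshine ONNX files plus tokens.txt. A minimal sketch, condensed from the offline-moonshine-decode-files.py example added in this PR (it assumes the sherpa-onnx-moonshine-tiny-en-int8 archive from the asr-models release page has been extracted into the current directory):

import sherpa_onnx
import soundfile as sf

d = "./sherpa-onnx-moonshine-tiny-en-int8"
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
    preprocessor=f"{d}/preprocess.onnx",
    encoder=f"{d}/encode.int8.onnx",
    uncached_decoder=f"{d}/uncached_decode.int8.onnx",
    cached_decoder=f"{d}/cached_decode.int8.onnx",
    tokens=f"{d}/tokens.txt",
)

# Decode one of the bundled test files; the input does not
# need to be sampled at 16000 Hz.
audio, sample_rate = sf.read(f"{d}/test_wavs/0.wav", dtype="float32", always_2d=True)
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio[:, 0])
recognizer.decode_stream(stream)
print(stream.result.text)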
.github/scripts/test-offline-moonshine.sh
0 → 100755
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+export GIT_CLONE_PROTECTION_ACTIVE=false
+
+echo "EXE is $EXE"
+echo "PATH: $PATH"
+
+which $EXE
+
+names=(
+tiny
+base
+)
+
+for name in ${names[@]}; do
+  log "------------------------------------------------------------"
+  log "Run $name"
+  log "------------------------------------------------------------"
+
+  repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-$name-en-int8.tar.bz2
+  curl -SL -O $repo_url
+  tar xvf sherpa-onnx-moonshine-$name-en-int8.tar.bz2
+  rm sherpa-onnx-moonshine-$name-en-int8.tar.bz2
+  repo=sherpa-onnx-moonshine-$name-en-int8
+  log "Start testing ${repo_url}"
+
+  log "test int8 onnx"
+
+  time $EXE \
+    --moonshine-preprocessor=$repo/preprocess.onnx \
+    --moonshine-encoder=$repo/encode.int8.onnx \
+    --moonshine-uncached-decoder=$repo/uncached_decode.int8.onnx \
+    --moonshine-cached-decoder=$repo/cached_decode.int8.onnx \
+    --tokens=$repo/tokens.txt \
+    --num-threads=2 \
+    $repo/test_wavs/0.wav \
+    $repo/test_wavs/1.wav \
+    $repo/test_wavs/8k.wav
+
+  rm -rf $repo
+done
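The script reads the decoder binary's name from the $EXE environment variable; the workflow steps below export EXE=sherpa-onnx-offline (sherpa-onnx-offline.exe on Windows) and put the build directory on PATH before invoking it.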
@@ -8,6 +8,16 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }
 
+log "test offline Moonshine"
+
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+
+python3 ./python-api-examples/offline-moonshine-decode-files.py
+
+rm -rf sherpa-onnx-moonshine-tiny-en-int8
+
 log "test offline speaker diarization"
 
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
@@ -149,6 +149,19 @@ jobs:
           name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
           path: install/*
 
+      - name: Test offline Moonshine
+        if: matrix.build_type != 'Debug'
+        shell: bash
+        run: |
+          du -h -d1 .
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline
+
+          readelf -d build/bin/sherpa-onnx-offline
+
+          .github/scripts/test-offline-moonshine.sh
+          du -h -d1 .
+
       - name: Test offline CTC
         shell: bash
         run: |
@@ -121,6 +121,15 @@ jobs:
           otool -L build/bin/sherpa-onnx
           otool -l build/bin/sherpa-onnx
 
+      - name: Test offline Moonshine
+        if: matrix.build_type != 'Debug'
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline
+
+          .github/scripts/test-offline-moonshine.sh
+
       - name: Test C++ API
         shell: bash
         run: |
@@ -243,8 +252,6 @@ jobs:
 
           .github/scripts/test-offline-whisper.sh
 
-
-
       - name: Test online transducer
         shell: bash
         run: |
@@ -93,6 +93,14 @@ jobs:
           name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
           path: build/install/*
 
+      - name: Test offline Moonshine for windows x64
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx-offline.exe
+
+          .github/scripts/test-offline-moonshine.sh
+
       - name: Test C++ API
         shell: bash
         run: |
@@ -93,6 +93,14 @@ jobs:
           name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
           path: build/install/*
 
+      - name: Test offline Moonshine for windows x86
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx-offline.exe
+
+          .github/scripts/test-offline-moonshine.sh
+
       - name: Test C++ API
         shell: bash
         run: |
@@ -47,7 +47,19 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
   --feature-dim=80 \
   /path/to/test.mp4
 
-(3) For Whisper models
+(3) For Moonshine models
+
+./python-api-examples/generate-subtitles.py \
+  --silero-vad-model=/path/to/silero_vad.onnx \
+  --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
+  --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
+  --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
+  --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
+  --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
+  --num-threads=2 \
+  /path/to/test.mp4
+
+(4) For Whisper models
 
 ./python-api-examples/generate-subtitles.py \
   --silero-vad-model=/path/to/silero_vad.onnx \
@@ -58,7 +70,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
   --num-threads=2 \
   /path/to/test.mp4
 
-(4) For SenseVoice CTC models
+(5) For SenseVoice CTC models
 
 ./python-api-examples/generate-subtitles.py \
   --silero-vad-model=/path/to/silero_vad.onnx \
@@ -68,7 +80,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
   /path/to/test.mp4
 
 
-(5) For WeNet CTC models
+(6) For WeNet CTC models
 
 ./python-api-examples/generate-subtitles.py \
   --silero-vad-model=/path/to/silero_vad.onnx \
@@ -83,6 +95,7 @@ to install sherpa-onnx and to download non-streaming pre-trained models
 used in this file.
 """
 import argparse
+import datetime as dt
 import shutil
 import subprocess
 import sys
@@ -157,7 +170,7 @@ def get_args():
     parser.add_argument(
         "--num-threads",
         type=int,
-        default=1,
+        default=2,
         help="Number of threads for neural network computation",
     )
 
@@ -209,6 +222,34 @@ def get_args():
     )
 
     parser.add_argument(
+        "--moonshine-preprocessor",
+        default="",
+        type=str,
+        help="Path to moonshine preprocessor model",
+    )
+
+    parser.add_argument(
+        "--moonshine-encoder",
+        default="",
+        type=str,
+        help="Path to moonshine encoder model",
+    )
+
+    parser.add_argument(
+        "--moonshine-uncached-decoder",
+        default="",
+        type=str,
+        help="Path to moonshine uncached decoder model",
+    )
+
+    parser.add_argument(
+        "--moonshine-cached-decoder",
+        default="",
+        type=str,
+        help="Path to moonshine cached decoder model",
+    )
+
+    parser.add_argument(
         "--decoding-method",
         type=str,
         default="greedy_search",
@@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.encoder)
         assert_file_exists(args.decoder)
@@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.paraformer)
 
@@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.wenet_ctc) == 0, args.wenet_ctc
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.sense_voice)
         recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
@@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
     elif args.wenet_ctc:
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.wenet_ctc)
 
@@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
     elif args.whisper_encoder:
         assert_file_exists(args.whisper_encoder)
         assert_file_exists(args.whisper_decoder)
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
             encoder=args.whisper_encoder,
@@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             task=args.whisper_task,
             tail_paddings=args.whisper_tail_paddings,
         )
+    elif args.moonshine_preprocessor:
+        assert_file_exists(args.moonshine_preprocessor)
+        assert_file_exists(args.moonshine_encoder)
+        assert_file_exists(args.moonshine_uncached_decoder)
+        assert_file_exists(args.moonshine_cached_decoder)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
+            preprocessor=args.moonshine_preprocessor,
+            encoder=args.moonshine_encoder,
+            uncached_decoder=args.moonshine_uncached_decoder,
+            cached_decoder=args.moonshine_cached_decoder,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            decoding_method=args.decoding_method,
+            debug=args.debug,
+        )
     else:
         raise ValueError("Please specify at least one model")
 
@@ -424,28 +511,32 @@ def main():
     segment_list = []
 
     print("Started!")
+    start_t = dt.datetime.now()
+    num_processed_samples = 0
 
-    is_silence = False
+    is_eof = False
     # TODO(fangjun): Support multithreads
     while True:
         # *2 because int16_t has two bytes
         data = process.stdout.read(frames_per_read * 2)
         if not data:
-            if is_silence:
+            if is_eof:
                 break
-            is_silence = True
-            # The converted audio file does not have a mute data of 1 second or more at the end, which will result in the loss of the last segment data
+            is_eof = True
+            # pad 1 second at the end of the file for the VAD
            data = np.zeros(1 * args.sample_rate, dtype=np.int16)
 
         samples = np.frombuffer(data, dtype=np.int16)
         samples = samples.astype(np.float32) / 32768
 
+        num_processed_samples += samples.shape[0]
+
         buffer = np.concatenate([buffer, samples])
         while len(buffer) > window_size:
             vad.accept_waveform(buffer[:window_size])
             buffer = buffer[window_size:]
 
-    if is_silence:
+    if is_eof:
         vad.flush()
 
     streams = []
@@ -471,6 +562,11 @@ def main():
         seg.text = stream.result.text
         segment_list.append(seg)
 
+    end_t = dt.datetime.now()
+    elapsed_seconds = (end_t - start_t).total_seconds()
+    duration = num_processed_samples / 16000
+    rtf = elapsed_seconds / duration
+
     srt_filename = Path(args.sound_file).with_suffix(".srt")
     with open(srt_filename, "w", encoding="utf-8") as f:
         for i, seg in enumerate(segment_list):
@@ -479,6 +575,9 @@ def main():
             print("", file=f)
 
     print(f"Saved to {srt_filename}")
+    print(f"Audio duration:\t{duration:.3f} s")
+    print(f"Elapsed:\t{elapsed_seconds:.3f} s")
+    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
     print("Done!")
 
 
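Besides wiring Moonshine into generate-subtitles.py, the changes above rename the misleading is_silence flag to is_eof (it marks the end-of-file padding, not silence), raise the default --num-threads from 1 to 2, and add timing so the script reports a real-time factor, RTF = elapsed_seconds / audio_duration.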
@@ -66,7 +66,21 @@ python3 ./python-api-examples/non_streaming_server.py \
   --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
   --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
 
-(5) Use a Whisper model
+(5) Use a Moonshine model
+
+cd /path/to/sherpa-onnx
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+
+python3 ./python-api-examples/non_streaming_server.py \
+  --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
+  --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
+  --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
+  --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
+  --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt
+
+(6) Use a Whisper model
 
 cd /path/to/sherpa-onnx
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
@@ -78,7 +92,7 @@ python3 ./python-api-examples/non_streaming_server.py \
   --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
   --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
 
-(5) Use a tdnn model of the yesno recipe from icefall
+(7) Use a tdnn model of the yesno recipe from icefall
 
 cd /path/to/sherpa-onnx
 
@@ -92,7 +106,7 @@ python3 ./python-api-examples/non_streaming_server.py \
   --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
   --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
 
-(6) Use a Non-streaming SenseVoice model
+(8) Use a Non-streaming SenseVoice model
 
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
 tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
@@ -254,6 +268,36 @@ def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
     )
 
 
+def add_moonshine_model_args(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--moonshine-preprocessor",
+        default="",
+        type=str,
+        help="Path to moonshine preprocessor model",
+    )
+
+    parser.add_argument(
+        "--moonshine-encoder",
+        default="",
+        type=str,
+        help="Path to moonshine encoder model",
+    )
+
+    parser.add_argument(
+        "--moonshine-uncached-decoder",
+        default="",
+        type=str,
+        help="Path to moonshine uncached decoder model",
+    )
+
+    parser.add_argument(
+        "--moonshine-cached-decoder",
+        default="",
+        type=str,
+        help="Path to moonshine cached decoder model",
+    )
+
+
 def add_whisper_model_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--whisper-encoder",
@@ -311,6 +355,7 @@ def add_model_args(parser: argparse.ArgumentParser):
     add_wenet_ctc_model_args(parser)
     add_tdnn_ctc_model_args(parser)
     add_whisper_model_args(parser)
+    add_moonshine_model_args(parser)
 
     parser.add_argument(
         "--tokens",
@@ -876,6 +921,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.encoder)
         assert_file_exists(args.decoder)
@@ -903,6 +954,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.paraformer)
 
@@ -921,6 +978,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.sense_voice)
         recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
@@ -934,6 +997,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.nemo_ctc)
 
@@ -950,6 +1019,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
         assert len(args.tdnn_model) == 0, args.tdnn_model
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.wenet_ctc)
 
@@ -966,6 +1041,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.tdnn_model) == 0, args.tdnn_model
         assert_file_exists(args.whisper_encoder)
         assert_file_exists(args.whisper_decoder)
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
             encoder=args.whisper_encoder,
@@ -980,6 +1061,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         )
     elif args.tdnn_model:
         assert_file_exists(args.tdnn_model)
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
             model=args.tdnn_model,
@@ -990,6 +1077,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             decoding_method=args.decoding_method,
             provider=args.provider,
         )
+    elif args.moonshine_preprocessor:
+        assert_file_exists(args.moonshine_preprocessor)
+        assert_file_exists(args.moonshine_encoder)
+        assert_file_exists(args.moonshine_uncached_decoder)
+        assert_file_exists(args.moonshine_cached_decoder)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
+            preprocessor=args.moonshine_preprocessor,
+            encoder=args.moonshine_encoder,
+            uncached_decoder=args.moonshine_uncached_decoder,
+            cached_decoder=args.moonshine_cached_decoder,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            decoding_method=args.decoding_method,
+        )
     else:
         raise ValueError("Please specify at least one model")
 
+#!/usr/bin/env python3
+
+"""
+This file shows how to use a non-streaming Moonshine model from
+https://github.com/usefulsensors/moonshine
+to decode files.
+
+Please download model files from
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+
+For instance,
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+"""
+
+import datetime as dt
+from pathlib import Path
+
+import sherpa_onnx
+import soundfile as sf
+
+
+def create_recognizer():
+    preprocessor = "./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx"
+    encoder = "./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx"
+    uncached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx"
+    cached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx"
+
+    tokens = "./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt"
+    test_wav = "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav"
+
+    if not Path(preprocessor).is_file() or not Path(test_wav).is_file():
+        raise ValueError(
+            """Please download model files from
+            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+            """
+        )
+    return (
+        sherpa_onnx.OfflineRecognizer.from_moonshine(
+            preprocessor=preprocessor,
+            encoder=encoder,
+            uncached_decoder=uncached_decoder,
+            cached_decoder=cached_decoder,
+            tokens=tokens,
+            debug=True,
+        ),
+        test_wav,
+    )
+
+
+def main():
+    recognizer, wave_filename = create_recognizer()
+
+    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
+    audio = audio[:, 0]  # only use the first channel
+
+    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
+    # sample_rate does not need to be 16000 Hz
+
+    start_t = dt.datetime.now()
+
+    stream = recognizer.create_stream()
+    stream.accept_waveform(sample_rate, audio)
+    recognizer.decode_stream(stream)
+
+    end_t = dt.datetime.now()
+    elapsed_seconds = (end_t - start_t).total_seconds()
+    duration = audio.shape[-1] / sample_rate
+    rtf = elapsed_seconds / duration
+
+    print(stream.result)
+    print(wave_filename)
+    print("Text:", stream.result.text)
+    print(f"Audio duration:\t{duration:.3f} s")
+    print(f"Elapsed:\t{elapsed_seconds:.3f} s")
+    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
+
+
+if __name__ == "__main__":
+    main()
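Note that the example uses paths relative to the current directory, so it must be run from the directory where sherpa-onnx-moonshine-tiny-en-int8 was extracted (the CI script above does exactly that); it prints the recognized text followed by the audio duration, elapsed time, and RTF.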
+#!/usr/bin/env python3
+
+"""
+This file shows how to use a non-streaming whisper model from
+https://github.com/openai/whisper
+to decode files.
+
+Please download model files from
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+
+For instance,
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
+tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
+rm sherpa-onnx-whisper-tiny.en.tar.bz2
+"""
+
+import datetime as dt
+from pathlib import Path
+
+import sherpa_onnx
+import soundfile as sf
+
+
+def create_recognizer():
+    encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
+    decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
+    tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
+    test_wav = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
+
+    if not Path(encoder).is_file() or not Path(test_wav).is_file():
+        raise ValueError(
+            """Please download model files from
+            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+            """
+        )
+    return (
+        sherpa_onnx.OfflineRecognizer.from_whisper(
+            encoder=encoder,
+            decoder=decoder,
+            tokens=tokens,
+            debug=True,
+        ),
+        test_wav,
+    )
+
+
+def main():
+    recognizer, wave_filename = create_recognizer()
+
+    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
+    audio = audio[:, 0]  # only use the first channel
+
+    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
+    # sample_rate does not need to be 16000 Hz
+
+    start_t = dt.datetime.now()
+
+    stream = recognizer.create_stream()
+    stream.accept_waveform(sample_rate, audio)
+    recognizer.decode_stream(stream)
+
+    end_t = dt.datetime.now()
+    elapsed_seconds = (end_t - start_t).total_seconds()
+    duration = audio.shape[-1] / sample_rate
+    rtf = elapsed_seconds / duration
+
+    print(stream.result)
+    print(wave_filename)
+    print("Text:", stream.result.text)
+    print(f"Audio duration:\t{duration:.3f} s")
+    print(f"Elapsed:\t{elapsed_seconds:.3f} s")
+    print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
+
+
+if __name__ == "__main__":
+    main()
@@ -35,7 +35,18 @@ Note that you need a non-streaming model for this script.
   --sample-rate=16000 \
   --feature-dim=80
 
-(3) For Whisper models
+(3) For Moonshine models
+
+./python-api-examples/vad-with-non-streaming-asr.py \
+  --silero-vad-model=/path/to/silero_vad.onnx \
+  --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
+  --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
+  --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
+  --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
+  --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
+  --num-threads=2
+
+(4) For Whisper models
 
 ./python-api-examples/vad-with-non-streaming-asr.py \
   --silero-vad-model=/path/to/silero_vad.onnx \
@@ -45,7 +56,7 @@ Note that you need a non-streaming model for this script.
   --whisper-task=transcribe \
   --num-threads=2
 
-(4) For SenseVoice CTC models
+(5) For SenseVoice CTC models
 
 ./python-api-examples/vad-with-non-streaming-asr.py \
   --silero-vad-model=/path/to/silero_vad.onnx \
@@ -193,6 +204,34 @@ def get_args():
     )
 
     parser.add_argument(
+        "--moonshine-preprocessor",
+        default="",
+        type=str,
+        help="Path to moonshine preprocessor model",
+    )
+
+    parser.add_argument(
+        "--moonshine-encoder",
+        default="",
+        type=str,
+        help="Path to moonshine encoder model",
+    )
+
+    parser.add_argument(
+        "--moonshine-uncached-decoder",
+        default="",
+        type=str,
+        help="Path to moonshine uncached decoder model",
+    )
+
+    parser.add_argument(
+        "--moonshine-cached-decoder",
+        default="",
+        type=str,
+        help="Path to moonshine cached decoder model",
+    )
+
+    parser.add_argument(
         "--blank-penalty",
         type=float,
         default=0.0,
@@ -251,6 +290,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.sense_voice) == 0, args.sense_voice
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.encoder)
         assert_file_exists(args.decoder)
@@ -272,6 +317,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
         assert len(args.sense_voice) == 0, args.sense_voice
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.paraformer)
 
@@ -287,6 +338,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
     elif args.sense_voice:
         assert len(args.whisper_encoder) == 0, args.whisper_encoder
         assert len(args.whisper_decoder) == 0, args.whisper_decoder
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         assert_file_exists(args.sense_voice)
         recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
@@ -299,6 +356,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
     elif args.whisper_encoder:
         assert_file_exists(args.whisper_encoder)
         assert_file_exists(args.whisper_decoder)
+        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
+        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
+        assert (
+            len(args.moonshine_uncached_decoder) == 0
+        ), args.moonshine_uncached_decoder
+        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
 
         recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
             encoder=args.whisper_encoder,
@@ -311,6 +374,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             task=args.whisper_task,
             tail_paddings=args.whisper_tail_paddings,
         )
+    elif args.moonshine_preprocessor:
+        assert_file_exists(args.moonshine_preprocessor)
+        assert_file_exists(args.moonshine_encoder)
+        assert_file_exists(args.moonshine_uncached_decoder)
+        assert_file_exists(args.moonshine_cached_decoder)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
+            preprocessor=args.moonshine_preprocessor,
+            encoder=args.moonshine_encoder,
+            uncached_decoder=args.moonshine_uncached_decoder,
+            cached_decoder=args.moonshine_cached_decoder,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            decoding_method=args.decoding_method,
+            debug=args.debug,
+        )
     else:
         raise ValueError("Please specify at least one model")
 
@@ -29,6 +29,9 @@ set(sources
   offline-lm-config.cc
   offline-lm.cc
   offline-model-config.cc
+  offline-moonshine-greedy-search-decoder.cc
+  offline-moonshine-model-config.cc
+  offline-moonshine-model.cc
  offline-nemo-enc-dec-ctc-model-config.cc
   offline-nemo-enc-dec-ctc-model.cc
   offline-paraformer-greedy-search-decoder.cc
@@ -19,6 +19,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
   zipformer_ctc.Register(po);
   wenet_ctc.Register(po);
   sense_voice.Register(po);
+  moonshine.Register(po);
 
   po->Register("telespeech-ctc", &telespeech_ctc,
                "Path to model.onnx for telespeech ctc");
@@ -99,6 +100,10 @@ bool OfflineModelConfig::Validate() const {
     return sense_voice.Validate();
   }
 
+  if (!moonshine.preprocessor.empty()) {
+    return moonshine.Validate();
+  }
+
   if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
     SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
                      telespeech_ctc.c_str());
@@ -124,6 +129,7 @@ std::string OfflineModelConfig::ToString() const {
   os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
   os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
   os << "sense_voice=" << sense_voice.ToString() << ", ";
+  os << "moonshine=" << moonshine.ToString() << ", ";
   os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
@@ -6,6 +6,7 @@
 
 #include <string>
 
+#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
 #include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
@@ -26,6 +27,7 @@ struct OfflineModelConfig {
   OfflineZipformerCtcModelConfig zipformer_ctc;
   OfflineWenetCtcModelConfig wenet_ctc;
   OfflineSenseVoiceModelConfig sense_voice;
+  OfflineMoonshineModelConfig moonshine;
   std::string telespeech_ctc;
 
   std::string tokens;
@@ -56,6 +58,7 @@ struct OfflineModelConfig {
                     const OfflineZipformerCtcModelConfig &zipformer_ctc,
                     const OfflineWenetCtcModelConfig &wenet_ctc,
                     const OfflineSenseVoiceModelConfig &sense_voice,
+                    const OfflineMoonshineModelConfig &moonshine,
                     const std::string &telespeech_ctc,
                     const std::string &tokens, int32_t num_threads, bool debug,
                     const std::string &provider, const std::string &model_type,
@@ -69,6 +72,7 @@ struct OfflineModelConfig {
         zipformer_ctc(zipformer_ctc),
         wenet_ctc(wenet_ctc),
         sense_voice(sense_voice),
+        moonshine(moonshine),
         telespeech_ctc(telespeech_ctc),
         tokens(tokens),
         num_threads(num_threads),
sherpa-onnx/csrc/offline-moonshine-decoder.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-moonshine-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OfflineMoonshineDecoderResult { | ||
| 15 | + /// The decoded token IDs | ||
| 16 | + std::vector<int32_t> tokens; | ||
| 17 | +}; | ||
| 18 | + | ||
| 19 | +class OfflineMoonshineDecoder { | ||
| 20 | + public: | ||
| 21 | + virtual ~OfflineMoonshineDecoder() = default; | ||
| 22 | + | ||
| 23 | + /** Run beam search given the output from the moonshine encoder model. | ||
| 24 | + * | ||
| 25 | + * @param encoder_out A 3-D tensor of shape (batch_size, T, dim) | ||
| 26 | + * @return Return a vector of size `N` containing the decoded results. | ||
| 27 | + */ | ||
| 28 | + virtual std::vector<OfflineMoonshineDecoderResult> Decode( | ||
| 29 | + Ort::Value encoder_out) = 0; | ||
| 30 | +}; | ||
| 31 | + | ||
| 32 | +} // namespace sherpa_onnx | ||
| 33 | + | ||
| 34 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <utility> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +std::vector<OfflineMoonshineDecoderResult> | ||
| 16 | +OfflineMoonshineGreedySearchDecoder::Decode(Ort::Value encoder_out) { | ||
| 17 | + auto encoder_out_shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 18 | + if (encoder_out_shape[0] != 1) { | ||
| 19 | + SHERPA_ONNX_LOGE("Support only batch size == 1. Given: %d\n", | ||
| 20 | + static_cast<int32_t>(encoder_out_shape[0])); | ||
| 21 | + return {}; | ||
| 22 | + } | ||
| 23 | + | ||
| 24 | + auto memory_info = | ||
| 25 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 26 | + | ||
| 27 | + // encoder_out_shape[1] * 384 is the number of audio samples, where | ||
| 28 | + // 384 is the samples-per-frame constant from the moonshine paper | ||
| 29 | + // and 16000 is the sample rate, so the expression below caps the | ||
| 30 | + // output at 6 tokens per second of audio, e.g., at most 60 tokens | ||
| 31 | + // for a 10-second utterance. | ||
| 32 | + int32_t max_len = | ||
| 33 | + static_cast<int32_t>(encoder_out_shape[1] * 384 / 16000.0 * 6); | ||
| 34 | + | ||
| 35 | + int32_t sos = 1; | ||
| 36 | + int32_t eos = 2; | ||
| 37 | + int32_t seq_len = 1; | ||
| 38 | + | ||
| 39 | + std::vector<int32_t> tokens; | ||
| 40 | + | ||
| 41 | + std::array<int64_t, 2> token_shape = {1, 1}; | ||
| 42 | + int64_t seq_len_shape = 1; | ||
| 43 | + | ||
| 44 | + Ort::Value token_tensor = Ort::Value::CreateTensor( | ||
| 45 | + memory_info, &sos, 1, token_shape.data(), token_shape.size()); | ||
| 46 | + | ||
| 47 | + Ort::Value seq_len_tensor = | ||
| 48 | + Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1); | ||
| 49 | + | ||
| 50 | + Ort::Value logits{nullptr}; | ||
| 51 | + std::vector<Ort::Value> states; | ||
| 52 | + | ||
| 53 | + std::tie(logits, states) = model_->ForwardUnCachedDecoder( | ||
| 54 | + std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out)); | ||
| 55 | + | ||
| 56 | + int32_t vocab_size = logits.GetTensorTypeAndShapeInfo().GetShape()[2]; | ||
| 57 | + | ||
| 58 | + for (int32_t i = 0; i != max_len; ++i) { | ||
| 59 | + const float *p = logits.GetTensorData<float>(); | ||
| 60 | + | ||
| 61 | + int32_t max_token_id = static_cast<int32_t>( | ||
| 62 | + std::distance(p, std::max_element(p, p + vocab_size))); | ||
| 63 | + if (max_token_id == eos) { | ||
| 64 | + break; | ||
| 65 | + } | ||
| 66 | + tokens.push_back(max_token_id); | ||
| 67 | + | ||
| 68 | + seq_len += 1; | ||
| 69 | + | ||
| 70 | + token_tensor = Ort::Value::CreateTensor( | ||
| 71 | + memory_info, &tokens.back(), 1, token_shape.data(), token_shape.size()); | ||
| 72 | + | ||
| 73 | + seq_len_tensor = | ||
| 74 | + Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1); | ||
| 75 | + | ||
| 76 | + std::tie(logits, states) = model_->ForwardCachedDecoder( | ||
| 77 | + std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out), | ||
| 78 | + std::move(states)); | ||
| 79 | + } | ||
| 80 | + | ||
| 81 | + OfflineMoonshineDecoderResult ans; | ||
| 82 | + ans.tokens = std::move(tokens); | ||
| 83 | + | ||
| 84 | + return {ans}; | ||
| 85 | +} | ||
| 86 | + | ||
| 87 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-moonshine-decoder.h" | ||
| 11 | +#include "sherpa-onnx/csrc/offline-moonshine-model.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class OfflineMoonshineGreedySearchDecoder : public OfflineMoonshineDecoder { | ||
| 16 | + public: | ||
| 17 | + explicit OfflineMoonshineGreedySearchDecoder(OfflineMoonshineModel *model) | ||
| 18 | + : model_(model) {} | ||
| 19 | + | ||
| 20 | + std::vector<OfflineMoonshineDecoderResult> Decode( | ||
| 21 | + Ort::Value encoder_out) override; | ||
| 22 | + | ||
| 23 | + private: | ||
| 24 | + OfflineMoonshineModel *model_; // not owned | ||
| 25 | +}; | ||
| 26 | + | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | + | ||
| 29 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_ |
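A two-line usage sketch (not part of the patch); note that the raw pointer means the model must outlive the decoder:

// model is an OfflineMoonshineModel created elsewhere; the decoder does not
// own it, so it must stay alive for the decoder's whole lifetime.
sherpa_onnx::OfflineMoonshineGreedySearchDecoder decoder(&model);
auto results = decoder.Decode(std::move(encoder_out));  // encoder_out: Ort::Value from ForwardEncoder()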
| 1 | +// sherpa-onnx/csrc/offline-moonshine-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-moonshine-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineMoonshineModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("moonshine-preprocessor", &preprocessor, | ||
| 14 | + "Path to onnx preprocessor of moonshine, e.g., preprocess.onnx"); | ||
| 15 | + | ||
| 16 | + po->Register("moonshine-encoder", &encoder, | ||
| 17 | + "Path to onnx encoder of moonshine, e.g., encode.onnx"); | ||
| 18 | + | ||
| 19 | + po->Register( | ||
| 20 | + "moonshine-uncached-decoder", &uncached_decoder, | ||
| 21 | + "Path to onnx uncached_decoder of moonshine, e.g., uncached_decode.onnx"); | ||
| 22 | + | ||
| 23 | + po->Register( | ||
| 24 | + "moonshine-cached-decoder", &cached_decoder, | ||
| 25 | + "Path to onnx cached_decoder of moonshine, e.g., cached_decode.onnx"); | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +bool OfflineMoonshineModelConfig::Validate() const { | ||
| 29 | + if (preprocessor.empty()) { | ||
| 30 | + SHERPA_ONNX_LOGE("Please provide --moonshine-preprocessor"); | ||
| 31 | + return false; | ||
| 32 | + } | ||
| 33 | + | ||
| 34 | + if (!FileExists(preprocessor)) { | ||
| 35 | + SHERPA_ONNX_LOGE("moonshine preprocessor file '%s' does not exist", | ||
| 36 | + preprocessor.c_str()); | ||
| 37 | + return false; | ||
| 38 | + } | ||
| 39 | + | ||
| 40 | + if (encoder.empty()) { | ||
| 41 | + SHERPA_ONNX_LOGE("Please provide --moonshine-encoder"); | ||
| 42 | + return false; | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + if (!FileExists(encoder)) { | ||
| 46 | + SHERPA_ONNX_LOGE("moonshine encoder file '%s' does not exist", | ||
| 47 | + encoder.c_str()); | ||
| 48 | + return false; | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + if (uncached_decoder.empty()) { | ||
| 52 | + SHERPA_ONNX_LOGE("Please provide --moonshine-uncached-decoder"); | ||
| 53 | + return false; | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | + if (!FileExists(uncached_decoder)) { | ||
| 57 | + SHERPA_ONNX_LOGE("moonshine uncached decoder file '%s' does not exist", | ||
| 58 | + uncached_decoder.c_str()); | ||
| 59 | + return false; | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + if (cached_decoder.empty()) { | ||
| 63 | + SHERPA_ONNX_LOGE("Please provide --moonshine-cached-decoder"); | ||
| 64 | + return false; | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + if (!FileExists(cached_decoder)) { | ||
| 68 | + SHERPA_ONNX_LOGE("moonshine cached decoder file '%s' does not exist", | ||
| 69 | + cached_decoder.c_str()); | ||
| 70 | + return false; | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + return true; | ||
| 74 | +} | ||
| 75 | + | ||
| 76 | +std::string OfflineMoonshineModelConfig::ToString() const { | ||
| 77 | + std::ostringstream os; | ||
| 78 | + | ||
| 79 | + os << "OfflineMoonshineModelConfig("; | ||
| 80 | + os << "preprocessor=\"" << preprocessor << "\", "; | ||
| 81 | + os << "encoder=\"" << encoder << "\", "; | ||
| 82 | + os << "uncached_decoder=\"" << uncached_decoder << "\", "; | ||
| 83 | + os << "cached_decoder=\"" << cached_decoder << "\")"; | ||
| 84 | + | ||
| 85 | + return os.str(); | ||
| 86 | +} | ||
| 87 | + | ||
| 88 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-moonshine-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineMoonshineModelConfig { | ||
| 14 | + std::string preprocessor; | ||
| 15 | + std::string encoder; | ||
| 16 | + std::string uncached_decoder; | ||
| 17 | + std::string cached_decoder; | ||
| 18 | + | ||
| 19 | + OfflineMoonshineModelConfig() = default; | ||
| 20 | + OfflineMoonshineModelConfig(const std::string &preprocessor, | ||
| 21 | + const std::string &encoder, | ||
| 22 | + const std::string &uncached_decoder, | ||
| 23 | + const std::string &cached_decoder) | ||
| 24 | + : preprocessor(preprocessor), | ||
| 25 | + encoder(encoder), | ||
| 26 | + uncached_decoder(uncached_decoder), | ||
| 27 | + cached_decoder(cached_decoder) {} | ||
| 28 | + | ||
| 29 | + void Register(ParseOptions *po); | ||
| 30 | + bool Validate() const; | ||
| 31 | + | ||
| 32 | + std::string ToString() const; | ||
| 33 | +}; | ||
| 34 | + | ||
| 35 | +} // namespace sherpa_onnx | ||
| 36 | + | ||
| 37 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ |
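For reference, a minimal sketch of how this config plugs into the kaldi-style option parser used throughout the codebase; the usage string and file paths are placeholders, and ParseOptions::Read is assumed to behave as it does in the existing binaries:

#include <cstdio>

#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

int main(int argc, char *argv[]) {
  sherpa_onnx::ParseOptions po("Usage: demo [options]");  // placeholder usage text

  sherpa_onnx::OfflineMoonshineModelConfig config;
  config.Register(&po);  // adds --moonshine-preprocessor, --moonshine-encoder, etc.
  po.Read(argc, argv);

  // Validate() checks that all four model files were given and exist on disk
  if (!config.Validate()) {
    return -1;
  }

  fprintf(stderr, "%s\n", config.ToString().c_str());
  return 0;
}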
sherpa-onnx/csrc/offline-moonshine-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-moonshine-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-moonshine-model.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 13 | +#include "sherpa-onnx/csrc/session.h" | ||
| 14 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +class OfflineMoonshineModel::Impl { | ||
| 19 | + public: | ||
| 20 | + explicit Impl(const OfflineModelConfig &config) | ||
| 21 | + : config_(config), | ||
| 22 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 23 | + sess_opts_(GetSessionOptions(config)), | ||
| 24 | + allocator_{} { | ||
| 25 | + { | ||
| 26 | + auto buf = ReadFile(config.moonshine.preprocessor); | ||
| 27 | + InitPreprocessor(buf.data(), buf.size()); | ||
| 28 | + } | ||
| 29 | + | ||
| 30 | + { | ||
| 31 | + auto buf = ReadFile(config.moonshine.encoder); | ||
| 32 | + InitEncoder(buf.data(), buf.size()); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + { | ||
| 36 | + auto buf = ReadFile(config.moonshine.uncached_decoder); | ||
| 37 | + InitUnCachedDecoder(buf.data(), buf.size()); | ||
| 38 | + } | ||
| 39 | + | ||
| 40 | + { | ||
| 41 | + auto buf = ReadFile(config.moonshine.cached_decoder); | ||
| 42 | + InitCachedDecoder(buf.data(), buf.size()); | ||
| 43 | + } | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | +#if __ANDROID_API__ >= 9 | ||
| 47 | + Impl(AAssetManager *mgr, const OfflineModelConfig &config) | ||
| 48 | + : config_(config), | ||
| 49 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 50 | + sess_opts_(GetSessionOptions(config)), | ||
| 51 | + allocator_{} { | ||
| 52 | + { | ||
| 53 | + auto buf = ReadFile(mgr, config.moonshine.preprocessor); | ||
| 54 | + InitPreprocessor(buf.data(), buf.size()); | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + { | ||
| 58 | + auto buf = ReadFile(mgr, config.moonshine.encoder); | ||
| 59 | + InitEncoder(buf.data(), buf.size()); | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + { | ||
| 63 | + auto buf = ReadFile(mgr, config.moonshine.uncached_decoder); | ||
| 64 | + InitUnCachedDecoder(buf.data(), buf.size()); | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + { | ||
| 68 | + auto buf = ReadFile(mgr, config.moonshine.cached_decoder); | ||
| 69 | + InitCachedDecoder(buf.data(), buf.size()); | ||
| 70 | + } | ||
| 71 | + } | ||
| 72 | +#endif | ||
| 73 | + | ||
| 74 | + Ort::Value ForwardPreprocessor(Ort::Value audio) { | ||
| 75 | + auto features = preprocessor_sess_->Run( | ||
| 76 | + {}, preprocessor_input_names_ptr_.data(), &audio, 1, | ||
| 77 | + preprocessor_output_names_ptr_.data(), | ||
| 78 | + preprocessor_output_names_ptr_.size()); | ||
| 79 | + | ||
| 80 | + return std::move(features[0]); | ||
| 81 | + } | ||
| 82 | + | ||
| 83 | + Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) { | ||
| 84 | + std::array<Ort::Value, 2> encoder_inputs{std::move(features), | ||
| 85 | + std::move(features_len)}; | ||
| 86 | + auto encoder_out = encoder_sess_->Run( | ||
| 87 | + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), | ||
| 88 | + encoder_inputs.size(), encoder_output_names_ptr_.data(), | ||
| 89 | + encoder_output_names_ptr_.size()); | ||
| 90 | + | ||
| 91 | + return std::move(encoder_out[0]); | ||
| 92 | + } | ||
| 93 | + | ||
| 94 | + std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder( | ||
| 95 | + Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out) { | ||
| 96 | + std::array<Ort::Value, 3> uncached_decoder_input = { | ||
| 97 | + std::move(tokens), | ||
| 98 | + std::move(encoder_out), | ||
| 99 | + std::move(seq_len), | ||
| 100 | + }; | ||
| 101 | + | ||
| 102 | + auto uncached_decoder_out = uncached_decoder_sess_->Run( | ||
| 103 | + {}, uncached_decoder_input_names_ptr_.data(), | ||
| 104 | + uncached_decoder_input.data(), uncached_decoder_input.size(), | ||
| 105 | + uncached_decoder_output_names_ptr_.data(), | ||
| 106 | + uncached_decoder_output_names_ptr_.size()); | ||
| 107 | + | ||
| 108 | + std::vector<Ort::Value> states; | ||
| 109 | + states.reserve(uncached_decoder_out.size() - 1); | ||
| 110 | + | ||
| 111 | + int32_t i = -1; | ||
| 112 | + for (auto &s : uncached_decoder_out) { | ||
| 113 | + ++i; | ||
| 114 | + if (i == 0) { | ||
| 115 | + continue; | ||
| 116 | + } | ||
| 117 | + | ||
| 118 | + states.push_back(std::move(s)); | ||
| 119 | + } | ||
| 120 | + | ||
| 121 | + return {std::move(uncached_decoder_out[0]), std::move(states)}; | ||
| 122 | + } | ||
| 123 | + | ||
| 124 | + std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder( | ||
| 125 | + Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out, | ||
| 126 | + std::vector<Ort::Value> states) { | ||
| 127 | + std::vector<Ort::Value> cached_decoder_input; | ||
| 128 | + cached_decoder_input.reserve(3 + states.size()); | ||
| 129 | + cached_decoder_input.push_back(std::move(tokens)); | ||
| 130 | + cached_decoder_input.push_back(std::move(encoder_out)); | ||
| 131 | + cached_decoder_input.push_back(std::move(seq_len)); | ||
| 132 | + | ||
| 133 | + for (auto &s : states) { | ||
| 134 | + cached_decoder_input.push_back(std::move(s)); | ||
| 135 | + } | ||
| 136 | + | ||
| 137 | + auto cached_decoder_out = cached_decoder_sess_->Run( | ||
| 138 | + {}, cached_decoder_input_names_ptr_.data(), cached_decoder_input.data(), | ||
| 139 | + cached_decoder_input.size(), cached_decoder_output_names_ptr_.data(), | ||
| 140 | + cached_decoder_output_names_ptr_.size()); | ||
| 141 | + | ||
| 142 | + std::vector<Ort::Value> next_states; | ||
| 143 | + next_states.reserve(cached_decoder_out.size() - 1); | ||
| 144 | + | ||
| 145 | + int32_t i = -1; | ||
| 146 | + for (auto &s : cached_decoder_out) { | ||
| 147 | + ++i; | ||
| 148 | + if (i == 0) { | ||
| 149 | + continue; | ||
| 150 | + } | ||
| 151 | + | ||
| 152 | + next_states.push_back(std::move(s)); | ||
| 153 | + } | ||
| 154 | + | ||
| 155 | + return {std::move(cached_decoder_out[0]), std::move(next_states)}; | ||
| 156 | + } | ||
| 157 | + | ||
| 158 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 159 | + | ||
| 160 | + private: | ||
| 161 | + void InitPreprocessor(void *model_data, size_t model_data_length) { | ||
| 162 | + preprocessor_sess_ = std::make_unique<Ort::Session>( | ||
| 163 | + env_, model_data, model_data_length, sess_opts_); | ||
| 164 | + | ||
| 165 | + GetInputNames(preprocessor_sess_.get(), &preprocessor_input_names_, | ||
| 166 | + &preprocessor_input_names_ptr_); | ||
| 167 | + | ||
| 168 | + GetOutputNames(preprocessor_sess_.get(), &preprocessor_output_names_, | ||
| 169 | + &preprocessor_output_names_ptr_); | ||
| 170 | + } | ||
| 171 | + | ||
| 172 | + void InitEncoder(void *model_data, size_t model_data_length) { | ||
| 173 | + encoder_sess_ = std::make_unique<Ort::Session>( | ||
| 174 | + env_, model_data, model_data_length, sess_opts_); | ||
| 175 | + | ||
| 176 | + GetInputNames(encoder_sess_.get(), &encoder_input_names_, | ||
| 177 | + &encoder_input_names_ptr_); | ||
| 178 | + | ||
| 179 | + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, | ||
| 180 | + &encoder_output_names_ptr_); | ||
| 181 | + } | ||
| 182 | + | ||
| 183 | + void InitUnCachedDecoder(void *model_data, size_t model_data_length) { | ||
| 184 | + uncached_decoder_sess_ = std::make_unique<Ort::Session>( | ||
| 185 | + env_, model_data, model_data_length, sess_opts_); | ||
| 186 | + | ||
| 187 | + GetInputNames(uncached_decoder_sess_.get(), &uncached_decoder_input_names_, | ||
| 188 | + &uncached_decoder_input_names_ptr_); | ||
| 189 | + | ||
| 190 | + GetOutputNames(uncached_decoder_sess_.get(), | ||
| 191 | + &uncached_decoder_output_names_, | ||
| 192 | + &uncached_decoder_output_names_ptr_); | ||
| 193 | + } | ||
| 194 | + | ||
| 195 | + void InitCachedDecoder(void *model_data, size_t model_data_length) { | ||
| 196 | + cached_decoder_sess_ = std::make_unique<Ort::Session>( | ||
| 197 | + env_, model_data, model_data_length, sess_opts_); | ||
| 198 | + | ||
| 199 | + GetInputNames(cached_decoder_sess_.get(), &cached_decoder_input_names_, | ||
| 200 | + &cached_decoder_input_names_ptr_); | ||
| 201 | + | ||
| 202 | + GetOutputNames(cached_decoder_sess_.get(), &cached_decoder_output_names_, | ||
| 203 | + &cached_decoder_output_names_ptr_); | ||
| 204 | + } | ||
| 205 | + | ||
| 206 | + private: | ||
| 207 | + OfflineModelConfig config_; | ||
| 208 | + Ort::Env env_; | ||
| 209 | + Ort::SessionOptions sess_opts_; | ||
| 210 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 211 | + | ||
| 212 | + std::unique_ptr<Ort::Session> preprocessor_sess_; | ||
| 213 | + std::unique_ptr<Ort::Session> encoder_sess_; | ||
| 214 | + std::unique_ptr<Ort::Session> uncached_decoder_sess_; | ||
| 215 | + std::unique_ptr<Ort::Session> cached_decoder_sess_; | ||
| 216 | + | ||
| 217 | + std::vector<std::string> preprocessor_input_names_; | ||
| 218 | + std::vector<const char *> preprocessor_input_names_ptr_; | ||
| 219 | + | ||
| 220 | + std::vector<std::string> preprocessor_output_names_; | ||
| 221 | + std::vector<const char *> preprocessor_output_names_ptr_; | ||
| 222 | + | ||
| 223 | + std::vector<std::string> encoder_input_names_; | ||
| 224 | + std::vector<const char *> encoder_input_names_ptr_; | ||
| 225 | + | ||
| 226 | + std::vector<std::string> encoder_output_names_; | ||
| 227 | + std::vector<const char *> encoder_output_names_ptr_; | ||
| 228 | + | ||
| 229 | + std::vector<std::string> uncached_decoder_input_names_; | ||
| 230 | + std::vector<const char *> uncached_decoder_input_names_ptr_; | ||
| 231 | + | ||
| 232 | + std::vector<std::string> uncached_decoder_output_names_; | ||
| 233 | + std::vector<const char *> uncached_decoder_output_names_ptr_; | ||
| 234 | + | ||
| 235 | + std::vector<std::string> cached_decoder_input_names_; | ||
| 236 | + std::vector<const char *> cached_decoder_input_names_ptr_; | ||
| 237 | + | ||
| 238 | + std::vector<std::string> cached_decoder_output_names_; | ||
| 239 | + std::vector<const char *> cached_decoder_output_names_ptr_; | ||
| 240 | +}; | ||
| 241 | + | ||
| 242 | +OfflineMoonshineModel::OfflineMoonshineModel(const OfflineModelConfig &config) | ||
| 243 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 244 | + | ||
| 245 | +#if __ANDROID_API__ >= 9 | ||
| 246 | +OfflineMoonshineModel::OfflineMoonshineModel(AAssetManager *mgr, | ||
| 247 | + const OfflineModelConfig &config) | ||
| 248 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 249 | +#endif | ||
| 250 | + | ||
| 251 | +OfflineMoonshineModel::~OfflineMoonshineModel() = default; | ||
| 252 | + | ||
| 253 | +Ort::Value OfflineMoonshineModel::ForwardPreprocessor(Ort::Value audio) const { | ||
| 254 | + return impl_->ForwardPreprocessor(std::move(audio)); | ||
| 255 | +} | ||
| 256 | + | ||
| 257 | +Ort::Value OfflineMoonshineModel::ForwardEncoder( | ||
| 258 | + Ort::Value features, Ort::Value features_len) const { | ||
| 259 | + return impl_->ForwardEncoder(std::move(features), std::move(features_len)); | ||
| 260 | +} | ||
| 261 | + | ||
| 262 | +std::pair<Ort::Value, std::vector<Ort::Value>> | ||
| 263 | +OfflineMoonshineModel::ForwardUnCachedDecoder(Ort::Value token, | ||
| 264 | + Ort::Value seq_len, | ||
| 265 | + Ort::Value encoder_out) const { | ||
| 266 | + return impl_->ForwardUnCachedDecoder(std::move(token), std::move(seq_len), | ||
| 267 | + std::move(encoder_out)); | ||
| 268 | +} | ||
| 269 | + | ||
| 270 | +std::pair<Ort::Value, std::vector<Ort::Value>> | ||
| 271 | +OfflineMoonshineModel::ForwardCachedDecoder( | ||
| 272 | + Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out, | ||
| 273 | + std::vector<Ort::Value> states) const { | ||
| 274 | + return impl_->ForwardCachedDecoder(std::move(token), std::move(seq_len), | ||
| 275 | + std::move(encoder_out), std::move(states)); | ||
| 276 | +} | ||
| 277 | + | ||
| 278 | +OrtAllocator *OfflineMoonshineModel::Allocator() const { | ||
| 279 | + return impl_->Allocator(); | ||
| 280 | +} | ||
| 281 | + | ||
| 282 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-moonshine-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-moonshine-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 17 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 18 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +// please see | ||
| 23 | +// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/moonshine/test.py | ||
| 24 | +class OfflineMoonshineModel { | ||
| 25 | + public: | ||
| 26 | + explicit OfflineMoonshineModel(const OfflineModelConfig &config); | ||
| 27 | + | ||
| 28 | +#if __ANDROID_API__ >= 9 | ||
| 29 | + OfflineMoonshineModel(AAssetManager *mgr, const OfflineModelConfig &config); | ||
| 30 | +#endif | ||
| 31 | + | ||
| 32 | + ~OfflineMoonshineModel(); | ||
| 33 | + | ||
| 34 | + /** Run the preprocessor model. | ||
| 35 | + * | ||
| 36 | + * @param audio A float32 tensor of shape (batch_size, num_samples) | ||
| 37 | + * | ||
| 38 | + * @return Return a float32 tensor of shape (batch_size, T, dim) that | ||
| 39 | + * can be used as the input of ForwardEncoder() | ||
| 40 | + */ | ||
| 41 | + Ort::Value ForwardPreprocessor(Ort::Value audio) const; | ||
| 42 | + | ||
| 43 | + /** Run the encoder model. | ||
| 44 | + * | ||
| 45 | + * @param features A float32 tensor of shape (batch_size, T, dim) | ||
| 46 | + * @param features_len An int32 tensor of shape (batch_size,) | ||
| 47 | + * @returns A float32 tensor of shape (batch_size, T, dim). | ||
| 48 | + */ | ||
| 49 | + Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) const; | ||
| 50 | + | ||
| 51 | + /** Run the uncached decoder. | ||
| 52 | + * | ||
| 53 | + * @param token An int32 tensor of shape (batch_size, num_tokens) | ||
| 54 | + * @param seq_len An int32 tensor of shape (batch_size,) containing the | ||
| 55 | + * number of predicted tokens so far | ||
| 56 | + * @param encoder_out A float32 tensor of shape (batch_size, T, dim) | ||
| 57 | + * | ||
| 58 | + * @returns Return a pair: | ||
| 59 | + * | ||
| 60 | + * - logits, a float32 tensor of shape (batch_size, 1, vocab_size) | ||
| 61 | + * - states, a list of states | ||
| 62 | + */ | ||
| 63 | + std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder( | ||
| 64 | + Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out) const; | ||
| 65 | + | ||
| 66 | + /** Run the cached decoder. | ||
| 67 | + * | ||
| 68 | + * @param token An int32 tensor of shape (batch_size, num_tokens) | ||
| 69 | + * @param seq_len An int32 tensor of shape (batch_size,) containing the | ||
| 70 | + * number of predicted tokens so far | ||
| 71 | + * @param encoder_out A float32 tensor of shape (batch_size, T, dim) | ||
| 72 | + * @param states A list of previous states | ||
| 73 | + * | ||
| 74 | + * @returns Return a pair: | ||
| 75 | + * - logits, a float32 tensor of shape (batch_size, 1, vocab_size) | ||
| 76 | + * - states, a list of new states | ||
| 77 | + */ | ||
| 78 | + std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder( | ||
| 79 | + Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out, | ||
| 80 | + std::vector<Ort::Value> states) const; | ||
| 81 | + | ||
| 82 | + /** Return an allocator for allocating memory | ||
| 83 | + */ | ||
| 84 | + OrtAllocator *Allocator() const; | ||
| 85 | + | ||
| 86 | + private: | ||
| 87 | + class Impl; | ||
| 88 | + std::unique_ptr<Impl> impl_; | ||
| 89 | +}; | ||
| 90 | + | ||
| 91 | +} // namespace sherpa_onnx | ||
| 92 | + | ||
| 93 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_ |
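To make the data flow concrete, here is a rough sketch (not part of the patch) that drives the preprocessor, encoder, and greedy-search decoder directly on one second of silence; the model file names are placeholders:

#include <array>
#include <vector>

#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-moonshine-model.h"

int main() {
  sherpa_onnx::OfflineModelConfig config;
  config.moonshine.preprocessor = "preprocess.onnx";  // placeholder paths
  config.moonshine.encoder = "encode.int8.onnx";
  config.moonshine.uncached_decoder = "uncached_decode.int8.onnx";
  config.moonshine.cached_decoder = "cached_decode.int8.onnx";

  sherpa_onnx::OfflineMoonshineModel model(config);

  // 1 second of silence at 16 kHz, just to exercise the pipeline
  std::vector<float> samples(16000, 0.0f);

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  std::array<int64_t, 2> shape{1, static_cast<int64_t>(samples.size())};
  Ort::Value audio = Ort::Value::CreateTensor(
      memory_info, samples.data(), samples.size(), shape.data(), shape.size());

  // raw samples -> features -> encoder output
  Ort::Value features = model.ForwardPreprocessor(std::move(audio));
  int32_t features_len = features.GetTensorTypeAndShapeInfo().GetShape()[1];
  int64_t len_shape = 1;
  Ort::Value features_len_tensor =
      Ort::Value::CreateTensor(memory_info, &features_len, 1, &len_shape, 1);
  Ort::Value encoder_out =
      model.ForwardEncoder(std::move(features), std::move(features_len_tensor));

  // The decoder calls ForwardUnCachedDecoder() once, then
  // ForwardCachedDecoder() in a loop until it sees the eos token.
  sherpa_onnx::OfflineMoonshineGreedySearchDecoder decoder(&model);
  auto results = decoder.Decode(std::move(encoder_out));
  // results[0].tokens now holds the predicted token IDs

  return 0;
}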
| @@ -20,6 +20,7 @@ | @@ -20,6 +20,7 @@ | ||
| 20 | #include "onnxruntime_cxx_api.h" // NOLINT | 20 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 21 | #include "sherpa-onnx/csrc/macros.h" | 21 | #include "sherpa-onnx/csrc/macros.h" |
| 22 | #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" | 22 | #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" |
| 23 | +#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h" | ||
| 23 | #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" | 24 | #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" |
| 24 | #include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h" | 25 | #include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h" |
| 25 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" | 26 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" |
| @@ -51,6 +52,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -51,6 +52,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 51 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); | 52 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); |
| 52 | } | 53 | } |
| 53 | 54 | ||
| 55 | + if (!config.model_config.moonshine.preprocessor.empty()) { | ||
| 56 | + return std::make_unique<OfflineRecognizerMoonshineImpl>(config); | ||
| 57 | + } | ||
| 58 | + | ||
| 54 | // TODO(fangjun): Refactor it. We only need to use model type for the | 59 | // TODO(fangjun): Refactor it. We only need to use model type for the |
| 55 | // following models: | 60 | // following models: |
| 56 | // 1. transducer and nemo_transducer | 61 | // 1. transducer and nemo_transducer |
| @@ -67,7 +72,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -67,7 +72,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 67 | model_type == "telespeech_ctc") { | 72 | model_type == "telespeech_ctc") { |
| 68 | return std::make_unique<OfflineRecognizerCtcImpl>(config); | 73 | return std::make_unique<OfflineRecognizerCtcImpl>(config); |
| 69 | } else if (model_type == "whisper") { | 74 | } else if (model_type == "whisper") { |
| 75 | + // unreachable | ||
| 70 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); | 76 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); |
| 77 | + } else if (model_type == "moonshine") { | ||
| 78 | + // unreachable | ||
| 79 | + return std::make_unique<OfflineRecognizerMoonshineImpl>(config); | ||
| 71 | } else { | 80 | } else { |
| 72 | SHERPA_ONNX_LOGE( | 81 | SHERPA_ONNX_LOGE( |
| 73 | "Invalid model_type: %s. Trying to load the model to get its type", | 82 | "Invalid model_type: %s. Trying to load the model to get its type", |
| @@ -225,6 +234,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -225,6 +234,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 225 | return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); | 234 | return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); |
| 226 | } | 235 | } |
| 227 | 236 | ||
| 237 | + if (!config.model_config.moonshine.preprocessor.empty()) { | ||
| 238 | + return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config); | ||
| 239 | + } | ||
| 240 | + | ||
| 228 | // TODO(fangjun): Refactor it. We only need to use model type for the | 241 | // TODO(fangjun): Refactor it. We only need to use model type for the |
| 229 | // following models: | 242 | // following models: |
| 230 | // 1. transducer and nemo_transducer | 243 | // 1. transducer and nemo_transducer |
| @@ -242,6 +255,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -242,6 +255,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 242 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); | 255 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); |
| 243 | } else if (model_type == "whisper") { | 256 | } else if (model_type == "whisper") { |
| 244 | return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); | 257 | return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); |
| 258 | + } else if (model_type == "moonshine") { | ||
| 259 | + return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config); | ||
| 245 | } else { | 260 | } else { |
| 246 | SHERPA_ONNX_LOGE( | 261 | SHERPA_ONNX_LOGE( |
| 247 | "Invalid model_type: %s. Trying to load the model to get its type", | 262 | "Invalid model_type: %s. Trying to load the model to get its type", |
| 1 | +// sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <cmath> | ||
| 10 | +#include <memory> | ||
| 11 | +#include <string> | ||
| 12 | +#include <utility> | ||
| 13 | +#include <vector> | ||
| 14 | + | ||
| 15 | +#if __ANDROID_API__ >= 9 | ||
| 16 | +#include "android/asset_manager.h" | ||
| 17 | +#include "android/asset_manager_jni.h" | ||
| 18 | +#endif | ||
| 19 | + | ||
| 20 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 21 | +#include "sherpa-onnx/csrc/offline-moonshine-decoder.h" | ||
| 22 | +#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h" | ||
| 23 | +#include "sherpa-onnx/csrc/offline-moonshine-model.h" | ||
| 24 | +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" | ||
| 25 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 26 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 27 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 28 | + | ||
| 29 | +namespace sherpa_onnx { | ||
| 30 | + | ||
| 31 | +static OfflineRecognitionResult Convert( | ||
| 32 | + const OfflineMoonshineDecoderResult &src, const SymbolTable &sym_table) { | ||
| 33 | + OfflineRecognitionResult r; | ||
| 34 | + r.tokens.reserve(src.tokens.size()); | ||
| 35 | + | ||
| 36 | + std::string text; | ||
| 37 | + for (auto i : src.tokens) { | ||
| 38 | + if (!sym_table.Contains(i)) { | ||
| 39 | + continue; | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + const auto &s = sym_table[i]; | ||
| 43 | + text += s; | ||
| 44 | + r.tokens.push_back(s); | ||
| 45 | + } | ||
| 46 | + | ||
| 47 | + r.text = text; | ||
| 48 | + | ||
| 49 | + return r; | ||
| 50 | +} | ||
| 51 | + | ||
| 52 | +class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl { | ||
| 53 | + public: | ||
| 54 | + explicit OfflineRecognizerMoonshineImpl(const OfflineRecognizerConfig &config) | ||
| 55 | + : OfflineRecognizerImpl(config), | ||
| 56 | + config_(config), | ||
| 57 | + symbol_table_(config_.model_config.tokens), | ||
| 58 | + model_(std::make_unique<OfflineMoonshineModel>(config.model_config)) { | ||
| 59 | + Init(); | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | +#if __ANDROID_API__ >= 9 | ||
| 63 | + OfflineRecognizerMoonshineImpl(AAssetManager *mgr, | ||
| 64 | + const OfflineRecognizerConfig &config) | ||
| 65 | + : OfflineRecognizerImpl(mgr, config), | ||
| 66 | + config_(config), | ||
| 67 | + symbol_table_(mgr, config_.model_config.tokens), | ||
| 68 | + model_( | ||
| 69 | + std::make_unique<OfflineMoonshineModel>(mgr, config.model_config)) { | ||
| 70 | + Init(); | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | +#endif | ||
| 74 | + | ||
| 75 | + void Init() { | ||
| 76 | + if (config_.decoding_method == "greedy_search") { | ||
| 77 | + decoder_ = | ||
| 78 | + std::make_unique<OfflineMoonshineGreedySearchDecoder>(model_.get()); | ||
| 79 | + } else { | ||
| 80 | + SHERPA_ONNX_LOGE( | ||
| 81 | + "Only greedy_search is supported at present for moonshine. Given %s", | ||
| 82 | + config_.decoding_method.c_str()); | ||
| 83 | + exit(-1); | ||
| 84 | + } | ||
| 85 | + } | ||
| 86 | + | ||
| 87 | + std::unique_ptr<OfflineStream> CreateStream() const override { | ||
| 88 | + MoonshineTag tag; | ||
| 89 | + return std::make_unique<OfflineStream>(tag); | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + void DecodeStreams(OfflineStream **ss, int32_t n) const override { | ||
| 93 | + // batch decoding is not implemented yet | ||
| 94 | + for (int32_t i = 0; i != n; ++i) { | ||
| 95 | + DecodeStream(ss[i]); | ||
| 96 | + } | ||
| 97 | + } | ||
| 98 | + | ||
| 99 | + OfflineRecognizerConfig GetConfig() const override { return config_; } | ||
| 100 | + | ||
| 101 | + private: | ||
| 102 | + void DecodeStream(OfflineStream *s) const { | ||
| 103 | + auto memory_info = | ||
| 104 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 105 | + | ||
| 106 | + std::vector<float> audio = s->GetFrames(); | ||
| 107 | + | ||
| 108 | + try { | ||
| 109 | + std::array<int64_t, 2> shape{1, static_cast<int64_t>(audio.size())}; | ||
| 110 | + | ||
| 111 | + Ort::Value audio_tensor = Ort::Value::CreateTensor( | ||
| 112 | + memory_info, audio.data(), audio.size(), shape.data(), shape.size()); | ||
| 113 | + | ||
| 114 | + Ort::Value features = | ||
| 115 | + model_->ForwardPreprocessor(std::move(audio_tensor)); | ||
| 116 | + | ||
| 117 | + int32_t features_len = features.GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 118 | + | ||
| 119 | + int64_t features_shape = 1; | ||
| 120 | + | ||
| 121 | + Ort::Value features_len_tensor = Ort::Value::CreateTensor( | ||
| 122 | + memory_info, &features_len, 1, &features_shape, 1); | ||
| 123 | + | ||
| 124 | + Ort::Value encoder_out = model_->ForwardEncoder( | ||
| 125 | + std::move(features), std::move(features_len_tensor)); | ||
| 126 | + | ||
| 127 | + auto results = decoder_->Decode(std::move(encoder_out)); | ||
| 128 | + | ||
| 129 | + auto r = Convert(results[0], symbol_table_); | ||
| 130 | + r.text = ApplyInverseTextNormalization(std::move(r.text)); | ||
| 131 | + s->SetResult(r); | ||
| 132 | + } catch (const Ort::Exception &ex) { | ||
| 133 | + SHERPA_ONNX_LOGE( | ||
| 134 | + "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of " | ||
| 135 | + "audio samples: %d", | ||
| 136 | + ex.what(), static_cast<int32_t>(audio.size())); | ||
| 137 | + return; | ||
| 138 | + } | ||
| 139 | + } | ||
| 140 | + | ||
| 141 | + private: | ||
| 142 | + OfflineRecognizerConfig config_; | ||
| 143 | + SymbolTable symbol_table_; | ||
| 144 | + std::unique_ptr<OfflineMoonshineModel> model_; | ||
| 145 | + std::unique_ptr<OfflineMoonshineDecoder> decoder_; | ||
| 146 | +}; | ||
| 147 | + | ||
| 148 | +} // namespace sherpa_onnx | ||
| 149 | + | ||
| 150 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_ |
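In practice this impl is reached through the public OfflineRecognizer front end. A minimal end-to-end sketch, assuming the existing OfflineRecognizer class and the ReadWave() helper behave as elsewhere in the codebase; the file paths are placeholders:

#include <cstdio>
#include <vector>

#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/wave-reader.h"

int main() {
  sherpa_onnx::OfflineRecognizerConfig config;
  config.model_config.moonshine.preprocessor = "preprocess.onnx";
  config.model_config.moonshine.encoder = "encode.int8.onnx";
  config.model_config.moonshine.uncached_decoder = "uncached_decode.int8.onnx";
  config.model_config.moonshine.cached_decoder = "cached_decode.int8.onnx";
  config.model_config.tokens = "tokens.txt";
  config.decoding_method = "greedy_search";  // the only supported method here

  // A non-empty moonshine.preprocessor routes Create() to the moonshine impl
  sherpa_onnx::OfflineRecognizer recognizer(config);

  int32_t sampling_rate = -1;
  bool is_ok = false;
  std::vector<float> samples =
      sherpa_onnx::ReadWave("0.wav", &sampling_rate, &is_ok);
  if (!is_ok) {
    return -1;
  }

  auto stream = recognizer.CreateStream();  // an OfflineStream with MoonshineTag
  stream->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  recognizer.DecodeStream(stream.get());

  fprintf(stderr, "text: %s\n", stream->GetResult().text.c_str());
  return 0;
}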
| @@ -133,6 +133,10 @@ class OfflineStream::Impl { | @@ -133,6 +133,10 @@ class OfflineStream::Impl { | ||
| 133 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 133 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| 134 | } | 134 | } |
| 135 | 135 | ||
| 136 | + explicit Impl(MoonshineTag /*tag*/) : is_moonshine_(true) { | ||
| 137 | + config_.sampling_rate = 16000; | ||
| 138 | + } | ||
| 139 | + | ||
| 136 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { | 140 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { |
| 137 | if (config_.normalize_samples) { | 141 | if (config_.normalize_samples) { |
| 138 | AcceptWaveformImpl(sampling_rate, waveform, n); | 142 | AcceptWaveformImpl(sampling_rate, waveform, n); |
| @@ -164,7 +168,9 @@ class OfflineStream::Impl { | @@ -164,7 +168,9 @@ class OfflineStream::Impl { | ||
| 164 | std::vector<float> samples; | 168 | std::vector<float> samples; |
| 165 | resampler->Resample(waveform, n, true, &samples); | 169 | resampler->Resample(waveform, n, true, &samples); |
| 166 | 170 | ||
| 167 | - if (fbank_) { | 171 | + if (is_moonshine_) { |
| 172 | + samples_.insert(samples_.end(), samples.begin(), samples.end()); | ||
| 173 | + } else if (fbank_) { | ||
| 168 | fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), | 174 | fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), |
| 169 | samples.size()); | 175 | samples.size()); |
| 170 | fbank_->InputFinished(); | 176 | fbank_->InputFinished(); |
| @@ -181,7 +187,9 @@ class OfflineStream::Impl { | @@ -181,7 +187,9 @@ class OfflineStream::Impl { | ||
| 181 | return; | 187 | return; |
| 182 | } // if (sampling_rate != config_.sampling_rate) | 188 | } // if (sampling_rate != config_.sampling_rate) |
| 183 | 189 | ||
| 184 | - if (fbank_) { | 190 | + if (is_moonshine_) { |
| 191 | + samples_.insert(samples_.end(), waveform, waveform + n); | ||
| 192 | + } else if (fbank_) { | ||
| 185 | fbank_->AcceptWaveform(sampling_rate, waveform, n); | 193 | fbank_->AcceptWaveform(sampling_rate, waveform, n); |
| 186 | fbank_->InputFinished(); | 194 | fbank_->InputFinished(); |
| 187 | } else if (mfcc_) { | 195 | } else if (mfcc_) { |
| @@ -194,10 +202,18 @@ class OfflineStream::Impl { | @@ -194,10 +202,18 @@ class OfflineStream::Impl { | ||
| 194 | } | 202 | } |
| 195 | 203 | ||
| 196 | int32_t FeatureDim() const { | 204 | int32_t FeatureDim() const { |
| 205 | + if (is_moonshine_) { | ||
| 206 | + return samples_.size(); | ||
| 207 | + } | ||
| 208 | + | ||
| 197 | return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins; | 209 | return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins; |
| 198 | } | 210 | } |
| 199 | 211 | ||
| 200 | std::vector<float> GetFrames() const { | 212 | std::vector<float> GetFrames() const { |
| 213 | + if (is_moonshine_) { | ||
| 214 | + return samples_; | ||
| 215 | + } | ||
| 216 | + | ||
| 201 | int32_t n = fbank_ ? fbank_->NumFramesReady() | 217 | int32_t n = fbank_ ? fbank_->NumFramesReady() |
| 202 | : mfcc_ ? mfcc_->NumFramesReady() | 218 | : mfcc_ ? mfcc_->NumFramesReady() |
| 203 | : whisper_fbank_->NumFramesReady(); | 219 | : whisper_fbank_->NumFramesReady(); |
| @@ -300,6 +316,10 @@ class OfflineStream::Impl { | @@ -300,6 +316,10 @@ class OfflineStream::Impl { | ||
| 300 | OfflineRecognitionResult r_; | 316 | OfflineRecognitionResult r_; |
| 301 | ContextGraphPtr context_graph_; | 317 | ContextGraphPtr context_graph_; |
| 302 | bool is_ced_ = false; | 318 | bool is_ced_ = false; |
| 319 | + bool is_moonshine_ = false; | ||
| 320 | + | ||
| 321 | + // used only when is_moonshine_ == true | ||
| 322 | + std::vector<float> samples_; | ||
| 303 | }; | 323 | }; |
| 304 | 324 | ||
| 305 | OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/, | 325 | OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/, |
| @@ -311,6 +331,9 @@ OfflineStream::OfflineStream(WhisperTag tag) | @@ -311,6 +331,9 @@ OfflineStream::OfflineStream(WhisperTag tag) | ||
| 311 | 331 | ||
| 312 | OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {} | 332 | OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {} |
| 313 | 333 | ||
| 334 | +OfflineStream::OfflineStream(MoonshineTag tag) | ||
| 335 | + : impl_(std::make_unique<Impl>(tag)) {} | ||
| 336 | + | ||
| 314 | OfflineStream::~OfflineStream() = default; | 337 | OfflineStream::~OfflineStream() = default; |
| 315 | 338 | ||
| 316 | void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, | 339 | void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, |
| @@ -34,7 +34,7 @@ struct OfflineRecognitionResult { | @@ -34,7 +34,7 @@ struct OfflineRecognitionResult { | ||
| 34 | // event target of the audio. | 34 | // event target of the audio. |
| 35 | std::string event; | 35 | std::string event; |
| 36 | 36 | ||
| 37 | - /// timestamps.size() == tokens.size() | 37 | + /// timestamps.size() == tokens.size() |
| 38 | /// timestamps[i] records the time in seconds when tokens[i] is decoded. | 38 | /// timestamps[i] records the time in seconds when tokens[i] is decoded. |
| 39 | std::vector<float> timestamps; | 39 | std::vector<float> timestamps; |
| 40 | 40 | ||
| @@ -49,6 +49,10 @@ struct WhisperTag { | @@ -49,6 +49,10 @@ struct WhisperTag { | ||
| 49 | 49 | ||
| 50 | struct CEDTag {}; | 50 | struct CEDTag {}; |
| 51 | 51 | ||
| 52 | +// Moonshine uses a neural network model, its preprocessor, to convert | ||
| 53 | +// raw audio samples to features, so no fbank/MFCC extractor is needed | ||
| 54 | +struct MoonshineTag {}; | ||
| 55 | + | ||
| 52 | class OfflineStream { | 56 | class OfflineStream { |
| 53 | public: | 57 | public: |
| 54 | explicit OfflineStream(const FeatureExtractorConfig &config = {}, | 58 | explicit OfflineStream(const FeatureExtractorConfig &config = {}, |
| @@ -56,6 +60,7 @@ class OfflineStream { | @@ -56,6 +60,7 @@ class OfflineStream { | ||
| 56 | 60 | ||
| 57 | explicit OfflineStream(WhisperTag tag); | 61 | explicit OfflineStream(WhisperTag tag); |
| 58 | explicit OfflineStream(CEDTag tag); | 62 | explicit OfflineStream(CEDTag tag); |
| 63 | + explicit OfflineStream(MoonshineTag tag); | ||
| 59 | ~OfflineStream(); | 64 | ~OfflineStream(); |
| 60 | 65 | ||
| 61 | /** | 66 | /** |
| @@ -72,7 +77,10 @@ class OfflineStream { | @@ -72,7 +77,10 @@ class OfflineStream { | ||
| 72 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, | 77 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, |
| 73 | int32_t n) const; | 78 | int32_t n) const; |
| 74 | 79 | ||
| 75 | - /// Return feature dim of this extractor | 80 | + /// Return feature dim of this extractor. |
| 81 | + /// | ||
| 82 | + /// Note: if it is Moonshine, then it returns the number of audio samples | ||
| 83 | + /// currently received. | ||
| 76 | int32_t FeatureDim() const; | 84 | int32_t FeatureDim() const; |
| 77 | 85 | ||
| 78 | // Get all the feature frames of this stream in a 1-D array, which is | 86 | // Get all the feature frames of this stream in a 1-D array, which is |
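The practical effect of MoonshineTag is that the stream skips feature extraction entirely. A short fragment (not a full program; samples is assumed to hold 16 kHz audio) illustrating the resulting behavior:

sherpa_onnx::OfflineStream stream{sherpa_onnx::MoonshineTag{}};
stream.AcceptWaveform(16000, samples.data(), samples.size());

// For moonshine there is no fbank/MFCC stage: GetFrames() returns the raw
// samples unchanged, and FeatureDim() reports the number of samples so far.
std::vector<float> raw = stream.GetFrames();
int32_t dim = stream.FeatureDim();  // == raw.size()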
| @@ -23,7 +23,6 @@ class OfflineWhisperModel::Impl { | @@ -23,7 +23,6 @@ class OfflineWhisperModel::Impl { | ||
| 23 | explicit Impl(const OfflineModelConfig &config) | 23 | explicit Impl(const OfflineModelConfig &config) |
| 24 | : config_(config), | 24 | : config_(config), |
| 25 | env_(ORT_LOGGING_LEVEL_ERROR), | 25 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 26 | - debug_(config.debug), | ||
| 27 | sess_opts_(GetSessionOptions(config)), | 26 | sess_opts_(GetSessionOptions(config)), |
| 28 | allocator_{} { | 27 | allocator_{} { |
| 29 | { | 28 | { |
| @@ -40,7 +39,6 @@ class OfflineWhisperModel::Impl { | @@ -40,7 +39,6 @@ class OfflineWhisperModel::Impl { | ||
| 40 | explicit Impl(const SpokenLanguageIdentificationConfig &config) | 39 | explicit Impl(const SpokenLanguageIdentificationConfig &config) |
| 41 | : lid_config_(config), | 40 | : lid_config_(config), |
| 42 | env_(ORT_LOGGING_LEVEL_ERROR), | 41 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 43 | - debug_(config_.debug), | ||
| 44 | sess_opts_(GetSessionOptions(config)), | 42 | sess_opts_(GetSessionOptions(config)), |
| 45 | allocator_{} { | 43 | allocator_{} { |
| 46 | { | 44 | { |
| @@ -60,7 +58,6 @@ class OfflineWhisperModel::Impl { | @@ -60,7 +58,6 @@ class OfflineWhisperModel::Impl { | ||
| 60 | env_(ORT_LOGGING_LEVEL_ERROR), | 58 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 61 | sess_opts_(GetSessionOptions(config)), | 59 | sess_opts_(GetSessionOptions(config)), |
| 62 | allocator_{} { | 60 | allocator_{} { |
| 63 | - debug_ = config_.debug; | ||
| 64 | { | 61 | { |
| 65 | auto buf = ReadFile(mgr, config.whisper.encoder); | 62 | auto buf = ReadFile(mgr, config.whisper.encoder); |
| 66 | InitEncoder(buf.data(), buf.size()); | 63 | InitEncoder(buf.data(), buf.size()); |
| @@ -77,7 +74,6 @@ class OfflineWhisperModel::Impl { | @@ -77,7 +74,6 @@ class OfflineWhisperModel::Impl { | ||
| 77 | env_(ORT_LOGGING_LEVEL_ERROR), | 74 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 78 | sess_opts_(GetSessionOptions(config)), | 75 | sess_opts_(GetSessionOptions(config)), |
| 79 | allocator_{} { | 76 | allocator_{} { |
| 80 | - debug_ = config_.debug; | ||
| 81 | { | 77 | { |
| 82 | auto buf = ReadFile(mgr, config.whisper.encoder); | 78 | auto buf = ReadFile(mgr, config.whisper.encoder); |
| 83 | InitEncoder(buf.data(), buf.size()); | 79 | InitEncoder(buf.data(), buf.size()); |
| @@ -164,7 +160,7 @@ class OfflineWhisperModel::Impl { | @@ -164,7 +160,7 @@ class OfflineWhisperModel::Impl { | ||
| 164 | } | 160 | } |
| 165 | } | 161 | } |
| 166 | 162 | ||
| 167 | - if (debug_) { | 163 | + if (config_.debug) { |
| 168 | SHERPA_ONNX_LOGE("Detected language: %s", | 164 | SHERPA_ONNX_LOGE("Detected language: %s", |
| 169 | GetID2Lang().at(lang_id).c_str()); | 165 | GetID2Lang().at(lang_id).c_str()); |
| 170 | } | 166 | } |
| @@ -237,7 +233,7 @@ class OfflineWhisperModel::Impl { | @@ -237,7 +233,7 @@ class OfflineWhisperModel::Impl { | ||
| 237 | 233 | ||
| 238 | // get meta data | 234 | // get meta data |
| 239 | Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); | 235 | Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); |
| 240 | - if (debug_) { | 236 | + if (config_.debug) { |
| 241 | std::ostringstream os; | 237 | std::ostringstream os; |
| 242 | os << "---encoder---\n"; | 238 | os << "---encoder---\n"; |
| 243 | PrintModelMetadata(os, meta_data); | 239 | PrintModelMetadata(os, meta_data); |
| @@ -294,7 +290,6 @@ class OfflineWhisperModel::Impl { | @@ -294,7 +290,6 @@ class OfflineWhisperModel::Impl { | ||
| 294 | private: | 290 | private: |
| 295 | OfflineModelConfig config_; | 291 | OfflineModelConfig config_; |
| 296 | SpokenLanguageIdentificationConfig lid_config_; | 292 | SpokenLanguageIdentificationConfig lid_config_; |
| 297 | - bool debug_ = false; | ||
| 298 | Ort::Env env_; | 293 | Ort::Env env_; |
| 299 | Ort::SessionOptions sess_opts_; | 294 | Ort::SessionOptions sess_opts_; |
| 300 | Ort::AllocatorWithDefaultOptions allocator_; | 295 | Ort::AllocatorWithDefaultOptions allocator_; |
| @@ -43,7 +43,20 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/in | @@ -43,7 +43,20 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/in | ||
| 43 | --decoding-method=greedy_search \ | 43 | --decoding-method=greedy_search \ |
| 44 | /path/to/foo.wav [bar.wav foobar.wav ...] | 44 | /path/to/foo.wav [bar.wav foobar.wav ...] |
| 45 | 45 | ||
| 46 | -(3) Whisper models | 46 | +(3) Moonshine models |
| 47 | + | ||
| 48 | +See https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html | ||
| 49 | + | ||
| 50 | + ./bin/sherpa-onnx-offline \ | ||
| 51 | + --moonshine-preprocessor=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/preprocess.onnx \ | ||
| 52 | + --moonshine-encoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/encode.int8.onnx \ | ||
| 53 | + --moonshine-uncached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/uncached_decode.int8.onnx \ | ||
| 54 | + --moonshine-cached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/cached_decode.int8.onnx \ | ||
| 55 | + --tokens=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/tokens.txt \ | ||
| 56 | + --num-threads=1 \ | ||
| 57 | + /path/to/foo.wav [bar.wav foobar.wav ...] | ||
| 58 | + | ||
| 59 | +(4) Whisper models | ||
| 47 | 60 | ||
| 48 | See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html | 61 | See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html |
| 49 | 62 | ||
| @@ -54,7 +67,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html | @@ -54,7 +67,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html | ||
| 54 | --num-threads=1 \ | 67 | --num-threads=1 \ |
| 55 | /path/to/foo.wav [bar.wav foobar.wav ...] | 68 | /path/to/foo.wav [bar.wav foobar.wav ...] |
| 56 | 69 | ||
| 57 | -(4) NeMo CTC models | 70 | +(5) NeMo CTC models |
| 58 | 71 | ||
| 59 | See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html | 72 | See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html |
| 60 | 73 | ||
| @@ -68,7 +81,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.htm | @@ -68,7 +81,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.htm | ||
| 68 | ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \ | 81 | ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \ |
| 69 | ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav | 82 | ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav |
| 70 | 83 | ||
| 71 | -(5) TDNN CTC model for the yesno recipe from icefall | 84 | +(6) TDNN CTC model for the yesno recipe from icefall |
| 72 | 85 | ||
| 73 | See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html | 86 | See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html |
| 74 | // | 87 | // |
| @@ -109,6 +109,8 @@ const std::string SymbolTable::operator[](int32_t id) const { | @@ -109,6 +109,8 @@ const std::string SymbolTable::operator[](int32_t id) const { | ||
| 109 | 109 | ||
| 110 | // for byte-level BPE | 110 | // for byte-level BPE |
| 111 | // id 0 is blank, id 1 is sos/eos, id 2 is unk | 111 | // id 0 is blank, id 1 is sos/eos, id 2 is unk |
| 112 | + // | ||
| 113 | + // Note: For moonshine models, 0 is <unk>, 1 is <s>, and 2 is </s> | ||
| 112 | if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && | 114 | if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && |
| 113 | sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { | 115 | sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { |
| 114 | std::ostringstream os; | 116 | std::ostringstream os; |
| @@ -11,6 +11,7 @@ set(srcs | @@ -11,6 +11,7 @@ set(srcs | ||
| 11 | offline-ctc-fst-decoder-config.cc | 11 | offline-ctc-fst-decoder-config.cc |
| 12 | offline-lm-config.cc | 12 | offline-lm-config.cc |
| 13 | offline-model-config.cc | 13 | offline-model-config.cc |
| 14 | + offline-moonshine-model-config.cc | ||
| 14 | offline-nemo-enc-dec-ctc-model-config.cc | 15 | offline-nemo-enc-dec-ctc-model-config.cc |
| 15 | offline-paraformer-model-config.cc | 16 | offline-paraformer-model-config.cc |
| 16 | offline-punctuation.cc | 17 | offline-punctuation.cc |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | #include <vector> | 8 | #include <vector> |
| 9 | 9 | ||
| 10 | #include "sherpa-onnx/csrc/offline-model-config.h" | 10 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 11 | +#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h" | ||
| 11 | #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" | 12 | #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" |
| 12 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" | 13 | #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" |
| 13 | #include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h" | 14 | #include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h" |
| @@ -28,6 +29,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -28,6 +29,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 28 | PybindOfflineZipformerCtcModelConfig(m); | 29 | PybindOfflineZipformerCtcModelConfig(m); |
| 29 | PybindOfflineWenetCtcModelConfig(m); | 30 | PybindOfflineWenetCtcModelConfig(m); |
| 30 | PybindOfflineSenseVoiceModelConfig(m); | 31 | PybindOfflineSenseVoiceModelConfig(m); |
| 32 | + PybindOfflineMoonshineModelConfig(m); | ||
| 31 | 33 | ||
| 32 | using PyClass = OfflineModelConfig; | 34 | using PyClass = OfflineModelConfig; |
| 33 | py::class_<PyClass>(*m, "OfflineModelConfig") | 35 | py::class_<PyClass>(*m, "OfflineModelConfig") |
| @@ -39,7 +41,8 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -39,7 +41,8 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 39 | const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &, | 41 | const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &, |
| 40 | const OfflineZipformerCtcModelConfig &, | 42 | const OfflineZipformerCtcModelConfig &, |
| 41 | const OfflineWenetCtcModelConfig &, | 43 | const OfflineWenetCtcModelConfig &, |
| 42 | - const OfflineSenseVoiceModelConfig &, const std::string &, | 44 | + const OfflineSenseVoiceModelConfig &, |
| 45 | + const OfflineMoonshineModelConfig &, const std::string &, | ||
| 43 | const std::string &, int32_t, bool, const std::string &, | 46 | const std::string &, int32_t, bool, const std::string &, |
| 44 | const std::string &, const std::string &, const std::string &>(), | 47 | const std::string &, const std::string &, const std::string &>(), |
| 45 | py::arg("transducer") = OfflineTransducerModelConfig(), | 48 | py::arg("transducer") = OfflineTransducerModelConfig(), |
| @@ -50,6 +53,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -50,6 +53,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 50 | py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), | 53 | py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), |
| 51 | py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), | 54 | py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), |
| 52 | py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), | 55 | py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), |
| 56 | + py::arg("moonshine") = OfflineMoonshineModelConfig(), | ||
| 53 | py::arg("telespeech_ctc") = "", py::arg("tokens"), | 57 | py::arg("telespeech_ctc") = "", py::arg("tokens"), |
| 54 | py::arg("num_threads"), py::arg("debug") = false, | 58 | py::arg("num_threads"), py::arg("debug") = false, |
| 55 | py::arg("provider") = "cpu", py::arg("model_type") = "", | 59 | py::arg("provider") = "cpu", py::arg("model_type") = "", |
| @@ -62,6 +66,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -62,6 +66,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 62 | .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) | 66 | .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) |
| 63 | .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) | 67 | .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) |
| 64 | .def_readwrite("sense_voice", &PyClass::sense_voice) | 68 | .def_readwrite("sense_voice", &PyClass::sense_voice) |
| 69 | + .def_readwrite("moonshine", &PyClass::moonshine) | ||
| 65 | .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) | 70 | .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) |
| 66 | .def_readwrite("tokens", &PyClass::tokens) | 71 | .def_readwrite("tokens", &PyClass::tokens) |
| 67 | .def_readwrite("num_threads", &PyClass::num_threads) | 72 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 1 | +// sherpa-onnx/python/csrc/offline-moonshine-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-moonshine-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindOfflineMoonshineModelConfig(py::module *m) { | ||
| 15 | + using PyClass = OfflineMoonshineModelConfig; | ||
| 16 | + py::class_<PyClass>(*m, "OfflineMoonshineModelConfig") | ||
| 17 | + .def(py::init<const std::string &, const std::string &, | ||
| 18 | + const std::string &, const std::string &>(), | ||
| 19 | + py::arg("preprocessor"), py::arg("encoder"), | ||
| 20 | + py::arg("uncached_decoder"), py::arg("cached_decoder")) | ||
| 21 | + .def_readwrite("preprocessor", &PyClass::preprocessor) | ||
| 22 | + .def_readwrite("encoder", &PyClass::encoder) | ||
| 23 | + .def_readwrite("uncached_decoder", &PyClass::uncached_decoder) | ||
| 24 | + .def_readwrite("cached_decoder", &PyClass::cached_decoder) | ||
| 25 | + .def("__str__", &PyClass::ToString); | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +} // namespace sherpa_onnx |
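For reference, a short sketch of what this binding exposes once compiled; the paths below are placeholders used only to illustrate the four constructor arguments:

    from _sherpa_onnx import OfflineMoonshineModelConfig

    config = OfflineMoonshineModelConfig(
        preprocessor="preprocess.onnx",
        encoder="encode.int8.onnx",
        uncached_decoder="uncached_decode.int8.onnx",
        cached_decoder="cached_decode.int8.onnx",
    )
    config.encoder = "encode.onnx"  # attributes are writable via def_readwrite
    print(config)  # __str__ delegates to OfflineMoonshineModelConfig::ToString()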
| 1 | +// sherpa-onnx/python/csrc/offline-moonshine-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOfflineMoonshineModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +}  // namespace sherpa_onnx | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ |
| @@ -8,13 +8,14 @@ from _sherpa_onnx import ( | @@ -8,13 +8,14 @@ from _sherpa_onnx import ( | ||
| 8 | OfflineCtcFstDecoderConfig, | 8 | OfflineCtcFstDecoderConfig, |
| 9 | OfflineLMConfig, | 9 | OfflineLMConfig, |
| 10 | OfflineModelConfig, | 10 | OfflineModelConfig, |
| 11 | + OfflineMoonshineModelConfig, | ||
| 11 | OfflineNemoEncDecCtcModelConfig, | 12 | OfflineNemoEncDecCtcModelConfig, |
| 12 | OfflineParaformerModelConfig, | 13 | OfflineParaformerModelConfig, |
| 13 | - OfflineSenseVoiceModelConfig, | ||
| 14 | ) | 14 | ) |
| 15 | from _sherpa_onnx import OfflineRecognizer as _Recognizer | 15 | from _sherpa_onnx import OfflineRecognizer as _Recognizer |
| 16 | from _sherpa_onnx import ( | 16 | from _sherpa_onnx import ( |
| 17 | OfflineRecognizerConfig, | 17 | OfflineRecognizerConfig, |
| 18 | + OfflineSenseVoiceModelConfig, | ||
| 18 | OfflineStream, | 19 | OfflineStream, |
| 19 | OfflineTdnnModelConfig, | 20 | OfflineTdnnModelConfig, |
| 20 | OfflineTransducerModelConfig, | 21 | OfflineTransducerModelConfig, |
| @@ -503,12 +504,12 @@ class OfflineRecognizer(object): | @@ -503,12 +504,12 @@ class OfflineRecognizer(object): | ||
| 503 | e.g., tiny, tiny.en, base, base.en, etc. | 504 | e.g., tiny, tiny.en, base, base.en, etc. |
| 504 | 505 | ||
| 505 | Args: | 506 | Args: |
| 506 | - encoder_model: | ||
| 507 | - Path to the encoder model, e.g., tiny-encoder.onnx, | ||
| 508 | - tiny-encoder.int8.onnx, tiny-encoder.ort, etc. | ||
| 509 | - decoder_model: | 507 | + encoder: |
| 510 | Path to the encoder model, e.g., tiny-encoder.onnx, | 508 | Path to the encoder model, e.g., tiny-encoder.onnx, |
| 511 | tiny-encoder.int8.onnx, tiny-encoder.ort, etc. | 509 | tiny-encoder.int8.onnx, tiny-encoder.ort, etc. |
| 510 | + decoder: | ||
| 511 | + Path to the decoder model, e.g., tiny-decoder.onnx, | ||
| 512 | + tiny-decoder.int8.onnx, tiny-decoder.ort, etc. | ||
| 512 | tokens: | 513 | tokens: |
| 513 | Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two | 514 | Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two |
| 514 | columns:: | 515 | columns:: |
| @@ -571,6 +572,87 @@ class OfflineRecognizer(object): | @@ -571,6 +572,87 @@ class OfflineRecognizer(object): | ||
| 571 | return self | 572 | return self |
| 572 | 573 | ||
| 573 | @classmethod | 574 | @classmethod |
| 575 | + def from_moonshine( | ||
| 576 | + cls, | ||
| 577 | + preprocessor: str, | ||
| 578 | + encoder: str, | ||
| 579 | + uncached_decoder: str, | ||
| 580 | + cached_decoder: str, | ||
| 581 | + tokens: str, | ||
| 582 | + num_threads: int = 1, | ||
| 583 | + decoding_method: str = "greedy_search", | ||
| 584 | + debug: bool = False, | ||
| 585 | + provider: str = "cpu", | ||
| 586 | + rule_fsts: str = "", | ||
| 587 | + rule_fars: str = "", | ||
| 588 | + ): | ||
| 589 | + """ | ||
| 590 | + Please refer to | ||
| 591 | + `<https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html>`_ | ||
| 592 | + to download pre-trained models for different kinds of Moonshine models, | ||
| 593 | + e.g., tiny, base, etc. | ||
| 594 | + | ||
| 595 | + Args: | ||
| 596 | + preprocessor: | ||
| 597 | + Path to the preprocessor model, e.g., preprocess.onnx | ||
| 598 | + encoder: | ||
| 599 | + Path to the encoder model, e.g., encode.int8.onnx | ||
| 600 | + uncached_decoder: | ||
| 601 | + Path to the uncached decoder model, e.g., uncached_decode.int8.onnx | ||
| 602 | + cached_decoder: | ||
| 603 | + Path to the cached decoder model, e.g., cached_decode.int8.onnx | ||
| 604 | + tokens: | ||
| 605 | + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two | ||
| 606 | + columns:: | ||
| 607 | + | ||
| 608 | + symbol integer_id | ||
| 609 | + | ||
| 610 | + num_threads: | ||
| 611 | + Number of threads for neural network computation. | ||
| 612 | + decoding_method: | ||
| 613 | + Valid values: greedy_search. | ||
| 614 | + debug: | ||
| 615 | + True to show debug messages. | ||
| 616 | + provider: | ||
| 617 | + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | ||
| 618 | + rule_fsts: | ||
| 619 | + If not empty, it specifies fsts for inverse text normalization. | ||
| 620 | + If there are multiple fsts, they are separated by a comma. | ||
| 621 | + rule_fars: | ||
| 622 | + If not empty, it specifies fst archives for inverse text normalization. | ||
| 623 | + If there are multiple archives, they are separated by a comma. | ||
| 624 | + """ | ||
| 625 | + self = cls.__new__(cls) | ||
| 626 | + model_config = OfflineModelConfig( | ||
| 627 | + moonshine=OfflineMoonshineModelConfig( | ||
| 628 | + preprocessor=preprocessor, | ||
| 629 | + encoder=encoder, | ||
| 630 | + uncached_decoder=uncached_decoder, | ||
| 631 | + cached_decoder=cached_decoder, | ||
| 632 | + ), | ||
| 633 | + tokens=tokens, | ||
| 634 | + num_threads=num_threads, | ||
| 635 | + debug=debug, | ||
| 636 | + provider=provider, | ||
| 637 | + ) | ||
| 638 | + | ||
| 639 | + unused_feat_config = FeatureExtractorConfig( | ||
| 640 | + sampling_rate=16000, | ||
| 641 | + feature_dim=80, | ||
| 642 | + ) | ||
| 643 | + | ||
| 644 | + recognizer_config = OfflineRecognizerConfig( | ||
| 645 | + model_config=model_config, | ||
| 646 | + feat_config=unused_feat_config, | ||
| 647 | + decoding_method=decoding_method, | ||
| 648 | + rule_fsts=rule_fsts, | ||
| 649 | + rule_fars=rule_fars, | ||
| 650 | + ) | ||
| 651 | + self.recognizer = _Recognizer(recognizer_config) | ||
| 652 | + self.config = recognizer_config | ||
| 653 | + return self | ||
| 654 | + | ||
| 655 | + @classmethod | ||
| 574 | def from_tdnn_ctc( | 656 | def from_tdnn_ctc( |
| 575 | cls, | 657 | cls, |
| 576 | model: str, | 658 | model: str, |
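End to end, the new classmethod is used like the other `from_*` constructors. A sketch assuming the sherpa-onnx-moonshine-tiny-en-int8 archive from the release page has been unpacked into the current directory; `soundfile` here stands in for the `read_wave` helper that the repo's own python-api-examples define:

    import soundfile as sf
    import sherpa_onnx

    repo = "./sherpa-onnx-moonshine-tiny-en-int8"
    recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
        preprocessor=f"{repo}/preprocess.onnx",
        encoder=f"{repo}/encode.int8.onnx",
        uncached_decoder=f"{repo}/uncached_decode.int8.onnx",
        cached_decoder=f"{repo}/cached_decode.int8.onnx",
        tokens=f"{repo}/tokens.txt",
        num_threads=2,
    )

    # accept_waveform() expects float32 mono samples; inputs whose sample
    # rate differs from the model's expected rate are resampled internally.
    samples, sample_rate = sf.read(f"{repo}/test_wavs/0.wav", dtype="float32")
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(stream)
    print(stream.result.text)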