Fangjun Kuang
Committed by GitHub

Add C++ runtime and Python APIs for Moonshine models (#1473)

Showing 33 changed files with 1,572 additions and 36 deletions
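For reference, the new offline recognizer can be created from Python roughly as follows. This is a minimal sketch mirroring the python-api-examples/offline-moonshine-decode-files.py file added in this PR; it assumes you have downloaded and extracted sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 from the asr-models release page and have soundfile installed.

import sherpa_onnx
import soundfile as sf

# Build an offline (non-streaming) recognizer from the four Moonshine ONNX files.
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
    preprocessor="./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx",
    encoder="./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx",
    uncached_decoder="./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx",
    cached_decoder="./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx",
    tokens="./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt",
    num_threads=2,
)

# Decode one of the bundled test wave files; the input sample rate
# does not need to be 16000 Hz.
audio, sample_rate = sf.read(
    "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav",
    dtype="float32",
    always_2d=True,
)
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio[:, 0])  # use the first channel
recognizer.decode_stream(stream)
print(stream.result.text)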
  1 +#!/usr/bin/env bash
  2 +
  3 +set -e
  4 +
  5 +log() {
  6 + # This function is from espnet
  7 + local fname=${BASH_SOURCE[1]##*/}
  8 + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
  9 +}
  10 +
  11 +export GIT_CLONE_PROTECTION_ACTIVE=false
  12 +
  13 +echo "EXE is $EXE"
  14 +echo "PATH: $PATH"
  15 +
  16 +which $EXE
  17 +
  18 +names=(
  19 +tiny
  20 +base
  21 +)
  22 +
  23 +for name in ${names[@]}; do
  24 + log "------------------------------------------------------------"
  25 + log "Run $name"
  26 + log "------------------------------------------------------------"
  27 +
  28 + repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-$name.tar.bz2
  29 + repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-$name-en-int8.tar.bz2
  30 + curl -SL -O $repo_url
  31 + tar xvf sherpa-onnx-moonshine-$name-en-int8.tar.bz2
  32 + rm sherpa-onnx-moonshine-$name-en-int8.tar.bz2
  33 + repo=sherpa-onnx-moonshine-$name-en-int8
  34 + log "Start testing ${repo_url}"
  35 +
  36 + log "test int8 onnx"
  37 +
  38 + time $EXE \
  39 + --moonshine-preprocessor=$repo/preprocess.onnx \
  40 + --moonshine-encoder=$repo/encode.int8.onnx \
  41 + --moonshine-uncached-decoder=$repo/uncached_decode.int8.onnx \
  42 + --moonshine-cached-decoder=$repo/cached_decode.int8.onnx \
  43 + --tokens=$repo/tokens.txt \
  44 + --num-threads=2 \
  45 + $repo/test_wavs/0.wav \
  46 + $repo/test_wavs/1.wav \
  47 + $repo/test_wavs/8k.wav
  48 +
  49 + rm -rf $repo
  50 +done
@@ -8,6 +8,16 @@ log() {
8 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 9 }
10 10
  11 +log "test offline Moonshine"
  12 +
  13 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
  14 +tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
  15 +rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
  16 +
  17 +python3 ./python-api-examples/offline-moonshine-decode-files.py
  18 +
  19 +rm -rf sherpa-onnx-moonshine-tiny-en-int8
  20 +
11 log "test offline speaker diarization" 21 log "test offline speaker diarization"
12 22
13 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 23 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
@@ -149,6 +149,19 @@ jobs: @@ -149,6 +149,19 @@ jobs:
149 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} 149 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
150 path: install/* 150 path: install/*
151 151
  152 + - name: Test offline Moonshine
  153 + if: matrix.build_type != 'Debug'
  154 + shell: bash
  155 + run: |
  156 + du -h -d1 .
  157 + export PATH=$PWD/build/bin:$PATH
  158 + export EXE=sherpa-onnx-offline
  159 +
  160 + readelf -d build/bin/sherpa-onnx-offline
  161 +
  162 + .github/scripts/test-offline-moonshine.sh
  163 + du -h -d1 .
  164 +
152 165 - name: Test offline CTC
153 166 shell: bash
154 167 run: |
@@ -121,6 +121,15 @@ jobs:
121 121 otool -L build/bin/sherpa-onnx
122 122 otool -l build/bin/sherpa-onnx
123 123
  124 + - name: Test offline Moonshine
  125 + if: matrix.build_type != 'Debug'
  126 + shell: bash
  127 + run: |
  128 + export PATH=$PWD/build/bin:$PATH
  129 + export EXE=sherpa-onnx-offline
  130 +
  131 + .github/scripts/test-offline-moonshine.sh
  132 +
124 133 - name: Test C++ API
125 134 shell: bash
126 135 run: |
@@ -243,8 +252,6 @@ jobs:
243 252
244 253 .github/scripts/test-offline-whisper.sh
245 254
246 -
247 -
248 255 - name: Test online transducer
249 256 shell: bash
250 257 run: |
@@ -93,6 +93,14 @@ jobs:
93 93 name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
94 94 path: build/install/*
95 95
  96 + - name: Test offline Moonshine for windows x64
  97 + shell: bash
  98 + run: |
  99 + export PATH=$PWD/build/bin/Release:$PATH
  100 + export EXE=sherpa-onnx-offline.exe
  101 +
  102 + .github/scripts/test-offline-moonshine.sh
  103 +
96 104 - name: Test C++ API
97 105 shell: bash
98 106 run: |
@@ -93,6 +93,14 @@ jobs:
93 93 name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
94 94 path: build/install/*
95 95
  96 + - name: Test offline Moonshine for windows x86
  97 + shell: bash
  98 + run: |
  99 + export PATH=$PWD/build/bin/Release:$PATH
  100 + export EXE=sherpa-onnx-offline.exe
  101 +
  102 + .github/scripts/test-offline-moonshine.sh
  103 +
96 104 - name: Test C++ API
97 105 shell: bash
98 106 run: |
@@ -47,7 +47,19 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
47 47 --feature-dim=80 \
48 48 /path/to/test.mp4
49 49
50 -(3) For Whisper models 50 +(3) For Moonshine models
  51 +
  52 +./python-api-examples/generate-subtitles.py \
  53 + --silero-vad-model=/path/to/silero_vad.onnx \
  54 + --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
  55 + --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
  56 + --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
  57 + --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
  58 + --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
  59 + --num-threads=2 \
  60 + /path/to/test.mp4
  61 +
  62 +(4) For Whisper models
51 63
52 64 ./python-api-examples/generate-subtitles.py \
53 65 --silero-vad-model=/path/to/silero_vad.onnx \
@@ -58,7 +70,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
58 70 --num-threads=2 \
59 71 /path/to/test.mp4
60 72
61 -(4) For SenseVoice CTC models 73 +(5) For SenseVoice CTC models
62 74
63 75 ./python-api-examples/generate-subtitles.py \
64 76 --silero-vad-model=/path/to/silero_vad.onnx \
@@ -68,7 +80,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
68 80 /path/to/test.mp4
69 81
70 82
71 -(5) For WeNet CTC models 83 +(6) For WeNet CTC models
72 84
73 85 ./python-api-examples/generate-subtitles.py \
74 86 --silero-vad-model=/path/to/silero_vad.onnx \
@@ -83,6 +95,7 @@ to install sherpa-onnx and to download non-streaming pre-trained models
83 95 used in this file.
84 96 """
85 97 import argparse
  98 +import datetime as dt
86 99 import shutil
87 100 import subprocess
88 101 import sys
@@ -157,7 +170,7 @@ def get_args():
157 170 parser.add_argument(
158 171 "--num-threads",
159 172 type=int,
160 - default=1, 173 + default=2,
161 174 help="Number of threads for neural network computation",
162 175 )
163 176
@@ -209,6 +222,34 @@ def get_args():
209 222 )
210 223
211 224 parser.add_argument(
  225 + "--moonshine-preprocessor",
  226 + default="",
  227 + type=str,
  228 + help="Path to moonshine preprocessor model",
  229 + )
  230 +
  231 + parser.add_argument(
  232 + "--moonshine-encoder",
  233 + default="",
  234 + type=str,
  235 + help="Path to moonshine encoder model",
  236 + )
  237 +
  238 + parser.add_argument(
  239 + "--moonshine-uncached-decoder",
  240 + default="",
  241 + type=str,
  242 + help="Path to moonshine uncached decoder model",
  243 + )
  244 +
  245 + parser.add_argument(
  246 + "--moonshine-cached-decoder",
  247 + default="",
  248 + type=str,
  249 + help="Path to moonshine cached decoder model",
  250 + )
  251 +
  252 + parser.add_argument(
212 "--decoding-method", 253 "--decoding-method",
213 type=str, 254 type=str,
214 default="greedy_search", 255 default="greedy_search",
@@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
263 assert len(args.wenet_ctc) == 0, args.wenet_ctc 304 assert len(args.wenet_ctc) == 0, args.wenet_ctc
264 assert len(args.whisper_encoder) == 0, args.whisper_encoder 305 assert len(args.whisper_encoder) == 0, args.whisper_encoder
265 assert len(args.whisper_decoder) == 0, args.whisper_decoder 306 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  307 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  308 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  309 + assert (
  310 + len(args.moonshine_uncached_decoder) == 0
  311 + ), args.moonshine_uncached_decoder
  312 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
266 313
267 314 assert_file_exists(args.encoder)
268 315 assert_file_exists(args.decoder)
@@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
284 331 assert len(args.wenet_ctc) == 0, args.wenet_ctc
285 332 assert len(args.whisper_encoder) == 0, args.whisper_encoder
286 333 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  334 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  335 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  336 + assert (
  337 + len(args.moonshine_uncached_decoder) == 0
  338 + ), args.moonshine_uncached_decoder
  339 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
287 340
288 341 assert_file_exists(args.paraformer)
289 342
@@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
300 353 assert len(args.wenet_ctc) == 0, args.wenet_ctc
301 354 assert len(args.whisper_encoder) == 0, args.whisper_encoder
302 355 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  356 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  357 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  358 + assert (
  359 + len(args.moonshine_uncached_decoder) == 0
  360 + ), args.moonshine_uncached_decoder
  361 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
303 362
304 363 assert_file_exists(args.sense_voice)
305 364 recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
@@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
312 371 elif args.wenet_ctc:
313 372 assert len(args.whisper_encoder) == 0, args.whisper_encoder
314 373 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  374 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  375 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  376 + assert (
  377 + len(args.moonshine_uncached_decoder) == 0
  378 + ), args.moonshine_uncached_decoder
  379 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
315 380
316 381 assert_file_exists(args.wenet_ctc)
317 382
@@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
327 392 elif args.whisper_encoder:
328 393 assert_file_exists(args.whisper_encoder)
329 394 assert_file_exists(args.whisper_decoder)
  395 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  396 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  397 + assert (
  398 + len(args.moonshine_uncached_decoder) == 0
  399 + ), args.moonshine_uncached_decoder
  400 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
330 401
331 402 recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
332 403 encoder=args.whisper_encoder,
@@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
339 410 task=args.whisper_task,
340 411 tail_paddings=args.whisper_tail_paddings,
341 412 )
  413 + elif args.moonshine_preprocessor:
  414 + assert_file_exists(args.moonshine_preprocessor)
  415 + assert_file_exists(args.moonshine_encoder)
  416 + assert_file_exists(args.moonshine_uncached_decoder)
  417 + assert_file_exists(args.moonshine_cached_decoder)
  418 +
  419 + recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
  420 + preprocessor=args.moonshine_preprocessor,
  421 + encoder=args.moonshine_encoder,
  422 + uncached_decoder=args.moonshine_uncached_decoder,
  423 + cached_decoder=args.moonshine_cached_decoder,
  424 + tokens=args.tokens,
  425 + num_threads=args.num_threads,
  426 + decoding_method=args.decoding_method,
  427 + debug=args.debug,
  428 + )
342 429 else:
343 430 raise ValueError("Please specify at least one model")
344 431
@@ -424,28 +511,32 @@ def main():
424 511 segment_list = []
425 512
426 513 print("Started!")
  514 + start_t = dt.datetime.now()
  515 + num_processed_samples = 0
427 516
428 - is_silence = False 517 + is_eof = False
429 518 # TODO(fangjun): Support multithreads
430 519 while True:
431 520 # *2 because int16_t has two bytes
432 521 data = process.stdout.read(frames_per_read * 2)
433 522 if not data:
434 - if is_silence: 523 + if is_eof:
435 524 break
436 - is_silence = True  
437 - # The converted audio file does not have a mute data of 1 second or more at the end, which will result in the loss of the last segment data 525 + is_eof = True
  526 + # pad 1 second at the end of the file for the VAD
438 527 data = np.zeros(1 * args.sample_rate, dtype=np.int16)
439 528
440 529 samples = np.frombuffer(data, dtype=np.int16)
441 530 samples = samples.astype(np.float32) / 32768
442 531
  532 + num_processed_samples += samples.shape[0]
  533 +
443 534 buffer = np.concatenate([buffer, samples])
444 535 while len(buffer) > window_size:
445 536 vad.accept_waveform(buffer[:window_size])
446 537 buffer = buffer[window_size:]
447 538
448 - if is_silence: 539 + if is_eof:
449 540 vad.flush()
450 541
451 542 streams = []
@@ -471,6 +562,11 @@ def main():
471 562 seg.text = stream.result.text
472 563 segment_list.append(seg)
473 564
  565 + end_t = dt.datetime.now()
  566 + elapsed_seconds = (end_t - start_t).total_seconds()
  567 + duration = num_processed_samples / 16000
  568 + rtf = elapsed_seconds / duration
  569 +
474 570 srt_filename = Path(args.sound_file).with_suffix(".srt")
475 571 with open(srt_filename, "w", encoding="utf-8") as f:
476 572 for i, seg in enumerate(segment_list):
@@ -479,6 +575,9 @@ def main():
479 575 print("", file=f)
480 576
481 577 print(f"Saved to {srt_filename}")
  578 + print(f"Audio duration:\t{duration:.3f} s")
  579 + print(f"Elapsed:\t{elapsed_seconds:.3f} s")
  580 + print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
482 print("Done!") 581 print("Done!")
483 582
484 583
@@ -66,7 +66,21 @@ python3 ./python-api-examples/non_streaming_server.py \ @@ -66,7 +66,21 @@ python3 ./python-api-examples/non_streaming_server.py \
66 --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ 66 --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
67 --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt 67 --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
68 68
69 -(5) Use a Whisper model 69 +(5) Use a Moonshine model
  70 +
  71 +cd /path/to/sherpa-onnx
  72 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
  73 +tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
  74 +rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
  75 +
  76 +python3 ./python-api-examples/non_streaming_server.py \
  77 + --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
  78 + --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
  79 + --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
  80 + --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
  81 + --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt
  82 +
  83 +(6) Use a Whisper model
70 84
71 85 cd /path/to/sherpa-onnx
72 86 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
@@ -78,7 +92,7 @@ python3 ./python-api-examples/non_streaming_server.py \
78 92 --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
79 93 --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
80 94
81 -(5) Use a tdnn model of the yesno recipe from icefall 95 +(7) Use a tdnn model of the yesno recipe from icefall
82 96
83 97 cd /path/to/sherpa-onnx
84 98
@@ -92,7 +106,7 @@ python3 ./python-api-examples/non_streaming_server.py \
92 106 --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
93 107 --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
94 108
95 -(6) Use a Non-streaming SenseVoice model 109 +(8) Use a Non-streaming SenseVoice model
96 110
97 111 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
98 112 tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
@@ -254,6 +268,36 @@ def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
254 268 )
255 269
256 270
  271 +def add_moonshine_model_args(parser: argparse.ArgumentParser):
  272 + parser.add_argument(
  273 + "--moonshine-preprocessor",
  274 + default="",
  275 + type=str,
  276 + help="Path to moonshine preprocessor model",
  277 + )
  278 +
  279 + parser.add_argument(
  280 + "--moonshine-encoder",
  281 + default="",
  282 + type=str,
  283 + help="Path to moonshine encoder model",
  284 + )
  285 +
  286 + parser.add_argument(
  287 + "--moonshine-uncached-decoder",
  288 + default="",
  289 + type=str,
  290 + help="Path to moonshine uncached decoder model",
  291 + )
  292 +
  293 + parser.add_argument(
  294 + "--moonshine-cached-decoder",
  295 + default="",
  296 + type=str,
  297 + help="Path to moonshine cached decoder model",
  298 + )
  299 +
  300 +
257 301 def add_whisper_model_args(parser: argparse.ArgumentParser):
258 302 parser.add_argument(
259 303 "--whisper-encoder",
@@ -311,6 +355,7 @@ def add_model_args(parser: argparse.ArgumentParser):
311 355 add_wenet_ctc_model_args(parser)
312 356 add_tdnn_ctc_model_args(parser)
313 357 add_whisper_model_args(parser)
  358 + add_moonshine_model_args(parser)
314 359
315 360 parser.add_argument(
316 361 "--tokens",
@@ -876,6 +921,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
876 921 assert len(args.whisper_encoder) == 0, args.whisper_encoder
877 922 assert len(args.whisper_decoder) == 0, args.whisper_decoder
878 923 assert len(args.tdnn_model) == 0, args.tdnn_model
  924 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  925 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  926 + assert (
  927 + len(args.moonshine_uncached_decoder) == 0
  928 + ), args.moonshine_uncached_decoder
  929 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
879 930
880 931 assert_file_exists(args.encoder)
881 932 assert_file_exists(args.decoder)
@@ -903,6 +954,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
903 954 assert len(args.whisper_encoder) == 0, args.whisper_encoder
904 955 assert len(args.whisper_decoder) == 0, args.whisper_decoder
905 956 assert len(args.tdnn_model) == 0, args.tdnn_model
  957 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  958 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  959 + assert (
  960 + len(args.moonshine_uncached_decoder) == 0
  961 + ), args.moonshine_uncached_decoder
  962 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
906 963
907 964 assert_file_exists(args.paraformer)
908 965
@@ -921,6 +978,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
921 978 assert len(args.whisper_encoder) == 0, args.whisper_encoder
922 979 assert len(args.whisper_decoder) == 0, args.whisper_decoder
923 980 assert len(args.tdnn_model) == 0, args.tdnn_model
  981 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  982 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  983 + assert (
  984 + len(args.moonshine_uncached_decoder) == 0
  985 + ), args.moonshine_uncached_decoder
  986 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
924 987
925 988 assert_file_exists(args.sense_voice)
926 989 recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
@@ -934,6 +997,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
934 997 assert len(args.whisper_encoder) == 0, args.whisper_encoder
935 998 assert len(args.whisper_decoder) == 0, args.whisper_decoder
936 999 assert len(args.tdnn_model) == 0, args.tdnn_model
  1000 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  1001 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  1002 + assert (
  1003 + len(args.moonshine_uncached_decoder) == 0
  1004 + ), args.moonshine_uncached_decoder
  1005 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
937 1006
938 1007 assert_file_exists(args.nemo_ctc)
939 1008
@@ -950,6 +1019,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
950 1019 assert len(args.whisper_encoder) == 0, args.whisper_encoder
951 1020 assert len(args.whisper_decoder) == 0, args.whisper_decoder
952 1021 assert len(args.tdnn_model) == 0, args.tdnn_model
  1022 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  1023 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  1024 + assert (
  1025 + len(args.moonshine_uncached_decoder) == 0
  1026 + ), args.moonshine_uncached_decoder
  1027 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
953 1028
954 1029 assert_file_exists(args.wenet_ctc)
955 1030
@@ -966,6 +1041,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
966 1041 assert len(args.tdnn_model) == 0, args.tdnn_model
967 1042 assert_file_exists(args.whisper_encoder)
968 1043 assert_file_exists(args.whisper_decoder)
  1044 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  1045 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  1046 + assert (
  1047 + len(args.moonshine_uncached_decoder) == 0
  1048 + ), args.moonshine_uncached_decoder
  1049 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
969 1050
970 1051 recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
971 1052 encoder=args.whisper_encoder,
@@ -980,6 +1061,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
980 1061 )
981 1062 elif args.tdnn_model:
982 1063 assert_file_exists(args.tdnn_model)
  1064 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  1065 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  1066 + assert (
  1067 + len(args.moonshine_uncached_decoder) == 0
  1068 + ), args.moonshine_uncached_decoder
  1069 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
983 1070
984 1071 recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
985 1072 model=args.tdnn_model,
@@ -990,6 +1077,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
990 1077 decoding_method=args.decoding_method,
991 1078 provider=args.provider,
992 1079 )
  1080 + elif args.moonshine_preprocessor:
  1081 + assert_file_exists(args.moonshine_preprocessor)
  1082 + assert_file_exists(args.moonshine_encoder)
  1083 + assert_file_exists(args.moonshine_uncached_decoder)
  1084 + assert_file_exists(args.moonshine_cached_decoder)
  1085 +
  1086 + recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
  1087 + preprocessor=args.moonshine_preprocessor,
  1088 + encoder=args.moonshine_encoder,
  1089 + uncached_decoder=args.moonshine_uncached_decoder,
  1090 + cached_decoder=args.moonshine_cached_decoder,
  1091 + tokens=args.tokens,
  1092 + num_threads=args.num_threads,
  1093 + decoding_method=args.decoding_method,
  1094 + )
993 1095 else:
994 1096 raise ValueError("Please specify at least one model")
995 1097
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This file shows how to use a non-streaming Moonshine model from
  5 +https://github.com/usefulsensors/moonshine
  6 +to decode files.
  7 +
  8 +Please download model files from
  9 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
  10 +
  11 +For instance,
  12 +
  13 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
  14 +tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
  15 +rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
  16 +"""
  17 +
  18 +import datetime as dt
  19 +from pathlib import Path
  20 +
  21 +import sherpa_onnx
  22 +import soundfile as sf
  23 +
  24 +
  25 +def create_recognizer():
  26 + preprocessor = "./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx"
  27 + encoder = "./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx"
  28 + uncached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx"
  29 + cached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx"
  30 +
  31 + tokens = "./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt"
  32 + test_wav = "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav"
  33 +
  34 + if not Path(preprocessor).is_file() or not Path(test_wav).is_file():
  35 + raise ValueError(
  36 + """Please download model files from
  37 + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
  38 + """
  39 + )
  40 + return (
  41 + sherpa_onnx.OfflineRecognizer.from_moonshine(
  42 + preprocessor=preprocessor,
  43 + encoder=encoder,
  44 + uncached_decoder=uncached_decoder,
  45 + cached_decoder=cached_decoder,
  46 + tokens=tokens,
  47 + debug=True,
  48 + ),
  49 + test_wav,
  50 + )
  51 +
  52 +
  53 +def main():
  54 + recognizer, wave_filename = create_recognizer()
  55 +
  56 + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
  57 + audio = audio[:, 0] # only use the first channel
  58 +
  59 + # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
  60 + # sample_rate does not need to be 16000 Hz
  61 +
  62 + start_t = dt.datetime.now()
  63 +
  64 + stream = recognizer.create_stream()
  65 + stream.accept_waveform(sample_rate, audio)
  66 + recognizer.decode_stream(stream)
  67 +
  68 + end_t = dt.datetime.now()
  69 + elapsed_seconds = (end_t - start_t).total_seconds()
  70 + duration = audio.shape[-1] / sample_rate
  71 + rtf = elapsed_seconds / duration
  72 +
  73 + print(stream.result)
  74 + print(wave_filename)
  75 + print("Text:", stream.result.text)
  76 + print(f"Audio duration:\t{duration:.3f} s")
  77 + print(f"Elapsed:\t{elapsed_seconds:.3f} s")
  78 + print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
  79 +
  80 +
  81 +if __name__ == "__main__":
  82 + main()
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This file shows how to use a non-streaming whisper model from
  5 +https://github.com/openai/whisper
  6 +to decode files.
  7 +
  8 +Please download model files from
  9 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
  10 +
  11 +For instance,
  12 +
  13 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
  14 +tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
  15 +rm sherpa-onnx-whisper-tiny.en.tar.bz2
  16 +"""
  17 +
  18 +import datetime as dt
  19 +from pathlib import Path
  20 +
  21 +import sherpa_onnx
  22 +import soundfile as sf
  23 +
  24 +
  25 +def create_recognizer():
  26 + encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
  27 + decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
  28 + tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
  29 + test_wav = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
  30 +
  31 + if not Path(encoder).is_file() or not Path(test_wav).is_file():
  32 + raise ValueError(
  33 + """Please download model files from
  34 + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
  35 + """
  36 + )
  37 + return (
  38 + sherpa_onnx.OfflineRecognizer.from_whisper(
  39 + encoder=encoder,
  40 + decoder=decoder,
  41 + tokens=tokens,
  42 + debug=True,
  43 + ),
  44 + test_wav,
  45 + )
  46 +
  47 +
  48 +def main():
  49 + recognizer, wave_filename = create_recognizer()
  50 +
  51 + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
  52 + audio = audio[:, 0] # only use the first channel
  53 +
  54 + # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
  55 + # sample_rate does not need to be 16000 Hz
  56 +
  57 + start_t = dt.datetime.now()
  58 +
  59 + stream = recognizer.create_stream()
  60 + stream.accept_waveform(sample_rate, audio)
  61 + recognizer.decode_stream(stream)
  62 +
  63 + end_t = dt.datetime.now()
  64 + elapsed_seconds = (end_t - start_t).total_seconds()
  65 + duration = audio.shape[-1] / sample_rate
  66 + rtf = elapsed_seconds / duration
  67 +
  68 + print(stream.result)
  69 + print(wave_filename)
  70 + print("Text:", stream.result.text)
  71 + print(f"Audio duration:\t{duration:.3f} s")
  72 + print(f"Elapsed:\t{elapsed_seconds:.3f} s")
  73 + print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
  74 +
  75 +
  76 +if __name__ == "__main__":
  77 + main()
@@ -35,7 +35,18 @@ Note that you need a non-streaming model for this script.
35 35 --sample-rate=16000 \
36 36 --feature-dim=80
37 37
38 -(3) For Whisper models 38 +(3) For Moonshine models
  39 +
  40 +./python-api-examples/vad-with-non-streaming-asr.py \
  41 + --silero-vad-model=/path/to/silero_vad.onnx \
  42 + --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
  43 + --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
  44 + --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
  45 + --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
  46 + --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
  47 + --num-threads=2
  48 +
  49 +(4) For Whisper models
39 50
40 51 ./python-api-examples/vad-with-non-streaming-asr.py \
41 52 --silero-vad-model=/path/to/silero_vad.onnx \
@@ -45,7 +56,7 @@ Note that you need a non-streaming model for this script.
45 56 --whisper-task=transcribe \
46 57 --num-threads=2
47 58
48 -(4) For SenseVoice CTC models 59 +(5) For SenseVoice CTC models
49 60
50 61 ./python-api-examples/vad-with-non-streaming-asr.py \
51 62 --silero-vad-model=/path/to/silero_vad.onnx \
@@ -193,6 +204,34 @@ def get_args():
193 204 )
194 205
195 206 parser.add_argument(
  207 + "--moonshine-preprocessor",
  208 + default="",
  209 + type=str,
  210 + help="Path to moonshine preprocessor model",
  211 + )
  212 +
  213 + parser.add_argument(
  214 + "--moonshine-encoder",
  215 + default="",
  216 + type=str,
  217 + help="Path to moonshine encoder model",
  218 + )
  219 +
  220 + parser.add_argument(
  221 + "--moonshine-uncached-decoder",
  222 + default="",
  223 + type=str,
  224 + help="Path to moonshine uncached decoder model",
  225 + )
  226 +
  227 + parser.add_argument(
  228 + "--moonshine-cached-decoder",
  229 + default="",
  230 + type=str,
  231 + help="Path to moonshine cached decoder model",
  232 + )
  233 +
  234 + parser.add_argument(
196 "--blank-penalty", 235 "--blank-penalty",
197 type=float, 236 type=float,
198 default=0.0, 237 default=0.0,
@@ -251,6 +290,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -251,6 +290,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
251 assert len(args.sense_voice) == 0, args.sense_voice 290 assert len(args.sense_voice) == 0, args.sense_voice
252 assert len(args.whisper_encoder) == 0, args.whisper_encoder 291 assert len(args.whisper_encoder) == 0, args.whisper_encoder
253 assert len(args.whisper_decoder) == 0, args.whisper_decoder 292 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  293 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  294 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  295 + assert (
  296 + len(args.moonshine_uncached_decoder) == 0
  297 + ), args.moonshine_uncached_decoder
  298 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
254 299
255 300 assert_file_exists(args.encoder)
256 301 assert_file_exists(args.decoder)
@@ -272,6 +317,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
272 317 assert len(args.sense_voice) == 0, args.sense_voice
273 318 assert len(args.whisper_encoder) == 0, args.whisper_encoder
274 319 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  320 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  321 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  322 + assert (
  323 + len(args.moonshine_uncached_decoder) == 0
  324 + ), args.moonshine_uncached_decoder
  325 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
275 326
276 327 assert_file_exists(args.paraformer)
277 328
@@ -287,6 +338,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
287 338 elif args.sense_voice:
288 339 assert len(args.whisper_encoder) == 0, args.whisper_encoder
289 340 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  341 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  342 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  343 + assert (
  344 + len(args.moonshine_uncached_decoder) == 0
  345 + ), args.moonshine_uncached_decoder
  346 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
290 347
291 348 assert_file_exists(args.sense_voice)
292 349 recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
@@ -299,6 +356,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
299 356 elif args.whisper_encoder:
300 357 assert_file_exists(args.whisper_encoder)
301 358 assert_file_exists(args.whisper_decoder)
  359 + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
  360 + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
  361 + assert (
  362 + len(args.moonshine_uncached_decoder) == 0
  363 + ), args.moonshine_uncached_decoder
  364 + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
302 365
303 366 recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
304 367 encoder=args.whisper_encoder,
@@ -311,6 +374,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
311 374 task=args.whisper_task,
312 375 tail_paddings=args.whisper_tail_paddings,
313 376 )
  377 + elif args.moonshine_preprocessor:
  378 + assert_file_exists(args.moonshine_preprocessor)
  379 + assert_file_exists(args.moonshine_encoder)
  380 + assert_file_exists(args.moonshine_uncached_decoder)
  381 + assert_file_exists(args.moonshine_cached_decoder)
  382 +
  383 + recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
  384 + preprocessor=args.moonshine_preprocessor,
  385 + encoder=args.moonshine_encoder,
  386 + uncached_decoder=args.moonshine_uncached_decoder,
  387 + cached_decoder=args.moonshine_cached_decoder,
  388 + tokens=args.tokens,
  389 + num_threads=args.num_threads,
  390 + decoding_method=args.decoding_method,
  391 + debug=args.debug,
  392 + )
314 393 else:
315 394 raise ValueError("Please specify at least one model")
316 395
@@ -29,6 +29,9 @@ set(sources
29 29 offline-lm-config.cc
30 30 offline-lm.cc
31 31 offline-model-config.cc
  32 + offline-moonshine-greedy-search-decoder.cc
  33 + offline-moonshine-model-config.cc
  34 + offline-moonshine-model.cc
32 35 offline-nemo-enc-dec-ctc-model-config.cc
33 36 offline-nemo-enc-dec-ctc-model.cc
34 37 offline-paraformer-greedy-search-decoder.cc
@@ -19,6 +19,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
19 19 zipformer_ctc.Register(po);
20 20 wenet_ctc.Register(po);
21 21 sense_voice.Register(po);
  22 + moonshine.Register(po);
22 23
23 po->Register("telespeech-ctc", &telespeech_ctc, 24 po->Register("telespeech-ctc", &telespeech_ctc,
24 "Path to model.onnx for telespeech ctc"); 25 "Path to model.onnx for telespeech ctc");
@@ -99,6 +100,10 @@ bool OfflineModelConfig::Validate() const { @@ -99,6 +100,10 @@ bool OfflineModelConfig::Validate() const {
99 return sense_voice.Validate(); 100 return sense_voice.Validate();
100 } 101 }
101 102
  103 + if (!moonshine.preprocessor.empty()) {
  104 + return moonshine.Validate();
  105 + }
  106 +
102 107 if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
103 108 SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
104 109 telespeech_ctc.c_str());
@@ -124,6 +129,7 @@ std::string OfflineModelConfig::ToString() const {
124 129 os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
125 130 os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
126 131 os << "sense_voice=" << sense_voice.ToString() << ", ";
  132 + os << "moonshine=" << moonshine.ToString() << ", ";
127 os << "telespeech_ctc=\"" << telespeech_ctc << "\", "; 133 os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
128 os << "tokens=\"" << tokens << "\", "; 134 os << "tokens=\"" << tokens << "\", ";
129 os << "num_threads=" << num_threads << ", "; 135 os << "num_threads=" << num_threads << ", ";
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 6
7 #include <string> 7 #include <string>
8 8
  9 +#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
9 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" 10 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
10 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" 11 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
11 #include "sherpa-onnx/csrc/offline-sense-voice-model-config.h" 12 #include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
@@ -26,6 +27,7 @@ struct OfflineModelConfig { @@ -26,6 +27,7 @@ struct OfflineModelConfig {
26 OfflineZipformerCtcModelConfig zipformer_ctc; 27 OfflineZipformerCtcModelConfig zipformer_ctc;
27 OfflineWenetCtcModelConfig wenet_ctc; 28 OfflineWenetCtcModelConfig wenet_ctc;
28 OfflineSenseVoiceModelConfig sense_voice; 29 OfflineSenseVoiceModelConfig sense_voice;
  30 + OfflineMoonshineModelConfig moonshine;
29 31 std::string telespeech_ctc;
30 32
31 33 std::string tokens;
@@ -56,6 +58,7 @@ struct OfflineModelConfig {
56 58 const OfflineZipformerCtcModelConfig &zipformer_ctc,
57 59 const OfflineWenetCtcModelConfig &wenet_ctc,
58 60 const OfflineSenseVoiceModelConfig &sense_voice,
  61 + const OfflineMoonshineModelConfig &moonshine,
59 62 const std::string &telespeech_ctc,
60 63 const std::string &tokens, int32_t num_threads, bool debug,
61 64 const std::string &provider, const std::string &model_type,
@@ -69,6 +72,7 @@ struct OfflineModelConfig {
69 72 zipformer_ctc(zipformer_ctc),
70 73 wenet_ctc(wenet_ctc),
71 74 sense_voice(sense_voice),
  75 + moonshine(moonshine),
72 76 telespeech_ctc(telespeech_ctc),
73 77 tokens(tokens),
74 78 num_threads(num_threads),
  1 +// sherpa-onnx/csrc/offline-moonshine-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "onnxruntime_cxx_api.h" // NOLINT
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct OfflineMoonshineDecoderResult {
  15 + /// The decoded token IDs
  16 + std::vector<int32_t> tokens;
  17 +};
  18 +
  19 +class OfflineMoonshineDecoder {
  20 + public:
  21 + virtual ~OfflineMoonshineDecoder() = default;
  22 +
  23 + /** Run decoding given the output from the moonshine encoder model.
  24 + *
  25 + * @param encoder_out A 3-D tensor of shape (batch_size, T, dim)
  26 + * @return Return a vector of size `N` containing the decoded results.
  27 + */
  28 + virtual std::vector<OfflineMoonshineDecoderResult> Decode(
  29 + Ort::Value encoder_out) = 0;
  30 +};
  31 +
  32 +} // namespace sherpa_onnx
  33 +
  34 +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
  1 +// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
  6 +
  7 +#include <algorithm>
  8 +#include <utility>
  9 +
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +#include "sherpa-onnx/csrc/onnx-utils.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +std::vector<OfflineMoonshineDecoderResult>
  16 +OfflineMoonshineGreedySearchDecoder::Decode(Ort::Value encoder_out) {
  17 + auto encoder_out_shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
  18 + if (encoder_out_shape[0] != 1) {
  19 + SHERPA_ONNX_LOGE("Support only batch size == 1. Given: %d\n",
  20 + static_cast<int32_t>(encoder_out_shape[0]));
  21 + return {};
  22 + }
  23 +
  24 + auto memory_info =
  25 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  26 +
  27 + // encoder_out_shape[1] * 384 is the number of audio samples
  28 + // (the factor 384 comes from the moonshine paper) and 16000 is
  29 + // the sample rate, so encoder_out_shape[1] * 384 / 16000 is the
  30 + // audio duration in seconds. Decoding is capped at 6 tokens per
  31 + // second of audio, which gives max_len below.
  32 + int32_t max_len =
  33 + static_cast<int32_t>(encoder_out_shape[1] * 384 / 16000.0 * 6);
  34 +
  35 + int32_t sos = 1;
  36 + int32_t eos = 2;
  37 + int32_t seq_len = 1;
  38 +
  39 + std::vector<int32_t> tokens;
  40 +
  41 + std::array<int64_t, 2> token_shape = {1, 1};
  42 + int64_t seq_len_shape = 1;
  43 +
  44 + Ort::Value token_tensor = Ort::Value::CreateTensor(
  45 + memory_info, &sos, 1, token_shape.data(), token_shape.size());
  46 +
  47 + Ort::Value seq_len_tensor =
  48 + Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);
  49 +
  50 + Ort::Value logits{nullptr};
  51 + std::vector<Ort::Value> states;
  52 +
  53 + std::tie(logits, states) = model_->ForwardUnCachedDecoder(
  54 + std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out));
  55 +
  56 + int32_t vocab_size = logits.GetTensorTypeAndShapeInfo().GetShape()[2];
  57 +
  58 + for (int32_t i = 0; i != max_len; ++i) {
  59 + const float *p = logits.GetTensorData<float>();
  60 +
  61 + int32_t max_token_id = static_cast<int32_t>(
  62 + std::distance(p, std::max_element(p, p + vocab_size)));
  63 + if (max_token_id == eos) {
  64 + break;
  65 + }
  66 + tokens.push_back(max_token_id);
  67 +
  68 + seq_len += 1;
  69 +
  70 + token_tensor = Ort::Value::CreateTensor(
  71 + memory_info, &tokens.back(), 1, token_shape.data(), token_shape.size());
  72 +
  73 + seq_len_tensor =
  74 + Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);
  75 +
  76 + std::tie(logits, states) = model_->ForwardCachedDecoder(
  77 + std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out),
  78 + std::move(states));
  79 + }
  80 +
  81 + OfflineMoonshineDecoderResult ans;
  82 + ans.tokens = std::move(tokens);
  83 +
  84 + return {ans};
  85 +}
  86 +
  87 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
  11 +#include "sherpa-onnx/csrc/offline-moonshine-model.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +class OfflineMoonshineGreedySearchDecoder : public OfflineMoonshineDecoder {
  16 + public:
  17 + explicit OfflineMoonshineGreedySearchDecoder(OfflineMoonshineModel *model)
  18 + : model_(model) {}
  19 +
  20 + std::vector<OfflineMoonshineDecoderResult> Decode(
  21 + Ort::Value encoder_out) override;
  22 +
  23 + private:
  24 + OfflineMoonshineModel *model_; // not owned
  25 +};
  26 +
  27 +} // namespace sherpa_onnx
  28 +
  29 +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
  1 +// sherpa-onnx/csrc/offline-moonshine-model-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflineMoonshineModelConfig::Register(ParseOptions *po) {
  13 + po->Register("moonshine-preprocessor", &preprocessor,
  14 + "Path to onnx preprocessor of moonshine, e.g., preprocess.onnx");
  15 +
  16 + po->Register("moonshine-encoder", &encoder,
  17 + "Path to onnx encoder of moonshine, e.g., encode.onnx");
  18 +
  19 + po->Register(
  20 + "moonshine-uncached-decoder", &uncached_decoder,
  21 + "Path to onnx uncached_decoder of moonshine, e.g., uncached_decode.onnx");
  22 +
  23 + po->Register(
  24 + "moonshine-cached-decoder", &cached_decoder,
  25 + "Path to onnx cached_decoder of moonshine, e.g., cached_decode.onnx");
  26 +}
  27 +
  28 +bool OfflineMoonshineModelConfig::Validate() const {
  29 + if (preprocessor.empty()) {
  30 + SHERPA_ONNX_LOGE("Please provide --moonshine-preprocessor");
  31 + return false;
  32 + }
  33 +
  34 + if (!FileExists(preprocessor)) {
  35 + SHERPA_ONNX_LOGE("moonshine preprocessor file '%s' does not exist",
  36 + preprocessor.c_str());
  37 + return false;
  38 + }
  39 +
  40 + if (encoder.empty()) {
  41 + SHERPA_ONNX_LOGE("Please provide --moonshine-encoder");
  42 + return false;
  43 + }
  44 +
  45 + if (!FileExists(encoder)) {
  46 + SHERPA_ONNX_LOGE("moonshine encoder file '%s' does not exist",
  47 + encoder.c_str());
  48 + return false;
  49 + }
  50 +
  51 + if (uncached_decoder.empty()) {
  52 + SHERPA_ONNX_LOGE("Please provide --moonshine-uncached-decoder");
  53 + return false;
  54 + }
  55 +
  56 + if (!FileExists(uncached_decoder)) {
  57 + SHERPA_ONNX_LOGE("moonshine uncached decoder file '%s' does not exist",
  58 + uncached_decoder.c_str());
  59 + return false;
  60 + }
  61 +
  62 + if (cached_decoder.empty()) {
  63 + SHERPA_ONNX_LOGE("Please provide --moonshine-cached-decoder");
  64 + return false;
  65 + }
  66 +
  67 + if (!FileExists(cached_decoder)) {
  68 + SHERPA_ONNX_LOGE("moonshine cached decoder file '%s' does not exist",
  69 + cached_decoder.c_str());
  70 + return false;
  71 + }
  72 +
  73 + return true;
  74 +}
  75 +
  76 +std::string OfflineMoonshineModelConfig::ToString() const {
  77 + std::ostringstream os;
  78 +
  79 + os << "OfflineMoonshineModelConfig(";
  80 + os << "preprocessor=\"" << preprocessor << "\", ";
  81 + os << "encoder=\"" << encoder << "\", ";
  82 + os << "uncached_decoder=\"" << uncached_decoder << "\", ";
  83 + os << "cached_decoder=\"" << cached_decoder << "\")";
  84 +
  85 + return os.str();
  86 +}
  87 +
  88 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-moonshine-model-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OfflineMoonshineModelConfig {
  14 + std::string preprocessor;
  15 + std::string encoder;
  16 + std::string uncached_decoder;
  17 + std::string cached_decoder;
  18 +
  19 + OfflineMoonshineModelConfig() = default;
  20 + OfflineMoonshineModelConfig(const std::string &preprocessor,
  21 + const std::string &encoder,
  22 + const std::string &uncached_decoder,
  23 + const std::string &cached_decoder)
  24 + : preprocessor(preprocessor),
  25 + encoder(encoder),
  26 + uncached_decoder(uncached_decoder),
  27 + cached_decoder(cached_decoder) {}
  28 +
  29 + void Register(ParseOptions *po);
  30 + bool Validate() const;
  31 +
  32 + std::string ToString() const;
  33 +};
  34 +
  35 +} // namespace sherpa_onnx
  36 +
  37 +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-moonshine-model.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-moonshine-model.h"
  6 +
  7 +#include <string>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +#include "sherpa-onnx/csrc/onnx-utils.h"
  13 +#include "sherpa-onnx/csrc/session.h"
  14 +#include "sherpa-onnx/csrc/text-utils.h"
  15 +
  16 +namespace sherpa_onnx {
  17 +
  18 +class OfflineMoonshineModel::Impl {
  19 + public:
  20 + explicit Impl(const OfflineModelConfig &config)
  21 + : config_(config),
  22 + env_(ORT_LOGGING_LEVEL_ERROR),
  23 + sess_opts_(GetSessionOptions(config)),
  24 + allocator_{} {
  25 + {
  26 + auto buf = ReadFile(config.moonshine.preprocessor);
  27 + InitPreprocessor(buf.data(), buf.size());
  28 + }
  29 +
  30 + {
  31 + auto buf = ReadFile(config.moonshine.encoder);
  32 + InitEncoder(buf.data(), buf.size());
  33 + }
  34 +
  35 + {
  36 + auto buf = ReadFile(config.moonshine.uncached_decoder);
  37 + InitUnCachedDecoder(buf.data(), buf.size());
  38 + }
  39 +
  40 + {
  41 + auto buf = ReadFile(config.moonshine.cached_decoder);
  42 + InitCachedDecoder(buf.data(), buf.size());
  43 + }
  44 + }
  45 +
  46 +#if __ANDROID_API__ >= 9
  47 + Impl(AAssetManager *mgr, const OfflineModelConfig &config)
  48 + : config_(config),
  49 + env_(ORT_LOGGING_LEVEL_ERROR),
  50 + sess_opts_(GetSessionOptions(config)),
  51 + allocator_{} {
  52 + {
  53 + auto buf = ReadFile(mgr, config.moonshine.preprocessor);
  54 + InitPreprocessor(buf.data(), buf.size());
  55 + }
  56 +
  57 + {
  58 + auto buf = ReadFile(mgr, config.moonshine.encoder);
  59 + InitEncoder(buf.data(), buf.size());
  60 + }
  61 +
  62 + {
  63 + auto buf = ReadFile(mgr, config.moonshine.uncached_decoder);
  64 + InitUnCachedDecoder(buf.data(), buf.size());
  65 + }
  66 +
  67 + {
  68 + auto buf = ReadFile(mgr, config.moonshine.cached_decoder);
  69 + InitCachedDecoder(buf.data(), buf.size());
  70 + }
  71 + }
  72 +#endif
  73 +
  74 + Ort::Value ForwardPreprocessor(Ort::Value audio) {
  75 + auto features = preprocessor_sess_->Run(
  76 + {}, preprocessor_input_names_ptr_.data(), &audio, 1,
  77 + preprocessor_output_names_ptr_.data(),
  78 + preprocessor_output_names_ptr_.size());
  79 +
  80 + return std::move(features[0]);
  81 + }
  82 +
  83 + Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) {
  84 + std::array<Ort::Value, 2> encoder_inputs{std::move(features),
  85 + std::move(features_len)};
  86 + auto encoder_out = encoder_sess_->Run(
  87 + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
  88 + encoder_inputs.size(), encoder_output_names_ptr_.data(),
  89 + encoder_output_names_ptr_.size());
  90 +
  91 + return std::move(encoder_out[0]);
  92 + }
  93 +
  94 + std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
  95 + Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out) {
  96 + std::array<Ort::Value, 3> uncached_decoder_input = {
  97 + std::move(tokens),
  98 + std::move(encoder_out),
  99 + std::move(seq_len),
  100 + };
  101 +
  102 + auto uncached_decoder_out = uncached_decoder_sess_->Run(
  103 + {}, uncached_decoder_input_names_ptr_.data(),
  104 + uncached_decoder_input.data(), uncached_decoder_input.size(),
  105 + uncached_decoder_output_names_ptr_.data(),
  106 + uncached_decoder_output_names_ptr_.size());
  107 +
  108 + std::vector<Ort::Value> states;
  109 + states.reserve(uncached_decoder_out.size() - 1);
  110 +
  111 + int32_t i = -1;
  112 + for (auto &s : uncached_decoder_out) {
  113 + ++i;
  114 + if (i == 0) {
  115 + continue;
  116 + }
  117 +
  118 + states.push_back(std::move(s));
  119 + }
  120 +
  121 + return {std::move(uncached_decoder_out[0]), std::move(states)};
  122 + }
  123 +
  124 + std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
  125 + Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out,
  126 + std::vector<Ort::Value> states) {
  127 + std::vector<Ort::Value> cached_decoder_input;
  128 + cached_decoder_input.reserve(3 + states.size());
  129 + cached_decoder_input.push_back(std::move(tokens));
  130 + cached_decoder_input.push_back(std::move(encoder_out));
  131 + cached_decoder_input.push_back(std::move(seq_len));
  132 +
  133 + for (auto &s : states) {
  134 + cached_decoder_input.push_back(std::move(s));
  135 + }
  136 +
  137 + auto cached_decoder_out = cached_decoder_sess_->Run(
  138 + {}, cached_decoder_input_names_ptr_.data(), cached_decoder_input.data(),
  139 + cached_decoder_input.size(), cached_decoder_output_names_ptr_.data(),
  140 + cached_decoder_output_names_ptr_.size());
  141 +
  142 + std::vector<Ort::Value> next_states;
  143 + next_states.reserve(cached_decoder_out.size() - 1);
  144 +
  145 + int32_t i = -1;
  146 + for (auto &s : cached_decoder_out) {
  147 + ++i;
  148 + if (i == 0) {
  149 + continue;
  150 + }
  151 +
  152 + next_states.push_back(std::move(s));
  153 + }
  154 +
  155 + return {std::move(cached_decoder_out[0]), std::move(next_states)};
  156 + }
  157 +
  158 + OrtAllocator *Allocator() const { return allocator_; }
  159 +
  160 + private:
  161 + void InitPreprocessor(void *model_data, size_t model_data_length) {
  162 + preprocessor_sess_ = std::make_unique<Ort::Session>(
  163 + env_, model_data, model_data_length, sess_opts_);
  164 +
  165 + GetInputNames(preprocessor_sess_.get(), &preprocessor_input_names_,
  166 + &preprocessor_input_names_ptr_);
  167 +
  168 + GetOutputNames(preprocessor_sess_.get(), &preprocessor_output_names_,
  169 + &preprocessor_output_names_ptr_);
  170 + }
  171 +
  172 + void InitEncoder(void *model_data, size_t model_data_length) {
  173 + encoder_sess_ = std::make_unique<Ort::Session>(
  174 + env_, model_data, model_data_length, sess_opts_);
  175 +
  176 + GetInputNames(encoder_sess_.get(), &encoder_input_names_,
  177 + &encoder_input_names_ptr_);
  178 +
  179 + GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
  180 + &encoder_output_names_ptr_);
  181 + }
  182 +
  183 + void InitUnCachedDecoder(void *model_data, size_t model_data_length) {
  184 + uncached_decoder_sess_ = std::make_unique<Ort::Session>(
  185 + env_, model_data, model_data_length, sess_opts_);
  186 +
  187 + GetInputNames(uncached_decoder_sess_.get(), &uncached_decoder_input_names_,
  188 + &uncached_decoder_input_names_ptr_);
  189 +
  190 + GetOutputNames(uncached_decoder_sess_.get(),
  191 + &uncached_decoder_output_names_,
  192 + &uncached_decoder_output_names_ptr_);
  193 + }
  194 +
  195 + void InitCachedDecoder(void *model_data, size_t model_data_length) {
  196 + cached_decoder_sess_ = std::make_unique<Ort::Session>(
  197 + env_, model_data, model_data_length, sess_opts_);
  198 +
  199 + GetInputNames(cached_decoder_sess_.get(), &cached_decoder_input_names_,
  200 + &cached_decoder_input_names_ptr_);
  201 +
  202 + GetOutputNames(cached_decoder_sess_.get(), &cached_decoder_output_names_,
  203 + &cached_decoder_output_names_ptr_);
  204 + }
  205 +
  206 + private:
  207 + OfflineModelConfig config_;
  208 + Ort::Env env_;
  209 + Ort::SessionOptions sess_opts_;
  210 + Ort::AllocatorWithDefaultOptions allocator_;
  211 +
  212 + std::unique_ptr<Ort::Session> preprocessor_sess_;
  213 + std::unique_ptr<Ort::Session> encoder_sess_;
  214 + std::unique_ptr<Ort::Session> uncached_decoder_sess_;
  215 + std::unique_ptr<Ort::Session> cached_decoder_sess_;
  216 +
  217 + std::vector<std::string> preprocessor_input_names_;
  218 + std::vector<const char *> preprocessor_input_names_ptr_;
  219 +
  220 + std::vector<std::string> preprocessor_output_names_;
  221 + std::vector<const char *> preprocessor_output_names_ptr_;
  222 +
  223 + std::vector<std::string> encoder_input_names_;
  224 + std::vector<const char *> encoder_input_names_ptr_;
  225 +
  226 + std::vector<std::string> encoder_output_names_;
  227 + std::vector<const char *> encoder_output_names_ptr_;
  228 +
  229 + std::vector<std::string> uncached_decoder_input_names_;
  230 + std::vector<const char *> uncached_decoder_input_names_ptr_;
  231 +
  232 + std::vector<std::string> uncached_decoder_output_names_;
  233 + std::vector<const char *> uncached_decoder_output_names_ptr_;
  234 +
  235 + std::vector<std::string> cached_decoder_input_names_;
  236 + std::vector<const char *> cached_decoder_input_names_ptr_;
  237 +
  238 + std::vector<std::string> cached_decoder_output_names_;
  239 + std::vector<const char *> cached_decoder_output_names_ptr_;
  240 +};
  241 +
  242 +OfflineMoonshineModel::OfflineMoonshineModel(const OfflineModelConfig &config)
  243 + : impl_(std::make_unique<Impl>(config)) {}
  244 +
  245 +#if __ANDROID_API__ >= 9
  246 +OfflineMoonshineModel::OfflineMoonshineModel(AAssetManager *mgr,
  247 + const OfflineModelConfig &config)
  248 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  249 +#endif
  250 +
  251 +OfflineMoonshineModel::~OfflineMoonshineModel() = default;
  252 +
  253 +Ort::Value OfflineMoonshineModel::ForwardPreprocessor(Ort::Value audio) const {
  254 + return impl_->ForwardPreprocessor(std::move(audio));
  255 +}
  256 +
  257 +Ort::Value OfflineMoonshineModel::ForwardEncoder(
  258 + Ort::Value features, Ort::Value features_len) const {
  259 + return impl_->ForwardEncoder(std::move(features), std::move(features_len));
  260 +}
  261 +
  262 +std::pair<Ort::Value, std::vector<Ort::Value>>
  263 +OfflineMoonshineModel::ForwardUnCachedDecoder(Ort::Value token,
  264 + Ort::Value seq_len,
  265 + Ort::Value encoder_out) const {
  266 + return impl_->ForwardUnCachedDecoder(std::move(token), std::move(seq_len),
  267 + std::move(encoder_out));
  268 +}
  269 +
  270 +std::pair<Ort::Value, std::vector<Ort::Value>>
  271 +OfflineMoonshineModel::ForwardCachedDecoder(
  272 + Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
  273 + std::vector<Ort::Value> states) const {
  274 + return impl_->ForwardCachedDecoder(std::move(token), std::move(seq_len),
  275 + std::move(encoder_out), std::move(states));
  276 +}
  277 +
  278 +OrtAllocator *OfflineMoonshineModel::Allocator() const {
  279 + return impl_->Allocator();
  280 +}
  281 +
  282 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-moonshine-model.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
  17 +#include "onnxruntime_cxx_api.h" // NOLINT
  18 +#include "sherpa-onnx/csrc/offline-model-config.h"
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +// please see
  23 +// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/moonshine/test.py
  24 +class OfflineMoonshineModel {
  25 + public:
  26 + explicit OfflineMoonshineModel(const OfflineModelConfig &config);
  27 +
  28 +#if __ANDROID_API__ >= 9
  29 + OfflineMoonshineModel(AAssetManager *mgr, const OfflineModelConfig &config);
  30 +#endif
  31 +
  32 + ~OfflineMoonshineModel();
  33 +
  34 + /** Run the preprocessor model.
  35 + *
  36 + * @param audio A float32 tensor of shape (batch_size, num_samples)
  37 + *
  38 + * @return Return a float32 tensor of shape (batch_size, T, dim) that
  39 + * can be used as the input of ForwardEncoder()
  40 + */
  41 + Ort::Value ForwardPreprocessor(Ort::Value audio) const;
  42 +
  43 + /** Run the encoder model.
  44 + *
  45 + * @param features A float32 tensor of shape (batch_size, T, dim)
  46 + * @param features_len An int32 tensor of shape (batch_size,)
  47 + * @returns A float32 tensor of shape (batch_size, T, dim).
  48 + */
  49 + Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) const;
  50 +
  51 + /** Run the uncached decoder.
  52 + *
  53 + * @param token An int32 tensor of shape (batch_size, num_tokens)
  54 + * @param seq_len An int32 tensor of shape (batch_size,) containing the
  55 + * number of predicted tokens so far
  56 + * @param encoder_out A float32 tensor of shape (batch_size, T, dim)
  57 + *
  58 + * @returns Return a pair:
  59 + *
  60 + * - logits, a float32 tensor of shape (batch_size, 1, vocab_size)
  61 + * - states, a list of states
  62 + */
  63 + std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
  64 + Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out) const;
  65 +
  66 + /** Run the cached decoder.
  67 + *
  68 + * @param token An int32 tensor of shape (batch_size, num_tokens)
  69 + * @param seq_len An int32 tensor of shape (batch_size,) containing the
  70 + * number of predicted tokens so far
  71 + * @param encoder_out A float32 tensor of shape (batch_size, T, dim)
  72 + * @param states A list of previous states
  73 + *
  74 + * @returns Return a pair:
  75 + * - logits, a float32 tensor of shape (batch_size, 1, vocab_size)
  76 + * - states, a list of new states
  77 + */
  78 + std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
  79 + Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
  80 + std::vector<Ort::Value> states) const;
  81 +
  82 + /** Return an allocator for allocating memory
  83 + */
  84 + OrtAllocator *Allocator() const;
  85 +
  86 + private:
  87 + class Impl;
  88 + std::unique_ptr<Impl> impl_;
  89 +};
  90 +
  91 +} // namespace sherpa_onnx
  92 +
  93 +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
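For reference, the four models documented above are chained as preprocessor → encoder → uncached decoder (first step) → cached decoder (subsequent steps), mirroring ForwardPreprocessor / ForwardEncoder / ForwardUnCachedDecoder / ForwardCachedDecoder. Below is a minimal greedy-decoding sketch in Python that drives the exported ONNX files directly with onnxruntime. It is illustrative only: the positional input ordering follows the C++ Forward* calls in this PR, while the start/end token ids (<s>=1, </s>=2), the assumption that the cached decoder consumes only the most recent token, and the 100-token cap are assumptions rather than verbatim behavior of the shipped decoder.

```python
# Minimal greedy-decoding sketch driving the exported ONNX files with
# onnxruntime. Assumptions (not taken verbatim from this PR): input order
# matches the C++ Forward* methods, <s>=1 and </s>=2 are the start/end
# tokens, and the cached decoder consumes only the most recent token.
import numpy as np
import onnxruntime as ort


def run(sess, *args):
    # Feed inputs positionally, in the order the session declares them,
    # mirroring how the C++ code passes tensors.
    names = [i.name for i in sess.get_inputs()]
    return sess.run(None, dict(zip(names, args)))


preprocessor = ort.InferenceSession("preprocess.onnx")
encoder = ort.InferenceSession("encode.int8.onnx")
uncached_decoder = ort.InferenceSession("uncached_decode.int8.onnx")
cached_decoder = ort.InferenceSession("cached_decode.int8.onnx")

audio = np.zeros((1, 16000), dtype=np.float32)  # 1 second of silence at 16 kHz

features = run(preprocessor, audio)[0]                 # (1, T, dim)
features_len = np.array([features.shape[1]], dtype=np.int32)
encoder_out = run(encoder, features, features_len)[0]  # (1, T, dim)

tokens = [1]  # assumed start-of-sequence token <s>
token_tensor = np.array([tokens], dtype=np.int32)
seq_len = np.array([len(tokens)], dtype=np.int32)

# The first step has no cache yet, so use the uncached decoder.
out = run(uncached_decoder, token_tensor, encoder_out, seq_len)
logits, states = out[0], out[1:]

for _ in range(100):  # arbitrary cap on the number of decoded tokens
    next_token = int(logits[0, -1].argmax())
    if next_token == 2:  # assumed end-of-sequence token </s>
        break
    tokens.append(next_token)
    token_tensor = np.array([[next_token]], dtype=np.int32)
    seq_len = np.array([len(tokens)], dtype=np.int32)
    out = run(cached_decoder, token_tensor, encoder_out, seq_len, *states)
    logits, states = out[0], out[1:]

print(tokens)  # map ids to symbols with tokens.txt to get the transcript
```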
@@ -20,6 +20,7 @@ @@ -20,6 +20,7 @@
20 #include "onnxruntime_cxx_api.h" // NOLINT 20 #include "onnxruntime_cxx_api.h" // NOLINT
21 #include "sherpa-onnx/csrc/macros.h" 21 #include "sherpa-onnx/csrc/macros.h"
22 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" 22 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
  23 +#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
23 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" 24 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
24 #include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h" 25 #include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
25 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" 26 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
@@ -51,6 +52,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -51,6 +52,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
51 return std::make_unique<OfflineRecognizerWhisperImpl>(config); 52 return std::make_unique<OfflineRecognizerWhisperImpl>(config);
52 } 53 }
53 54
  55 + if (!config.model_config.moonshine.preprocessor.empty()) {
  56 + return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
  57 + }
  58 +
54 // TODO(fangjun): Refactor it. We only need to use model type for the 59 // TODO(fangjun): Refactor it. We only need to use model type for the
55 // following models: 60 // following models:
56 // 1. transducer and nemo_transducer 61 // 1. transducer and nemo_transducer
@@ -67,7 +72,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -67,7 +72,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
67 model_type == "telespeech_ctc") { 72 model_type == "telespeech_ctc") {
68 return std::make_unique<OfflineRecognizerCtcImpl>(config); 73 return std::make_unique<OfflineRecognizerCtcImpl>(config);
69 } else if (model_type == "whisper") { 74 } else if (model_type == "whisper") {
  75 + // unreachable
70 return std::make_unique<OfflineRecognizerWhisperImpl>(config); 76 return std::make_unique<OfflineRecognizerWhisperImpl>(config);
  77 + } else if (model_type == "moonshine") {
  78 + // unreachable
  79 + return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
71 } else { 80 } else {
72 SHERPA_ONNX_LOGE( 81 SHERPA_ONNX_LOGE(
73 "Invalid model_type: %s. Trying to load the model to get its type", 82 "Invalid model_type: %s. Trying to load the model to get its type",
@@ -225,6 +234,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -225,6 +234,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
225 return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); 234 return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
226 } 235 }
227 236
  237 + if (!config.model_config.moonshine.preprocessor.empty()) {
  238 + return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
  239 + }
  240 +
228 // TODO(fangjun): Refactor it. We only need to use model type for the 241 // TODO(fangjun): Refactor it. We only need to use model type for the
229 // following models: 242 // following models:
230 // 1. transducer and nemo_transducer 243 // 1. transducer and nemo_transducer
@@ -242,6 +255,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -242,6 +255,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
242 return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); 255 return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
243 } else if (model_type == "whisper") { 256 } else if (model_type == "whisper") {
244 return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); 257 return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
  258 + } else if (model_type == "moonshine") {
  259 + return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
245 } else { 260 } else {
246 SHERPA_ONNX_LOGE( 261 SHERPA_ONNX_LOGE(
247 "Invalid model_type: %s. Trying to load the model to get its type", 262 "Invalid model_type: %s. Trying to load the model to get its type",
  1 +// sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
  7 +
  8 +#include <algorithm>
  9 +#include <cmath>
  10 +#include <memory>
  11 +#include <string>
  12 +#include <utility>
  13 +#include <vector>
  14 +
  15 +#if __ANDROID_API__ >= 9
  16 +#include "android/asset_manager.h"
  17 +#include "android/asset_manager_jni.h"
  18 +#endif
  19 +
  20 +#include "sherpa-onnx/csrc/offline-model-config.h"
  21 +#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
  22 +#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
  23 +#include "sherpa-onnx/csrc/offline-moonshine-model.h"
  24 +#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
  25 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  26 +#include "sherpa-onnx/csrc/symbol-table.h"
  27 +#include "sherpa-onnx/csrc/transpose.h"
  28 +
  29 +namespace sherpa_onnx {
  30 +
  31 +static OfflineRecognitionResult Convert(
  32 + const OfflineMoonshineDecoderResult &src, const SymbolTable &sym_table) {
  33 + OfflineRecognitionResult r;
  34 + r.tokens.reserve(src.tokens.size());
  35 +
  36 + std::string text;
  37 + for (auto i : src.tokens) {
  38 + if (!sym_table.Contains(i)) {
  39 + continue;
  40 + }
  41 +
  42 + const auto &s = sym_table[i];
  43 + text += s;
  44 + r.tokens.push_back(s);
  45 + }
  46 +
  47 + r.text = text;
  48 +
  49 + return r;
  50 +}
  51 +
  52 +class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl {
  53 + public:
  54 + explicit OfflineRecognizerMoonshineImpl(const OfflineRecognizerConfig &config)
  55 + : OfflineRecognizerImpl(config),
  56 + config_(config),
  57 + symbol_table_(config_.model_config.tokens),
  58 + model_(std::make_unique<OfflineMoonshineModel>(config.model_config)) {
  59 + Init();
  60 + }
  61 +
  62 +#if __ANDROID_API__ >= 9
  63 + OfflineRecognizerMoonshineImpl(AAssetManager *mgr,
  64 + const OfflineRecognizerConfig &config)
  65 + : OfflineRecognizerImpl(mgr, config),
  66 + config_(config),
  67 + symbol_table_(mgr, config_.model_config.tokens),
  68 + model_(
  69 + std::make_unique<OfflineMoonshineModel>(mgr, config.model_config)) {
  70 + Init();
  71 + }
  72 +
  73 +#endif
  74 +
  75 + void Init() {
  76 + if (config_.decoding_method == "greedy_search") {
  77 + decoder_ =
  78 + std::make_unique<OfflineMoonshineGreedySearchDecoder>(model_.get());
  79 + } else {
  80 + SHERPA_ONNX_LOGE(
  81 + "Only greedy_search is supported at present for moonshine. Given %s",
  82 + config_.decoding_method.c_str());
  83 + exit(-1);
  84 + }
  85 + }
  86 +
  87 + std::unique_ptr<OfflineStream> CreateStream() const override {
  88 + MoonshineTag tag;
  89 + return std::make_unique<OfflineStream>(tag);
  90 + }
  91 +
  92 + void DecodeStreams(OfflineStream **ss, int32_t n) const override {
  93 + // batch decoding is not implemented yet
  94 + for (int32_t i = 0; i != n; ++i) {
  95 + DecodeStream(ss[i]);
  96 + }
  97 + }
  98 +
  99 + OfflineRecognizerConfig GetConfig() const override { return config_; }
  100 +
  101 + private:
  102 + void DecodeStream(OfflineStream *s) const {
  103 + auto memory_info =
  104 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  105 +
  106 + std::vector<float> audio = s->GetFrames();
  107 +
  108 + try {
  109 + std::array<int64_t, 2> shape{1, static_cast<int64_t>(audio.size())};
  110 +
  111 + Ort::Value audio_tensor = Ort::Value::CreateTensor(
  112 + memory_info, audio.data(), audio.size(), shape.data(), shape.size());
  113 +
  114 + Ort::Value features =
  115 + model_->ForwardPreprocessor(std::move(audio_tensor));
  116 +
  117 + int32_t features_len = features.GetTensorTypeAndShapeInfo().GetShape()[1];
  118 +
  119 + int64_t features_shape = 1;
  120 +
  121 + Ort::Value features_len_tensor = Ort::Value::CreateTensor(
  122 + memory_info, &features_len, 1, &features_shape, 1);
  123 +
  124 + Ort::Value encoder_out = model_->ForwardEncoder(
  125 + std::move(features), std::move(features_len_tensor));
  126 +
  127 + auto results = decoder_->Decode(std::move(encoder_out));
  128 +
  129 + auto r = Convert(results[0], symbol_table_);
  130 + r.text = ApplyInverseTextNormalization(std::move(r.text));
  131 + s->SetResult(r);
  132 + } catch (const Ort::Exception &ex) {
  133 + SHERPA_ONNX_LOGE(
  134 + "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
  135 + "audio samples: %d",
  136 + ex.what(), static_cast<int32_t>(audio.size()));
  137 + return;
  138 + }
  139 + }
  140 +
  141 + private:
  142 + OfflineRecognizerConfig config_;
  143 + SymbolTable symbol_table_;
  144 + std::unique_ptr<OfflineMoonshineModel> model_;
  145 + std::unique_ptr<OfflineMoonshineDecoder> decoder_;
  146 +};
  147 +
  148 +} // namespace sherpa_onnx
  149 +
  150 +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
@@ -133,6 +133,10 @@ class OfflineStream::Impl { @@ -133,6 +133,10 @@ class OfflineStream::Impl {
133 fbank_ = std::make_unique<knf::OnlineFbank>(opts_); 133 fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
134 } 134 }
135 135
  136 + explicit Impl(MoonshineTag /*tag*/) : is_moonshine_(true) {
  137 + config_.sampling_rate = 16000;
  138 + }
  139 +
136 void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { 140 void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
137 if (config_.normalize_samples) { 141 if (config_.normalize_samples) {
138 AcceptWaveformImpl(sampling_rate, waveform, n); 142 AcceptWaveformImpl(sampling_rate, waveform, n);
@@ -164,7 +168,9 @@ class OfflineStream::Impl { @@ -164,7 +168,9 @@ class OfflineStream::Impl {
164 std::vector<float> samples; 168 std::vector<float> samples;
165 resampler->Resample(waveform, n, true, &samples); 169 resampler->Resample(waveform, n, true, &samples);
166 170
167 - if (fbank_) { 171 + if (is_moonshine_) {
  172 + samples_.insert(samples_.end(), samples.begin(), samples.end());
  173 + } else if (fbank_) {
168 fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), 174 fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
169 samples.size()); 175 samples.size());
170 fbank_->InputFinished(); 176 fbank_->InputFinished();
@@ -181,7 +187,9 @@ class OfflineStream::Impl { @@ -181,7 +187,9 @@ class OfflineStream::Impl {
181 return; 187 return;
182 } // if (sampling_rate != config_.sampling_rate) 188 } // if (sampling_rate != config_.sampling_rate)
183 189
184 - if (fbank_) { 190 + if (is_moonshine_) {
  191 + samples_.insert(samples_.end(), waveform, waveform + n);
  192 + } else if (fbank_) {
185 fbank_->AcceptWaveform(sampling_rate, waveform, n); 193 fbank_->AcceptWaveform(sampling_rate, waveform, n);
186 fbank_->InputFinished(); 194 fbank_->InputFinished();
187 } else if (mfcc_) { 195 } else if (mfcc_) {
@@ -194,10 +202,18 @@ class OfflineStream::Impl { @@ -194,10 +202,18 @@ class OfflineStream::Impl {
194 } 202 }
195 203
196 int32_t FeatureDim() const { 204 int32_t FeatureDim() const {
  205 + if (is_moonshine_) {
  206 + return samples_.size();
  207 + }
  208 +
197 return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins; 209 return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
198 } 210 }
199 211
200 std::vector<float> GetFrames() const { 212 std::vector<float> GetFrames() const {
  213 + if (is_moonshine_) {
  214 + return samples_;
  215 + }
  216 +
201 int32_t n = fbank_ ? fbank_->NumFramesReady() 217 int32_t n = fbank_ ? fbank_->NumFramesReady()
202 : mfcc_ ? mfcc_->NumFramesReady() 218 : mfcc_ ? mfcc_->NumFramesReady()
203 : whisper_fbank_->NumFramesReady(); 219 : whisper_fbank_->NumFramesReady();
@@ -300,6 +316,10 @@ class OfflineStream::Impl { @@ -300,6 +316,10 @@ class OfflineStream::Impl {
300 OfflineRecognitionResult r_; 316 OfflineRecognitionResult r_;
301 ContextGraphPtr context_graph_; 317 ContextGraphPtr context_graph_;
302 bool is_ced_ = false; 318 bool is_ced_ = false;
  319 + bool is_moonshine_ = false;
  320 +
  321 + // used only when is_moonshine_ == true
  322 + std::vector<float> samples_;
303 }; 323 };
304 324
305 OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/, 325 OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
@@ -311,6 +331,9 @@ OfflineStream::OfflineStream(WhisperTag tag) @@ -311,6 +331,9 @@ OfflineStream::OfflineStream(WhisperTag tag)
311 331
312 OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {} 332 OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
313 333
  334 +OfflineStream::OfflineStream(MoonshineTag tag)
  335 + : impl_(std::make_unique<Impl>(tag)) {}
  336 +
314 OfflineStream::~OfflineStream() = default; 337 OfflineStream::~OfflineStream() = default;
315 338
316 void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, 339 void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
@@ -34,7 +34,7 @@ struct OfflineRecognitionResult { @@ -34,7 +34,7 @@ struct OfflineRecognitionResult {
34 // event target of the audio. 34 // event target of the audio.
35 std::string event; 35 std::string event;
36 36
37 - /// timestamps.size() == tokens.size() 37 + /// timestamps.size() == tokens.size()
38 /// timestamps[i] records the time in seconds when tokens[i] is decoded. 38 /// timestamps[i] records the time in seconds when tokens[i] is decoded.
39 std::vector<float> timestamps; 39 std::vector<float> timestamps;
40 40
@@ -49,6 +49,10 @@ struct WhisperTag { @@ -49,6 +49,10 @@ struct WhisperTag {
49 49
50 struct CEDTag {}; 50 struct CEDTag {};
51 51
  52 +// Moonshine uses a neural network model, a preprocessor, to convert
  53 +// audio samples to features.
  54 +struct MoonshineTag {};
  55 +
52 class OfflineStream { 56 class OfflineStream {
53 public: 57 public:
54 explicit OfflineStream(const FeatureExtractorConfig &config = {}, 58 explicit OfflineStream(const FeatureExtractorConfig &config = {},
@@ -56,6 +60,7 @@ class OfflineStream { @@ -56,6 +60,7 @@ class OfflineStream {
56 60
57 explicit OfflineStream(WhisperTag tag); 61 explicit OfflineStream(WhisperTag tag);
58 explicit OfflineStream(CEDTag tag); 62 explicit OfflineStream(CEDTag tag);
  63 + explicit OfflineStream(MoonshineTag tag);
59 ~OfflineStream(); 64 ~OfflineStream();
60 65
61 /** 66 /**
@@ -72,7 +77,10 @@ class OfflineStream { @@ -72,7 +77,10 @@ class OfflineStream {
72 void AcceptWaveform(int32_t sampling_rate, const float *waveform, 77 void AcceptWaveform(int32_t sampling_rate, const float *waveform,
73 int32_t n) const; 78 int32_t n) const;
74 79
75 - /// Return feature dim of this extractor 80 + /// Return feature dim of this extractor.
  81 + ///
  82 + /// Note: if it is Moonshine, then it returns the number of audio samples
  83 + /// currently received.
76 int32_t FeatureDim() const; 84 int32_t FeatureDim() const;
77 85
78 // Get all the feature frames of this stream in a 1-D array, which is 86 // Get all the feature frames of this stream in a 1-D array, which is
@@ -23,7 +23,6 @@ class OfflineWhisperModel::Impl { @@ -23,7 +23,6 @@ class OfflineWhisperModel::Impl {
23 explicit Impl(const OfflineModelConfig &config) 23 explicit Impl(const OfflineModelConfig &config)
24 : config_(config), 24 : config_(config),
25 env_(ORT_LOGGING_LEVEL_ERROR), 25 env_(ORT_LOGGING_LEVEL_ERROR),
26 - debug_(config.debug),  
27 sess_opts_(GetSessionOptions(config)), 26 sess_opts_(GetSessionOptions(config)),
28 allocator_{} { 27 allocator_{} {
29 { 28 {
@@ -40,7 +39,6 @@ class OfflineWhisperModel::Impl { @@ -40,7 +39,6 @@ class OfflineWhisperModel::Impl {
40 explicit Impl(const SpokenLanguageIdentificationConfig &config) 39 explicit Impl(const SpokenLanguageIdentificationConfig &config)
41 : lid_config_(config), 40 : lid_config_(config),
42 env_(ORT_LOGGING_LEVEL_ERROR), 41 env_(ORT_LOGGING_LEVEL_ERROR),
43 - debug_(config_.debug),  
44 sess_opts_(GetSessionOptions(config)), 42 sess_opts_(GetSessionOptions(config)),
45 allocator_{} { 43 allocator_{} {
46 { 44 {
@@ -60,7 +58,6 @@ class OfflineWhisperModel::Impl { @@ -60,7 +58,6 @@ class OfflineWhisperModel::Impl {
60 env_(ORT_LOGGING_LEVEL_ERROR), 58 env_(ORT_LOGGING_LEVEL_ERROR),
61 sess_opts_(GetSessionOptions(config)), 59 sess_opts_(GetSessionOptions(config)),
62 allocator_{} { 60 allocator_{} {
63 - debug_ = config_.debug;  
64 { 61 {
65 auto buf = ReadFile(mgr, config.whisper.encoder); 62 auto buf = ReadFile(mgr, config.whisper.encoder);
66 InitEncoder(buf.data(), buf.size()); 63 InitEncoder(buf.data(), buf.size());
@@ -77,7 +74,6 @@ class OfflineWhisperModel::Impl { @@ -77,7 +74,6 @@ class OfflineWhisperModel::Impl {
77 env_(ORT_LOGGING_LEVEL_ERROR), 74 env_(ORT_LOGGING_LEVEL_ERROR),
78 sess_opts_(GetSessionOptions(config)), 75 sess_opts_(GetSessionOptions(config)),
79 allocator_{} { 76 allocator_{} {
80 - debug_ = config_.debug;  
81 { 77 {
82 auto buf = ReadFile(mgr, config.whisper.encoder); 78 auto buf = ReadFile(mgr, config.whisper.encoder);
83 InitEncoder(buf.data(), buf.size()); 79 InitEncoder(buf.data(), buf.size());
@@ -164,7 +160,7 @@ class OfflineWhisperModel::Impl { @@ -164,7 +160,7 @@ class OfflineWhisperModel::Impl {
164 } 160 }
165 } 161 }
166 162
167 - if (debug_) { 163 + if (config_.debug) {
168 SHERPA_ONNX_LOGE("Detected language: %s", 164 SHERPA_ONNX_LOGE("Detected language: %s",
169 GetID2Lang().at(lang_id).c_str()); 165 GetID2Lang().at(lang_id).c_str());
170 } 166 }
@@ -237,7 +233,7 @@ class OfflineWhisperModel::Impl { @@ -237,7 +233,7 @@ class OfflineWhisperModel::Impl {
237 233
238 // get meta data 234 // get meta data
239 Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); 235 Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
240 - if (debug_) { 236 + if (config_.debug) {
241 std::ostringstream os; 237 std::ostringstream os;
242 os << "---encoder---\n"; 238 os << "---encoder---\n";
243 PrintModelMetadata(os, meta_data); 239 PrintModelMetadata(os, meta_data);
@@ -294,7 +290,6 @@ class OfflineWhisperModel::Impl { @@ -294,7 +290,6 @@ class OfflineWhisperModel::Impl {
294 private: 290 private:
295 OfflineModelConfig config_; 291 OfflineModelConfig config_;
296 SpokenLanguageIdentificationConfig lid_config_; 292 SpokenLanguageIdentificationConfig lid_config_;
297 - bool debug_ = false;  
298 Ort::Env env_; 293 Ort::Env env_;
299 Ort::SessionOptions sess_opts_; 294 Ort::SessionOptions sess_opts_;
300 Ort::AllocatorWithDefaultOptions allocator_; 295 Ort::AllocatorWithDefaultOptions allocator_;
@@ -43,7 +43,20 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/in @@ -43,7 +43,20 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/in
43 --decoding-method=greedy_search \ 43 --decoding-method=greedy_search \
44 /path/to/foo.wav [bar.wav foobar.wav ...] 44 /path/to/foo.wav [bar.wav foobar.wav ...]
45 45
46 -(3) Whisper models 46 +(3) Moonshine models
  47 +
  48 +See https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html
  49 +
  50 + ./bin/sherpa-onnx-offline \
  51 + --moonshine-preprocessor=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/preprocess.onnx \
  52 + --moonshine-encoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/encode.int8.onnx \
  53 + --moonshine-uncached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/uncached_decode.int8.onnx \
  54 + --moonshine-cached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/cached_decode.int8.onnx \
  55 + --tokens=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/tokens.txt \
  56 + --num-threads=1 \
  57 + /path/to/foo.wav [bar.wav foobar.wav ...]
  58 +
  59 +(4) Whisper models
47 60
48 See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html 61 See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
49 62
@@ -54,7 +67,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html @@ -54,7 +67,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
54 --num-threads=1 \ 67 --num-threads=1 \
55 /path/to/foo.wav [bar.wav foobar.wav ...] 68 /path/to/foo.wav [bar.wav foobar.wav ...]
56 69
57 -(4) NeMo CTC models 70 +(5) NeMo CTC models
58 71
59 See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html 72 See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
60 73
@@ -68,7 +81,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.htm @@ -68,7 +81,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.htm
68 ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \ 81 ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
69 ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav 82 ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
70 83
71 -(5) TDNN CTC model for the yesno recipe from icefall 84 +(6) TDNN CTC model for the yesno recipe from icefall
72 85
73 See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html 86 See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
74 // 87 //
@@ -109,6 +109,8 @@ const std::string SymbolTable::operator[](int32_t id) const { @@ -109,6 +109,8 @@ const std::string SymbolTable::operator[](int32_t id) const {
109 109
110 // for byte-level BPE 110 // for byte-level BPE
111 // id 0 is blank, id 1 is sos/eos, id 2 is unk 111 // id 0 is blank, id 1 is sos/eos, id 2 is unk
  112 + //
  113 + // Note: For moonshine models, 0 is <unk>, 1 is <s>, and 2 is </s>
112 if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && 114 if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
113 sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { 115 sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
114 std::ostringstream os; 116 std::ostringstream os;
@@ -11,6 +11,7 @@ set(srcs @@ -11,6 +11,7 @@ set(srcs
11 offline-ctc-fst-decoder-config.cc 11 offline-ctc-fst-decoder-config.cc
12 offline-lm-config.cc 12 offline-lm-config.cc
13 offline-model-config.cc 13 offline-model-config.cc
  14 + offline-moonshine-model-config.cc
14 offline-nemo-enc-dec-ctc-model-config.cc 15 offline-nemo-enc-dec-ctc-model-config.cc
15 offline-paraformer-model-config.cc 16 offline-paraformer-model-config.cc
16 offline-punctuation.cc 17 offline-punctuation.cc
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 #include <vector> 8 #include <vector>
9 9
10 #include "sherpa-onnx/csrc/offline-model-config.h" 10 #include "sherpa-onnx/csrc/offline-model-config.h"
  11 +#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
11 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" 12 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
12 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" 13 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
13 #include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h" 14 #include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
@@ -28,6 +29,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -28,6 +29,7 @@ void PybindOfflineModelConfig(py::module *m) {
28 PybindOfflineZipformerCtcModelConfig(m); 29 PybindOfflineZipformerCtcModelConfig(m);
29 PybindOfflineWenetCtcModelConfig(m); 30 PybindOfflineWenetCtcModelConfig(m);
30 PybindOfflineSenseVoiceModelConfig(m); 31 PybindOfflineSenseVoiceModelConfig(m);
  32 + PybindOfflineMoonshineModelConfig(m);
31 33
32 using PyClass = OfflineModelConfig; 34 using PyClass = OfflineModelConfig;
33 py::class_<PyClass>(*m, "OfflineModelConfig") 35 py::class_<PyClass>(*m, "OfflineModelConfig")
@@ -39,7 +41,8 @@ void PybindOfflineModelConfig(py::module *m) { @@ -39,7 +41,8 @@ void PybindOfflineModelConfig(py::module *m) {
39 const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &, 41 const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
40 const OfflineZipformerCtcModelConfig &, 42 const OfflineZipformerCtcModelConfig &,
41 const OfflineWenetCtcModelConfig &, 43 const OfflineWenetCtcModelConfig &,
42 - const OfflineSenseVoiceModelConfig &, const std::string &, 44 + const OfflineSenseVoiceModelConfig &,
  45 + const OfflineMoonshineModelConfig &, const std::string &,
43 const std::string &, int32_t, bool, const std::string &, 46 const std::string &, int32_t, bool, const std::string &,
44 const std::string &, const std::string &, const std::string &>(), 47 const std::string &, const std::string &, const std::string &>(),
45 py::arg("transducer") = OfflineTransducerModelConfig(), 48 py::arg("transducer") = OfflineTransducerModelConfig(),
@@ -50,6 +53,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -50,6 +53,7 @@ void PybindOfflineModelConfig(py::module *m) {
50 py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), 53 py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
51 py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), 54 py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
52 py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), 55 py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
  56 + py::arg("moonshine") = OfflineMoonshineModelConfig(),
53 py::arg("telespeech_ctc") = "", py::arg("tokens"), 57 py::arg("telespeech_ctc") = "", py::arg("tokens"),
54 py::arg("num_threads"), py::arg("debug") = false, 58 py::arg("num_threads"), py::arg("debug") = false,
55 py::arg("provider") = "cpu", py::arg("model_type") = "", 59 py::arg("provider") = "cpu", py::arg("model_type") = "",
@@ -62,6 +66,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -62,6 +66,7 @@ void PybindOfflineModelConfig(py::module *m) {
62 .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) 66 .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
63 .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) 67 .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
64 .def_readwrite("sense_voice", &PyClass::sense_voice) 68 .def_readwrite("sense_voice", &PyClass::sense_voice)
  69 + .def_readwrite("moonshine", &PyClass::moonshine)
65 .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) 70 .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
66 .def_readwrite("tokens", &PyClass::tokens) 71 .def_readwrite("tokens", &PyClass::tokens)
67 .def_readwrite("num_threads", &PyClass::num_threads) 72 .def_readwrite("num_threads", &PyClass::num_threads)
  1 +// sherpa-onnx/python/csrc/offline-moonshine-model-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOfflineMoonshineModelConfig(py::module *m) {
  15 + using PyClass = OfflineMoonshineModelConfig;
  16 + py::class_<PyClass>(*m, "OfflineMoonshineModelConfig")
  17 + .def(py::init<const std::string &, const std::string &,
  18 + const std::string &, const std::string &>(),
  19 + py::arg("preprocessor"), py::arg("encoder"),
  20 + py::arg("uncached_decoder"), py::arg("cached_decoder"))
  21 + .def_readwrite("preprocessor", &PyClass::preprocessor)
  22 + .def_readwrite("encoder", &PyClass::encoder)
  23 + .def_readwrite("uncached_decoder", &PyClass::uncached_decoder)
  24 + .def_readwrite("cached_decoder", &PyClass::cached_decoder)
  25 + .def("__str__", &PyClass::ToString);
  26 +}
  27 +
  28 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-moonshine-model-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineMoonshineModelConfig(py::module *m);
  13 +
  14 +} // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
@@ -8,13 +8,14 @@ from _sherpa_onnx import ( @@ -8,13 +8,14 @@ from _sherpa_onnx import (
8 OfflineCtcFstDecoderConfig, 8 OfflineCtcFstDecoderConfig,
9 OfflineLMConfig, 9 OfflineLMConfig,
10 OfflineModelConfig, 10 OfflineModelConfig,
  11 + OfflineMoonshineModelConfig,
11 OfflineNemoEncDecCtcModelConfig, 12 OfflineNemoEncDecCtcModelConfig,
12 OfflineParaformerModelConfig, 13 OfflineParaformerModelConfig,
13 - OfflineSenseVoiceModelConfig,  
14 ) 14 )
15 from _sherpa_onnx import OfflineRecognizer as _Recognizer 15 from _sherpa_onnx import OfflineRecognizer as _Recognizer
16 from _sherpa_onnx import ( 16 from _sherpa_onnx import (
17 OfflineRecognizerConfig, 17 OfflineRecognizerConfig,
  18 + OfflineSenseVoiceModelConfig,
18 OfflineStream, 19 OfflineStream,
19 OfflineTdnnModelConfig, 20 OfflineTdnnModelConfig,
20 OfflineTransducerModelConfig, 21 OfflineTransducerModelConfig,
@@ -503,12 +504,12 @@ class OfflineRecognizer(object): @@ -503,12 +504,12 @@ class OfflineRecognizer(object):
503 e.g., tiny, tiny.en, base, base.en, etc. 504 e.g., tiny, tiny.en, base, base.en, etc.
504 505
505 Args: 506 Args:
506 - encoder_model:  
507 - Path to the encoder model, e.g., tiny-encoder.onnx,  
508 - tiny-encoder.int8.onnx, tiny-encoder.ort, etc.  
509 - decoder_model: 507 + encoder:
510 Path to the encoder model, e.g., tiny-encoder.onnx, 508 Path to the encoder model, e.g., tiny-encoder.onnx,
511 tiny-encoder.int8.onnx, tiny-encoder.ort, etc. 509 tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
  510 + decoder:
  511 + Path to the decoder model, e.g., tiny-decoder.onnx,
  512 + tiny-decoder.int8.onnx, tiny-decoder.ort, etc.
512 tokens: 513 tokens:
513 Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two 514 Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
514 columns:: 515 columns::
@@ -571,6 +572,87 @@ class OfflineRecognizer(object): @@ -571,6 +572,87 @@ class OfflineRecognizer(object):
571 return self 572 return self
572 573
573 @classmethod 574 @classmethod
  575 + def from_moonshine(
  576 + cls,
  577 + preprocessor: str,
  578 + encoder: str,
  579 + uncached_decoder: str,
  580 + cached_decoder: str,
  581 + tokens: str,
  582 + num_threads: int = 1,
  583 + decoding_method: str = "greedy_search",
  584 + debug: bool = False,
  585 + provider: str = "cpu",
  586 + rule_fsts: str = "",
  587 + rule_fars: str = "",
  588 + ):
  589 + """
  590 + Please refer to
  591 + `<https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html>`_
  592 + to download pre-trained models for different kinds of moonshine models,
  593 + e.g., tiny, base, etc.
  594 +
  595 + Args:
  596 + preprocessor:
  597 + Path to the preprocessor model, e.g., preprocess.onnx
  598 + encoder:
  599 + Path to the encoder model, e.g., encode.int8.onnx
  600 + uncached_decoder:
  601 + Path to the uncached decoder model, e.g., uncached_decode.int8.onnx
  602 + cached_decoder:
  603 + Path to the cached decoder model, e.g., cached_decode.int8.onnx
  604 + tokens:
  605 + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
  606 + columns::
  607 +
  608 + symbol integer_id
  609 +
  610 + num_threads:
  611 + Number of threads for neural network computation.
  612 + decoding_method:
  613 + Valid values: greedy_search.
  614 + debug:
  615 + True to show debug messages.
  616 + provider:
  617 + onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  618 + rule_fsts:
  619 + If not empty, it specifies fsts for inverse text normalization.
  620 + If there are multiple fsts, they are separated by a comma.
  621 + rule_fars:
  622 + If not empty, it specifies fst archives for inverse text normalization.
  623 + If there are multiple archives, they are separated by a comma.
  624 + """
  625 + self = cls.__new__(cls)
  626 + model_config = OfflineModelConfig(
  627 + moonshine=OfflineMoonshineModelConfig(
  628 + preprocessor=preprocessor,
  629 + encoder=encoder,
  630 + uncached_decoder=uncached_decoder,
  631 + cached_decoder=cached_decoder,
  632 + ),
  633 + tokens=tokens,
  634 + num_threads=num_threads,
  635 + debug=debug,
  636 + provider=provider,
  637 + )
  638 +
  639 + unused_feat_config = FeatureExtractorConfig(
  640 + sampling_rate=16000,
  641 + feature_dim=80,
  642 + )
  643 +
  644 + recognizer_config = OfflineRecognizerConfig(
  645 + model_config=model_config,
  646 + feat_config=unused_feat_config,
  647 + decoding_method=decoding_method,
  648 + rule_fsts=rule_fsts,
  649 + rule_fars=rule_fars,
  650 + )
  651 + self.recognizer = _Recognizer(recognizer_config)
  652 + self.config = recognizer_config
  653 + return self
  654 +
  655 + @classmethod
574 def from_tdnn_ctc( 656 def from_tdnn_ctc(
575 cls, 657 cls,
576 model: str, 658 model: str,
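To close the loop on the new Python API, here is a minimal usage sketch for OfflineRecognizer.from_moonshine(). It assumes the standard sherpa-onnx offline decoding flow (create_stream / accept_waveform / decode_stream) and reuses the file names from the docstring above; the paths, the 16-bit mono wave-reading helper, and the test file name are placeholders, not part of this PR.

```python
# Minimal usage sketch for OfflineRecognizer.from_moonshine(). The model and
# wave paths are placeholders; adjust them to where the files actually live.
import wave

import numpy as np
import sherpa_onnx


def read_wave(filename: str):
    # Read a 16-bit mono wav file; return float32 samples in [-1, 1] and the
    # sample rate.
    with wave.open(filename) as f:
        assert f.getsampwidth() == 2, "expect 16-bit samples"
        assert f.getnchannels() == 1, "expect mono audio"
        samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
        return samples.astype(np.float32) / 32768, f.getframerate()


recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
    preprocessor="./preprocess.onnx",
    encoder="./encode.int8.onnx",
    uncached_decoder="./uncached_decode.int8.onnx",
    cached_decoder="./cached_decode.int8.onnx",
    tokens="./tokens.txt",
    num_threads=2,
)

samples, sample_rate = read_wave("./0.wav")

stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
recognizer.decode_stream(stream)
print(stream.result.text)
```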