Fangjun Kuang
Committed by GitHub

Add C++ runtime and Python APIs for Moonshine models (#1473)

Showing 33 changed files with 1,572 additions and 36 deletions
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export GIT_CLONE_PROTECTION_ACTIVE=false
echo "EXE is $EXE"
echo "PATH: $PATH"
which $EXE
names=(
tiny
base
)
for name in ${names[@]}; do
log "------------------------------------------------------------"
log "Run $name"
log "------------------------------------------------------------"
repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-$name-en-int8.tar.bz2
curl -SL -O $repo_url
tar xvf sherpa-onnx-moonshine-$name-en-int8.tar.bz2
rm sherpa-onnx-moonshine-$name-en-int8.tar.bz2
repo=sherpa-onnx-moonshine-$name-en-int8
log "Start testing ${repo_url}"
log "test int8 onnx"
time $EXE \
--moonshine-preprocessor=$repo/preprocess.onnx \
--moonshine-encoder=$repo/encode.int8.onnx \
--moonshine-uncached-decoder=$repo/uncached_decode.int8.onnx \
--moonshine-cached-decoder=$repo/cached_decode.int8.onnx \
--tokens=$repo/tokens.txt \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav
rm -rf $repo
done
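# For a local run outside CI, the workflow steps later in this commit set up
# the environment for this script like so (a sketch mirroring those steps;
# it assumes a CMake build under ./build):
#
#   export PATH=$PWD/build/bin:$PATH
#   export EXE=sherpa-onnx-offline
#   .github/scripts/test-offline-moonshine.sh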
... ...
... ... @@ -8,6 +8,16 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "test offline Moonshine"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
python3 ./python-api-examples/offline-moonshine-decode-files.py
rm -rf sherpa-onnx-moonshine-tiny-en-int8
log "test offline speaker diarization"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
... ...
... ... @@ -149,6 +149,19 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*
- name: Test offline Moonshine
if: matrix.build_type != 'Debug'
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
readelf -d build/bin/sherpa-onnx-offline
.github/scripts/test-offline-moonshine.sh
du -h -d1 .
- name: Test offline CTC
shell: bash
run: |
... ...
... ... @@ -121,6 +121,15 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test offline Moonshine
if: matrix.build_type != 'Debug'
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-moonshine.sh
- name: Test C++ API
shell: bash
run: |
... ... @@ -243,8 +252,6 @@ jobs:
.github/scripts/test-offline-whisper.sh
- name: Test online transducer
shell: bash
run: |
... ...
... ... @@ -93,6 +93,14 @@ jobs:
name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*
- name: Test offline Moonshine for windows x64
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe
.github/scripts/test-offline-moonshine.sh
- name: Test C++ API
shell: bash
run: |
... ...
... ... @@ -93,6 +93,14 @@ jobs:
name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*
- name: Test offline Moonshine for windows x86
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe
.github/scripts/test-offline-moonshine.sh
- name: Test C++ API
shell: bash
run: |
... ...
... ... @@ -47,7 +47,19 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
--feature-dim=80 \
/path/to/test.mp4
(3) For Moonshine models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
--num-threads=2 \
/path/to/test.mp4
(4) For Whisper models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
... ... @@ -58,7 +70,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
--num-threads=2 \
/path/to/test.mp4
(5) For SenseVoice CTC models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
... ... @@ -68,7 +80,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
/path/to/test.mp4
(6) For WeNet CTC models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
... ... @@ -83,6 +95,7 @@ to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
import datetime as dt
import shutil
import subprocess
import sys
... ... @@ -157,7 +170,7 @@ def get_args():
parser.add_argument(
"--num-threads",
type=int,
default=2,
help="Number of threads for neural network computation",
)
... ... @@ -209,6 +222,34 @@ def get_args():
)
parser.add_argument(
"--moonshine-preprocessor",
default="",
type=str,
help="Path to moonshine preprocessor model",
)
parser.add_argument(
"--moonshine-encoder",
default="",
type=str,
help="Path to moonshine encoder model",
)
parser.add_argument(
"--moonshine-uncached-decoder",
default="",
type=str,
help="Path to moonshine uncached decoder model",
)
parser.add_argument(
"--moonshine-cached-decoder",
default="",
type=str,
help="Path to moonshine cached decoder model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
... ... @@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
... ... @@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.paraformer)
... ... @@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.sense_voice)
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
... ... @@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.wenet_ctc)
... ... @@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder,
... ... @@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
task=args.whisper_task,
tail_paddings=args.whisper_tail_paddings,
)
elif args.moonshine_preprocessor:
assert_file_exists(args.moonshine_preprocessor)
assert_file_exists(args.moonshine_encoder)
assert_file_exists(args.moonshine_uncached_decoder)
assert_file_exists(args.moonshine_cached_decoder)
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
preprocessor=args.moonshine_preprocessor,
encoder=args.moonshine_encoder,
uncached_decoder=args.moonshine_uncached_decoder,
cached_decoder=args.moonshine_cached_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
raise ValueError("Please specify at least one model")
... ... @@ -424,28 +511,32 @@ def main():
segment_list = []
print("Started!")
start_t = dt.datetime.now()
num_processed_samples = 0
is_eof = False
# TODO(fangjun): Support multithreading
while True:
# *2 because int16_t has two bytes
data = process.stdout.read(frames_per_read * 2)
if not data:
if is_eof:
break
# If the converted audio does not end with at least 1 second of silence,
# the VAD never closes the final segment and its text is lost.
is_eof = True
# pad 1 second at the end of the file for the VAD
data = np.zeros(1 * args.sample_rate, dtype=np.int16)
samples = np.frombuffer(data, dtype=np.int16)
samples = samples.astype(np.float32) / 32768
num_processed_samples += samples.shape[0]
buffer = np.concatenate([buffer, samples])
while len(buffer) > window_size:
vad.accept_waveform(buffer[:window_size])
buffer = buffer[window_size:]
if is_eof:
vad.flush()
streams = []
... ... @@ -471,6 +562,11 @@ def main():
seg.text = stream.result.text
segment_list.append(seg)
end_t = dt.datetime.now()
elapsed_seconds = (end_t - start_t).total_seconds()
duration = num_processed_samples / args.sample_rate
rtf = elapsed_seconds / duration
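# An RTF (real-time factor) below 1 means faster than real time; e.g.,
# RTF = 0.05 means 10 s of audio is transcribed in 0.5 s.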
srt_filename = Path(args.sound_file).with_suffix(".srt")
with open(srt_filename, "w", encoding="utf-8") as f:
for i, seg in enumerate(segment_list):
... ... @@ -479,6 +575,9 @@ def main():
print("", file=f)
print(f"Saved to {srt_filename}")
print(f"Audio duration:\t{duration:.3f} s")
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print("Done!")
... ...
... ... @@ -66,7 +66,21 @@ python3 ./python-api-examples/non_streaming_server.py \
--wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
(5) Use a Moonshine model
cd /path/to/sherpa-onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
python3 ./python-api-examples/non_streaming_server.py \
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt
(6) Use a Whisper model
cd /path/to/sherpa-onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
... ... @@ -78,7 +92,7 @@ python3 ./python-api-examples/non_streaming_server.py \
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
(7) Use a tdnn model of the yesno recipe from icefall
cd /path/to/sherpa-onnx
... ... @@ -92,7 +106,7 @@ python3 ./python-api-examples/non_streaming_server.py \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
(8) Use a Non-streaming SenseVoice model
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
... ... @@ -254,6 +268,36 @@ def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
)
def add_moonshine_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--moonshine-preprocessor",
default="",
type=str,
help="Path to moonshine preprocessor model",
)
parser.add_argument(
"--moonshine-encoder",
default="",
type=str,
help="Path to moonshine encoder model",
)
parser.add_argument(
"--moonshine-uncached-decoder",
default="",
type=str,
help="Path to moonshine uncached decoder model",
)
parser.add_argument(
"--moonshine-cached-decoder",
default="",
type=str,
help="Path to moonshine cached decoder model",
)
def add_whisper_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--whisper-encoder",
... ... @@ -311,6 +355,7 @@ def add_model_args(parser: argparse.ArgumentParser):
add_wenet_ctc_model_args(parser)
add_tdnn_ctc_model_args(parser)
add_whisper_model_args(parser)
add_moonshine_model_args(parser)
parser.add_argument(
"--tokens",
... ... @@ -876,6 +921,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
... ... @@ -903,6 +954,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.paraformer)
... ... @@ -921,6 +978,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.sense_voice)
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
... ... @@ -934,6 +997,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.nemo_ctc)
... ... @@ -950,6 +1019,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.wenet_ctc)
... ... @@ -966,6 +1041,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder,
... ... @@ -980,6 +1061,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
)
elif args.tdnn_model:
assert_file_exists(args.tdnn_model)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
model=args.tdnn_model,
... ... @@ -990,6 +1077,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
decoding_method=args.decoding_method,
provider=args.provider,
)
elif args.moonshine_preprocessor:
assert_file_exists(args.moonshine_preprocessor)
assert_file_exists(args.moonshine_encoder)
assert_file_exists(args.moonshine_uncached_decoder)
assert_file_exists(args.moonshine_cached_decoder)
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
preprocessor=args.moonshine_preprocessor,
encoder=args.moonshine_encoder,
uncached_decoder=args.moonshine_uncached_decoder,
cached_decoder=args.moonshine_cached_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
)
else:
raise ValueError("Please specify at least one model")
... ...
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming Moonshine model from
https://github.com/usefulsensors/moonshine
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
For instance,
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
"""
import datetime as dt
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def create_recognizer():
preprocessor = "./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx"
encoder = "./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx"
uncached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx"
cached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx"
tokens = "./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt"
test_wav = "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav"
if not Path(preprocessor).is_file() or not Path(test_wav).is_file():
raise ValueError(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
return (
sherpa_onnx.OfflineRecognizer.from_moonshine(
preprocessor=preprocessor,
encoder=encoder,
uncached_decoder=uncached_decoder,
cached_decoder=cached_decoder,
tokens=tokens,
debug=True,
),
test_wav,
)
def main():
recognizer, wave_filename = create_recognizer()
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
# sample_rate does not need to be 16000 Hz
start_t = dt.datetime.now()
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
end_t = dt.datetime.now()
elapsed_seconds = (end_t - start_t).total_seconds()
duration = audio.shape[-1] / sample_rate
rtf = elapsed_seconds / duration
print(stream.result)
print(wave_filename)
print("Text:", stream.result.text)
print(f"Audio duration:\t{duration:.3f} s")
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
if __name__ == "__main__":
main()
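A possible extension of the example above (a sketch, not part of this commit): the OfflineRecognizer API also exposes decode_streams() for decoding several files in one call; as DecodeStreams() in offline-recognizer-moonshine-impl.h below shows, Moonshine currently decodes the streams one by one internally. The sketch assumes both test wavs exist:

import sherpa_onnx
import soundfile as sf

recognizer, _ = create_recognizer()  # from the example above
wavs = [
    "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav",
    "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/1.wav",
]
streams = []
for wav in wavs:
    audio, sample_rate = sf.read(wav, dtype="float32", always_2d=True)
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio[:, 0])  # first channel only
    streams.append(stream)
recognizer.decode_streams(streams)  # decoded sequentially for Moonshine
for wav, stream in zip(wavs, streams):
    print(wav, stream.result.text)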
... ...
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming Whisper model from
https://github.com/openai/whisper
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
For instance,
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
rm sherpa-onnx-whisper-tiny.en.tar.bz2
"""
import datetime as dt
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def create_recognizer():
encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
test_wav = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
if not Path(encoder).is_file() or not Path(test_wav).is_file():
raise ValueError(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
return (
sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=encoder,
decoder=decoder,
tokens=tokens,
debug=True,
),
test_wav,
)
def main():
recognizer, wave_filename = create_recognizer()
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
# sample_rate does not need to be 16000 Hz
start_t = dt.datetime.now()
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
end_t = dt.datetime.now()
elapsed_seconds = (end_t - start_t).total_seconds()
duration = audio.shape[-1] / sample_rate
rtf = elapsed_seconds / duration
print(stream.result)
print(wave_filename)
print("Text:", stream.result.text)
print(f"Audio duration:\t{duration:.3f} s")
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
if __name__ == "__main__":
main()
... ...
... ... @@ -35,7 +35,18 @@ Note that you need a non-streaming model for this script.
--sample-rate=16000 \
--feature-dim=80
(3) For Moonshine models
./python-api-examples/vad-with-non-streaming-asr.py \
--silero-vad-model=/path/to/silero_vad.onnx \
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
--num-threads=2
(4) For Whisper models
./python-api-examples/vad-with-non-streaming-asr.py \
--silero-vad-model=/path/to/silero_vad.onnx \
... ... @@ -45,7 +56,7 @@ Note that you need a non-streaming model for this script.
--whisper-task=transcribe \
--num-threads=2
(5) For SenseVoice CTC models
./python-api-examples/vad-with-non-streaming-asr.py \
--silero-vad-model=/path/to/silero_vad.onnx \
... ... @@ -193,6 +204,34 @@ def get_args():
)
parser.add_argument(
"--moonshine-preprocessor",
default="",
type=str,
help="Path to moonshine preprocessor model",
)
parser.add_argument(
"--moonshine-encoder",
default="",
type=str,
help="Path to moonshine encoder model",
)
parser.add_argument(
"--moonshine-uncached-decoder",
default="",
type=str,
help="Path to moonshine uncached decoder model",
)
parser.add_argument(
"--moonshine-cached-decoder",
default="",
type=str,
help="Path to moonshine cached decoder model",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
... ... @@ -251,6 +290,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.sense_voice) == 0, args.sense_voice
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
... ... @@ -272,6 +317,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.sense_voice) == 0, args.sense_voice
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.paraformer)
... ... @@ -287,6 +338,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.sense_voice:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
assert_file_exists(args.sense_voice)
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
... ... @@ -299,6 +356,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder,
... ... @@ -311,6 +374,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
task=args.whisper_task,
tail_paddings=args.whisper_tail_paddings,
)
elif args.moonshine_preprocessor:
assert_file_exists(args.moonshine_preprocessor)
assert_file_exists(args.moonshine_encoder)
assert_file_exists(args.moonshine_uncached_decoder)
assert_file_exists(args.moonshine_cached_decoder)
recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
preprocessor=args.moonshine_preprocessor,
encoder=args.moonshine_encoder,
uncached_decoder=args.moonshine_uncached_decoder,
cached_decoder=args.moonshine_cached_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
raise ValueError("Please specify at least one model")
... ...
... ... @@ -29,6 +29,9 @@ set(sources
offline-lm-config.cc
offline-lm.cc
offline-model-config.cc
offline-moonshine-greedy-search-decoder.cc
offline-moonshine-model-config.cc
offline-moonshine-model.cc
offline-nemo-enc-dec-ctc-model-config.cc
offline-nemo-enc-dec-ctc-model.cc
offline-paraformer-greedy-search-decoder.cc
... ...
... ... @@ -19,6 +19,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
zipformer_ctc.Register(po);
wenet_ctc.Register(po);
sense_voice.Register(po);
moonshine.Register(po);
po->Register("telespeech-ctc", &telespeech_ctc,
"Path to model.onnx for telespeech ctc");
... ... @@ -99,6 +100,10 @@ bool OfflineModelConfig::Validate() const {
return sense_voice.Validate();
}
if (!moonshine.preprocessor.empty()) {
return moonshine.Validate();
}
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
telespeech_ctc.c_str());
... ... @@ -124,6 +129,7 @@ std::string OfflineModelConfig::ToString() const {
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "sense_voice=" << sense_voice.ToString() << ", ";
os << "moonshine=" << moonshine.ToString() << ", ";
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
... ...
... ... @@ -6,6 +6,7 @@
#include <string>
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
... ... @@ -26,6 +27,7 @@ struct OfflineModelConfig {
OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc;
OfflineSenseVoiceModelConfig sense_voice;
OfflineMoonshineModelConfig moonshine;
std::string telespeech_ctc;
std::string tokens;
... ... @@ -56,6 +58,7 @@ struct OfflineModelConfig {
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const OfflineSenseVoiceModelConfig &sense_voice,
const OfflineMoonshineModelConfig &moonshine,
const std::string &telespeech_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type,
... ... @@ -69,6 +72,7 @@ struct OfflineModelConfig {
zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc),
sense_voice(sense_voice),
moonshine(moonshine),
telespeech_ctc(telespeech_ctc),
tokens(tokens),
num_threads(num_threads),
... ...
// sherpa-onnx/csrc/offline-moonshine-decoder.h
//
// Copyright (c)  2024  Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
struct OfflineMoonshineDecoderResult {
/// The decoded token IDs
std::vector<int32_t> tokens;
};
class OfflineMoonshineDecoder {
public:
virtual ~OfflineMoonshineDecoder() = default;
/** Run decoding given the output from the moonshine encoder model.
*
* @param encoder_out A 3-D tensor of shape (batch_size, T, dim)
* @return Return a vector of size `batch_size` containing the decoded results.
*/
virtual std::vector<OfflineMoonshineDecoderResult> Decode(
Ort::Value encoder_out) = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc
//
// Copyright (c)  2024  Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
#include <algorithm>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
std::vector<OfflineMoonshineDecoderResult>
OfflineMoonshineGreedySearchDecoder::Decode(Ort::Value encoder_out) {
auto encoder_out_shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
if (encoder_out_shape[0] != 1) {
SHERPA_ONNX_LOGE("Support only batch size == 1. Given: %d\n",
static_cast<int32_t>(encoder_out_shape[0]));
return {};
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
// encoder_out_shape[1] * 384 is the number of audio samples;
// 384 samples per encoder frame is from the Moonshine paper,
// and 16000 is the sample rate. The trailing factor 6 caps
// decoding at about 6 tokens per second of audio.
int32_t max_len =
static_cast<int32_t>(encoder_out_shape[1] * 384 / 16000.0 * 6);
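// A worked instance of the formula above (a sketch, not from model
// metadata): 10 s of 16 kHz audio gives roughly 10 * 16000 / 384 ≈ 416
// encoder frames, so max_len ≈ 416 * 384 / 16000 * 6 ≈ 60 tokens.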
int32_t sos = 1;
int32_t eos = 2;
int32_t seq_len = 1;
std::vector<int32_t> tokens;
std::array<int64_t, 2> token_shape = {1, 1};
int64_t seq_len_shape = 1;
Ort::Value token_tensor = Ort::Value::CreateTensor(
memory_info, &sos, 1, token_shape.data(), token_shape.size());
Ort::Value seq_len_tensor =
Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);
Ort::Value logits{nullptr};
std::vector<Ort::Value> states;
std::tie(logits, states) = model_->ForwardUnCachedDecoder(
std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out));
int32_t vocab_size = logits.GetTensorTypeAndShapeInfo().GetShape()[2];
for (int32_t i = 0; i != max_len; ++i) {
const float *p = logits.GetTensorData<float>();
int32_t max_token_id = static_cast<int32_t>(
std::distance(p, std::max_element(p, p + vocab_size)));
if (max_token_id == eos) {
break;
}
tokens.push_back(max_token_id);
seq_len += 1;
token_tensor = Ort::Value::CreateTensor(
memory_info, &tokens.back(), 1, token_shape.data(), token_shape.size());
seq_len_tensor =
Ort::Value::CreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1);
std::tie(logits, states) = model_->ForwardCachedDecoder(
std::move(token_tensor), std::move(seq_len_tensor), View(&encoder_out),
std::move(states));
}
OfflineMoonshineDecoderResult ans;
ans.tokens = std::move(tokens);
return {ans};
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
namespace sherpa_onnx {
class OfflineMoonshineGreedySearchDecoder : public OfflineMoonshineDecoder {
public:
explicit OfflineMoonshineGreedySearchDecoder(OfflineMoonshineModel *model)
: model_(model) {}
std::vector<OfflineMoonshineDecoderResult> Decode(
Ort::Value encoder_out) override;
private:
OfflineMoonshineModel *model_; // not owned
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-moonshine-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineMoonshineModelConfig::Register(ParseOptions *po) {
po->Register("moonshine-preprocessor", &preprocessor,
"Path to onnx preprocessor of moonshine, e.g., preprocess.onnx");
po->Register("moonshine-encoder", &encoder,
"Path to onnx encoder of moonshine, e.g., encode.onnx");
po->Register(
"moonshine-uncached-decoder", &uncached_decoder,
"Path to onnx uncached_decoder of moonshine, e.g., uncached_decode.onnx");
po->Register(
"moonshine-cached-decoder", &cached_decoder,
"Path to onnx cached_decoder of moonshine, e.g., cached_decode.onnx");
}
bool OfflineMoonshineModelConfig::Validate() const {
if (preprocessor.empty()) {
SHERPA_ONNX_LOGE("Please provide --moonshine-preprocessor");
return false;
}
if (!FileExists(preprocessor)) {
SHERPA_ONNX_LOGE("moonshine preprocessor file '%s' does not exist",
preprocessor.c_str());
return false;
}
if (encoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --moonshine-encoder");
return false;
}
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("moonshine encoder file '%s' does not exist",
encoder.c_str());
return false;
}
if (uncached_decoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --moonshine-uncached-decoder");
return false;
}
if (!FileExists(uncached_decoder)) {
SHERPA_ONNX_LOGE("moonshine uncached decoder file '%s' does not exist",
uncached_decoder.c_str());
return false;
}
if (cached_decoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --moonshine-cached-decoder");
return false;
}
if (!FileExists(cached_decoder)) {
SHERPA_ONNX_LOGE("moonshine cached decoder file '%s' does not exist",
cached_decoder.c_str());
return false;
}
return true;
}
std::string OfflineMoonshineModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineMoonshineModelConfig(";
os << "preprocessor=\"" << preprocessor << "\", ";
os << "encoder=\"" << encoder << "\", ";
os << "uncached_decoder=\"" << uncached_decoder << "\", ";
os << "cached_decoder=\"" << cached_decoder << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-moonshine-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineMoonshineModelConfig {
std::string preprocessor;
std::string encoder;
std::string uncached_decoder;
std::string cached_decoder;
OfflineMoonshineModelConfig() = default;
OfflineMoonshineModelConfig(const std::string &preprocessor,
const std::string &encoder,
const std::string &uncached_decoder,
const std::string &cached_decoder)
: preprocessor(preprocessor),
encoder(encoder),
uncached_decoder(uncached_decoder),
cached_decoder(cached_decoder) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-moonshine-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineMoonshineModel::Impl {
public:
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.moonshine.preprocessor);
InitPreprocessor(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.moonshine.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.moonshine.uncached_decoder);
InitUnCachedDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.moonshine.cached_decoder);
InitCachedDecoder(buf.data(), buf.size());
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.moonshine.preprocessor);
InitPreprocessor(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.moonshine.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.moonshine.uncached_decoder);
InitUnCachedDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.moonshine.cached_decoder);
InitCachedDecoder(buf.data(), buf.size());
}
}
#endif
Ort::Value ForwardPreprocessor(Ort::Value audio) {
auto features = preprocessor_sess_->Run(
{}, preprocessor_input_names_ptr_.data(), &audio, 1,
preprocessor_output_names_ptr_.data(),
preprocessor_output_names_ptr_.size());
return std::move(features[0]);
}
Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) {
std::array<Ort::Value, 2> encoder_inputs{std::move(features),
std::move(features_len)};
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
encoder_inputs.size(), encoder_output_names_ptr_.data(),
encoder_output_names_ptr_.size());
return std::move(encoder_out[0]);
}
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out) {
std::array<Ort::Value, 3> uncached_decoder_input = {
std::move(tokens),
std::move(encoder_out),
std::move(seq_len),
};
auto uncached_decoder_out = uncached_decoder_sess_->Run(
{}, uncached_decoder_input_names_ptr_.data(),
uncached_decoder_input.data(), uncached_decoder_input.size(),
uncached_decoder_output_names_ptr_.data(),
uncached_decoder_output_names_ptr_.size());
std::vector<Ort::Value> states;
states.reserve(uncached_decoder_out.size() - 1);
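// uncached_decoder_out[0] is the logits tensor; the remaining outputs
// are the decoder states to feed back via ForwardCachedDecoder(), so
// skip index 0 below.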
int32_t i = -1;
for (auto &s : uncached_decoder_out) {
++i;
if (i == 0) {
continue;
}
states.push_back(std::move(s));
}
return {std::move(uncached_decoder_out[0]), std::move(states)};
}
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
Ort::Value tokens, Ort::Value seq_len, Ort::Value encoder_out,
std::vector<Ort::Value> states) {
std::vector<Ort::Value> cached_decoder_input;
cached_decoder_input.reserve(3 + states.size());
cached_decoder_input.push_back(std::move(tokens));
cached_decoder_input.push_back(std::move(encoder_out));
cached_decoder_input.push_back(std::move(seq_len));
for (auto &s : states) {
cached_decoder_input.push_back(std::move(s));
}
auto cached_decoder_out = cached_decoder_sess_->Run(
{}, cached_decoder_input_names_ptr_.data(), cached_decoder_input.data(),
cached_decoder_input.size(), cached_decoder_output_names_ptr_.data(),
cached_decoder_output_names_ptr_.size());
std::vector<Ort::Value> next_states;
next_states.reserve(cached_decoder_out.size() - 1);
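// As above: cached_decoder_out[0] is the logits tensor and the rest
// are the next decoder states; skip index 0.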
int32_t i = -1;
for (auto &s : cached_decoder_out) {
++i;
if (i == 0) {
continue;
}
next_states.push_back(std::move(s));
}
return {std::move(cached_decoder_out[0]), std::move(next_states)};
}
OrtAllocator *Allocator() const { return allocator_; }
private:
void InitPreprocessor(void *model_data, size_t model_data_length) {
preprocessor_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(preprocessor_sess_.get(), &preprocessor_input_names_,
&preprocessor_input_names_ptr_);
GetOutputNames(preprocessor_sess_.get(), &preprocessor_output_names_,
&preprocessor_output_names_ptr_);
}
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
}
void InitUnCachedDecoder(void *model_data, size_t model_data_length) {
uncached_decoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(uncached_decoder_sess_.get(), &uncached_decoder_input_names_,
&uncached_decoder_input_names_ptr_);
GetOutputNames(uncached_decoder_sess_.get(),
&uncached_decoder_output_names_,
&uncached_decoder_output_names_ptr_);
}
void InitCachedDecoder(void *model_data, size_t model_data_length) {
cached_decoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(cached_decoder_sess_.get(), &cached_decoder_input_names_,
&cached_decoder_input_names_ptr_);
GetOutputNames(cached_decoder_sess_.get(), &cached_decoder_output_names_,
&cached_decoder_output_names_ptr_);
}
private:
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> preprocessor_sess_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> uncached_decoder_sess_;
std::unique_ptr<Ort::Session> cached_decoder_sess_;
std::vector<std::string> preprocessor_input_names_;
std::vector<const char *> preprocessor_input_names_ptr_;
std::vector<std::string> preprocessor_output_names_;
std::vector<const char *> preprocessor_output_names_ptr_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> uncached_decoder_input_names_;
std::vector<const char *> uncached_decoder_input_names_ptr_;
std::vector<std::string> uncached_decoder_output_names_;
std::vector<const char *> uncached_decoder_output_names_ptr_;
std::vector<std::string> cached_decoder_input_names_;
std::vector<const char *> cached_decoder_input_names_ptr_;
std::vector<std::string> cached_decoder_output_names_;
std::vector<const char *> cached_decoder_output_names_ptr_;
};
OfflineMoonshineModel::OfflineMoonshineModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineMoonshineModel::OfflineMoonshineModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineMoonshineModel::~OfflineMoonshineModel() = default;
Ort::Value OfflineMoonshineModel::ForwardPreprocessor(Ort::Value audio) const {
return impl_->ForwardPreprocessor(std::move(audio));
}
Ort::Value OfflineMoonshineModel::ForwardEncoder(
Ort::Value features, Ort::Value features_len) const {
return impl_->ForwardEncoder(std::move(features), std::move(features_len));
}
std::pair<Ort::Value, std::vector<Ort::Value>>
OfflineMoonshineModel::ForwardUnCachedDecoder(Ort::Value token,
Ort::Value seq_len,
Ort::Value encoder_out) const {
return impl_->ForwardUnCachedDecoder(std::move(token), std::move(seq_len),
std::move(encoder_out));
}
std::pair<Ort::Value, std::vector<Ort::Value>>
OfflineMoonshineModel::ForwardCachedDecoder(
Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
std::vector<Ort::Value> states) const {
return impl_->ForwardCachedDecoder(std::move(token), std::move(seq_len),
std::move(encoder_out), std::move(states));
}
OrtAllocator *OfflineMoonshineModel::Allocator() const {
return impl_->Allocator();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-moonshine-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
// please see
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/moonshine/test.py
class OfflineMoonshineModel {
public:
explicit OfflineMoonshineModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineMoonshineModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineMoonshineModel();
/** Run the preprocessor model.
*
* @param audio A float32 tensor of shape (batch_size, num_samples)
*
* @return Return a float32 tensor of shape (batch_size, T, dim) that
* can be used as the input of ForwardEncoder()
*/
Ort::Value ForwardPreprocessor(Ort::Value audio) const;
/** Run the encoder model.
*
* @param features A float32 tensor of shape (batch_size, T, dim)
* @param features_len An int32 tensor of shape (batch_size,)
* @returns A float32 tensor of shape (batch_size, T, dim).
*/
Ort::Value ForwardEncoder(Ort::Value features, Ort::Value features_len) const;
/** Run the uncached decoder.
*
* @param token An int32 tensor of shape (batch_size, num_tokens)
* @param seq_len An int32 tensor of shape (batch_size,) containing the number
* of predicted tokens so far
* @param encoder_out A float32 tensor of shape (batch_size, T, dim)
*
* @returns Return a pair:
*
* - logits, a float32 tensor of shape (batch_size, 1, vocab_size)
* - states, a list of states
*/
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardUnCachedDecoder(
Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out) const;
/** Run the cached decoder.
*
* @param token An int32 tensor of shape (batch_size, num_tokens)
* @param seq_len An int32 tensor of shape (batch_size,) containing the number
* of predicted tokens so far
* @param encoder_out A float32 tensor of shape (batch_size, T, dim)
* @param states A list of previous states
*
* @returns Return a pair:
* - logits, a float32 tensor of shape (batch_size, 1, vocab_size)
* - states, a list of new states
*/
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardCachedDecoder(
Ort::Value token, Ort::Value seq_len, Ort::Value encoder_out,
std::vector<Ort::Value> states) const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_
... ...
... ... @@ -20,6 +20,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
... ... @@ -51,6 +52,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
}
if (!config.model_config.moonshine.preprocessor.empty()) {
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
}
// TODO(fangjun): Refactor it. We only need to use model type for the
// following models:
// 1. transducer and nemo_transducer
... ... @@ -67,7 +72,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "whisper") {
// unreachable
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
} else if (model_type == "moonshine") {
// unreachable
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
... ... @@ -225,6 +234,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
}
if (!config.model_config.moonshine.preprocessor.empty()) {
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
}
// TODO(fangjun): Refactor it. We only need to use model type for the
// following models:
// 1. transducer and nemo_transducer
... ... @@ -242,6 +255,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
} else if (model_type == "moonshine") {
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
... ...
// sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-moonshine-decoder.h"
#include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-moonshine-model.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
static OfflineRecognitionResult Convert(
const OfflineMoonshineDecoderResult &src, const SymbolTable &sym_table) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
std::string text;
for (auto i : src.tokens) {
if (!sym_table.Contains(i)) {
continue;
}
const auto &s = sym_table[i];
text += s;
r.tokens.push_back(s);
}
r.text = text;
return r;
}
class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerMoonshineImpl(const OfflineRecognizerConfig &config)
: OfflineRecognizerImpl(config),
config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineMoonshineModel>(config.model_config)) {
Init();
}
#if __ANDROID_API__ >= 9
OfflineRecognizerMoonshineImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: OfflineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(
std::make_unique<OfflineMoonshineModel>(mgr, config.model_config)) {
Init();
}
#endif
void Init() {
if (config_.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineMoonshineGreedySearchDecoder>(model_.get());
} else {
SHERPA_ONNX_LOGE(
"Only greedy_search is supported at present for moonshine. Given %s",
config_.decoding_method.c_str());
exit(-1);
}
}
std::unique_ptr<OfflineStream> CreateStream() const override {
MoonshineTag tag;
return std::make_unique<OfflineStream>(tag);
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
// batch decoding is not implemented yet
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
}
OfflineRecognizerConfig GetConfig() const override { return config_; }
private:
void DecodeStream(OfflineStream *s) const {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<float> audio = s->GetFrames();
try {
std::array<int64_t, 2> shape{1, static_cast<int64_t>(audio.size())};
Ort::Value audio_tensor = Ort::Value::CreateTensor(
memory_info, audio.data(), audio.size(), shape.data(), shape.size());
Ort::Value features =
model_->ForwardPreprocessor(std::move(audio_tensor));
int32_t features_len = features.GetTensorTypeAndShapeInfo().GetShape()[1];
int64_t features_shape = 1;
Ort::Value features_len_tensor = Ort::Value::CreateTensor(
memory_info, &features_len, 1, &features_shape, 1);
Ort::Value encoder_out = model_->ForwardEncoder(
std::move(features), std::move(features_len_tensor));
auto results = decoder_->Decode(std::move(encoder_out));
auto r = Convert(results[0], symbol_table_);
r.text = ApplyInverseTextNormalization(std::move(r.text));
s->SetResult(r);
} catch (const Ort::Exception &ex) {
SHERPA_ONNX_LOGE(
"\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
"audio samples: %d",
ex.what(), static_cast<int32_t>(audio.size()));
return;
}
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OfflineMoonshineModel> model_;
std::unique_ptr<OfflineMoonshineDecoder> decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_
... ...
... ... @@ -133,6 +133,10 @@ class OfflineStream::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
explicit Impl(MoonshineTag /*tag*/) : is_moonshine_(true) {
config_.sampling_rate = 16000;
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n);
... ... @@ -164,7 +168,9 @@ class OfflineStream::Impl {
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
if (is_moonshine_) {
samples_.insert(samples_.end(), samples.begin(), samples.end());
} else if (fbank_) {
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
fbank_->InputFinished();
... ... @@ -181,7 +187,9 @@ class OfflineStream::Impl {
return;
} // if (sampling_rate != config_.sampling_rate)
if (is_moonshine_) {
samples_.insert(samples_.end(), waveform, waveform + n);
} else if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished();
} else if (mfcc_) {
... ... @@ -194,10 +202,18 @@ class OfflineStream::Impl {
}
int32_t FeatureDim() const {
if (is_moonshine_) {
return samples_.size();
}
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
}
std::vector<float> GetFrames() const {
if (is_moonshine_) {
return samples_;
}
int32_t n = fbank_ ? fbank_->NumFramesReady()
: mfcc_ ? mfcc_->NumFramesReady()
: whisper_fbank_->NumFramesReady();
... ... @@ -300,6 +316,10 @@ class OfflineStream::Impl {
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
bool is_ced_ = false;
bool is_moonshine_ = false;
// used only when is_moonshine_ == true
std::vector<float> samples_;
};
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
... ... @@ -311,6 +331,9 @@ OfflineStream::OfflineStream(WhisperTag tag)
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
OfflineStream::OfflineStream(MoonshineTag tag)
: impl_(std::make_unique<Impl>(tag)) {}
OfflineStream::~OfflineStream() = default;
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
... ...
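The hunks above give Moonshine streams a third path next to fbank/MFCC: instead of extracting features, the stream buffers the (possibly resampled) 16 kHz samples verbatim, and GetFrames() hands them back as-is. A condensed, hypothetical Python mirror of that branching; compute_fbank is a stand-in, not a real sherpa-onnx API:

from typing import List

def compute_fbank(waveform: List[float]) -> List[List[float]]:
    """Placeholder for real fbank extraction; 80-dim frames assumed."""
    return [[s] * 80 for s in waveform]

class OfflineStreamImpl:
    def __init__(self, is_moonshine: bool = False):
        self.is_moonshine = is_moonshine
        self.samples: List[float] = []        # Moonshine: raw samples
        self.frames: List[List[float]] = []   # other models: feature frames

    def accept_waveform(self, waveform: List[float]) -> None:
        if self.is_moonshine:
            self.samples.extend(waveform)     # no feature extraction at all
        else:
            self.frames.extend(compute_fbank(waveform))

    def get_frames(self):
        # Mirrors GetFrames(): raw samples for Moonshine, frames otherwise
        return self.samples if self.is_moonshine else self.frames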
... ... @@ -34,7 +34,7 @@ struct OfflineRecognitionResult {
// event target of the audio.
std::string event;
/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;
... ... @@ -49,6 +49,10 @@ struct WhisperTag {
struct CEDTag {};
// Moonshine models use a neural network model, called the preprocessor,
// to convert audio samples to features
struct MoonshineTag {};
class OfflineStream {
public:
explicit OfflineStream(const FeatureExtractorConfig &config = {},
... ... @@ -56,6 +60,7 @@ class OfflineStream {
explicit OfflineStream(WhisperTag tag);
explicit OfflineStream(CEDTag tag);
explicit OfflineStream(MoonshineTag tag);
~OfflineStream();
/**
... ... @@ -72,7 +77,10 @@ class OfflineStream {
void AcceptWaveform(int32_t sampling_rate, const float *waveform,
int32_t n) const;
/// Return feature dim of this extractor.
///
/// Note: if it is Moonshine, then it returns the number of audio samples
/// currently received.
int32_t FeatureDim() const;
// Get all the feature frames of this stream in a 1-D array, which is
... ...
... ... @@ -23,7 +23,6 @@ class OfflineWhisperModel::Impl {
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
debug_(config.debug),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
... ... @@ -40,7 +39,6 @@ class OfflineWhisperModel::Impl {
explicit Impl(const SpokenLanguageIdentificationConfig &config)
: lid_config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
debug_(config_.debug),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
... ... @@ -60,7 +58,6 @@ class OfflineWhisperModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(mgr, config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
... ... @@ -77,7 +74,6 @@ class OfflineWhisperModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(mgr, config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
... ... @@ -164,7 +160,7 @@ class OfflineWhisperModel::Impl {
}
}
if (debug_) {
if (config_.debug) {
SHERPA_ONNX_LOGE("Detected language: %s",
GetID2Lang().at(lang_id).c_str());
}
... ... @@ -237,7 +233,7 @@ class OfflineWhisperModel::Impl {
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (debug_) {
if (config_.debug) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
... ... @@ -294,7 +290,6 @@ class OfflineWhisperModel::Impl {
private:
OfflineModelConfig config_;
SpokenLanguageIdentificationConfig lid_config_;
bool debug_ = false;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
... ...
... ... @@ -43,7 +43,20 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/in
--decoding-method=greedy_search \
/path/to/foo.wav [bar.wav foobar.wav ...]
(3) Whisper models
(3) Moonshine models
See https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html
./bin/sherpa-onnx-offline \
--moonshine-preprocessor=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/preprocess.onnx \
--moonshine-encoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/encode.int8.onnx \
--moonshine-uncached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/uncached_decode.int8.onnx \
--moonshine-cached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/cached_decode.int8.onnx \
--tokens=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/tokens.txt \
--num-threads=1 \
/path/to/foo.wav [bar.wav foobar.wav ...]
(4) Whisper models
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
... ... @@ -54,7 +67,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
--num-threads=1 \
/path/to/foo.wav [bar.wav foobar.wav ...]
(4) NeMo CTC models
(5) NeMo CTC models
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
... ... @@ -68,7 +81,7 @@ See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.htm
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
(5) TDNN CTC model for the yesno recipe from icefall
(6) TDNN CTC model for the yesno recipe from icefall
See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
//
... ...
... ... @@ -109,6 +109,8 @@ const std::string SymbolTable::operator[](int32_t id) const {
// for byte-level BPE
// id 0 is blank, id 1 is sos/eos, id 2 is unk
//
// Note: For Moonshine models, 0 is <unk>, 1 is <s>, 2 is </s>
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
std::ostringstream os;
... ...
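The branch above recognizes byte-level BPE symbols of the form <0xAB> (exactly six characters, ids 3 through 258) and maps them back to raw bytes before the final text is assembled. A hypothetical Python equivalent of the same rule; the helper name and the UTF-8 reassembly step are illustrative only, not part of sherpa-onnx:

def decode_byte_tokens(symbols):
    """Turn a symbol sequence with <0xAB>-style byte tokens into text."""
    out = bytearray()
    for sym in symbols:
        if len(sym) == 6 and sym.startswith("<0x") and sym.endswith(">"):
            out.append(int(sym[3:5], 16))  # e.g. "<0x41>" -> 0x41
        else:
            out.extend(sym.encode("utf-8"))
    return out.decode("utf-8", errors="replace")

print(decode_byte_tokens(["<0xE4>", "<0xBD>", "<0xA0>"]))  # -> 你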
... ... @@ -11,6 +11,7 @@ set(srcs
offline-ctc-fst-decoder-config.cc
offline-lm-config.cc
offline-model-config.cc
offline-moonshine-model-config.cc
offline-nemo-enc-dec-ctc-model-config.cc
offline-paraformer-model-config.cc
offline-punctuation.cc
... ...
... ... @@ -8,6 +8,7 @@
#include <vector>
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
... ... @@ -28,6 +29,7 @@ void PybindOfflineModelConfig(py::module *m) {
PybindOfflineZipformerCtcModelConfig(m);
PybindOfflineWenetCtcModelConfig(m);
PybindOfflineSenseVoiceModelConfig(m);
PybindOfflineMoonshineModelConfig(m);
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
... ... @@ -39,7 +41,8 @@ void PybindOfflineModelConfig(py::module *m) {
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &,
const OfflineSenseVoiceModelConfig &, const std::string &,
const OfflineSenseVoiceModelConfig &,
const OfflineMoonshineModelConfig &, const std::string &,
const std::string &, int32_t, bool, const std::string &,
const std::string &, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
... ... @@ -50,6 +53,7 @@ void PybindOfflineModelConfig(py::module *m) {
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
py::arg("moonshine") = OfflineMoonshineModelConfig(),
py::arg("telespeech_ctc") = "", py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "",
... ... @@ -62,6 +66,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("sense_voice", &PyClass::sense_voice)
.def_readwrite("moonshine", &PyClass::moonshine)
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
... ...
// sherpa-onnx/python/csrc/offline-moonshine-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
namespace sherpa_onnx {
void PybindOfflineMoonshineModelConfig(py::module *m) {
using PyClass = OfflineMoonshineModelConfig;
py::class_<PyClass>(*m, "OfflineMoonshineModelConfig")
.def(py::init<const std::string &, const std::string &,
const std::string &, const std::string &>(),
py::arg("preprocessor"), py::arg("encoder"),
py::arg("uncached_decoder"), py::arg("cached_decoder"))
.def_readwrite("preprocessor", &PyClass::preprocessor)
.def_readwrite("encoder", &PyClass::encoder)
.def_readwrite("uncached_decoder", &PyClass::uncached_decoder)
.def_readwrite("cached_decoder", &PyClass::cached_decoder)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
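For reference, constructing the config bound above from Python; the binding takes all four model paths as keyword (or positional) arguments. Whether the class is also re-exported at the top level of the sherpa_onnx package is not shown in this diff, so the sketch below imports it from _sherpa_onnx directly:

from _sherpa_onnx import OfflineMoonshineModelConfig

cfg = OfflineMoonshineModelConfig(
    preprocessor="preprocess.onnx",
    encoder="encode.int8.onnx",
    uncached_decoder="uncached_decode.int8.onnx",
    cached_decoder="cached_decode.int8.onnx",
)
print(cfg)  # __str__ is bound to ToString()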
// sherpa-onnx/python/csrc/offline-moonshine-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineMoonshineModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_
... ...
... ... @@ -8,13 +8,14 @@ from _sherpa_onnx import (
OfflineCtcFstDecoderConfig,
OfflineLMConfig,
OfflineModelConfig,
OfflineMoonshineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
OfflineSenseVoiceModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
OfflineRecognizerConfig,
OfflineSenseVoiceModelConfig,
OfflineStream,
OfflineTdnnModelConfig,
OfflineTransducerModelConfig,
... ... @@ -503,12 +504,12 @@ class OfflineRecognizer(object):
e.g., tiny, tiny.en, base, base.en, etc.
Args:
encoder_model:
Path to the encoder model, e.g., tiny-encoder.onnx,
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
decoder_model:
encoder:
Path to the encoder model, e.g., tiny-encoder.onnx,
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
decoder:
Path to the decoder model, e.g., tiny-decoder.onnx,
tiny-decoder.int8.onnx, tiny-decoder.ort, etc.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
... ... @@ -571,6 +572,87 @@ class OfflineRecognizer(object):
return self
@classmethod
def from_moonshine(
cls,
preprocessor: str,
encoder: str,
uncached_decoder: str,
cached_decoder: str,
tokens: str,
num_threads: int = 1,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html>`_
to download pre-trained models for different kinds of moonshine models,
e.g., tiny, base, etc.
Args:
preprocessor:
Path to the preprocessor model, e.g., preprocess.onnx
encoder:
Path to the encoder model, e.g., encode.int8.onnx
uncached_decoder:
Path to the uncached decoder model, e.g., uncached_decode.int8.onnx
cached_decoder:
Path to the cached decoder model, e.g., cached_decode.int8.onnx
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
num_threads:
Number of threads for neural network computation.
decoding_method:
Valid values: greedy_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
moonshine=OfflineMoonshineModelConfig(
preprocessor=preprocessor,
encoder=encoder,
uncached_decoder=uncached_decoder,
cached_decoder=cached_decoder,
),
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
)
unused_feat_config = FeatureExtractorConfig(
sampling_rate=16000,
feature_dim=80,
)
recognizer_config = OfflineRecognizerConfig(
model_config=model_config,
feat_config=unused_feat_config,
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
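A usage sketch for the new classmethod, using the file layout of the sherpa-onnx-moonshine-tiny-en-int8 archive from the CI scripts. The stream calls follow the existing offline recognizer API (create_stream, accept_waveform, decode_stream, result.text); loading the wav with soundfile is an assumption made for brevity:

import soundfile as sf  # assumption: used only to load the test wav

import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
    preprocessor="sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx",
    encoder="sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx",
    uncached_decoder="sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx",
    cached_decoder="sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx",
    tokens="sherpa-onnx-moonshine-tiny-en-int8/tokens.txt",
    num_threads=2,
)

samples, sample_rate = sf.read(
    "sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav", dtype="float32"
)
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
recognizer.decode_stream(stream)
print(stream.result.text)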
@classmethod
def from_tdnn_ctc(
cls,
model: str,
... ...