Fangjun Kuang
Committed by GitHub

Support whisper models (#238)

正在显示 39 个修改的文件 包含 1835 行增加51 行删除
name: export-whisper-to-onnx
on:
workflow_dispatch:
concurrency:
group: release-whisper-${{ github.ref }}
cancel-in-progress: true
jobs:
release-whisper-models:
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
name: ${{ matrix.model }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [macos-latest]
model: ["tiny.en", "base.en", "small.en", "medium.en"]
steps:
- uses: actions/checkout@v2
- name: Install dependencies
shell: bash
run: |
python3 -m pip install openai-whisper torch onnxruntime onnx
- name: export ${{ matrix.model }}
shell: bash
run: |
cd scripts/whisper
python3 ./export-onnx.py --model ${{ matrix.model }}
python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
ls -lh
ls -lh ~/.cache/whisper
- name: Publish ${{ matrix.model }} to huggingface
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
cd scripts/whisper
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
cp *.onnx ./huggingface
cp *.ort ./huggingface
cp *tokens.txt ./huggingface
cd huggingface
git status
ls -lh
git lfs track "*.onnx"
git lfs track "*.ort"
git add .
git commit -m "upload ${{ matrix.model }}"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
... ...
... ... @@ -23,14 +23,14 @@ on:
- 'sherpa-onnx/jni/*'
concurrency:
group: jni-${{ github.ref }}
group: run-java-test-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
jni:
run_java_test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
... ...
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.5.5")
set(SHERPA_ONNX_VERSION "1.6.0")
# Disable warning about
#
... ...
function(download_kaldi_native_fbank)
include(FetchContent)
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.17.tar.gz")
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.17.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=300dc282d51d738e70f194ef13a50bf4cf8d54a3b2686d75f7fc2fb821f8c1e6")
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.1.tar.gz")
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.1.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=c7676f319fa97e8c8bca6018792de120895dcfe122fa9b4bff00f8f9165348e7")
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
... ... @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-native-fbank-1.17.tar.gz
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.17.tar.gz
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.17.tar.gz
/tmp/kaldi-native-fbank-1.17.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.17.tar.gz
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.1.tar.gz
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.1.tar.gz
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.1.tar.gz
/tmp/kaldi-native-fbank-1.18.1.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.1.tar.gz
)
foreach(f IN LISTS possible_file_locations)
... ...
#!/usr/bin/env python3
#
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation
"""
This file demonstrates how to use sherpa-onnx Python API to transcribe
... ... @@ -34,6 +35,27 @@ file(s) with a non-streaming model.
(3) For CTC models from NeMo
python3 ./python-api-examples/offline-decode-files.py \
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
(4) For Whisper models
python3 ./python-api-examples/offline-decode-files.py \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--num-threads=1 \
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models
... ... @@ -145,6 +167,20 @@ def get_args():
)
parser.add_argument(
"--whisper-encoder",
default="",
type=str,
help="Path to whisper encoder model",
)
parser.add_argument(
"--whisper-decoder",
default="",
type=str,
help="Path to whisper decoder model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
... ... @@ -247,6 +283,8 @@ def main():
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
... ... @@ -271,6 +309,9 @@ def main():
)
elif args.paraformer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert_file_exists(args.paraformer)
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
... ... @@ -283,6 +324,11 @@ def main():
debug=args.debug,
)
elif args.nemo_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert_file_exists(args.nemo_ctc)
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
model=args.nemo_ctc,
tokens=args.tokens,
... ... @@ -292,6 +338,18 @@ def main():
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder,
decoder=args.whisper_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
print("Please specify at least one model")
return
... ...
*.onnx
*.config
*.ort
*-tokens.txt
... ...
# Introduction
This folder contains code showing how to convert [Whisper][whisper] to onnx
and use onnxruntime to replace PyTorch for speech recognition.
You can use [sherpa-onnx][sherpa-onnx] to run the converted model.
[whisper]: https://github.com/openai/whisper
[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
... ...
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# flake8: noqa
"""
Note: Code in this file is modified from
https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
Thanks to https://github.com/TadaoYamaoka
for making the onnx export script public.
"""
import argparse
from pathlib import Path
from typing import Any, Dict, Optional
import onnx
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch import Tensor, nn
import whisper
from whisper.model import (
AudioEncoder,
MultiHeadAttention,
ResidualAttentionBlock,
TextDecoder,
)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
required=True,
# fmt: off
choices=[
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2"],
# fmt: on
)
return parser.parse_args()
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
class AudioEncoderTensorCache(nn.Module):
def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
super().__init__()
self.audioEncoder = inAudioEncoder
self.textDecoder = inTextDecoder
def forward(self, x: Tensor):
audio_features = self.audioEncoder(x)
n_layer_cross_k_list = []
n_layer_cross_v_list = []
for block in self.textDecoder.blocks:
n_layer_cross_k_list.append(block.cross_attn.key(audio_features))
n_layer_cross_v_list.append(block.cross_attn.value(audio_features))
return torch.stack(n_layer_cross_k_list), torch.stack(n_layer_cross_v_list)
class MultiHeadAttentionCross(nn.Module):
def __init__(self, inMultiHeadAttention: MultiHeadAttention):
super().__init__()
self.multiHeadAttention = inMultiHeadAttention
def forward(
self,
x: Tensor,
k: Tensor,
v: Tensor,
mask: Optional[Tensor] = None,
):
q = self.multiHeadAttention.query(x)
wv, qk = self.multiHeadAttention.qkv_attention(q, k, v, mask)
return self.multiHeadAttention.out(wv)
class MultiHeadAttentionSelf(nn.Module):
def __init__(self, inMultiHeadAttention: MultiHeadAttention):
super().__init__()
self.multiHeadAttention = inMultiHeadAttention
def forward(
self,
x: Tensor, # (b, n_ctx , n_state)
k_cache: Tensor, # (b, n_ctx_cache, n_state)
v_cache: Tensor, # (b, n_ctx_cache, n_state)
mask: Tensor,
):
q = self.multiHeadAttention.query(x) # (b, n_ctx, n_state)
k = self.multiHeadAttention.key(x) # (b, n_ctx, n_state)
v = self.multiHeadAttention.value(x) # (b, n_ctx, n_state)
k_cache[:, -k.shape[1] :, :] = k # (b, n_ctx_cache + n_ctx, n_state)
v_cache[:, -v.shape[1] :, :] = v # (b, n_ctx_cache + n_ctx, n_state)
wv, qk = self.multiHeadAttention.qkv_attention(q, k_cache, v_cache, mask)
return self.multiHeadAttention.out(wv), k_cache, v_cache
class ResidualAttentionBlockTensorCache(nn.Module):
def __init__(self, inResidualAttentionBlock: ResidualAttentionBlock):
super().__init__()
self.originalBlock = inResidualAttentionBlock
self.attn = MultiHeadAttentionSelf(inResidualAttentionBlock.attn)
self.cross_attn = (
MultiHeadAttentionCross(inResidualAttentionBlock.cross_attn)
if inResidualAttentionBlock.cross_attn
else None
)
def forward(
self,
x: Tensor,
self_k_cache: Tensor,
self_v_cache: Tensor,
cross_k: Tensor,
cross_v: Tensor,
mask: Tensor,
):
self_attn_x, self_k_cache_updated, self_v_cache_updated = self.attn(
self.originalBlock.attn_ln(x), self_k_cache, self_v_cache, mask=mask
)
x = x + self_attn_x
if self.cross_attn:
x = x + self.cross_attn(
self.originalBlock.cross_attn_ln(x), cross_k, cross_v
)
x = x + self.originalBlock.mlp(self.originalBlock.mlp_ln(x))
return x, self_k_cache_updated, self_v_cache_updated
class TextDecoderTensorCache(nn.Module):
def __init__(self, inTextDecoder: TextDecoder, in_n_ctx: int):
super().__init__()
self.textDecoder = inTextDecoder
self.n_ctx = in_n_ctx
self.blocks = []
for orginal_block in self.textDecoder.blocks:
self.blocks.append(ResidualAttentionBlockTensorCache(orginal_block))
def forward(
self,
tokens: Tensor,
n_layer_self_k_cache: Tensor,
n_layer_self_v_cache: Tensor,
n_layer_cross_k: Tensor,
n_layer_cross_v: Tensor,
offset: Tensor,
):
x = (
self.textDecoder.token_embedding(tokens)
+ self.textDecoder.positional_embedding[
offset[0] : offset[0] + tokens.shape[-1]
]
)
x = x.to(n_layer_cross_k[0].dtype)
i = 0
for block in self.blocks:
self_k_cache = n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :]
self_v_cache = n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :]
x, self_k_cache, self_v_cache = block(
x,
self_k_cache=self_k_cache,
self_v_cache=self_v_cache,
cross_k=n_layer_cross_k[i],
cross_v=n_layer_cross_v[i],
mask=self.textDecoder.mask,
)
n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_k_cache
n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_v_cache
i += 1
x = self.textDecoder.ln(x)
logits = (
x
@ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits, n_layer_self_k_cache, n_layer_self_v_cache
# ref: https://github.com/ggerganov/whisper.cpp/blob/master/models/convert-pt-to-ggml.py#L232
def convert_tokens(name, model):
whisper_dir = Path(whisper.__file__).parent
multilingual = model.is_multilingual
tokenizer = (
whisper_dir
/ "assets"
/ (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
)
if not tokenizer.is_file():
raise ValueError(f"Cannot find {tokenizer}")
# import base64
with open(tokenizer, "r") as f:
contents = f.read()
# tokens = {
# base64.b64decode(token): int(rank)
# for token, rank in (line.split() for line in contents.splitlines() if line)
# }
tokens = {
token: int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
with open(f"{name}-tokens.txt", "w") as f:
for t, i in tokens.items():
f.write(f"{t} {i}\n")
@torch.no_grad()
def main():
args = get_args()
name = args.model
opset_version = 13
model = whisper.load_model(name)
convert_tokens(name=name, model=model)
# write tokens
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
model.eval()
print(model.dims)
audio = torch.rand(16000 * 2)
audio = whisper.pad_or_trim(audio)
assert audio.shape == (16000 * 30,), audio.shape
# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
batch_size = 1
assert mel.shape == (batch_size, 80, 30 * 100)
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
n_layer_cross_k, n_layer_cross_v = encoder(mel)
assert n_layer_cross_k.shape == (
model.dims.n_text_layer,
batch_size,
model.dims.n_audio_ctx,
model.dims.n_text_state,
), n_layer_cross_k.shape
assert n_layer_cross_v.shape == (
model.dims.n_text_layer,
batch_size,
model.dims.n_audio_ctx,
model.dims.n_text_state,
), n_layer_cross_v.shape
encoder_filename = f"{name}-encoder.onnx"
torch.onnx.export(
encoder,
mel,
encoder_filename,
opset_version=opset_version,
input_names=["mel"],
output_names=["n_layer_cross_k", "n_layer_cross_v"],
dynamic_axes={
"mel": {0: "n_audio"}, # n_audio is also known as batch_size
"n_layer_cross_k": {1: "n_audio"},
"n_layer_cross_v": {1: "n_audio"},
},
)
encoder_meta_data = {
"model_type": f"whisper-{name}",
"version": "1",
"maintainer": "k2-fsa",
"n_mels": model.dims.n_mels,
"n_audio_ctx": model.dims.n_audio_ctx,
"n_audio_state": model.dims.n_audio_state,
"n_audio_head": model.dims.n_audio_head,
"n_audio_layer": model.dims.n_audio_layer,
"n_vocab": model.dims.n_vocab,
"n_text_ctx": model.dims.n_text_ctx,
"n_text_state": model.dims.n_text_state,
"n_text_head": model.dims.n_text_head,
"n_text_layer": model.dims.n_text_layer,
"sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
"all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
"all_language_codes": ",".join(tokenizer.all_language_codes),
"sot": tokenizer.sot,
"sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
"eot": tokenizer.eot,
"blank_id": tokenizer.encode(" ")[0],
"is_multilingual": int(model.is_multilingual),
"no_speech": tokenizer.no_speech,
"non_speech_tokens": ",".join(list(map(str, tokenizer.non_speech_tokens))),
"transcribe": tokenizer.transcribe,
"translate": tokenizer.translate,
"sot_prev": tokenizer.sot_prev,
"sot_lm": tokenizer.sot_lm,
"no_timestamps": tokenizer.no_timestamps,
}
print(f"encoder_meta_data: {encoder_meta_data}")
add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data)
n_audio = mel.shape[0]
tokens = torch.tensor([[tokenizer.sot, tokenizer.sot, tokenizer.sot]] * n_audio).to(
mel.device
) # [n_audio, 3]
decoder = TextDecoderTensorCache(model.decoder, model.dims.n_text_ctx)
n_layer_self_k_cache = torch.zeros(
(
len(model.decoder.blocks),
n_audio,
model.dims.n_text_ctx,
model.dims.n_text_state,
),
device=mel.device,
)
n_layer_self_v_cache = torch.zeros(
(
len(model.decoder.blocks),
n_audio,
model.dims.n_text_ctx,
model.dims.n_text_state,
),
device=mel.device,
)
offset = torch.zeros(1, dtype=torch.int64).to(mel.device)
logits, n_layer_self_k_cache, n_layer_self_v_cache = decoder(
tokens,
n_layer_self_k_cache,
n_layer_self_v_cache,
n_layer_cross_k,
n_layer_cross_v,
offset,
)
assert logits.shape == (n_audio, tokens.shape[1], model.dims.n_vocab)
assert n_layer_self_k_cache.shape == (
model.dims.n_text_layer,
n_audio,
model.dims.n_text_ctx,
model.dims.n_text_state,
)
assert n_layer_self_v_cache.shape == (
model.dims.n_text_layer,
n_audio,
model.dims.n_text_ctx,
model.dims.n_text_state,
)
offset = torch.tensor([tokens.shape[1]], dtype=torch.int64).to(mel.device)
tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = decoder(
tokens,
n_layer_self_k_cache,
n_layer_self_v_cache,
n_layer_cross_k,
n_layer_cross_v,
offset,
)
decoder_filename = f"{name}-decoder.onnx"
torch.onnx.export(
decoder,
(
tokens,
n_layer_self_k_cache,
n_layer_self_v_cache,
n_layer_cross_k,
n_layer_cross_v,
offset,
),
decoder_filename,
opset_version=opset_version,
input_names=[
"tokens",
"in_n_layer_self_k_cache",
"in_n_layer_self_v_cache",
"n_layer_cross_k",
"n_layer_cross_v",
"offset",
],
output_names=["logits", "out_n_layer_self_k_cache", "out_n_layer_self_v_cache"],
dynamic_axes={
"tokens": {0: "n_audio", 1: "n_tokens"},
"in_n_layer_self_k_cache": {1: "n_audio"},
"in_n_layer_self_v_cache": {1: "n_audio"},
"n_layer_cross_k": {1: "n_audio"},
"n_layer_cross_v": {1: "n_audio"},
},
)
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
print("Generate int8 quantization models")
encoder_filename_int8 = f"{name}-encoder.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = f"{name}-decoder.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
main()
... ...
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
Please first run ./export-onnx.py
before you run this script
"""
import base64
from typing import Tuple
import kaldi_native_fbank as knf
import onnxruntime as ort
import torch
import whisper
import argparse
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
required=True,
# fmt: off
choices=[
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2"],
# fmt: on
)
return parser.parse_args()
class OnnxModel:
def __init__(
self,
encoder: str,
decoder: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.init_encoder(encoder)
self.init_decoder(decoder)
def init_encoder(self, encoder: str):
self.encoder = ort.InferenceSession(
encoder,
sess_options=self.session_opts,
)
meta = self.encoder.get_modelmeta().custom_metadata_map
self.n_text_layer = int(meta["n_text_layer"])
self.n_text_ctx = int(meta["n_text_ctx"])
self.n_text_state = int(meta["n_text_state"])
self.sot = int(meta["sot"])
self.eot = int(meta["eot"])
self.translate = int(meta["translate"])
self.no_timestamps = int(meta["no_timestamps"])
self.no_speech = int(meta["no_speech"])
self.blank = int(meta["blank_id"])
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
self.is_multilingual = int(meta["is_multilingual"]) == 1
def init_decoder(self, decoder: str):
self.decoder = ort.InferenceSession(
decoder,
sess_options=self.session_opts,
)
def run_encoder(
self,
mel: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
n_layer_cross_k, n_layer_cross_v = self.encoder.run(
[
self.encoder.get_outputs()[0].name,
self.encoder.get_outputs()[1].name,
],
{
self.encoder.get_inputs()[0].name: mel.numpy(),
},
)
return torch.from_numpy(n_layer_cross_k), torch.from_numpy(n_layer_cross_v)
def run_decoder(
self,
tokens: torch.Tensor,
n_layer_self_k_cache: torch.Tensor,
n_layer_self_v_cache: torch.Tensor,
n_layer_cross_k: torch.Tensor,
n_layer_cross_v: torch.Tensor,
offset: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
[
self.decoder.get_outputs()[0].name,
self.decoder.get_outputs()[1].name,
self.decoder.get_outputs()[2].name,
],
{
self.decoder.get_inputs()[0].name: tokens.numpy(),
self.decoder.get_inputs()[1].name: n_layer_self_k_cache.numpy(),
self.decoder.get_inputs()[2].name: n_layer_self_v_cache.numpy(),
self.decoder.get_inputs()[3].name: n_layer_cross_k.numpy(),
self.decoder.get_inputs()[4].name: n_layer_cross_v.numpy(),
self.decoder.get_inputs()[5].name: offset.numpy(),
},
)
return (
torch.from_numpy(logits),
torch.from_numpy(out_n_layer_self_k_cache),
torch.from_numpy(out_n_layer_self_v_cache),
)
def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = 1
n_layer_self_k_cache = torch.zeros(
self.n_text_layer,
batch_size,
self.n_text_ctx,
self.n_text_state,
)
n_layer_self_v_cache = torch.zeros(
self.n_text_layer,
batch_size,
self.n_text_ctx,
self.n_text_state,
)
return n_layer_self_k_cache, n_layer_self_v_cache
def suppress_tokens(self, logits, is_initial: bool) -> None:
# suppress blank
if is_initial:
logits[self.eot] = float("-inf")
logits[self.blank] = float("-inf")
# suppress <|notimestamps|>
logits[self.no_timestamps] = float("-inf")
logits[self.sot] = float("-inf")
logits[self.no_speech] = float("-inf")
# logits is changed in-place
logits[self.translate] = float("-inf")
def load_tokens(filename):
tokens = dict()
with open(filename, "r") as f:
for line in f:
t, i = line.split()
tokens[int(i)] = t
return tokens
def main():
args = get_args()
name = args.model
encoder = f"./{name}-encoder.onnx"
decoder = f"./{name}-decoder.onnx"
audio = whisper.load_audio("0.wav")
features = []
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
online_whisper_fbank.accept_waveform(16000, audio)
online_whisper_fbank.input_finished()
for i in range(online_whisper_fbank.num_frames_ready):
f = online_whisper_fbank.get_frame(i)
f = torch.from_numpy(f)
features.append(f)
features = torch.stack(features)
log_spec = torch.clamp(features, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
mel = (log_spec + 4.0) / 4.0
target = 3000
mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
mel = mel.t().unsqueeze(0)
model = OnnxModel(encoder, decoder)
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
offset = torch.zeros(1, dtype=torch.int64)
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
tokens=tokens,
n_layer_self_k_cache=n_layer_self_k_cache,
n_layer_self_v_cache=n_layer_self_v_cache,
n_layer_cross_k=n_layer_cross_k,
n_layer_cross_v=n_layer_cross_v,
offset=offset,
)
# logits.shape (batch_size, tokens.shape[1], vocab_size)
logits = logits[0, -1]
model.suppress_tokens(logits, is_initial=True)
# logits = logits.softmax(dim=-1)
# for greedy search, we don't need to compute softmax or log_softmax
max_token_id = logits.argmax(dim=-1)
results = []
for i in range(model.n_text_ctx):
if max_token_id == model.eot:
break
results.append(max_token_id.item())
tokens = torch.tensor([[results[-1]]])
offset += 1
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
tokens=tokens,
n_layer_self_k_cache=n_layer_self_k_cache,
n_layer_self_v_cache=n_layer_self_v_cache,
n_layer_cross_k=n_layer_cross_k,
n_layer_cross_v=n_layer_cross_v,
offset=offset,
)
logits = logits[0, -1]
model.suppress_tokens(logits, is_initial=False)
max_token_id = logits.argmax(dim=-1)
token_table = load_tokens(f"./{name}-tokens.txt")
s = b""
for i in results:
if i in token_table:
s += base64.b64decode(token_table[i])
else:
print("oov", i)
print(s.decode().strip())
print(results)
print(model.sot_sequence)
if __name__ == "__main__":
main()
... ...
... ... @@ -11,6 +11,7 @@ if(SHERPA_ONNX_ENABLE_PYTHON)
endif()
set(sources
base64-decode.cc
cat.cc
context-graph.cc
endpoint.cc
... ... @@ -35,6 +36,9 @@ set(sources
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
offline-whisper-greedy-search-decoder.cc
offline-whisper-model-config.cc
offline-whisper-model.cc
online-conformer-transducer-model.cc
online-lm-config.cc
online-lm.cc
... ... @@ -50,12 +54,12 @@ set(sources
online-zipformer-transducer-model.cc
online-zipformer2-transducer-model.cc
onnx-utils.cc
session.cc
packed-sequence.cc
pad-sequence.cc
parse-options.cc
provider.cc
resample.cc
session.cc
slice.cc
stack.cc
symbol-table.cc
... ...
// sherpa-onnx/csrc/base64-decode.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
static int32_t Ord(char c) {
if (c >= 'A' && c <= 'Z') {
return c - 'A';
} else if (c >= 'a' && c <= 'z') {
return c - 'a' + ('Z' - 'A') + 1;
} else if (c >= '0' && c <= '9') {
return c - '0' + ('Z' - 'A') + ('z' - 'a') + 2;
} else if (c == '+') {
return 62;
} else if (c == '/') {
return 63;
}
SHERPA_ONNX_LOGE("Unknown character %d, %c\n", c, c);
exit(-1);
}
// see
// https://github.com/ReneNyffenegger/cpp-base64/blob/master/base64.cpp#L243
std::string Base64Decode(const std::string &s) {
if (s.empty()) {
SHERPA_ONNX_LOGE("Empty string!");
exit(-1);
}
int32_t n = s.size() / 4 * 3;
std::string ans;
ans.reserve(n);
int32_t i = 0;
while (i < static_cast<int32_t>(s.size())) {
if (s[i] == '=') {
return " ";
}
int32_t first = (Ord(s[i]) << 2) + ((Ord(s[i + 1]) & 0x30) >> 4);
ans.push_back(first);
if (i + 2 < static_cast<int32_t>(s.size()) && s[i + 2] != '=') {
int32_t second =
((Ord(s[i + 1]) & 0x0f) << 4) + ((Ord(s[i + 2]) & 0x3c) >> 2);
ans.push_back(second);
if (i + 3 < static_cast<int32_t>(s.size()) && s[i + 3] != '=') {
int32_t third = ((Ord(s[i + 2]) & 0x03) << 6) + Ord(s[i + 3]);
ans.push_back(third);
}
}
i += 4;
}
return ans;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/base64-decode.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_BASE64_DECODE_H_
#define SHERPA_ONNX_CSRC_BASE64_DECODE_H_
#include <string>
namespace sherpa_onnx {
/** @param s A base64 encoded string.
* @return Return the decoded string.
*/
std::string Base64Decode(const std::string &s);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_BASE64_DECODE_H_
... ...
// sherpa-onnx/csrc/macros.h
//
// Copyright 2023 Xiaomi Corporation
... ...
... ... @@ -14,6 +14,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
transducer.Register(po);
paraformer.Register(po);
nemo_ctc.Register(po);
whisper.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
... ... @@ -28,7 +29,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: transducer, paraformer, nemo_ctc. "
"Valid values are: transducer, paraformer, nemo_ctc, whisper."
"All other values lead to loading the model twice.");
}
... ... @@ -51,6 +52,10 @@ bool OfflineModelConfig::Validate() const {
return nemo_ctc.Validate();
}
if (!whisper.encoder.empty()) {
return whisper.Validate();
}
return transducer.Validate();
}
... ... @@ -61,6 +66,7 @@ std::string OfflineModelConfig::ToString() const {
os << "transducer=" << transducer.ToString() << ", ";
os << "paraformer=" << paraformer.ToString() << ", ";
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "whisper=" << whisper.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
... ...
... ... @@ -9,6 +9,7 @@
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
namespace sherpa_onnx {
... ... @@ -16,6 +17,7 @@ struct OfflineModelConfig {
OfflineTransducerModelConfig transducer;
OfflineParaformerModelConfig paraformer;
OfflineNemoEncDecCtcModelConfig nemo_ctc;
OfflineWhisperModelConfig whisper;
std::string tokens;
int32_t num_threads = 2;
... ... @@ -37,11 +39,13 @@ struct OfflineModelConfig {
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
const OfflineParaformerModelConfig &paraformer,
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
const OfflineWhisperModelConfig &whisper,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
: transducer(transducer),
paraformer(paraformer),
nemo_ctc(nemo_ctc),
whisper(whisper),
tokens(tokens),
num_threads(num_threads),
debug(debug),
... ...
... ... @@ -16,7 +16,7 @@ void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
bool OfflineNemoEncDecCtcModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
SHERPA_ONNX_LOGE("NeMo model: %s does not exist", model.c_str());
return false;
}
... ...
... ... @@ -15,7 +15,7 @@ void OfflineParaformerModelConfig::Register(ParseOptions *po) {
bool OfflineParaformerModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
SHERPA_ONNX_LOGE("Paraformer model %s does not exist", model.c_str());
return false;
}
... ...
... ... @@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
... ... @@ -26,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
... ... @@ -43,6 +46,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.paraformer.model;
} else if (!config.model_config.nemo_ctc.model.empty()) {
model_filename = config.model_config.nemo_ctc.model;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
SHERPA_ONNX_LOGE("Please provide a model");
exit(-1);
... ... @@ -77,6 +82,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n "
"https://huggingface.co/csukuangfj/"
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
"\n "
"(3) Whisper"
"\n");
exit(-1);
}
... ... @@ -95,12 +102,17 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}
if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
}
SHERPA_ONNX_LOGE(
"\nUnsupported model_type: %s\n"
"We support only the following model types at present: \n"
" - Non-streaming transducer models from icefall\n"
" - Non-streaming Paraformer models from FunASR\n"
" - EncDecCTCModelBPE models from NeMo\n",
" - EncDecCTCModelBPE models from NeMo\n"
" - Whisper models\n",
model_type.c_str());
exit(-1);
... ...
// sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
const SymbolTable &sym_table) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
for (auto i : src.tokens) {
if (!sym_table.contains(i)) {
continue;
}
const auto &s = sym_table[i];
r.text += s;
r.tokens.push_back(s);
}
return r;
}
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
// tokens.txt from whisper is base64 encoded, so we need to decode it
symbol_table_.ApplyBase64Decode();
if (config.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
} else {
SHERPA_ONNX_LOGE(
"Only greedy_search is supported at present for whisper. Given %s",
config.decoding_method.c_str());
exit(-1);
}
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(WhisperTag{});
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
// batch decoding is not implemented yet
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
}
private:
void DecodeStream(OfflineStream *s) const {
int32_t max_num_frames = 3000;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t feat_dim = s->FeatureDim();
std::vector<float> f = s->GetFrames();
int32_t num_frames = f.size() / feat_dim;
if (num_frames > max_num_frames) {
SHERPA_ONNX_LOGE("Only waves less than 30 seconds are supported.");
exit(-1);
}
NormalizeFeatures(f.data(), num_frames, feat_dim);
std::array<int64_t, 3> shape{1, max_num_frames, feat_dim};
Ort::Value mel = Ort::Value::CreateTensor<float>(
model_->Allocator(), shape.data(), shape.size());
float *p_mel = mel.GetTensorMutableData<float>();
std::copy(f.begin(), f.end(), p_mel);
memset(p_mel + f.size(), 0,
(max_num_frames - num_frames) * feat_dim * sizeof(float));
mel = Transpose12(model_->Allocator(), &mel);
auto cross_kv = model_->ForwardEncoder(std::move(mel));
auto results =
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
auto r = Convert(results[0], symbol_table_);
s->SetResult(r);
}
private:
static void NormalizeFeatures(float *features, int32_t num_frames,
int32_t feat_dim) {
// log_spec = torch.clamp(features, min=1e-10).log10()
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
// mel = (log_spec + 4.0) / 4.0
int32_t n = num_frames * feat_dim;
float max_v = -1e20;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max<float>(f, 1e-10);
f = std::log10(f);
max_v = std::max(f, max_v);
features[i] = f;
}
max_v -= 8;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max(f, max_v);
f = (f + 4) / 4;
features[i] = f;
}
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OfflineWhisperModel> model_;
std::unique_ptr<OfflineWhisperDecoder> decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
... ...
... ... @@ -86,6 +86,15 @@ class OfflineStream::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
: context_graph_(context_graph) {
config_.normalize_samples = true;
opts_.frame_opts.samp_freq = 16000;
opts_.mel_opts.num_bins = 80;
whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n);
... ... @@ -117,20 +126,35 @@ class OfflineStream::Impl {
lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
fbank_->InputFinished();
if (fbank_) {
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
fbank_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
samples.data(), samples.size());
whisper_fbank_->InputFinished();
}
return;
}
} // if (sampling_rate != opts_.frame_opts.samp_freq)
fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished();
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
whisper_fbank_->InputFinished();
}
}
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
std::vector<float> GetFrames() const {
int32_t n = fbank_->NumFramesReady();
int32_t n =
fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
assert(n > 0 && "Please first call AcceptWaveform()");
int32_t feature_dim = FeatureDim();
... ... @@ -140,7 +164,8 @@ class OfflineStream::Impl {
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
const float *f = fbank_->GetFrame(i);
const float *f =
fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
... ... @@ -191,6 +216,7 @@ class OfflineStream::Impl {
private:
OfflineFeatureExtractorConfig config_;
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
knf::FbankOptions opts_;
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
... ... @@ -201,6 +227,10 @@ OfflineStream::OfflineStream(
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::OfflineStream(WhisperTag tag,
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
OfflineStream::~OfflineStream() = default;
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
... ...
... ... @@ -65,10 +65,15 @@ struct OfflineFeatureExtractorConfig {
void Register(ParseOptions *po);
};
struct WhisperTag {};
class OfflineStream {
public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = nullptr);
explicit OfflineStream(WhisperTag tag,
ContextGraphPtr context_graph = nullptr);
~OfflineStream();
/**
... ...
... ... @@ -18,17 +18,20 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
bool OfflineTransducerModelConfig::Validate() const {
if (!FileExists(encoder_filename)) {
SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
SHERPA_ONNX_LOGE("transducer encoder: %s does not exist",
encoder_filename.c_str());
return false;
}
if (!FileExists(decoder_filename)) {
SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
SHERPA_ONNX_LOGE("transducer decoder: %s does not exist",
decoder_filename.c_str());
return false;
}
if (!FileExists(joiner_filename)) {
SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
SHERPA_ONNX_LOGE("transducer joiner: %s does not exist",
joiner_filename.c_str());
return false;
}
... ...
// sherpa-onnx/csrc/offline-whisper-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
struct OfflineWhisperDecoderResult {
/// The decoded token IDs
std::vector<int32_t> tokens;
};
class OfflineWhisperDecoder {
public:
virtual ~OfflineWhisperDecoder() = default;
/** Run beam search given the output from the whisper encoder model.
*
* @param n_layer_cross_k A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
* @param n_layer_cross_v A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
*
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineWhisperDecoderResult> Decode(
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
#include <algorithm>
#include <utility>
namespace sherpa_onnx {
std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
auto self_kv_cache = model_->GetInitialSelfKVCache();
std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
int32_t batch_size = 1;
std::array<int64_t, 2> token_shape{
batch_size, static_cast<int64_t>(initial_tokens.size())};
Ort::Value tokens = Ort::Value::CreateTensor(
memory_info, initial_tokens.data(), initial_tokens.size(),
token_shape.data(), token_shape.size());
std::array<int64_t, 1> offset_shape{1};
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
model_->Allocator(), offset_shape.data(), offset_shape.size());
*(offset.GetTensorMutableData<int64_t>()) = 0;
auto decoder_out = model_->ForwardDecoder(
std::move(tokens), std::move(self_kv_cache.first),
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
std::move(offset));
const auto &logits = std::get<0>(decoder_out);
const float *p_logits = logits.GetTensorData<float>();
auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
int32_t vocab_size = logits_shape[2];
int32_t max_token_id = static_cast<int32_t>(std::distance(
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
int32_t n_text_ctx = model_->TextCtx();
std::vector<int32_t> predicted_tokens;
for (int32_t i = 0; i < n_text_ctx; ++i) {
if (max_token_id == model_->EOT()) {
break;
}
predicted_tokens.push_back(max_token_id);
std::array<int64_t, 2> token_shape{1, 1};
Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
model_->Allocator(), token_shape.data(), token_shape.size());
int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
p_tokens[0] = max_token_id;
int64_t *p_offset =
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
if (i == 0) {
*p_offset = initial_tokens.size();
} else {
*p_offset += 1;
}
decoder_out = model_->ForwardDecoder(std::move(tokens),
std::move(std::get<1>(decoder_out)),
std::move(std::get<2>(decoder_out)),
std::move(std::get<3>(decoder_out)),
std::move(std::get<4>(decoder_out)),
std::move(std::get<5>(decoder_out)));
const auto &logits = std::get<0>(decoder_out);
const float *p_logits = logits.GetTensorData<float>();
max_token_id = static_cast<int64_t>(std::distance(
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
}
std::vector<OfflineWhisperDecoderResult> ans(1);
ans[0].tokens = std::move(predicted_tokens);
return ans;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
namespace sherpa_onnx {
class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
public:
explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
: model_(model) {}
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
Ort::Value cross_v) override;
private:
OfflineWhisperModel *model_; // not owned
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-whisper-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineWhisperModelConfig::Register(ParseOptions *po) {
po->Register("whisper-encoder", &encoder,
"Path to onnx encoder of whisper, e.g., tiny-encoder.onnx, "
"medium.en-encoder.onnx.");
po->Register("whisper-decoder", &decoder,
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
"medium.en-decoder.onnx.");
}
bool OfflineWhisperModelConfig::Validate() const {
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
return false;
}
return true;
}
std::string OfflineWhisperModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineWhisperModelConfig(";
os << "encoder=\"" << encoder << "\", ";
os << "decoder=\"" << decoder << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-whisper-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineWhisperModelConfig {
std::string encoder;
std::string decoder;
OfflineWhisperModelConfig() = default;
OfflineWhisperModelConfig(const std::string &encoder,
const std::string &decoder)
: encoder(encoder), decoder(decoder) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-whisper-model.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include <algorithm>
#include <string>
#include <tuple>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineWhisperModel::Impl {
public:
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.whisper.decoder);
InitDecoder(buf.data(), buf.size());
}
}
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) {
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), &features, 1,
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
return {std::move(encoder_out[0]), std::move(encoder_out[1])};
}
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value>
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v, Ort::Value offset) {
std::array<Ort::Value, 6> decoder_input = {std::move(tokens),
std::move(n_layer_self_k_cache),
std::move(n_layer_self_v_cache),
std::move(n_layer_cross_k),
std::move(n_layer_cross_v),
std::move(offset)};
auto decoder_out = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), decoder_input.data(),
decoder_input.size(), decoder_output_names_ptr_.data(),
decoder_output_names_ptr_.size());
return {std::move(decoder_out[0]), std::move(decoder_out[1]),
std::move(decoder_out[2]), std::move(decoder_input[3]),
std::move(decoder_input[4]), std::move(decoder_input[5])};
}
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>(
Allocator(), shape.data(), shape.size());
Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>(
Allocator(), shape.data(), shape.size());
auto n = shape[0] * shape[1] * shape[2] * shape[3];
float *p_k = n_layer_self_k_cache.GetTensorMutableData<float>();
float *p_v = n_layer_self_v_cache.GetTensorMutableData<float>();
memset(p_k, 0, sizeof(float) * n);
memset(p_v, 0, sizeof(float) * n);
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
}
OrtAllocator *Allocator() const { return allocator_; }
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
int32_t EOT() const { return eot_; }
int32_t TextCtx() const { return n_text_ctx_; }
private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
SHERPA_ONNX_READ_META_DATA(sot_, "sot");
SHERPA_ONNX_READ_META_DATA(eot_, "eot");
SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
SHERPA_ONNX_READ_META_DATA(translate_, "translate");
SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
}
void InitDecoder(void *model_data, size_t model_data_length) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
}
private:
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
// model meta data
int32_t n_text_layer_;
int32_t n_text_ctx_;
int32_t n_text_state_;
int32_t sot_;
int32_t eot_;
int32_t blank_;
int32_t translate_;
int32_t no_timestamps_;
int32_t no_speech_;
std::vector<int64_t> sot_sequence_;
};
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
OfflineWhisperModel::~OfflineWhisperModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
Ort::Value features) {
return impl_->ForwardEncoder(std::move(features));
}
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value>
OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache,
Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v,
Ort::Value offset) {
return impl_->ForwardDecoder(
std::move(tokens), std::move(n_layer_self_k_cache),
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
std::move(n_layer_cross_v), std::move(offset));
}
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
return impl_->GetInitialSelfKVCache();
}
OrtAllocator *OfflineWhisperModel::Allocator() const {
return impl_->Allocator();
}
const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
return impl_->GetInitialTokens();
}
int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-whisper-model.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
class OfflineWhisperModel {
public:
explicit OfflineWhisperModel(const OfflineModelConfig &config);
~OfflineWhisperModel();
/** Run the encoder model.
*
* @param features A tensor of shape (N, C, T). It is changed in-place.
* C is 80 and T is 3000.
*
* @return Return a pair containing:
* - n_layer_cross_k: A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state)
* - n_layer_cross_v: A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state)
*/
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
/** Run the decoder model.
*
* @param tokens A int64 tensor of shape (N, num_words)
* @param n_layer_self_k_cache A 4-D tensor of shape
* (n_text_layer, N, n_text_ctx, n_text_state).
* @param n_layer_self_v_cache A 4-D tensor of shape
* (n_text_layer, N, n_text_ctx, n_text_state).
* @param n_layer_cross_k A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
* @param n_layer_cross_v A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
* @param offset A int64 tensor of shape (N,)
*
* @return Return a tuple containing 6 tensors:
*
* - logits A 3-D tensor of shape (N, num_words, vocab_size)
* - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache
* - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache
* - out_n_layer_cross_k Same as n_layer_cross_k
* - out_n_layer_cross_v Same as n_layer_cross_v
* - out_offset Same as offset
*/
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value>
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v, Ort::Value offset);
/** Return the initial self kv cache in a pair
* - n_layer_self_k_cache A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
* - n_layer_self_v_cache A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
*/
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
const std::vector<int64_t> &GetInitialTokens() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
int32_t EOT() const;
int32_t TextCtx() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
... ...
... ... @@ -98,11 +98,15 @@ Usage:
./bin/sherpa-onnx-microphone-offline \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/model.onnx \
--num-threads=2 \
--decoding-method=greedy_search
--num-threads=1
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search.
(3) Whisper models
./bin/sherpa-onnx-microphone-offline \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--num-threads=1
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
... ...
... ... @@ -23,7 +23,7 @@ Usage:
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--num-threads=2 \
--num-threads=1 \
--decoding-method=greedy_search \
/path/to/foo.wav [bar.wav foobar.wav ...]
... ... @@ -33,14 +33,22 @@ Usage:
./bin/sherpa-onnx-offline \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/model.onnx \
--num-threads=2 \
--num-threads=1 \
--decoding-method=greedy_search \
/path/to/foo.wav [bar.wav foobar.wav ...]
(3) Whisper models
./bin/sherpa-onnx-offline \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--num-threads=1 \
/path/to/foo.wav [bar.wav foobar.wav ...]
Note: It supports decoding multiple files in batches
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search.
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
... ... @@ -55,6 +63,7 @@ for a list of pre-trained models to download.
po.Read(argc, argv);
if (po.NumArgs() < 1) {
fprintf(stderr, "Error: Please provide at least 1 wave file.\n\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}
... ...
... ... @@ -9,6 +9,7 @@
#include <sstream>
#include <strstream>
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#if __ANDROID_API__ >= 9
... ... @@ -82,4 +83,12 @@ std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
return os << symbol_table.ToString();
}
void SymbolTable::ApplyBase64Decode() {
sym2id_.clear();
for (auto &p : id2sym_) {
p.second = Base64Decode(p.second);
sym2id_[p.second] = p.first;
}
}
} // namespace sherpa_onnx
... ...
... ... @@ -45,6 +45,9 @@ class SymbolTable {
/// Return true if there is a given symbol in the symbol table.
bool contains(const std::string &sym) const;
// for tokens.txt from Whisper
void ApplyBase64Decode();
private:
void Init(std::istream &is);
... ...
... ... @@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx
offline-recognizer.cc
offline-stream.cc
offline-transducer-model-config.cc
offline-whisper-model-config.cc
online-lm-config.cc
online-recognizer.cc
online-stream.cc
... ...
... ... @@ -11,6 +11,7 @@
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
namespace sherpa_onnx {
... ... @@ -18,22 +19,25 @@ void PybindOfflineModelConfig(py::module *m) {
PybindOfflineTransducerModelConfig(m);
PybindOfflineParaformerModelConfig(m);
PybindOfflineNemoEncDecCtcModelConfig(m);
PybindOfflineWhisperModelConfig(m);
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
.def(
py::init<const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def(py::init<const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
.def_readwrite("whisper", &PyClass::whisper)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
... ...
// sherpa-onnx/python/csrc/offline-whisper-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
namespace sherpa_onnx {
void PybindOfflineWhisperModelConfig(py::module *m) {
using PyClass = OfflineWhisperModelConfig;
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
.def(py::init<const std::string &, const std::string &>(),
py::arg("encoder"), py::arg("decoder"))
.def_readwrite("encoder", &PyClass::encoder)
.def_readwrite("decoder", &PyClass::decoder)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-whisper-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineWhisperModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
... ...
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation
from pathlib import Path
from typing import List, Optional
... ... @@ -7,6 +8,7 @@ from _sherpa_onnx import (
OfflineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
OfflineWhisperModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
... ... @@ -69,7 +71,7 @@ class OfflineRecognizer(object):
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Support only greedy_search for now.
Valid values: greedy_search, modified_beam_search.
debug:
True to show debug messages.
provider:
... ... @@ -137,7 +139,7 @@ class OfflineRecognizer(object):
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Valid values are greedy_search, modified_beam_search.
Valid values are greedy_search.
debug:
True to show debug messages.
provider:
... ... @@ -185,14 +187,14 @@ class OfflineRecognizer(object):
English, etc.
Args:
model:
Path to ``model.onnx``.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
model:
Path to ``model.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
... ... @@ -200,7 +202,7 @@ class OfflineRecognizer(object):
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Valid values are greedy_search, modified_beam_search.
Valid values are greedy_search.
debug:
True to show debug messages.
provider:
... ... @@ -229,6 +231,68 @@ class OfflineRecognizer(object):
self.recognizer = _Recognizer(recognizer_config)
return self
@classmethod
def from_whisper(
cls,
encoder: str,
decoder: str,
tokens: str,
num_threads: int,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
to download pre-trained models for different kinds of whisper models,
e.g., tiny, tiny.en, base, base.en, etc.
Args:
encoder_model:
Path to the encoder model, e.g., tiny-encoder.onnx,
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
decoder_model:
Path to the encoder model, e.g., tiny-encoder.onnx,
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
num_threads:
Number of threads for neural network computation.
decoding_method:
Valid values: greedy_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder),
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
model_type="whisper",
)
feat_config = OfflineFeatureExtractorConfig(
sampling_rate=16000,
feature_dim=80,
)
recognizer_config = OfflineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
return self
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()
... ...