Fangjun Kuang
Committed by GitHub

Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)

... ... @@ -15,9 +15,9 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
# model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell"]
os: [macos-latest]
model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell", "large", "large-v1", "large-v2", "distil-large-v2"]
# model: ["large", "large-v1", "large-v2", "large-v3", "distil-large-v2"]
python-version: ["3.8"]
steps:
... ... @@ -32,7 +32,7 @@ jobs:
shell: bash
run: |
python3 -m pip install torch==1.13.0 torchaudio==0.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
python3 -m pip install openai-whisper==20230314 onnxruntime onnx
python3 -m pip install openai-whisper==20231117 onnxruntime onnx soundfile librosa
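Note: the openai-whisper bump (20230314 → 20231117) is what makes large-v3 exportable, and soundfile/librosa replace torchaudio in the updated test.py further down. A quick sanity check, assuming openai-whisper==20231117 is installed:

import whisper

# large-v3 ships with this release; the plain "large" alias also
# resolves to the large-v3 checkpoint here
assert "large-v3" in whisper.available_models()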
- name: export ${{ matrix.model }}
shell: bash
... ... @@ -62,7 +62,6 @@ jobs:
rm -fv medium-aishell-decoder.onnx
fi
ls -lh
ls -lh ~/.cache/whisper || true
... ... @@ -74,7 +73,8 @@ jobs:
src=sherpa-onnx-whisper-${{ matrix.model }}
cd ..
mv whisper $src
mkdir $src
mv -v whisper/$model* $src/
echo "------------------------------"
... ... @@ -97,19 +97,16 @@ jobs:
ls -lh $src
echo "--------------------"
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
#tar cvjf - $src | split --bytes=1024MB - $src.tar.bz2.
tar cvjf $src.tar.bz2 $src
split -b 1G $src.tar.bz2 $src.tar.bz2.
rm $src.tar.bz2
# cat $src.tar.bz2.* | tar xjf -
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
echo "Don't release model to github for large models. $model"
else
tar cvjf $src.tar.bz2 $src
fi
ls -lh
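For consumers of the split archives: `split -b 1G` writes pieces named $src.tar.bz2.aa, .ab, ..., and the commented `cat` line shows the shell-side reassembly. A Python equivalent (the model name is just an example):

import glob
import shutil
import tarfile

src = "sherpa-onnx-whisper-large-v2"  # example model directory name
# concatenate the 1 GB pieces in lexical order, then extract
with open(f"{src}.tar.bz2", "wb") as out:
    for piece in sorted(glob.glob(f"{src}.tar.bz2.*")):
        with open(piece, "rb") as f:
            shutil.copyfileobj(f, out)
with tarfile.open(f"{src}.tar.bz2", "r:bz2") as tar:
    tar.extractall()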
- name: Release
if: matrix.model != 'large' && matrix.model != 'large-v1' && matrix.model != 'large-v2' && matrix.model != 'large-v3' && matrix.model != 'distil-large-v2'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
... ... @@ -119,19 +116,6 @@ jobs:
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: asr-models
- name: Test ${{ matrix.model }}
shell: bash
run: |
python3 -m pip install kaldi-native-fbank
git checkout .
model=${{ matrix.model }}
src=sherpa-onnx-whisper-$model
python3 scripts/whisper/test.py \
--encoder $src/$model-encoder.int8.onnx \
--decoder $src/$model-decoder.int8.onnx \
--tokens $src/$model-tokens.txt \
$src/test_wavs/0.wav
- name: Publish ${{ matrix.model }} to huggingface
shell: bash
env:
... ... @@ -144,27 +128,36 @@ jobs:
export GIT_CLONE_PROTECTION_ACTIVE=false
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
export GIT_LFS_SKIP_SMUDGE=1
git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
if [[ $model != medium-aishell ]]; then
rm -rf huggingface/*
fi
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
mv $src.tar* ./huggingface
else
cp -v $src/*.onnx ./huggingface
cp -v $src/*tokens* ./huggingface
cp -av $src/test_wavs ./huggingface
fi
cp -av $src/* ./huggingface/
cd huggingface
git status
ls -lh
git lfs track "*gz*"
git lfs track "*onnx*"
git lfs track "*weights*"
git add .
git commit -m "upload ${{ matrix.model }}"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
- name: Test ${{ matrix.model }}
shell: bash
run: |
python3 -m pip install kaldi-native-fbank
git checkout .
model=${{ matrix.model }}
src=sherpa-onnx-whisper-$model
time python3 scripts/whisper/test.py \
--encoder $src/$model-encoder.onnx \
--decoder $src/$model-decoder.onnx \
--tokens $src/$model-tokens.txt \
$src/test_wavs/0.wav
... ...
## 1.10.14 (to-be-released)
## 1.10.14
* Support Whisper large-v3
* Update onnxruntime from v1.18.0 to v1.18.1
* Fix invalid UTF-8 sequences from Whisper for the Dart API.
... ...
... ... @@ -11,7 +11,7 @@ project(sherpa-onnx)
# ./nodejs-addon-examples
# ./dart-api-examples/
# ./CHANGELOG.md
set(SHERPA_ONNX_VERSION "1.10.13")
set(SHERPA_ONNX_VERSION "1.10.14")
# Disable warning about
#
... ...
function(download_kaldi_native_fbank)
include(FetchContent)
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz")
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=335fe1daf1b9bfb2a7b6bf03b64c4c4686c39077c57fb8058c02611981676638")
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9")
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
... ... @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-native-fbank-1.19.3.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.19.3.tar.gz
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.19.3.tar.gz
/tmp/kaldi-native-fbank-1.19.3.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.19.3.tar.gz
$ENV{HOME}/Downloads/kaldi-native-fbank-1.20.0.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.20.0.tar.gz
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.20.0.tar.gz
/tmp/kaldi-native-fbank-1.20.0.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.20.0.tar.gz
)
foreach(f IN LISTS possible_file_locations)
... ...
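If you pre-download the tarball to one of the locations listed above, you can check it against the pinned hash before CMake does:

import hashlib

expected = "c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9"
with open("kaldi-native-fbank-1.20.0.tar.gz", "rb") as f:
    assert hashlib.sha256(f.read()).hexdigest() == expected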
... ... @@ -10,7 +10,7 @@ environment:
# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.13
sherpa_onnx: ^1.10.14
path: ^1.9.0
args: ^2.5.0
... ...
... ... @@ -11,7 +11,7 @@ environment:
# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.13
sherpa_onnx: ^1.10.14
path: ^1.9.0
args: ^2.5.0
... ...
... ... @@ -8,7 +8,7 @@ environment:
# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.13
sherpa_onnx: ^1.10.14
path: ^1.9.0
args: ^2.5.0
... ...
... ... @@ -9,7 +9,7 @@ environment:
sdk: ^3.4.0
dependencies:
sherpa_onnx: ^1.10.13
sherpa_onnx: ^1.10.14
path: ^1.9.0
args: ^2.5.0
... ...
... ... @@ -5,7 +5,7 @@ description: >
publish_to: 'none'
version: 1.10.13
version: 1.10.14
topics:
- speech-recognition
... ... @@ -30,7 +30,7 @@ dependencies:
record: ^5.1.0
url_launcher: ^6.2.6
sherpa_onnx: ^1.10.13
sherpa_onnx: ^1.10.14
# sherpa_onnx:
# path: ../../flutter/sherpa_onnx
... ...
... ... @@ -17,7 +17,7 @@ dependencies:
cupertino_icons: ^1.0.6
path_provider: ^2.1.3
path: ^1.9.0
sherpa_onnx: ^1.10.13
sherpa_onnx: ^1.10.14
url_launcher: ^6.2.6
audioplayers: ^5.0.0
... ...
... ... @@ -17,7 +17,7 @@ topics:
- voice-activity-detection
# remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec
version: 1.10.13
version: 1.10.14
homepage: https://github.com/k2-fsa/sherpa-onnx
... ... @@ -30,19 +30,19 @@ dependencies:
flutter:
sdk: flutter
sherpa_onnx_android: ^1.10.13
sherpa_onnx_android: ^1.10.14
# path: ../sherpa_onnx_android
sherpa_onnx_macos: ^1.10.13
sherpa_onnx_macos: ^1.10.14
# path: ../sherpa_onnx_macos
sherpa_onnx_linux: ^1.10.13
sherpa_onnx_linux: ^1.10.14
# path: ../sherpa_onnx_linux
#
sherpa_onnx_windows: ^1.10.13
sherpa_onnx_windows: ^1.10.14
# path: ../sherpa_onnx_windows
sherpa_onnx_ios: ^1.10.13
sherpa_onnx_ios: ^1.10.14
# sherpa_onnx_ios:
# path: ../sherpa_onnx_ios
... ...
... ... @@ -7,7 +7,7 @@
# https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c
Pod::Spec.new do |s|
s.name = 'sherpa_onnx_ios'
s.version = '1.10.13'
s.version = '1.10.14'
s.summary = 'A new Flutter FFI plugin project.'
s.description = <<-DESC
A new Flutter FFI plugin project.
... ...
... ... @@ -4,7 +4,7 @@
#
Pod::Spec.new do |s|
s.name = 'sherpa_onnx_macos'
s.version = '1.10.13'
s.version = '1.10.14'
s.summary = 'sherpa-onnx Flutter FFI plugin project.'
s.description = <<-DESC
sherpa-onnx Flutter FFI plugin project.
... ...
{
"dependencies": {
"sherpa-onnx-node": "^1.10.13"
"sherpa-onnx-node": "^1.10.14"
}
}
... ...
... ... @@ -17,7 +17,7 @@ topics:
- voice-activity-detection
# remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec
version: 1.10.13
version: 1.10.14
homepage: https://github.com/k2-fsa/sherpa-onnx
... ...
... ... @@ -2,3 +2,9 @@
*.config
*.ort
*-tokens.txt
*.bias
*.weights
*.weight
*.*embedding
_Const*
onnx__*
... ...
... ... @@ -32,6 +32,9 @@ from whisper.model import (
TextDecoder,
)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
... ... @@ -43,8 +46,9 @@ def get_args():
choices=[
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2",
"large", "large-v1", "large-v2", "large-v3",
"distil-medium.en", "distil-small.en", "distil-large-v2",
# "distil-large-v3", # distil-large-v3 is not supported!
# for fine-tuned models from icefall
"medium-aishell",
],
... ... @@ -63,12 +67,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
if "large" in filename:
external_filename = filename.split(".onnx")[0]
onnx.save(
model,
filename,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_filename + ".weights",
)
else:
onnx.save(model, filename)
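Background for this branch: protobuf serialization is capped at 2 GB, so the large checkpoints cannot be stored as a single .onnx file; their tensors go into a sibling .weights file instead. Loading stays the same as long as both files sit in one directory; a minimal sketch with hypothetical filenames:

import onnxruntime as ort

# the .onnx references large-v3-encoder.weights by relative path,
# so keep the two files together; session creation is unchanged
sess = ort.InferenceSession(
    "large-v3-encoder.onnx", providers=["CPUExecutionProvider"]
)
print([i.name for i in sess.get_inputs()])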
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
... ... @@ -376,7 +394,9 @@ def main():
# write tokens
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
tokenizer = whisper.tokenizer.get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
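num_languages matters because large-v3 added one language (Cantonese), which shifts every special token that follows the language-token block. A sketch of the off-by-one, assuming openai-whisper==20231117:

import whisper

tok_v2 = whisper.tokenizer.get_tokenizer(True, num_languages=99)
tok_v3 = whisper.tokenizer.get_tokenizer(True, num_languages=100)
# tokens after the language block (e.g. <|notimestamps|>) shift by one
assert tok_v3.no_timestamps == tok_v2.no_timestamps + 1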
model.eval()
print(model.dims)
... ... @@ -384,10 +404,15 @@ def main():
audio = whisper.pad_or_trim(audio)
assert audio.shape == (16000 * 30,), audio.shape
# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
if args.model in ("large", "large-v3"):
n_mels = 128
else:
n_mels = 80
mel = (
whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0)
)
batch_size = 1
assert mel.shape == (batch_size, 80, 30 * 100)
assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape
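The mel dimension is the other externally visible difference: large-v3 uses 128 bins instead of 80, and the two-way check above also covers the "large" alias because it resolves to v3 in this whisper release. For illustration:

import torch
import whisper

audio = torch.zeros(16000 * 30)  # 30 s stand-in signal
assert whisper.log_mel_spectrogram(audio, n_mels=80).shape == (80, 3000)
assert whisper.log_mel_spectrogram(audio, n_mels=128).shape == (128, 3000)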
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
... ... @@ -547,6 +572,17 @@ def main():
)
if "large" in args.model:
decoder_external_filename = decoder_filename.split(".onnx")[0]
decoder_model = onnx.load(decoder_filename)
onnx.save(
decoder_model,
decoder_filename,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=decoder_external_filename + ".weights",
)
if "large" in args.model:
# int8 quantization causes errors for large models, so skip it.
return
# Generate int8 quantization models
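The elided quantization code is onnxruntime dynamic quantization; it presumably looks roughly like this (filenames hypothetical):

from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="tiny.en-encoder.onnx",
    model_output="tiny.en-encoder.int8.onnx",
    weight_type=QuantType.QInt8,
)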
... ...
... ... @@ -9,9 +9,10 @@ import base64
from typing import Tuple
import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
import torchaudio
def get_args():
... ... @@ -98,7 +99,6 @@ class OnnxModel:
self.blank = int(meta["blank_id"])
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
self.sot_sequence.append(self.no_timestamps)
self.all_language_tokens = list(
... ... @@ -226,7 +226,18 @@ def load_tokens(filename):
return tokens
def compute_features(filename: str) -> torch.Tensor:
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
data, sample_rate = sf.read(
filename,
always_2d=True,
dtype="float32",
)
data = data[:, 0] # use only the first channel
samples = np.ascontiguousarray(data)
return samples, sample_rate
def compute_features(filename: str, dim: int = 80) -> torch.Tensor:
"""
Args:
filename:
... ... @@ -234,16 +245,18 @@ def compute_features(filename: str) -> torch.Tensor:
Returns:
Return a 3-D float32 tensor of shape (1, dim, 3000) containing the features.
"""
wave, sample_rate = torchaudio.load(filename)
audio = wave[0].contiguous() # only use the first channel
wave, sample_rate = load_audio(filename)
if sample_rate != 16000:
audio = torchaudio.functional.resample(
audio, orig_freq=sample_rate, new_freq=16000
)
import librosa
wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
sample_rate = 16000
features = []
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
online_whisper_fbank.accept_waveform(16000, audio.numpy())
opts = knf.WhisperFeatureOptions()
opts.dim = dim
online_whisper_fbank = knf.OnlineWhisperFbank(opts)
online_whisper_fbank.accept_waveform(16000, wave)
online_whisper_fbank.input_finished()
for i in range(online_whisper_fbank.num_frames_ready):
f = online_whisper_fbank.get_frame(i)
... ... @@ -280,8 +293,9 @@ def compute_features(filename: str) -> torch.Tensor:
def main():
args = get_args()
mel = compute_features(args.sound_file)
model = OnnxModel(args.encoder, args.decoder)
dim = 80 if "large-v3" not in args.encoder else 128
mel = compute_features(args.sound_file, dim=dim)
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
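Note that the filename heuristic above misses an encoder exported via the plain "large" alias, which also carries 128 mel bins with openai-whisper 20231117. A more robust sketch reads n_mels from the metadata that export-onnx.py writes (the C++ runtime below does exactly this):

import onnxruntime as ort

sess = ort.InferenceSession(args.encoder, providers=["CPUExecutionProvider"])
meta = sess.get_modelmeta().custom_metadata_map
dim = int(meta.get("n_mels", 80))  # fall back to 80 for older exports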
... ... @@ -313,6 +327,7 @@ def main():
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
print(model.sot_sequence)
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
offset = torch.zeros(1, dtype=torch.int64)
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
... ...
... ... @@ -88,7 +88,9 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(WhisperTag{});
WhisperTag tag;
tag.dim = model_->FeatureDim();
return std::make_unique<OfflineStream>(tag);
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
... ...
... ... @@ -97,12 +97,16 @@ class OfflineStream::Impl {
}
}
explicit Impl(WhisperTag /*tag*/) {
explicit Impl(WhisperTag tag) {
config_.normalize_samples = true;
opts_.frame_opts.samp_freq = 16000;
opts_.mel_opts.num_bins = 80; // not used
whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
opts_.mel_opts.num_bins = tag.dim;
knf::WhisperFeatureOptions whisper_opts;
whisper_opts.frame_opts = opts_.frame_opts;
whisper_opts.dim = tag.dim;
whisper_fbank_ = std::make_unique<knf::OnlineWhisperFbank>(whisper_opts);
config_.sampling_rate = opts_.frame_opts.samp_freq;
}
... ...
... ... @@ -35,7 +35,10 @@ struct OfflineRecognitionResult {
std::string AsJsonString() const;
};
struct WhisperTag {};
struct WhisperTag {
int32_t dim = 80;
};
struct CEDTag {};
class OfflineStream {
... ...
... ... @@ -217,6 +217,8 @@ class OfflineWhisperModel::Impl {
int32_t VocabSize() const { return n_vocab_; }
int32_t FeatureDim() const { return n_mels_; }
int32_t Translate() const { return translate_; }
bool IsMultiLingual() const { return is_multilingual_; }
... ... @@ -242,6 +244,7 @@ class OfflineWhisperModel::Impl {
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(n_mels_, "n_mels");
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
... ... @@ -316,6 +319,7 @@ class OfflineWhisperModel::Impl {
std::unordered_map<int32_t, std::string> id2lang_;
// model meta data
int32_t n_mels_ = 80;
int32_t n_text_layer_ = 0;
int32_t n_text_ctx_ = 0;
int32_t n_text_state_ = 0;
... ... @@ -414,6 +418,8 @@ int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }
int32_t OfflineWhisperModel::FeatureDim() const { return impl_->FeatureDim(); }
int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }
bool OfflineWhisperModel::IsMultiLingual() const {
... ...
... ... @@ -102,6 +102,7 @@ class OfflineWhisperModel {
int32_t SOT() const;
int32_t TextCtx() const;
int32_t VocabSize() const;
int32_t FeatureDim() const;
int32_t Translate() const;
bool IsMultiLingual() const;
... ...
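Net effect of the C++ changes: the feature dimension now flows from the model's n_mels metadata through WhisperTag into the fbank extractor, so nothing changes at the API surface. A usage sketch, assuming sherpa-onnx's Python bindings and the exported large-v3 files:

import sherpa_onnx
import soundfile as sf

recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
    encoder="sherpa-onnx-whisper-large-v3/large-v3-encoder.onnx",
    decoder="sherpa-onnx-whisper-large-v3/large-v3-decoder.onnx",
    tokens="sherpa-onnx-whisper-large-v3/large-v3-tokens.txt",
)
samples, sample_rate = sf.read("test_wavs/0.wav", dtype="float32")  # mono wav
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
recognizer.decode_stream(stream)
print(stream.result.text)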