Fangjun Kuang
Committed by GitHub
name: export-moonshine-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-moonshine-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-moonshine-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export moonshine models to ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        # Quoted so YAML does not parse 3.10 as the float 3.1.
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        shell: bash
        run: |
          pip install -q onnx onnxruntime librosa tokenizers soundfile

      - name: Run
        shell: bash
        run: |
          pushd scripts/moonshine
          ./run.sh
          popd
          # Collect the generated archives and model directories at the
          # workspace root for the Release/Publish steps below.
          mv -v scripts/moonshine/*.tar.bz2 .
          mv -v scripts/moonshine/sherpa-onnx-* ./

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models

      - name: Publish to huggingface (tiny)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            d=sherpa-onnx-moonshine-tiny-en-int8

            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            # Remove any leftover clone from a previous retry attempt or
            # publish step; "git clone" fails on an existing non-empty
            # directory.
            rm -rf huggingface

            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
            mv -v $d/* ./huggingface

            cd huggingface
            git lfs track "*.onnx"
            git lfs track "*.wav"
            git status
            git add .
            git status
            git commit -m "add models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main

            # Fix: the original ran "rm -rf huggingface" while still inside
            # the clone, which removed nothing and broke the next clone.
            cd ..
            rm -rf huggingface

      - name: Publish to huggingface (base)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            d=sherpa-onnx-moonshine-base-en-int8

            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            # Remove any leftover clone from a previous retry attempt or
            # publish step; "git clone" fails on an existing non-empty
            # directory.
            rm -rf huggingface

            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
            mv -v $d/* ./huggingface

            cd huggingface
            git lfs track "*.onnx"
            git lfs track "*.wav"
            git status
            git add .
            git status
            git commit -m "add models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main

            # Fix: the original ran "rm -rf huggingface" while still inside
            # the clone, which removed nothing and broke the next clone.
            cd ..
            rm -rf huggingface
... ...
tokenizer.json
... ...
# Introduction
This directory contains models from
https://github.com/usefulsensors/moonshine
See its license at
https://github.com/usefulsensors/moonshine/blob/main/LICENSE
... ...
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
from pathlib import Path
import tokenizers
from onnxruntime.quantization import QuantType, quantize_dynamic
def generate_tokens():
    """Create ./tokens.txt with one "token<TAB>id" line per vocab entry.

    The vocabulary is read from ./tokenizer.json.  The function is
    idempotent: it returns immediately when tokens.txt already exists.
    """
    if Path("./tokens.txt").is_file():
        return
    print("Generating tokens.txt")

    tokenizer = tokenizers.Tokenizer.from_file("./tokenizer.json")
    with open("tokens.txt", "w", encoding="utf-8") as f:
        for idx in range(tokenizer.get_vocab_size()):
            # strip() removes surrounding whitespace from the raw token text
            token = tokenizer.id_to_token(idx).strip()
            f.write(f"{token}\t{idx}\n")
def main():
    """Generate tokens.txt, then int8-quantize the decoder/encoder models."""
    generate_tokens()

    # Note(fangjun): Don't use int8 for the preprocessor since it has
    # a larger impact on the accuracy
    for name in ("uncached_decode", "cached_decode", "encode"):
        # Skip models that were already quantized in a previous run.
        if Path(f"{name}.int8.onnx").is_file():
            continue
        print("processing", name)
        quantize_dynamic(
            model_input=f"{name}.onnx",
            model_output=f"{name}.int8.onnx",
            weight_type=QuantType.QInt8,
        )


if __name__ == "__main__":
    main()
... ...
#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
set -ex
cat >LICENSE <<EOF
MIT License
Copyright (c) 2024 Useful Sensors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
EOF
# Download the upstream moonshine ONNX models (tiny and base), the test
# wave files, and the tokenizer definition into the current directory.
function download_files() {
  local model_url=https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx
  local wav_url=https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs

  for d in tiny base; do
    mkdir $d
    pushd $d
    for m in preprocess encode uncached_decode cached_decode; do
      curl -SL -O $model_url/$d/$m.onnx
    done
    popd
  done

  for w in 0.wav 1.wav 8k.wav trans.txt; do
    curl -SL -O $wav_url/$w
  done

  curl -SL -O https://raw.githubusercontent.com/usefulsensors/moonshine/refs/heads/main/moonshine/assets/tokenizer.json
}
# For each model size: move its float models to the current directory,
# run export-onnx.py (int8 quantization + tokens.txt generation), sanity
# check with test.py, and collect the results back into the per-size dir.
# Note: preprocess.onnx is intentionally kept in float32.
function quantize() {
  for size in tiny base; do
    echo "==========$size=========="
    ls -lh
    mv $size/*.onnx .

    ./export-onnx.py

    # Drop the float32 originals that now have int8 counterparts.
    for m in cached_decode uncached_decode encode; do
      rm $m.onnx
    done

    ls -lh
    ./test.py

    mv *.onnx $size
    mv tokens.txt $size
    ls -lh $size
  done
}
# Rename each per-size directory to its release name, add test waves,
# the LICENSE, and the README, then pack it into a .tar.bz2 archive.
# (This shell function shadows the system `zip` binary within this
# script, which is fine because the script never needs the real zip.)
function zip() {
  for size in tiny base; do
    out=sherpa-onnx-moonshine-$size-en-int8
    mv $size $out

    mkdir $out/test_wavs
    cp -v *.wav $out/test_wavs
    cp trans.txt $out/test_wavs
    cp LICENSE $out/
    cp ./README.md $out

    ls -lh $out
    tar cjfv $out.tar.bz2 $out
  done
}
# Main entry: fetch models, quantize them, and package the release archives.
download_files
quantize
zip
ls -lh
... ...
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import datetime as dt
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf
def display(sess, name):
    """Print the input and output descriptions of an onnxruntime session.

    Args:
      sess: An object exposing get_inputs() / get_outputs() (e.g. an
        onnxruntime InferenceSession).
      name: Label used in the printed section headers.
    """
    print(f"=========={name} Input==========")
    for item in sess.get_inputs():
        print(item)
    print(f"=========={name} Output==========")
    for item in sess.get_outputs():
        print(item)
class OnnxModel:
    """Wrapper around the four moonshine ONNX sub-models.

    The model is split into:
      - preprocess: raw audio -> features
      - encode: features -> encoder output
      - uncached_decode: first decoder step (no cached states yet)
      - cached_decode: subsequent decoder steps (reuses decoder states)
    """

    def __init__(
        self,
        preprocess: str,
        encode: str,
        uncached_decode: str,
        cached_decode: str,
    ):
        """
        Args:
          preprocess: Path to preprocess.onnx
          encode: Path to encode.onnx (or its int8 variant)
          uncached_decode: Path to uncached_decode.onnx (or its int8 variant)
          cached_decode: Path to cached_decode.onnx (or its int8 variant)
        """
        self.init_preprocess(preprocess)
        display(self.preprocess, "preprocess")

        self.init_encode(encode)
        display(self.encode, "encode")

        self.init_uncached_decode(uncached_decode)
        display(self.uncached_decode, "uncached_decode")

        self.init_cached_decode(cached_decode)
        display(self.cached_decode, "cached_decode")

    @staticmethod
    def _create_session(filename: str):
        """Create a single-threaded CPU inference session for `filename`.

        All four sub-models share exactly this configuration; the original
        code duplicated it four times.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        return ort.InferenceSession(
            filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_preprocess(self, preprocess):
        self.preprocess = self._create_session(preprocess)

    def init_encode(self, encode):
        self.encode = self._create_session(encode)

    def init_uncached_decode(self, uncached_decode):
        self.uncached_decode = self._create_session(uncached_decode)

    def init_cached_decode(self, cached_decode):
        self.cached_decode = self._create_session(cached_decode)

    def run_preprocess(self, audio):
        """
        Args:
          audio: (batch_size, num_samples), float32
        Returns:
          A tensor of shape (batch_size, T, dim), float32
        """
        return self.preprocess.run(
            [
                self.preprocess.get_outputs()[0].name,
            ],
            {
                self.preprocess.get_inputs()[0].name: audio,
            },
        )[0]

    def run_encode(self, features):
        """
        Args:
          features: (batch_size, T, dim)
        Returns:
          A tensor of shape (batch_size, T, dim)
        """
        # The encoder also expects the valid feature length as int32.
        features_len = np.array([features.shape[1]], dtype=np.int32)

        return self.encode.run(
            [
                self.encode.get_outputs()[0].name,
            ],
            {
                self.encode.get_inputs()[0].name: features,
                self.encode.get_inputs()[1].name: features_len,
            },
        )[0]

    def run_uncached_decode(self, token: int, token_len: int, encoder_out: np.ndarray):
        """Run the first decoder step, which produces the initial states.

        Args:
          token: The current token
          token_len: Number of predicted tokens so far
          encoder_out: A tensor of shape (batch_size, T, dim)
        Returns:
          A tuple:
            - logits, a tensor of shape (batch_size, 1, vocab_size)
            - a list of decoder states
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        out_names = [o.name for o in self.uncached_decode.get_outputs()]

        out = self.uncached_decode.run(
            out_names,
            {
                self.uncached_decode.get_inputs()[0].name: token_tensor,
                self.uncached_decode.get_inputs()[1].name: encoder_out,
                self.uncached_decode.get_inputs()[2].name: token_len_tensor,
            },
        )
        # out[0] is the logits; the remaining outputs are the states.
        return out[0], out[1:]

    def run_cached_decode(
        self, token: int, token_len: int, encoder_out: np.ndarray, states
    ):
        """Run a subsequent decoder step, reusing the previous states.

        Args:
          token: The current token
          token_len: Number of predicted tokens so far
          encoder_out: A tensor of shape (batch_size, T, dim)
          states: previous states, as returned by run_uncached_decode
            or run_cached_decode
        Returns:
          A tuple:
            - logits, a tensor of shape (batch_size, 1, vocab_size)
            - a list of updated decoder states
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        out_names = [o.name for o in self.cached_decode.get_outputs()]

        # Inputs 0-2 are token/encoder_out/token_len; inputs 3.. are the
        # state tensors, matched to `states` by position.
        states_inputs = {}
        for i in range(3, len(self.cached_decode.get_inputs())):
            name = self.cached_decode.get_inputs()[i].name
            states_inputs[name] = states[i - 3]

        out = self.cached_decode.run(
            out_names,
            {
                self.cached_decode.get_inputs()[0].name: token_tensor,
                self.cached_decode.get_inputs()[1].name: encoder_out,
                self.cached_decode.get_inputs()[2].name: token_len_tensor,
                **states_inputs,
            },
        )
        # out[0] is the logits; the remaining outputs are the new states.
        return out[0], out[1:]
def main():
    """Greedy-decode ./1.wav with the int8 moonshine models and print the text.

    Reads the token table from ./tokens.txt, runs preprocess -> encode,
    then decodes autoregressively (uncached first step, cached afterwards)
    until </s> or a length cap is reached.
    """
    wave = "./1.wav"

    # tokens.txt lines are "token\tid"; build both directions of the map.
    # (Fixed: the original used enumerate() but never used the index.)
    id2token = dict()
    token2id = dict()
    with open("./tokens.txt", encoding="utf-8") as f:
        for line in f:
            t, idx = line.split("\t")
            id2token[int(idx)] = t
            token2id[t] = int(idx)

    model = OnnxModel(
        preprocess="./preprocess.onnx",
        encode="./encode.int8.onnx",
        uncached_decode="./uncached_decode.int8.onnx",
        cached_decode="./cached_decode.int8.onnx",
    )

    audio, sample_rate = sf.read(wave, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        # The models expect 16 kHz input.
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    audio = audio[None]  # (1, num_samples)
    print("audio.shape", audio.shape)  # e.g. (1, 159414)

    start_t = dt.datetime.now()
    features = model.run_preprocess(audio)  # e.g. (1, 413, 288)
    print("features", features.shape)

    sos = token2id["<s>"]
    eos = token2id["</s>"]
    tokens = [sos]

    encoder_out = model.run_encode(features)
    print("encoder_out.shape", encoder_out.shape)  # e.g. (1, 413, 288)

    logits, states = model.run_uncached_decode(
        token=tokens[-1],
        token_len=len(tokens),
        encoder_out=encoder_out,
    )
    print("logits.shape", logits.shape)  # e.g. (1, 1, 32768)
    print("len(states)", len(states))  # e.g. 24

    # Cap decoding at 6 tokens per second of audio to avoid runaway loops.
    max_len = int((audio.shape[-1] / 16000) * 6)
    for i in range(max_len):
        # Greedy search: pick the most likely next token.
        # (int() keeps `tokens` a list of plain Python ints.)
        token = int(logits.squeeze().argmax())
        if token == eos:
            break
        tokens.append(token)

        logits, states = model.run_cached_decode(
            token=tokens[-1],
            token_len=len(tokens),
            encoder_out=encoder_out,
            states=states,
        )

    tokens = tokens[1:]  # remove sos
    words = [id2token[i] for i in tokens]

    # The tokenizer marks word boundaries with U+2581 (lower one eighth
    # block); replace it with a plain space.
    underline = "▁"
    text = "".join(words).replace(underline, " ").strip()

    end_t = dt.datetime.now()
    t = (end_t - start_t).total_seconds()
    rtf = t * 16000 / audio.shape[-1]
    print(text)
    print("RTF:", rtf)


if __name__ == "__main__":
    main()
... ...