Fangjun Kuang

export telespeech ctc models to sherpa-onnx (#968)

name: export-telespeech-ctc-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-telespeech-ctc-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-telespeech-ctc-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: telespeech
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        shell: bash
        run: |
          pip install onnx onnxruntime soundfile librosa numpy kaldi-native-fbank

      - name: Run
        shell: bash
        run: |
          cd scripts/tele-speech
          ./run.sh
          ./test.py

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
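
# The job above is manual-only (workflow_dispatch). A hedged sketch of
# triggering it from the command line with the GitHub CLI, assuming the
# file is saved as .github/workflows/export-telespeech-ctc-to-onnx.yaml:
#
#   gh workflow run export-telespeech-ctc-to-onnx.yaml --repo k2-fsa/sherpa-onnx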
... ...
# Introduction

This folder contains scripts for adding metadata to the ONNX models from
https://hf-mirror.com/lovemefan/telespeech/tree/main

Please see

- https://github.com/Tele-AI/TeleSpeech-ASR
- https://github.com/lovemefan/telespeech-asr-python
- the TeleSpeech Model Community License Agreement: [TeleSpeech模型社区许可协议.pdf](https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/TeleSpeech%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)

for more details.
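
After `add-metadata.py` runs, the metadata can be read back at runtime.
Below is a minimal sketch (assuming `onnxruntime` is installed and
`model.onnx` is in the current directory):

```python
import onnxruntime as ort

sess = ort.InferenceSession(
    "model.onnx",
    providers=["CPUExecutionProvider"],
)

# custom_metadata_map holds the key-value pairs written by add-metadata.py
meta = sess.get_modelmeta().custom_metadata_map
print(meta["model_type"])  # telespeech_ctc
print(meta["url"])         # https://github.com/Tele-AI/TeleSpeech-ASR
```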
... ...
#!/usr/bin/env python3

import json
from typing import Dict

import onnx
from onnxruntime.quantization import QuantType, quantize_dynamic


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add metadata to an ONNX model. The model file is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

    # Remove any existing metadata before adding ours
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


def main():
    with open("./vocab.json", "r", encoding="utf-8") as f:
        tokens = json.load(f)
    vocab_size = len(tokens)

    with open("tokens.txt", "w", encoding="utf-8") as f:
        for token, idx in tokens.items():
            if idx == 0:
                f.write("<blk> 0\n")
            else:
                f.write(f"{token} {idx}\n")

    filename = "model.onnx"
    meta_data = {
        "model_type": "telespeech_ctc",
        "version": "1",
        "model_author": "Tele-AI",
        "comment": "See also https://github.com/lovemefan/telespeech-asr-python",
        "license": "https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/TeleSpeech%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf",
        "url": "https://github.com/Tele-AI/TeleSpeech-ASR",
    }
    add_meta_data(filename, meta_data)

    # Dynamic int8 quantization of the MatMul weights
    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    #  filename_uint8 = "model.uint8.onnx"
    #  quantize_dynamic(
    #      model_input=filename,
    #      model_output=filename_uint8,
    #      op_types_to_quantize=["MatMul"],
    #      weight_type=QuantType.QUInt8,
    #  )


if __name__ == "__main__":
    main()
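
# A quick sanity check (a hedged sketch, not part of the export): compare the
# float32 and int8 models on the same random input to gauge the quantization
# error. Uncomment after both model files have been written. The input name
# "feats" and the (N, T, C) = (1, T, 40) shape come from the model itself;
# see the NodeArg dump in test.py.
#
#  import numpy as np
#  import onnxruntime as ort
#
#  x = np.random.rand(1, 100, 40).astype(np.float32)
#  y32 = ort.InferenceSession("model.onnx").run(None, {"feats": x})[0]
#  y8 = ort.InferenceSession("model.int8.onnx").run(None, {"feats": x})[0]
#  print("max abs diff:", np.abs(y32 - y8).max())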
... ...
#!/usr/bin/env bash

set -ex

# Download the exported model and its vocabulary
curl -SL -O https://hf-mirror.com/lovemefan/telespeech/resolve/main/model_export.onnx
mv model_export.onnx model.onnx

curl -SL -O https://hf-mirror.com/lovemefan/telespeech/resolve/main/vocab.json

# Download the PDF and the test wave files
curl -SL -O https://github.com/csukuangfj/models/releases/download/a/TeleSpeech.pdf

curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/resolve/main/test_wavs/3-sichuan.wav
curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/resolve/main/test_wavs/4-tianjin.wav
curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/resolve/main/test_wavs/5-henan.wav

ls -lh

# Add metadata to model.onnx; also generates tokens.txt and model.int8.onnx
./add-metadata.py

# Package the float32 model
dst=sherpa-onnx-telespeech-ctc-zh-2024-06-04
mkdir $dst
mkdir $dst/test_wavs
cp -v model.onnx $dst/
cp -v tokens.txt $dst
cp -v *.wav $dst/test_wavs
cp -v *.pdf $dst
cp -v README.md $dst
cp -v *.py $dst
ls -lh $dst
tar cjvf ${dst}.tar.bz2 $dst

# Package the int8 model
dst=sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04
mkdir $dst
mkdir $dst/test_wavs
cp -v model.int8.onnx $dst/
cp -v tokens.txt $dst
cp -v *.wav $dst/test_wavs
cp -v *.pdf $dst
cp -v README.md $dst
cp -v *.py $dst
ls -lh $dst
tar cjvf ${dst}.tar.bz2 $dst

cp -v *.tar.bz2 ../..
ls -lh ../../
... ...
#!/usr/bin/env python3
# Copyright    2024  Xiaomi Corp.  (authors: Fangjun Kuang)

from typing import Tuple

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf

"""
Model inputs and outputs (7535 is the vocabulary size):

NodeArg(name='feats', type='tensor(float)', shape=[1, 'T', 40])
-----
NodeArg(name='logits', type='tensor(float)', shape=['Addlogits_dim_0', 1, 7535])
"""


class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        self.show()

    def show(self):
        for i in self.model.get_inputs():
            print(i)

        print("-----")

        for i in self.model.get_outputs():
            print(i)

    def __call__(self, x):
        """
        Args:
          x: a float32 tensor of shape (N, T, C)
        """
        logits = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x,
            },
        )[0]

        return logits


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # use only the first channel
    samples = np.ascontiguousarray(data)
    return samples, sample_rate


def get_features(test_wav_filename):
    samples, sample_rate = load_audio(test_wav_filename)
    if sample_rate != 16000:
        import librosa

        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000

    # Scale float32 samples in [-1, 1] to the 16-bit PCM range
    samples *= 32768

    opts = knf.MfccOptions()

    # See https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/mfcc_hires.conf
    opts.frame_opts.dither = 0
    opts.num_ceps = 40
    opts.use_energy = False
    opts.mel_opts.num_bins = 40
    opts.mel_opts.low_freq = 40
    opts.mel_opts.high_freq = -200

    mfcc = knf.OnlineMfcc(opts)
    mfcc.accept_waveform(16000, samples)
    frames = []
    for i in range(mfcc.num_frames_ready):
        frames.append(mfcc.get_frame(i))
    frames = np.stack(frames, axis=0)  # (num_frames, 40)
    return frames


def cmvn(features):
    # Per-utterance cepstral mean and variance normalization.
    # See https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/wenet_representation/conf/train_d2v2_ark_conformer.yaml#L70
    # https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/wenet_representation/wenet/dataset/dataset.py#L184
    # https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/wenet_representation/wenet/dataset/processor.py#L278
    mean = features.mean(axis=0, keepdims=True)
    std = features.std(axis=0, keepdims=True)
    return (features - mean) / (std + 1e-5)
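
# A hedged numeric illustration of cmvn(): each feature dimension is shifted
# and scaled to roughly zero mean and unit variance over time. For example,
#
#   cmvn(np.array([[0.0, 10.0], [2.0, 30.0]], dtype=np.float32))
#
# has column means [1, 20] and stds [1, 10], so the result is approximately
# [[-1, -1], [1, 1]].
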
def main():
    # Please download the test data from
    # https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/tree/main/test_wavs
    #  test_wav_filename = "./3-sichuan.wav"
    #  test_wav_filename = "./4-tianjin.wav"
    test_wav_filename = "./5-henan.wav"

    features = get_features(test_wav_filename)
    features = cmvn(features)
    features = np.expand_dims(features, axis=0)  # (T, C) -> (N, T, C)

    model_filename = "./model.int8.onnx"
    model = OnnxModel(model_filename)
    logits = model(features)
    logits = logits.squeeze(axis=1)  # remove the batch axis

    # Greedy CTC decoding: take the argmax over the vocabulary per frame,
    # then drop blanks and repeated tokens
    ids = logits.argmax(axis=-1)

    id2token = dict()
    with open("./tokens.txt", encoding="utf-8") as f:
        for line in f:
            t, idx = line.split()
            id2token[int(idx)] = t

    tokens = []
    blank = 0
    prev = -1
    for k in ids:
        if k != blank and k != prev:
            tokens.append(k)
        prev = k

    tokens = [id2token[i] for i in tokens]
    text = "".join(tokens)
    print(text)


if __name__ == "__main__":
    main()
... ...