Fangjun Kuang
Committed by GitHub

Export https://github.com/KittenML/KittenTTS to sherpa-onnx (#2456)

name: export-kitten-to-onnx
on:
push:
branches:
- kitten-tts
workflow_dispatch:
concurrency:
group: export-kitten-to-onnx-${{ github.ref }}
cancel-in-progress: true
jobs:
export-kitten-to-onnx:
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
name: export kitten ${{ matrix.version }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.10"]
steps:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
shell: bash
run: |
pip install "numpy<=1.26.4" onnx==1.16.0 onnxruntime==1.17.1 librosa soundfile piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
- name: Run
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
shell: bash
run: |
cd scripts/kitten-tts/nano_v0_1
./run.sh
- name: Collect results
shell: bash
run: |
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
rm espeak-ng-data.tar.bz2
src=scripts/kitten-tts/nano_v0_1
d=kitten-nano-en-v0_1-fp16
mkdir $d
cp -a LICENSE $d/LICENSE
cp -a espeak-ng-data $d/
cp -v $src/model.fp16.onnx $d/model.fp16.onnx
cp -v $src/voices.bin $d/
cp -v $src/tokens.txt $d/
cp -v $src/../README.md $d/README.md
ls -lh $d/
tar cjfv $d.tar.bz2 $d
ls -lh $d.tar.bz2
- name: Release
if: github.repository_owner == 'csukuangfj'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.tar.bz2
overwrite: true
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: tts-models
- name: Release
if: github.repository_owner == 'k2-fsa'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.tar.bz2
overwrite: true
tag: tts-models
- name: Publish to huggingface
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
dirs=(
kitten-nano-en-v0_1-fp16
)
export GIT_LFS_SKIP_SMUDGE=1
export GIT_CLONE_PROTECTION_ACTIVE=false
for d in ${dirs[@]}; do
rm -rf huggingface
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
cd huggingface
rm -rf ./*
git lfs track "*.onnx"
git lfs track af_dict
git lfs track ar_dict
git lfs track cmn_dict
git lfs track da_dict en_dict fa_dict hu_dict ia_dict it_dict lb_dict phondata ru_dict ta_dict
git lfs track ur_dict yue_dict
cp -a ../$d/* ./
git add .
ls -lh
git status
git commit -m "add models"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main || true
done
... ...
... ... @@ -142,3 +142,5 @@ README-DEV.txt
.idea
sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02
dict
*.npz
voices.bin
... ...
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import numpy as np
import onnx
from generate_voices_bin import speaker2id
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, required=True, help="input and output onnx model"
)
return parser.parse_args()
def main():
args = get_args()
print(args.model)
model = onnx.load(args.model)
style = np.load("./voices.npz")
style_shape = style[list(style.keys())[0]].shape
speaker2id_str = ""
id2speaker_str = ""
sep = ""
for s, i in speaker2id.items():
speaker2id_str += f"{sep}{s}->{i}"
id2speaker_str += f"{sep}{i}->{s}"
sep = ","
meta_data = {
"model_type": "kitten-tts",
"language": "English",
"has_espeak": 1,
"sample_rate": 24000,
"version": 1,
"voice": "en-us",
"style_dim": ",".join(map(str, style_shape)),
"n_speakers": len(speaker2id),
"speaker2id": speaker2id_str,
"id2speaker": id2speaker_str,
"speaker_names": ",".join(map(str, speaker2id.keys())),
"model_url": "https://huggingface.co/KittenML/kitten-tts-nano-0.1",
"see_also": "https://github.com/KittenML/KittenTTS",
"maintainer": "k2-fsa",
"comment": "This is kitten-tts-nano-0.1 and supports only English",
}
print(model.metadata_props)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
print("--------------------")
print(model.metadata_props)
onnx.save(model, args.model)
print(f"Please see {args.model}")
if __name__ == "__main__":
main()
... ...
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
"""
Change the model so that it can be run in onnxruntime 1.17.1
"""
import onnx
def main():
model = onnx.load("kitten_tts_nano_v0_1.onnx")
# Print current opsets
for opset in model.opset_import:
print(f"Domain: '{opset.domain}', Version: {opset.version}")
# Modify the opset versions (be careful!)
for opset in model.opset_import:
if opset.domain == "": # ai.onnx domain
opset.version = 19 # change from 20 to 19
elif opset.domain == "ai.onnx.ml":
opset.version = 4 # change from 5 to 4
# Save the modified model
onnx.save(model, "model.fp16.onnx")
if __name__ == "__main__":
main()
... ...
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
def get_vocab():
# https://github.com/KittenML/KittenTTS/blob/main/kittentts/onnx_model.py#L17
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»"" '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
dicts = {}
for i in range(len((symbols))):
dicts[symbols[i]] = i
return dicts
def main():
token2id = get_vocab()
with open("tokens.txt", "w", encoding="utf-8") as f:
for s, i in token2id.items():
f.write(f"{s} {i}\n")
if __name__ == "__main__":
main()
... ...
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
from pathlib import Path
import numpy as np
speakers = [
"expr-voice-2-m",
"expr-voice-2-f",
"expr-voice-3-m",
"expr-voice-3-f",
"expr-voice-4-m",
"expr-voice-4-f",
"expr-voice-5-m",
"expr-voice-5-f",
]
id2speaker = {idx: speaker for idx, speaker in enumerate(speakers)}
speaker2id = {speaker: idx for idx, speaker in id2speaker.items()}
def main():
if Path("./voices.bin").is_file():
print("./voices.bin exists - skip")
return
voices = np.load("./voices.npz")
with open("voices.bin", "wb") as f:
for speaker in speakers:
v = voices[speaker]
# v.shape (1, 256)
f.write(v.tobytes())
if __name__ == "__main__":
main()
... ...
#!/usr/bin/env bash
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
set -ex
if [ ! -f kitten_tts_nano_v0_1.onnx ]; then
curl -SL -O https://huggingface.co/KittenML/kitten-tts-nano-0.1/resolve/main/kitten_tts_nano_v0_1.onnx
fi
if [ ! -f voices.npz ]; then
curl -SL -O https://huggingface.co/KittenML/kitten-tts-nano-0.1/resolve/main/voices.npz
fi
./generate_voices_bin.py
./generate_tokens.py
./convert_opset.py
./show.py
./add_meta_data.py --model ./model.fp16.onnx
# ./test.py --model ./model.fp16.onnx --tokens ./tokens.txt --voice ./voices.bin
ls -lh
... ...
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import onnxruntime
import onnx
"""
[key: "onnx.infer"
value: "onnxruntime.quant"
, key: "onnx.quant.pre_process"
value: "onnxruntime.quant"
]
NodeArg(name='input_ids', type='tensor(int64)', shape=[1, 'sequence_length'])
NodeArg(name='style', type='tensor(float)', shape=[1, 256])
NodeArg(name='speed', type='tensor(float)', shape=[1])
-----
NodeArg(name='waveform', type='tensor(float)', shape=['num_samples'])
NodeArg(name='duration', type='tensor(int64)', shape=['Castduration_dim_0'])
"""
def show(filename):
model = onnx.load(filename)
print(model.metadata_props)
session_opts = onnxruntime.SessionOptions()
session_opts.log_severity_level = 3
sess = onnxruntime.InferenceSession(
filename, session_opts, providers=["CPUExecutionProvider"]
)
for i in sess.get_inputs():
print(i)
print("-----")
for i in sess.get_outputs():
print(i)
def main():
show("./model.fp16.onnx")
if __name__ == "__main__":
main()
... ...
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import time
from pathlib import Path
from typing import Dict, List
import numpy as np
try:
from piper_phonemize import phonemize_espeak
except Exception as ex:
raise RuntimeError(
f"{ex}\nPlease run\n"
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
)
import onnxruntime as ort
import soundfile as sf
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
required=True,
help="Path to the model",
)
parser.add_argument(
"--voices-bin",
type=str,
required=True,
help="Path to the voices.bin",
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to tokens.txt",
)
return parser.parse_args()
def show(filename):
session_opts = ort.SessionOptions()
session_opts.log_severity_level = 3
sess = ort.InferenceSession(filename, session_opts)
for i in sess.get_inputs():
print(i)
print("-----")
for i in sess.get_outputs():
print(i)
def load_tokens(filename: str) -> Dict[str, int]:
ans = dict()
with open(filename, encoding="utf-8") as f:
for line in f:
fields = line.strip().split()
if len(fields) == 2:
token, idx = fields
ans[token] = int(idx)
else:
assert len(fields) == 1, (len(fields), line)
ans[" "] = int(fields[0])
return ans
def load_voices(speaker_names: List[str], dim: List[int], voices_bin: str):
embedding = (
np.fromfile(voices_bin, dtype="uint8")
.view(np.float32)
.reshape(len(speaker_names), *dim)
)
ans = dict()
for i in range(len(speaker_names)):
ans[speaker_names[i]] = embedding[i]
return ans
class OnnxModel:
def __init__(self, model_filename: str, voices_bin: str, tokens: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.token2id = load_tokens(tokens)
meta = self.model.get_modelmeta().custom_metadata_map
print(meta)
dim = list(map(int, meta["style_dim"].split(",")))
speaker_names = meta["speaker_names"].split(",")
self.voices = load_voices(
speaker_names=speaker_names, dim=dim, voices_bin=voices_bin
)
self.sample_rate = int(meta["sample_rate"])
def __call__(self, text: str, voice):
tokens = phonemize_espeak(text, "en-us")
# tokens is List[List[str]]
# Each sentence is a List[str]
# len(tokens) == number of sentences
flatten = []
for t in tokens:
flatten.extend(t)
# we append a space at the end of a sentence so that there is
# a pause in the generated audio
flatten.append(" ")
tokens = "".join(flatten)
tokens = list(tokens)
token_ids = [self.token2id[i] for i in tokens]
style = self.voices[voice]
token_ids = [0, *token_ids, 0]
token_ids = np.array([token_ids], dtype=np.int64)
speed = np.array([1.0], dtype=np.float32)
audio = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: token_ids,
self.model.get_inputs()[1].name: style,
self.model.get_inputs()[2].name: speed,
},
)[0]
return audio
def main():
args = get_args()
print(vars(args))
show(args.model)
# tokens = phonemize_espeak("how are you doing?", "en-us")
# [['h', 'ˌ', 'a', 'ʊ', ' ', 'ɑ', 'ː', 'ɹ', ' ', 'j', 'u', 'ː', ' ', 'd', 'ˈ', 'u', 'ː', 'ɪ', 'ŋ', '?']]
m = OnnxModel(
model_filename=args.model, voices_bin=args.voices_bin, tokens=args.tokens
)
text = (
"Today as always, men fall into two groups: slaves and free men. "
+ " Whoever does not have two-thirds of his day for himself, "
+ "is a slave, whatever he may be: a statesman, a businessman, "
+ "an official, or a scholar."
)
for i, voice in enumerate(m.voices.keys(), 1):
print(f"Testing {i}/{len(m.voices)} - {voice}/{args.model}")
start = time.time()
audio = m(text, voice=voice)
end = time.time()
elapsed_seconds = end - start
audio_duration = len(audio) / m.sample_rate
real_time_factor = elapsed_seconds / audio_duration
filename = f"{Path(args.model).stem}-{voice}.wav"
sf.write(
filename,
audio,
samplerate=m.sample_rate,
subtype="PCM_16",
)
print(f" Saved to {filename}")
print(f" Elapsed seconds: {elapsed_seconds:.3f}")
print(f" Audio duration in seconds: {audio_duration:.3f}")
print(
f" RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
)
if __name__ == "__main__":
main()
... ...