Fangjun Kuang
Committed by GitHub

Export MeloTTS to ONNX (#1129)

name: export-melo-tts-to-onnx
on:
push:
branches:
- export-melo-tts-onnx
workflow_dispatch:
concurrency:
group: export-melo-tts-to-onnx-${{ github.ref }}
cancel-in-progress: true
jobs:
export-melo-tts-to-onnx:
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
name: export melo-tts
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.10"]
steps:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Run
shell: bash
run: |
cd scripts/melo-tts
./run.sh
- uses: actions/upload-artifact@v4
with:
name: test.wav
path: scripts/melo-tts/test.wav
- name: Publish to huggingface (aishell)
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
export GIT_CLONE_PROTECTION_ACTIVE=false
git clone https://huggingface.co/csukuangfj/vits-melo-tts-zh_en huggingface
cd huggingface
git fetch
git pull
echo "pwd: $PWD"
ls -lh ../scripts/melo-tts
cp -v ../scripts/melo-tts/*.onnx .
cp -v ../scripts/melo-tts/lexicon.txt .
cp -v ../scripts/melo-tts/tokens.txt .
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst
curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2
tar xvf dict.tar.bz2
rm dict.tar.bz2
git lfs track "*.onnx"
git add .
git commit -m "add models"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true
cd ..
rm -rf huggingface/.git*
dst=vits-melo-tts-zh_en
mv huggingface $dst
tar cjvf $dst.tar.bz2 $dst
rm -rf $dst
- name: Release
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.tar.bz2
overwrite: true
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: tts-models
... ...
#!/usr/bin/env python3
from typing import Any, Dict
import onnx
import torch
from melo.api import TTS
from melo.text import language_id_map, language_tone_start_map
from melo.text.chinese import pinyin_to_symbol_map
from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
for k, v in pinyin_to_symbol_map.items():
pinyin_to_symbol_map[k] = v.split()
def get_initial_final_tone(word: str):
initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
ans_phone = []
ans_tone = []
for c, v in zip(initials, finals):
raw_pinyin = c + v
v_without_tone = v[:-1]
try:
tone = v[-1]
except:
print("skip", word, initials, finals)
return [], []
pinyin = c + v_without_tone
assert tone in "12345"
if c:
v_rep_map = {
"uei": "ui",
"iou": "iu",
"uen": "un",
}
if v_without_tone in v_rep_map.keys():
pinyin = c + v_rep_map[v_without_tone]
else:
pinyin_rep_map = {
"ing": "ying",
"i": "yi",
"in": "yin",
"u": "wu",
}
if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {
"v": "yu",
"e": "e",
"i": "y",
"u": "w",
}
if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
# print(word, initials, finals, pinyin)
if pinyin not in pinyin_to_symbol_map:
print("skip", pinyin, word, c, v, raw_pinyin)
continue
phone = pinyin_to_symbol_map[pinyin]
ans_phone += phone
ans_tone += [tone] * len(phone)
return ans_phone, ans_tone
def generate_tokens(symbol_list):
with open("tokens.txt", "w", encoding="utf-8") as f:
for i, s in enumerate(symbol_list):
f.write(f"{s} {i}\n")
def generate_lexicon():
word_dict = pinyin_dict.pinyin_dict
phrases = phrases_dict.phrases_dict
with open("lexicon.txt", "w", encoding="utf-8") as f:
for key in word_dict:
if not (0x4E00 <= key <= 0x9FA5):
continue
w = chr(key)
phone, tone = get_initial_final_tone(w)
if not phone:
continue
phone = " ".join(phone)
tone = " ".join(tone)
f.write(f"{w} {phone} {tone}\n")
for w in phrases:
phone, tone = get_initial_final_tone(w)
if not phone:
continue
assert len(phone) == len(tone), (len(phone), len(tone), phone, tone)
phone = " ".join(phone)
tone = " ".join(tone)
f.write(f"{w} {phone} {tone}\n")
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
class ModelWrapper(torch.nn.Module):
def __init__(self, model: "SynthesizerTrn"):
super().__init__()
self.model = model
def forward(
self,
x,
x_lengths,
tones,
lang_id,
bert,
ja_bert,
sid,
noise_scale,
length_scale,
noise_scale_w,
max_len=None,
):
"""
Args:
x: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
tones: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
sid: an integer
"""
return self.model.infer(
x=x,
x_lengths=x_lengths,
sid=sid,
tone=tones,
language=lang_id,
bert=bert,
ja_bert=ja_bert,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
)[0]
def main():
generate_lexicon()
language = "ZH"
model = TTS(language=language, device="cpu")
generate_tokens(model.hps["symbols"])
torch_model = ModelWrapper(model.model)
opset_version = 13
x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
print(x.shape)
x_lengths = torch.tensor([x.size(0)], dtype=torch.int64)
sid = torch.tensor([1], dtype=torch.int64)
tones = torch.zeros_like(x)
lang_id = torch.ones_like(x)
noise_scale = torch.tensor([1.0], dtype=torch.float32)
length_scale = torch.tensor([1.0], dtype=torch.float32)
noise_scale_w = torch.tensor([1.0], dtype=torch.float32)
bert = torch.zeros(1024, x.shape[0], dtype=torch.float32)
ja_bert = torch.zeros(768, x.shape[0], dtype=torch.float32)
x = x.unsqueeze(0)
tones = tones.unsqueeze(0)
lang_id = lang_id.unsqueeze(0)
bert = bert.unsqueeze(0)
ja_bert = ja_bert.unsqueeze(0)
filename = "model.onnx"
torch.onnx.export(
torch_model,
(
x,
x_lengths,
tones,
lang_id,
bert,
ja_bert,
sid,
noise_scale,
length_scale,
noise_scale_w,
),
filename,
opset_version=opset_version,
input_names=[
"x",
"x_lengths",
"tones",
"lang_id",
"bert",
"ja_bert",
"sid",
"noise_scale",
"length_scale",
"noise_scale_w",
],
output_names=["y"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"x_lengths": {0: "N"},
"tones": {0: "N", 1: "L"},
"lang_id": {0: "N", 1: "L"},
"bert": {0: "N", 2: "L"},
"ja_bert": {0: "N", 2: "L"},
"y": {0: "N", 1: "S", 2: "T"},
},
)
meta_data = {
"model_type": "melo-vits",
"comment": "melo",
"language": "Chinese + English",
"add_blank": int(model.hps.data.add_blank),
"n_speakers": 1,
"sample_rate": model.hps.data.sampling_rate,
"bert_dim": 1024,
"ja_bert_dim": 768,
"speaker_id": list(model.hps.data.spk2id.values())[0],
"lang_id": language_id_map[model.language],
"tone_start": language_tone_start_map[model.language],
"url": "https://github.com/myshell-ai/MeloTTS",
"license": "MIT license",
"description": "MeloTTS is a high-quality multi-lingual text-to-speech library by MyShell.ai",
}
add_meta_data(filename, meta_data)
if __name__ == "__main__":
main()
... ...
#!/usr/bin/env bash
set -ex
function install() {
pip install torch==2.3.1+cpu torchaudio==2.3.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
pushd /tmp
git clone https://github.com/myshell-ai/MeloTTS
cd MeloTTS
pip install -r ./requirements.txt
pip install soundfile onnx onnxruntime
python3 -m unidic download
popd
}
install
export PYTHONPATH=/tmp/MeloTTS:$PYTHONPATH
echo "pwd: $PWD"
./export-onnx.py
ls -lh
head lexicon.txt
echo "---"
tail lexicon.txt
echo "---"
head tokens.txt
echo "---"
tail tokens.txt
./test.py
ls -lh
... ...
#!/usr/bin/env python3
from typing import Iterable, List, Tuple
import jieba
import onnxruntime as ort
import soundfile as sf
import torch
class Lexicon:
def __init__(self, lexion_filename: str, tokens_filename: str):
tokens = dict()
with open(tokens_filename, encoding="utf-8") as f:
for line in f:
s, i = line.split()
tokens[s] = int(i)
lexicon = dict()
with open(lexion_filename, encoding="utf-8") as f:
for line in f:
splits = line.split()
word_or_phrase = splits[0]
phone_tone_list = splits[1:]
assert len(phone_tone_list) & 1 == 0, len(phone_tone_list)
phones = phone_tone_list[: len(phone_tone_list) // 2]
phones = [tokens[p] for p in phones]
tones = phone_tone_list[len(phone_tone_list) // 2 :]
tones = [int(t) for t in tones]
lexicon[word_or_phrase] = (phones, tones)
self.lexicon = lexicon
punctuation = ["!", "?", "…", ",", ".", "'", "-"]
for p in punctuation:
i = tokens[p]
tone = 0
self.lexicon[p] = ([i], [tone])
self.lexicon[" "] = ([tokens["_"]], [0])
def _convert(self, text: str) -> Tuple[List[int], List[int]]:
phones = []
tones = []
if text == ",":
text = ","
elif text == "。":
text = "."
elif text == "!":
text = "!"
elif text == "?":
text = "?"
if text not in self.lexicon:
print("t", text)
if len(text) > 1:
for w in text:
print("w", w)
p, t = self.convert(w)
if p:
phones += p
tones += t
return phones, tones
phones, tones = self.lexicon[text]
return phones, tones
def convert(self, text_list: Iterable[str]) -> Tuple[List[int], List[int]]:
phones = []
tones = []
for text in text_list:
print(text)
p, t = self._convert(text)
phones += p
tones += t
return phones, tones
class OnnxModel:
def __init__(self, filename):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.model = ort.InferenceSession(
filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
meta = self.model.get_modelmeta().custom_metadata_map
self.bert_dim = int(meta["bert_dim"])
self.ja_bert_dim = int(meta["ja_bert_dim"])
self.add_blank = int(meta["add_blank"])
self.sample_rate = int(meta["sample_rate"])
self.speaker_id = int(meta["speaker_id"])
self.lang_id = int(meta["lang_id"])
self.sample_rate = int(meta["sample_rate"])
def __call__(self, x, tones, lang):
"""
Args:
x: 1-D int64 torch tensor
tones: 1-D int64 torch tensor
lang: 1-D int64 torch tensor
"""
x = x.unsqueeze(0)
tones = tones.unsqueeze(0)
lang = lang.unsqueeze(0)
print(x.shape, tones.shape, lang.shape)
bert = torch.zeros(1, self.bert_dim, x.shape[-1])
ja_bert = torch.zeros(1, self.ja_bert_dim, x.shape[-1])
sid = torch.tensor([self.speaker_id], dtype=torch.int64)
noise_scale = torch.tensor([0.6], dtype=torch.float32)
length_scale = torch.tensor([1.0], dtype=torch.float32)
noise_scale_w = torch.tensor([0.8], dtype=torch.float32)
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.int64)
y = self.model.run(
["y"],
{
"x": x.numpy(),
"x_lengths": x_lengths.numpy(),
"tones": tones.numpy(),
"lang_id": lang.numpy(),
"bert": bert.numpy(),
"ja_bert": ja_bert.numpy(),
"sid": sid.numpy(),
"noise_scale": noise_scale.numpy(),
"noise_scale_w": noise_scale_w.numpy(),
"length_scale": length_scale.numpy(),
},
)[0][0][0]
return y
def main():
lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt")
text = "永远相信,美好的事情即将发生。多音字测试, 银行,行不行?长沙长大"
s = jieba.cut(text, HMM=True)
phones, tones = lexicon.convert(s)
model = OnnxModel("./model.onnx")
langs = [model.lang_id] * len(phones)
if model.add_blank:
new_phones = [0] * (2 * len(phones) + 1)
new_tones = [0] * (2 * len(tones) + 1)
new_langs = [0] * (2 * len(langs) + 1)
new_phones[1::2] = phones
new_tones[1::2] = tones
new_langs[1::2] = langs
phones = new_phones
tones = new_tones
langs = new_langs
phones = torch.tensor(phones, dtype=torch.int64)
tones = torch.tensor(tones, dtype=torch.int64)
langs = torch.tensor(langs, dtype=torch.int64)
print(phones.shape, tones.shape, langs.shape)
y = model(x=phones, tones=tones, lang=langs)
sf.write("./test.wav", y, model.sample_rate)
if __name__ == "__main__":
main()
... ...