Fangjun Kuang
Committed by GitHub

Export MeloTTS to ONNX (#1129)

  1 +name: export-melo-tts-to-onnx
  2 +
  3 +on:
  4 + push:
  5 + branches:
  6 + - export-melo-tts-onnx
  7 + workflow_dispatch:
  8 +
  9 +concurrency:
  10 + group: export-melo-tts-to-onnx-${{ github.ref }}
  11 + cancel-in-progress: true
  12 +
  13 +jobs:
  14 + export-melo-tts-to-onnx:
  15 + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
  16 + name: export melo-tts
  17 + runs-on: ${{ matrix.os }}
  18 + strategy:
  19 + fail-fast: false
  20 + matrix:
  21 + os: [ubuntu-latest]
  22 + python-version: ["3.10"]
  23 +
  24 + steps:
  25 + - uses: actions/checkout@v4
  26 +
  27 + - name: Setup Python ${{ matrix.python-version }}
  28 + uses: actions/setup-python@v5
  29 + with:
  30 + python-version: ${{ matrix.python-version }}
  31 +
  32 + - name: Run
  33 + shell: bash
  34 + run: |
  35 + cd scripts/melo-tts
  36 + ./run.sh
  37 +
  38 + - uses: actions/upload-artifact@v4
  39 + with:
  40 + name: test.wav
  41 + path: scripts/melo-tts/test.wav
  42 +
  43 + - name: Publish to huggingface (aishell)
  44 + env:
  45 + HF_TOKEN: ${{ secrets.HF_TOKEN }}
  46 + uses: nick-fields/retry@v3
  47 + with:
  48 + max_attempts: 20
  49 + timeout_seconds: 200
  50 + shell: bash
  51 + command: |
  52 + git config --global user.email "csukuangfj@gmail.com"
  53 + git config --global user.name "Fangjun Kuang"
  54 +
  55 + rm -rf huggingface
  56 + export GIT_LFS_SKIP_SMUDGE=1
  57 + export GIT_CLONE_PROTECTION_ACTIVE=false
  58 +
  59 + git clone https://huggingface.co/csukuangfj/vits-melo-tts-zh_en huggingface
  60 + cd huggingface
  61 + git fetch
  62 + git pull
  63 + echo "pwd: $PWD"
  64 + ls -lh ../scripts/melo-tts
  65 +
  66 + cp -v ../scripts/melo-tts/*.onnx .
  67 + cp -v ../scripts/melo-tts/lexicon.txt .
  68 + cp -v ../scripts/melo-tts/tokens.txt .
  69 +
  70 + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
  71 + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
  72 + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst
  73 + curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2
  74 + tar xvf dict.tar.bz2
  75 + rm dict.tar.bz2
  76 +
  77 + git lfs track "*.onnx"
  78 + git add .
  79 +
  80 + git commit -m "add models"
  81 + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true
  82 +
  83 + cd ..
  84 +
  85 + rm -rf huggingface/.git*
  86 + dst=vits-melo-tts-zh_en
  87 +
  88 + mv huggingface $dst
  89 +
  90 + tar cjvf $dst.tar.bz2 $dst
  91 + rm -rf $dst
  92 +
  93 + - name: Release
  94 + uses: svenstaro/upload-release-action@v2
  95 + with:
  96 + file_glob: true
  97 + file: ./*.tar.bz2
  98 + overwrite: true
  99 + repo_name: k2-fsa/sherpa-onnx
  100 + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
  101 + tag: tts-models
  1 +#!/usr/bin/env python3
  2 +from typing import Any, Dict
  3 +
  4 +import onnx
  5 +import torch
  6 +from melo.api import TTS
  7 +from melo.text import language_id_map, language_tone_start_map
  8 +from melo.text.chinese import pinyin_to_symbol_map
  9 +from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
  10 +
  11 +for k, v in pinyin_to_symbol_map.items():
  12 + pinyin_to_symbol_map[k] = v.split()
  13 +
  14 +
  15 +def get_initial_final_tone(word: str):
  16 + initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
  17 + finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
  18 +
  19 + ans_phone = []
  20 + ans_tone = []
  21 +
  22 + for c, v in zip(initials, finals):
  23 + raw_pinyin = c + v
  24 + v_without_tone = v[:-1]
  25 + try:
  26 + tone = v[-1]
  27 + except:
  28 + print("skip", word, initials, finals)
  29 + return [], []
  30 +
  31 + pinyin = c + v_without_tone
  32 + assert tone in "12345"
  33 +
  34 + if c:
  35 + v_rep_map = {
  36 + "uei": "ui",
  37 + "iou": "iu",
  38 + "uen": "un",
  39 + }
  40 + if v_without_tone in v_rep_map.keys():
  41 + pinyin = c + v_rep_map[v_without_tone]
  42 + else:
  43 + pinyin_rep_map = {
  44 + "ing": "ying",
  45 + "i": "yi",
  46 + "in": "yin",
  47 + "u": "wu",
  48 + }
  49 + if pinyin in pinyin_rep_map.keys():
  50 + pinyin = pinyin_rep_map[pinyin]
  51 + else:
  52 + single_rep_map = {
  53 + "v": "yu",
  54 + "e": "e",
  55 + "i": "y",
  56 + "u": "w",
  57 + }
  58 + if pinyin[0] in single_rep_map.keys():
  59 + pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
  60 + # print(word, initials, finals, pinyin)
  61 +
  62 + if pinyin not in pinyin_to_symbol_map:
  63 + print("skip", pinyin, word, c, v, raw_pinyin)
  64 + continue
  65 + phone = pinyin_to_symbol_map[pinyin]
  66 + ans_phone += phone
  67 + ans_tone += [tone] * len(phone)
  68 +
  69 + return ans_phone, ans_tone
  70 +
  71 +
  72 +def generate_tokens(symbol_list):
  73 + with open("tokens.txt", "w", encoding="utf-8") as f:
  74 + for i, s in enumerate(symbol_list):
  75 + f.write(f"{s} {i}\n")
  76 +
  77 +
  78 +def generate_lexicon():
  79 + word_dict = pinyin_dict.pinyin_dict
  80 + phrases = phrases_dict.phrases_dict
  81 + with open("lexicon.txt", "w", encoding="utf-8") as f:
  82 + for key in word_dict:
  83 + if not (0x4E00 <= key <= 0x9FA5):
  84 + continue
  85 + w = chr(key)
  86 + phone, tone = get_initial_final_tone(w)
  87 + if not phone:
  88 + continue
  89 + phone = " ".join(phone)
  90 + tone = " ".join(tone)
  91 + f.write(f"{w} {phone} {tone}\n")
  92 +
  93 + for w in phrases:
  94 + phone, tone = get_initial_final_tone(w)
  95 + if not phone:
  96 + continue
  97 + assert len(phone) == len(tone), (len(phone), len(tone), phone, tone)
  98 + phone = " ".join(phone)
  99 + tone = " ".join(tone)
  100 + f.write(f"{w} {phone} {tone}\n")
  101 +
  102 +
  103 +def add_meta_data(filename: str, meta_data: Dict[str, Any]):
  104 + """Add meta data to an ONNX model. It is changed in-place.
  105 +
  106 + Args:
  107 + filename:
  108 + Filename of the ONNX model to be changed.
  109 + meta_data:
  110 + Key-value pairs.
  111 + """
  112 + model = onnx.load(filename)
  113 + while len(model.metadata_props):
  114 + model.metadata_props.pop()
  115 +
  116 + for key, value in meta_data.items():
  117 + meta = model.metadata_props.add()
  118 + meta.key = key
  119 + meta.value = str(value)
  120 +
  121 + onnx.save(model, filename)
  122 +
  123 +
  124 +class ModelWrapper(torch.nn.Module):
  125 + def __init__(self, model: "SynthesizerTrn"):
  126 + super().__init__()
  127 + self.model = model
  128 +
  129 + def forward(
  130 + self,
  131 + x,
  132 + x_lengths,
  133 + tones,
  134 + lang_id,
  135 + bert,
  136 + ja_bert,
  137 + sid,
  138 + noise_scale,
  139 + length_scale,
  140 + noise_scale_w,
  141 + max_len=None,
  142 + ):
  143 + """
  144 + Args:
  145 + x: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
  146 + tones: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
  147 + lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
  148 + sid: an integer
  149 + """
  150 + return self.model.infer(
  151 + x=x,
  152 + x_lengths=x_lengths,
  153 + sid=sid,
  154 + tone=tones,
  155 + language=lang_id,
  156 + bert=bert,
  157 + ja_bert=ja_bert,
  158 + noise_scale=noise_scale,
  159 + noise_scale_w=noise_scale_w,
  160 + length_scale=length_scale,
  161 + )[0]
  162 +
  163 +
  164 +def main():
  165 + generate_lexicon()
  166 +
  167 + language = "ZH"
  168 + model = TTS(language=language, device="cpu")
  169 +
  170 + generate_tokens(model.hps["symbols"])
  171 +
  172 + torch_model = ModelWrapper(model.model)
  173 +
  174 + opset_version = 13
  175 + x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
  176 + print(x.shape)
  177 + x_lengths = torch.tensor([x.size(0)], dtype=torch.int64)
  178 + sid = torch.tensor([1], dtype=torch.int64)
  179 + tones = torch.zeros_like(x)
  180 + lang_id = torch.ones_like(x)
  181 + noise_scale = torch.tensor([1.0], dtype=torch.float32)
  182 + length_scale = torch.tensor([1.0], dtype=torch.float32)
  183 + noise_scale_w = torch.tensor([1.0], dtype=torch.float32)
  184 +
  185 + bert = torch.zeros(1024, x.shape[0], dtype=torch.float32)
  186 + ja_bert = torch.zeros(768, x.shape[0], dtype=torch.float32)
  187 +
  188 + x = x.unsqueeze(0)
  189 + tones = tones.unsqueeze(0)
  190 + lang_id = lang_id.unsqueeze(0)
  191 + bert = bert.unsqueeze(0)
  192 + ja_bert = ja_bert.unsqueeze(0)
  193 +
  194 + filename = "model.onnx"
  195 +
  196 + torch.onnx.export(
  197 + torch_model,
  198 + (
  199 + x,
  200 + x_lengths,
  201 + tones,
  202 + lang_id,
  203 + bert,
  204 + ja_bert,
  205 + sid,
  206 + noise_scale,
  207 + length_scale,
  208 + noise_scale_w,
  209 + ),
  210 + filename,
  211 + opset_version=opset_version,
  212 + input_names=[
  213 + "x",
  214 + "x_lengths",
  215 + "tones",
  216 + "lang_id",
  217 + "bert",
  218 + "ja_bert",
  219 + "sid",
  220 + "noise_scale",
  221 + "length_scale",
  222 + "noise_scale_w",
  223 + ],
  224 + output_names=["y"],
  225 + dynamic_axes={
  226 + "x": {0: "N", 1: "L"},
  227 + "x_lengths": {0: "N"},
  228 + "tones": {0: "N", 1: "L"},
  229 + "lang_id": {0: "N", 1: "L"},
  230 + "bert": {0: "N", 2: "L"},
  231 + "ja_bert": {0: "N", 2: "L"},
  232 + "y": {0: "N", 1: "S", 2: "T"},
  233 + },
  234 + )
  235 +
  236 + meta_data = {
  237 + "model_type": "melo-vits",
  238 + "comment": "melo",
  239 + "language": "Chinese + English",
  240 + "add_blank": int(model.hps.data.add_blank),
  241 + "n_speakers": 1,
  242 + "sample_rate": model.hps.data.sampling_rate,
  243 + "bert_dim": 1024,
  244 + "ja_bert_dim": 768,
  245 + "speaker_id": list(model.hps.data.spk2id.values())[0],
  246 + "lang_id": language_id_map[model.language],
  247 + "tone_start": language_tone_start_map[model.language],
  248 + "url": "https://github.com/myshell-ai/MeloTTS",
  249 + "license": "MIT license",
  250 + "description": "MeloTTS is a high-quality multi-lingual text-to-speech library by MyShell.ai",
  251 + }
  252 + add_meta_data(filename, meta_data)
  253 +
  254 +
  255 +if __name__ == "__main__":
  256 + main()
  1 +#!/usr/bin/env bash
  2 +
  3 +set -ex
  4 +
  5 +
  6 +
  7 +function install() {
  8 + pip install torch==2.3.1+cpu torchaudio==2.3.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
  9 +
  10 + pushd /tmp
  11 + git clone https://github.com/myshell-ai/MeloTTS
  12 + cd MeloTTS
  13 + pip install -r ./requirements.txt
  14 +
  15 + pip install soundfile onnx onnxruntime
  16 +
  17 + python3 -m unidic download
  18 + popd
  19 +}
  20 +
  21 +install
  22 +
  23 +export PYTHONPATH=/tmp/MeloTTS:$PYTHONPATH
  24 +
  25 +echo "pwd: $PWD"
  26 +
  27 +./export-onnx.py
  28 +
  29 +ls -lh
  30 +
  31 +head lexicon.txt
  32 +echo "---"
  33 +tail lexicon.txt
  34 +echo "---"
  35 +head tokens.txt
  36 +echo "---"
  37 +tail tokens.txt
  38 +
  39 +./test.py
  40 +
  41 +ls -lh
  1 +#!/usr/bin/env python3
  2 +
  3 +from typing import Iterable, List, Tuple
  4 +
  5 +import jieba
  6 +import onnxruntime as ort
  7 +import soundfile as sf
  8 +import torch
  9 +
  10 +
  11 +class Lexicon:
  12 + def __init__(self, lexion_filename: str, tokens_filename: str):
  13 + tokens = dict()
  14 + with open(tokens_filename, encoding="utf-8") as f:
  15 + for line in f:
  16 + s, i = line.split()
  17 + tokens[s] = int(i)
  18 +
  19 + lexicon = dict()
  20 + with open(lexion_filename, encoding="utf-8") as f:
  21 + for line in f:
  22 + splits = line.split()
  23 + word_or_phrase = splits[0]
  24 + phone_tone_list = splits[1:]
  25 + assert len(phone_tone_list) & 1 == 0, len(phone_tone_list)
  26 + phones = phone_tone_list[: len(phone_tone_list) // 2]
  27 + phones = [tokens[p] for p in phones]
  28 +
  29 + tones = phone_tone_list[len(phone_tone_list) // 2 :]
  30 + tones = [int(t) for t in tones]
  31 +
  32 + lexicon[word_or_phrase] = (phones, tones)
  33 + self.lexicon = lexicon
  34 +
  35 + punctuation = ["!", "?", "…", ",", ".", "'", "-"]
  36 + for p in punctuation:
  37 + i = tokens[p]
  38 + tone = 0
  39 + self.lexicon[p] = ([i], [tone])
  40 + self.lexicon[" "] = ([tokens["_"]], [0])
  41 +
  42 + def _convert(self, text: str) -> Tuple[List[int], List[int]]:
  43 + phones = []
  44 + tones = []
  45 +
  46 + if text == ",":
  47 + text = ","
  48 + elif text == "。":
  49 + text = "."
  50 + elif text == "!":
  51 + text = "!"
  52 + elif text == "?":
  53 + text = "?"
  54 +
  55 + if text not in self.lexicon:
  56 + print("t", text)
  57 + if len(text) > 1:
  58 + for w in text:
  59 + print("w", w)
  60 + p, t = self.convert(w)
  61 + if p:
  62 + phones += p
  63 + tones += t
  64 + return phones, tones
  65 +
  66 + phones, tones = self.lexicon[text]
  67 + return phones, tones
  68 +
  69 + def convert(self, text_list: Iterable[str]) -> Tuple[List[int], List[int]]:
  70 + phones = []
  71 + tones = []
  72 + for text in text_list:
  73 + print(text)
  74 + p, t = self._convert(text)
  75 + phones += p
  76 + tones += t
  77 + return phones, tones
  78 +
  79 +
  80 +class OnnxModel:
  81 + def __init__(self, filename):
  82 + session_opts = ort.SessionOptions()
  83 + session_opts.inter_op_num_threads = 1
  84 + session_opts.intra_op_num_threads = 4
  85 +
  86 + self.session_opts = session_opts
  87 + self.model = ort.InferenceSession(
  88 + filename,
  89 + sess_options=self.session_opts,
  90 + providers=["CPUExecutionProvider"],
  91 + )
  92 + meta = self.model.get_modelmeta().custom_metadata_map
  93 + self.bert_dim = int(meta["bert_dim"])
  94 + self.ja_bert_dim = int(meta["ja_bert_dim"])
  95 + self.add_blank = int(meta["add_blank"])
  96 + self.sample_rate = int(meta["sample_rate"])
  97 + self.speaker_id = int(meta["speaker_id"])
  98 + self.lang_id = int(meta["lang_id"])
  99 + self.sample_rate = int(meta["sample_rate"])
  100 +
  101 + def __call__(self, x, tones, lang):
  102 + """
  103 + Args:
  104 + x: 1-D int64 torch tensor
  105 + tones: 1-D int64 torch tensor
  106 + lang: 1-D int64 torch tensor
  107 + """
  108 + x = x.unsqueeze(0)
  109 + tones = tones.unsqueeze(0)
  110 + lang = lang.unsqueeze(0)
  111 +
  112 + print(x.shape, tones.shape, lang.shape)
  113 + bert = torch.zeros(1, self.bert_dim, x.shape[-1])
  114 + ja_bert = torch.zeros(1, self.ja_bert_dim, x.shape[-1])
  115 + sid = torch.tensor([self.speaker_id], dtype=torch.int64)
  116 + noise_scale = torch.tensor([0.6], dtype=torch.float32)
  117 + length_scale = torch.tensor([1.0], dtype=torch.float32)
  118 + noise_scale_w = torch.tensor([0.8], dtype=torch.float32)
  119 +
  120 + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.int64)
  121 +
  122 + y = self.model.run(
  123 + ["y"],
  124 + {
  125 + "x": x.numpy(),
  126 + "x_lengths": x_lengths.numpy(),
  127 + "tones": tones.numpy(),
  128 + "lang_id": lang.numpy(),
  129 + "bert": bert.numpy(),
  130 + "ja_bert": ja_bert.numpy(),
  131 + "sid": sid.numpy(),
  132 + "noise_scale": noise_scale.numpy(),
  133 + "noise_scale_w": noise_scale_w.numpy(),
  134 + "length_scale": length_scale.numpy(),
  135 + },
  136 + )[0][0][0]
  137 + return y
  138 +
  139 +
  140 +def main():
  141 + lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt")
  142 +
  143 + text = "永远相信,美好的事情即将发生。多音字测试, 银行,行不行?长沙长大"
  144 + s = jieba.cut(text, HMM=True)
  145 +
  146 + phones, tones = lexicon.convert(s)
  147 +
  148 + model = OnnxModel("./model.onnx")
  149 + langs = [model.lang_id] * len(phones)
  150 +
  151 + if model.add_blank:
  152 + new_phones = [0] * (2 * len(phones) + 1)
  153 + new_tones = [0] * (2 * len(tones) + 1)
  154 + new_langs = [0] * (2 * len(langs) + 1)
  155 +
  156 + new_phones[1::2] = phones
  157 + new_tones[1::2] = tones
  158 + new_langs[1::2] = langs
  159 +
  160 + phones = new_phones
  161 + tones = new_tones
  162 + langs = new_langs
  163 +
  164 + phones = torch.tensor(phones, dtype=torch.int64)
  165 + tones = torch.tensor(tones, dtype=torch.int64)
  166 + langs = torch.tensor(langs, dtype=torch.int64)
  167 +
  168 + print(phones.shape, tones.shape, langs.shape)
  169 +
  170 + y = model(x=phones, tones=tones, lang=langs)
  171 + sf.write("./test.wav", y, model.sample_rate)
  172 +
  173 +
  174 +if __name__ == "__main__":
  175 + main()