Fangjun Kuang
Committed by GitHub

Add script to convert vits models (#355)

.github/workflows/export-vits-ljspeech-to-onnx.yaml

name: export-vits-ljspeech-to-onnx

on:
  push:
    branches:
      - master
    paths:
      - 'scripts/vits/**'
      - '.github/workflows/export-vits-ljspeech-to-onnx.yaml'
  pull_request:
    paths:
      - 'scripts/vits/**'
      - '.github/workflows/export-vits-ljspeech-to-onnx.yaml'

  workflow_dispatch:

concurrency:
  group: export-vits-ljspeech-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-vits-ljspeech-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: vits ljspeech
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        torch: ["1.13.0"]

    steps:
      - uses: actions/checkout@v4

      - name: Install dependencies
        shell: bash
        run: |
          python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html numpy
          python3 -m pip install onnxruntime onnx soundfile
          python3 -m pip install scipy cython unidecode phonemizer

          # espeak-ng etc. are required by phonemizer; see
          # https://bootphon.github.io/phonemizer/install.html
          # This fixes: RuntimeError: espeak not installed on your system
          sudo apt-get install -y festival espeak-ng mbrola

      - name: export vits ljspeech
        shell: bash
        run: |
          cd scripts/vits

          echo "Downloading vits"
          git clone https://github.com/jaywalnut310/vits
          pushd vits/monotonic_align
          python3 setup.py build
          ls -lh build/
          ls -lh build/lib*/
          ls -lh build/lib*/*/

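          # core*.so was built under build/; copy it next to __init__.py and
          # rewrite the relative import .monotonic_align.core -> .core so the
          # extension is found without an in-place build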
          cp build/lib*/monotonic_align/core*.so .
          sed -i.bak s/.monotonic_align.core/.core/g ./__init__.py
          git diff
          popd

          export PYTHONPATH=$PWD/vits:$PYTHONPATH

          echo "Download models"

          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth
          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/test.py

          python3 ./export-onnx-ljs.py --config vits/configs/ljs_base.json --checkpoint ./pretrained_ljs.pth
          python3 ./test.py
          ls -lh *.wav

      - uses: actions/upload-artifact@v3
        with:
          name: test-0.wav
          path: scripts/vits/test-0.wav

      - uses: actions/upload-artifact@v3
        with:
          name: test-1.wav
          path: scripts/vits/test-1.wav

      - uses: actions/upload-artifact@v3
        with:
          name: test-2.wav
          path: scripts/vits/test-2.wav
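
The export step writes test-0.wav, test-1.wav, and test-2.wav, which are
uploaded as artifacts above. A quick local sanity check of the generated
audio (a sketch, using the soundfile package installed earlier):

    import soundfile as sf

    for i in range(3):
        samples, sample_rate = sf.read(f"scripts/vits/test-{i}.wav")
        print(f"test-{i}.wav: {sample_rate} Hz, {samples.shape[0]} samples")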

scripts/vits/.gitignore

tokens-ljs.txt

scripts/vits/export-onnx-ljs.py

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

"""
This script converts vits models trained using the LJ Speech dataset.

Usage:

(1) Download vits

cd /Users/fangjun/open-source
git clone https://github.com/jaywalnut310/vits

(2) Download the pre-trained model from
https://huggingface.co/csukuangfj/vits-ljs/tree/main

wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth

(3) Run this file

./export-onnx-ljs.py \
  --config ~/open-source/vits/configs/ljs_base.json \
  --checkpoint ~/open-source/icefall-models/vits-ljs/pretrained_ljs.pth

It will generate the following two files, plus tokens-ljs.txt:

$ ls -lh *.onnx
-rw-r--r--  1 fangjun  staff    36M Oct 10 20:48 vits-ljs.int8.onnx
-rw-r--r--  1 fangjun  staff   109M Oct 10 20:48 vits-ljs.onnx
"""
import sys

# Please change this line to point to the vits directory.
# You can download vits from
# https://github.com/jaywalnut310/vits
sys.path.insert(0, "/Users/fangjun/open-source/vits")  # noqa

import argparse
from pathlib import Path
from typing import Any, Dict

import commons
import onnx
import torch
import utils
from models import SynthesizerTrn
from onnxruntime.quantization import QuantType, quantize_dynamic
from text import text_to_sequence
from text.symbols import _punctuation, symbols


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="""Path to ljs_base.json.
        You can find it at
        https://huggingface.co/csukuangfj/vits-ljs/resolve/main/ljs_base.json
        """,
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="""Path to the checkpoint file.
        You can find it at
        https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth
        """,
    )

    return parser.parse_args()


class OnnxModel(torch.nn.Module):
    def __init__(self, model: SynthesizerTrn):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=1,
        length_scale=1,
        noise_scale_w=1.0,
        sid=None,
        max_len=None,
    ):
        return self.model.infer(
            x=x,
            x_lengths=x_lengths,
            sid=sid,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            max_len=max_len,
        )[0]
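
# Note: SynthesizerTrn.infer() returns a tuple; its first element is the
# generated waveform of shape (batch_size, 1, num_samples), which is why the
# wrapper above keeps only element [0].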


def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm
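
# For example, with add_blank=True a normalized sequence like [5, 9, 2]
# becomes [0, 5, 0, 9, 0, 2, 0]: commons.intersperse() puts the blank id 0
# around every token, matching the training-time preprocessing.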


def check_args(args):
    assert Path(args.config).is_file(), args.config
    assert Path(args.checkpoint).is_file(), args.checkpoint


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add metadata to an ONNX model. The file is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)
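
# The stored metadata can be read back at inference time, e.g. with
# onnxruntime:
#
#   import onnxruntime
#
#   sess = onnxruntime.InferenceSession("vits-ljs.onnx")
#   print(sess.get_modelmeta().custom_metadata_map)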


def generate_tokens():
    with open("tokens-ljs.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(symbols):
            f.write(f"{s} {i}\n")
    print("Generated tokens-ljs.txt")
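
# tokens-ljs.txt has one "<symbol> <id>" pair per line. Note that one of the
# symbols may itself be a space, so parse each line by position rather than
# with line.split(), e.g. (a sketch):
#
#   with open("tokens-ljs.txt", encoding="utf-8") as f:
#       token2id = {line[0]: int(line[2:]) for line in f}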


@torch.no_grad()
def main():
    args = get_args()
    check_args(args)

    generate_tokens()

    hps = utils.get_hparams_from_file(args.config)

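    # len(symbols) is the vocabulary size; filter_length // 2 + 1 is the
    # number of spectrogram bins; segment_size // hop_length is the number of
    # frames in a training segment (not needed at inference time).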
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    _ = net_g.eval()

    _ = utils.load_checkpoint(args.checkpoint, net_g, None)

    x = get_text("Liliana is the most beautiful assistant", hps)
    x = x.unsqueeze(0)

    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_w = torch.tensor([1], dtype=torch.float32)
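
    # noise_scale and noise_scale_w are the sampling temperatures of the flow
    # and of the stochastic duration predictor; length_scale multiplies the
    # predicted durations, so values > 1 give slower speech.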

    model = OnnxModel(net_g)

    opset_version = 13

    filename = "vits-ljs.onnx"

    torch.onnx.export(
        model,
        (x, x_length, noise_scale, length_scale, noise_scale_w),
        filename,
        opset_version=opset_version,
        input_names=["x", "x_length", "noise_scale", "length_scale", "noise_scale_w"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},  # N is the batch size, L the number of tokens
            "x_length": {0: "N"},
            "y": {0: "N", 2: "L"},  # for y, the dynamic last axis is the sample count
        },
    )
    meta_data = {
        "model_type": "vits",
        "comment": "ljspeech",
        "language": "English",
        "add_blank": int(hps.data.add_blank),
        "sample_rate": hps.data.sampling_rate,
        "punctuation": " ".join(list(_punctuation)),
    }
    print("meta_data", meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generating the int8 quantized model")

    filename_int8 = "vits-ljs.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    print(f"Saved to {filename} and {filename_int8}")
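
    # A quick way to exercise either model with onnxruntime (a sketch; the
    # keys match the input_names/output_names used above):
    #
    #   import onnxruntime
    #
    #   sess = onnxruntime.InferenceSession(filename)
    #   (audio,) = sess.run(
    #       ["y"],
    #       {
    #           "x": x.numpy(),
    #           "x_length": x_length.numpy(),
    #           "noise_scale": noise_scale.numpy(),
    #           "length_scale": length_scale.numpy(),
    #           "noise_scale_w": noise_scale_w.numpy(),
    #       },
    #   )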


if __name__ == "__main__":
    main()