Fangjun Kuang
Committed by GitHub

Support VITS VCTK models (#367)

* Support VITS VCTK models

* Release v1.8.1
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.8.0")
set(SHERPA_ONNX_VERSION "1.8.1")
# Disable warning about
#
... ...
... ... @@ -59,6 +59,16 @@ def get_args():
)
parser.add_argument(
"--sid",
type=int,
default=0,
help="""Speaker ID. Used only for multi-speaker models, e.g.
models trained using the VCTK dataset. Not used for single-speaker
models, e.g., models trained using the LJ speech dataset.
""",
)
parser.add_argument(
"--debug",
type=bool,
default=False,
... ... @@ -105,7 +115,7 @@ def main():
)
)
tts = sherpa_onnx.OfflineTts(tts_config)
audio = tts.generate(args.text)
audio = tts.generate(args.text, sid=args.sid)
sf.write(
args.output_filename,
audio.samples,
... ...
tokens-ljs.txt
tokens-vctk.txt
... ...
... ... @@ -191,6 +191,7 @@ def main():
"comment": "ljspeech",
"language": "English",
"add_blank": int(hps.data.add_blank),
"n_speakers": int(hps.data.n_speakers),
"sample_rate": hps.data.sampling_rate,
"punctuation": " ".join(list(_punctuation)),
}
... ...
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script converts vits models trained using the VCTK dataset.
Usage:
(1) Download vits
cd /Users/fangjun/open-source
git clone https://github.com/jaywalnut310/vits
(2) Download pre-trained models from
https://huggingface.co/csukuangfj/vits-vctk/tree/main
wget https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth
(3) Run this file
./export-onnx-vctk.py \
--config ~/open-source//vits/configs/vctk_base.json \
--checkpoint ~/open-source/icefall-models/vits-vctk/pretrained_vctk.pth
It will generate the following two files:
$ ls -lh *.onnx
-rw-r--r-- 1 fangjun staff 37M Oct 16 10:57 vits-vctk.int8.onnx
-rw-r--r-- 1 fangjun staff 116M Oct 16 10:57 vits-vctk.onnx
"""
import sys
# Please change this line to point to the vits directory.
# You can download vits from
# https://github.com/jaywalnut310/vits
sys.path.insert(0, "/Users/fangjun/open-source/vits") # noqa
import argparse
from pathlib import Path
from typing import Dict, Any
import commons
import onnx
import torch
import utils
from models import SynthesizerTrn
from onnxruntime.quantization import QuantType, quantize_dynamic
from text import text_to_sequence
from text.symbols import symbols
from text.symbols import _punctuation
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
type=str,
required=True,
help="""Path to vctk_base.json.
You can find it at
https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vctk_base.json
""",
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="""Path to the checkpoint file.
You can find it at
https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth
""",
)
return parser.parse_args()
class OnnxModel(torch.nn.Module):
def __init__(self, model: SynthesizerTrn):
super().__init__()
self.model = model
def forward(
self,
x,
x_lengths,
noise_scale=1,
length_scale=1,
noise_scale_w=1.0,
sid=0,
max_len=None,
):
return self.model.infer(
x=x,
x_lengths=x_lengths,
sid=sid,
noise_scale=noise_scale,
length_scale=length_scale,
noise_scale_w=noise_scale_w,
max_len=max_len,
)[0]
def get_text(text, hps):
text_norm = text_to_sequence(text, hps.data.text_cleaners)
if hps.data.add_blank:
text_norm = commons.intersperse(text_norm, 0)
text_norm = torch.LongTensor(text_norm)
return text_norm
def check_args(args):
assert Path(args.config).is_file(), args.config
assert Path(args.checkpoint).is_file(), args.checkpoint
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
def generate_tokens():
with open("tokens-vctk.txt", "w", encoding="utf-8") as f:
for i, s in enumerate(symbols):
f.write(f"{s} {i}\n")
print("Generated tokens-vctk.txt")
@torch.no_grad()
def main():
args = get_args()
check_args(args)
generate_tokens()
hps = utils.get_hparams_from_file(args.config)
net_g = SynthesizerTrn(
len(symbols),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
)
_ = net_g.eval()
_ = utils.load_checkpoint(args.checkpoint, net_g, None)
x = get_text("Liliana is the most beautiful assistant", hps)
x = x.unsqueeze(0)
x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
length_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_w = torch.tensor([1], dtype=torch.float32)
sid = torch.tensor([0], dtype=torch.int64)
model = OnnxModel(net_g)
opset_version = 13
filename = "vits-vctk.onnx"
torch.onnx.export(
model,
(x, x_length, noise_scale, length_scale, noise_scale_w, sid),
filename,
opset_version=opset_version,
input_names=[
"x",
"x_length",
"noise_scale",
"length_scale",
"noise_scale_w",
"sid",
],
output_names=["y"],
dynamic_axes={
"x": {0: "N", 1: "L"}, # n_audio is also known as batch_size
"x_length": {0: "N"},
"y": {0: "N", 2: "L"},
},
)
meta_data = {
"model_type": "vits",
"comment": "vctk",
"language": "English",
"add_blank": int(hps.data.add_blank),
"n_speakers": int(hps.data.n_speakers),
"sample_rate": hps.data.sampling_rate,
"punctuation": " ".join(list(_punctuation)),
}
print("meta_data", meta_data)
add_meta_data(filename=filename, meta_data=meta_data)
print("Generate int8 quantization models")
filename_int8 = "vits-vctk.int8.onnx"
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
weight_type=QuantType.QUInt8,
)
print(f"Saved to {filename} and {filename_int8}")
if __name__ == "__main__":
main()
... ...
... ... @@ -18,7 +18,8 @@ class OfflineTtsImpl {
static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);
virtual GeneratedAudio Generate(const std::string &text) const = 0;
virtual GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const = 0;
};
} // namespace sherpa_onnx
... ...
... ... @@ -23,7 +23,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations()) {}
GeneratedAudio Generate(const std::string &text) const override {
GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const override {
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
if (x.empty()) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
... ... @@ -47,7 +48,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
Ort::Value audio = model_->Run(std::move(x_tensor));
Ort::Value audio = model_->Run(std::move(x_tensor), sid);
std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();
... ...
... ... @@ -13,6 +13,11 @@ void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
po->Register("vits-model", &model, "Path to VITS model");
po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models");
po->Register("vits-noise-scale-w", &noise_scale_w,
"noise_scale_w for VITS models");
po->Register("vits-length-scale", &length_scale,
"length_scale for VITS models");
}
bool OfflineTtsVitsModelConfig::Validate() const {
... ... @@ -55,7 +60,10 @@ std::string OfflineTtsVitsModelConfig::ToString() const {
os << "OfflineTtsVitsModelConfig(";
os << "model=\"" << model << "\", ";
os << "lexicon=\"" << lexicon << "\", ";
os << "tokens=\"" << tokens << "\")";
os << "tokens=\"" << tokens << "\", ";
os << "noise_scale=" << noise_scale << ", ";
os << "noise_scale_w=" << noise_scale_w << ", ";
os << "length_scale=" << length_scale << ")";
return os.str();
}
... ...
... ... @@ -16,12 +16,26 @@ struct OfflineTtsVitsModelConfig {
std::string lexicon;
std::string tokens;
float noise_scale = 0.667;
float noise_scale_w = 0.8;
float length_scale = 1;
// used only for multi-speaker models, e.g, vctk speech dataset.
// Not applicable for single-speaker models, e.g., ljspeech dataset
OfflineTtsVitsModelConfig() = default;
OfflineTtsVitsModelConfig(const std::string &model,
const std::string &lexicon,
const std::string &tokens)
: model(model), lexicon(lexicon), tokens(tokens) {}
const std::string &tokens,
float noise_scale = 0.667,
float noise_scale_w = 0.8, float length_scale = 1)
: model(model),
lexicon(lexicon),
tokens(tokens),
noise_scale(noise_scale),
noise_scale_w(noise_scale_w),
length_scale(length_scale) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -26,7 +26,7 @@ class OfflineTtsVitsModel::Impl {
Init(buf.data(), buf.size());
}
Ort::Value Run(Ort::Value x) {
Ort::Value Run(Ort::Value x, int64_t sid) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
... ... @@ -44,20 +44,33 @@ class OfflineTtsVitsModel::Impl {
Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
int64_t scale_shape = 1;
float noise_scale = 1;
float length_scale = 1;
float noise_scale_w = 1;
float noise_scale = config_.vits.noise_scale;
float length_scale = config_.vits.length_scale;
float noise_scale_w = config_.vits.noise_scale_w;
Ort::Value noise_scale_tensor =
Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
memory_info, &length_scale, 1, &scale_shape, 1);
Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
memory_info, &noise_scale_w, 1, &scale_shape, 1);
std::array<Ort::Value, 5> inputs = {
std::move(x), std::move(x_length), std::move(noise_scale_tensor),
std::move(length_scale_tensor), std::move(noise_scale_w_tensor)};
Ort::Value sid_tensor =
Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);
std::vector<Ort::Value> inputs;
inputs.reserve(6);
inputs.push_back(std::move(x));
inputs.push_back(std::move(x_length));
inputs.push_back(std::move(noise_scale_tensor));
inputs.push_back(std::move(length_scale_tensor));
inputs.push_back(std::move(noise_scale_w_tensor));
if (input_names_.size() == 6 && input_names_.back() == "sid") {
inputs.push_back(std::move(sid_tensor));
}
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
... ... @@ -93,6 +106,7 @@ class OfflineTtsVitsModel::Impl {
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
}
... ... @@ -112,6 +126,7 @@ class OfflineTtsVitsModel::Impl {
int32_t sample_rate_;
int32_t add_blank_;
int32_t n_speakers_;
std::string punctuations_;
};
... ... @@ -120,8 +135,8 @@ OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) {
return impl_->Run(std::move(x));
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/) {
return impl_->Run(std::move(x), sid);
}
int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }
... ...
... ... @@ -22,10 +22,14 @@ class OfflineTtsVitsModel {
/** Run the model.
*
* @param x A int64 tensor of shape (1, num_tokens)
// @param sid Speaker ID. Used only for multi-speaker models, e.g., models
// trained using the VCTK dataset. It is not used for
// single-speaker models, e.g., models trained using the ljspeech
// dataset.
* @return Return a float32 tensor containing audio samples. You can flatten
* it to a 1-D tensor.
*/
Ort::Value Run(Ort::Value x);
Ort::Value Run(Ort::Value x, int64_t sid = 0);
// Sample rate of the generated audio
int32_t SampleRate() const;
... ...
... ... @@ -28,8 +28,9 @@ OfflineTts::OfflineTts(const OfflineTtsConfig &config)
OfflineTts::~OfflineTts() = default;
GeneratedAudio OfflineTts::Generate(const std::string &text) const {
return impl_->Generate(text);
GeneratedAudio OfflineTts::Generate(const std::string &text,
int64_t sid /*=0*/) const {
return impl_->Generate(text, sid);
}
} // namespace sherpa_onnx
... ...
... ... @@ -39,7 +39,11 @@ class OfflineTts {
~OfflineTts();
explicit OfflineTts(const OfflineTtsConfig &config);
// @param text A string containing words separated by spaces
GeneratedAudio Generate(const std::string &text) const;
// @param sid Speaker ID. Used only for multi-speaker models, e.g., models
// trained using the VCTK dataset. It is not used for
// single-speaker models, e.g., models trained using the ljspeech
// dataset.
GeneratedAudio Generate(const std::string &text, int64_t sid = 0) const;
private:
std::unique_ptr<OfflineTtsImpl> impl_;
... ...
... ... @@ -13,11 +13,12 @@ int main(int32_t argc, char *argv[]) {
Offline text-to-speech with sherpa-onnx
./bin/sherpa-onnx-offline-tts \
--vits-model /path/to/model.onnx \
--vits-lexicon /path/to/lexicon.txt \
--vits-tokens /path/to/tokens.txt
--output-filename ./generated.wav \
'some text within single quotes'
--vits-model=/path/to/model.onnx \
--vits-lexicon=/path/to/lexicon.txt \
--vits-tokens=/path/to/tokens.txt \
--sid=0 \
--output-filename=./generated.wav \
'some text within single quotes on linux/macos or use double quotes on windows'
It will generate a file ./generated.wav as specified by --output-filename.
... ... @@ -33,15 +34,27 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
--vits-model=./vits-ljs.onnx \
--vits-lexicon=./lexicon.txt \
--vits-tokens=./tokens.txt \
--sid=0 \
--output-filename=./generated.wav \
'liliana, the most beautiful and lovely assistant of our team!'
Please see
https://k2-fsa.github.io/sherpa/onnx/tts/index.html
or detailes.
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
std::string output_filename = "./generated.wav";
int32_t sid = 0;
po.Register("output-filename", &output_filename,
"Path to save the generated audio");
po.Register("sid", &sid,
"Speaker ID. Used only for multi-speaker models, e.g., models "
"trained using the VCTK dataset. Not used for single-speaker "
"models, e.g., models trained using the LJSpeech dataset");
sherpa_onnx::OfflineTtsConfig config;
config.Register(&po);
... ... @@ -67,7 +80,7 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
}
sherpa_onnx::OfflineTts tts(config);
auto audio = tts.Generate(po.GetArg(1));
auto audio = tts.Generate(po.GetArg(1), sid);
bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
audio.samples.data(), audio.samples.size());
... ... @@ -76,7 +89,8 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
exit(EXIT_FAILURE);
}
fprintf(stderr, "The text is: %s\n", po.GetArg(1).c_str());
fprintf(stderr, "The text is: %s. Speaker ID: %d\n", po.GetArg(1).c_str(),
sid);
fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str());
return 0;
... ...
... ... @@ -16,11 +16,16 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) {
py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &,
const std::string &>(),
py::arg("model"), py::arg("lexicon"), py::arg("tokens"))
const std::string &, float, float, float>(),
py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
py::arg("noise_scale") = 0.667, py::arg("noise_scale_w") = 0.8,
py::arg("length_scale") = 1.0)
.def_readwrite("model", &PyClass::model)
.def_readwrite("lexicon", &PyClass::lexicon)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("noise_scale", &PyClass::noise_scale)
.def_readwrite("noise_scale_w", &PyClass::noise_scale_w)
.def_readwrite("length_scale", &PyClass::length_scale)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -40,7 +40,7 @@ void PybindOfflineTts(py::module *m) {
using PyClass = OfflineTts;
py::class_<PyClass>(*m, "OfflineTts")
.def(py::init<const OfflineTtsConfig &>(), py::arg("config"))
.def("generate", &PyClass::Generate);
.def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0);
}
} // namespace sherpa_onnx
... ...