Fangjun Kuang
Committed by GitHub

Support VITS VCTK models (#367)

* Support VITS VCTK models

* Release v1.8.1
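
This change threads a new speaker-ID argument (`sid`) from the Python and C++ APIs down to the ONNX session and bumps the version to 1.8.1. As a quick orientation before the diff, here is a minimal sketch of the resulting Python usage, assuming the nested config layout from the offline-tts example touched below and model files downloaded from https://huggingface.co/csukuangfj/vits-vctk (the exact file names are placeholders):

import sherpa_onnx
import soundfile as sf

tts_config = sherpa_onnx.OfflineTtsConfig(
    model=sherpa_onnx.OfflineTtsModelConfig(
        vits=sherpa_onnx.OfflineTtsVitsModelConfig(
            model="./vits-vctk.onnx",
            lexicon="./lexicon.txt",
            tokens="./tokens-vctk.txt",
        ),
    ),
)
tts = sherpa_onnx.OfflineTts(tts_config)

# sid selects the speaker of a multi-speaker model; single-speaker
# models, e.g., the LJ Speech one, ignore it.
audio = tts.generate("liliana, the most beautiful assistant", sid=10)
sf.write("generated.wav", audio.samples, samplerate=audio.sample_rate)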
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)
 
-set(SHERPA_ONNX_VERSION "1.8.0")
+set(SHERPA_ONNX_VERSION "1.8.1")
 
 # Disable warning about
 #
@@ -59,6 +59,16 @@ def get_args():
     )
 
     parser.add_argument(
+        "--sid",
+        type=int,
+        default=0,
+        help="""Speaker ID. Used only for multi-speaker models, e.g.
+        models trained using the VCTK dataset. Not used for single-speaker
+        models, e.g., models trained using the LJ speech dataset.
+        """,
+    )
+
+    parser.add_argument(
         "--debug",
         type=bool,
         default=False,
@@ -105,7 +115,7 @@ def main():
         )
     )
     tts = sherpa_onnx.OfflineTts(tts_config)
-    audio = tts.generate(args.text)
+    audio = tts.generate(args.text, sid=args.sid)
     sf.write(
         args.output_filename,
         audio.samples,
 tokens-ljs.txt
+tokens-vctk.txt
@@ -191,6 +191,7 @@ def main():
         "comment": "ljspeech",
         "language": "English",
         "add_blank": int(hps.data.add_blank),
+        "n_speakers": int(hps.data.n_speakers),
         "sample_rate": hps.data.sampling_rate,
         "punctuation": " ".join(list(_punctuation)),
     }
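
Both export scripts now record `n_speakers` in the ONNX metadata so the runtime can tell single-speaker and multi-speaker models apart. A quick way to inspect it (a sketch using the onnx Python package; the file name is an example, and in the upstream VITS configs `n_speakers` is 0 for LJ Speech and 109 for VCTK):

import onnx

model = onnx.load("vits-ljs.onnx")
meta = {p.key: p.value for p in model.metadata_props}
print(meta["n_speakers"])  # expected "0" for a single-speaker LJ Speech export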
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+"""
+This script converts vits models trained using the VCTK dataset.
+
+Usage:
+
+(1) Download vits
+
+cd /Users/fangjun/open-source
+git clone https://github.com/jaywalnut310/vits
+
+(2) Download pre-trained models from
+https://huggingface.co/csukuangfj/vits-vctk/tree/main
+
+wget https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth
+
+(3) Run this file
+
+./export-onnx-vctk.py \
+  --config ~/open-source/vits/configs/vctk_base.json \
+  --checkpoint ~/open-source/icefall-models/vits-vctk/pretrained_vctk.pth
+
+It will generate the following two files:
+
+$ ls -lh *.onnx
+-rw-r--r--  1 fangjun  staff    37M Oct 16 10:57 vits-vctk.int8.onnx
+-rw-r--r--  1 fangjun  staff   116M Oct 16 10:57 vits-vctk.onnx
+"""
+import sys
+
+# Please change this line to point to the vits directory.
+# You can download vits from
+# https://github.com/jaywalnut310/vits
+sys.path.insert(0, "/Users/fangjun/open-source/vits")  # noqa
+
+import argparse
+from pathlib import Path
+from typing import Dict, Any
+
+import commons
+import onnx
+import torch
+import utils
+from models import SynthesizerTrn
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from text import text_to_sequence
+from text.symbols import symbols
+from text.symbols import _punctuation
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=True,
+        help="""Path to vctk_base.json.
+        You can find it at
+        https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vctk_base.json
+        """,
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="""Path to the checkpoint file.
+        You can find it at
+        https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth
+        """,
+    )
+
+    return parser.parse_args()
+
+
+class OnnxModel(torch.nn.Module):
+    def __init__(self, model: SynthesizerTrn):
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        x,
+        x_lengths,
+        noise_scale=1,
+        length_scale=1,
+        noise_scale_w=1.0,
+        sid=0,
+        max_len=None,
+    ):
+        return self.model.infer(
+            x=x,
+            x_lengths=x_lengths,
+            sid=sid,
+            noise_scale=noise_scale,
+            length_scale=length_scale,
+            noise_scale_w=noise_scale_w,
+            max_len=max_len,
+        )[0]
+
+
+def get_text(text, hps):
+    text_norm = text_to_sequence(text, hps.data.text_cleaners)
+    if hps.data.add_blank:
+        text_norm = commons.intersperse(text_norm, 0)
+    text_norm = torch.LongTensor(text_norm)
+    return text_norm
+
+
+def check_args(args):
+    assert Path(args.config).is_file(), args.config
+    assert Path(args.checkpoint).is_file(), args.checkpoint
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, Any]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
+def generate_tokens():
+    with open("tokens-vctk.txt", "w", encoding="utf-8") as f:
+        for i, s in enumerate(symbols):
+            f.write(f"{s} {i}\n")
+    print("Generated tokens-vctk.txt")
+
+
+@torch.no_grad()
+def main():
+    args = get_args()
+    check_args(args)
+
+    generate_tokens()
+
+    hps = utils.get_hparams_from_file(args.config)
+
+    net_g = SynthesizerTrn(
+        len(symbols),
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    )
+    _ = net_g.eval()
+
+    _ = utils.load_checkpoint(args.checkpoint, net_g, None)
+
+    x = get_text("Liliana is the most beautiful assistant", hps)
+    x = x.unsqueeze(0)
+
+    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
+    noise_scale = torch.tensor([1], dtype=torch.float32)
+    length_scale = torch.tensor([1], dtype=torch.float32)
+    noise_scale_w = torch.tensor([1], dtype=torch.float32)
+    sid = torch.tensor([0], dtype=torch.int64)
+
+    model = OnnxModel(net_g)
+
+    opset_version = 13
+
+    filename = "vits-vctk.onnx"
+
+    torch.onnx.export(
+        model,
+        (x, x_length, noise_scale, length_scale, noise_scale_w, sid),
+        filename,
+        opset_version=opset_version,
+        input_names=[
+            "x",
+            "x_length",
+            "noise_scale",
+            "length_scale",
+            "noise_scale_w",
+            "sid",
+        ],
+        output_names=["y"],
+        dynamic_axes={
+            "x": {0: "N", 1: "L"},  # n_audio is also known as batch_size
+            "x_length": {0: "N"},
+            "y": {0: "N", 2: "L"},
+        },
+    )
+    meta_data = {
+        "model_type": "vits",
+        "comment": "vctk",
+        "language": "English",
+        "add_blank": int(hps.data.add_blank),
+        "n_speakers": int(hps.data.n_speakers),
+        "sample_rate": hps.data.sampling_rate,
+        "punctuation": " ".join(list(_punctuation)),
+    }
+    print("meta_data", meta_data)
+    add_meta_data(filename=filename, meta_data=meta_data)
+
+    print("Generate int8 quantization models")
+
+    filename_int8 = "vits-vctk.int8.onnx"
+    quantize_dynamic(
+        model_input=filename,
+        model_output=filename_int8,
+        weight_type=QuantType.QUInt8,
+    )
+
+    print(f"Saved to {filename} and {filename_int8}")
+
+
+if __name__ == "__main__":
+    main()
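
Before wiring the exported model into sherpa-onnx, it can be sanity-checked directly with onnxruntime. A rough sketch; the token IDs below are dummy values rather than real entries from tokens-vctk.txt:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    "vits-vctk.onnx", providers=["CPUExecutionProvider"]
)

x = np.array([[0, 50, 0, 156, 0, 54, 0]], dtype=np.int64)  # (N, L)
y = sess.run(
    ["y"],
    {
        "x": x,
        "x_length": np.array([x.shape[1]], dtype=np.int64),
        "noise_scale": np.array([0.667], dtype=np.float32),
        "length_scale": np.array([1.0], dtype=np.float32),
        "noise_scale_w": np.array([0.8], dtype=np.float32),
        "sid": np.array([10], dtype=np.int64),  # any ID in [0, n_speakers)
    },
)[0]
print(y.shape)  # (N, 1, num_samples)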
@@ -18,7 +18,8 @@ class OfflineTtsImpl {
 
   static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);
 
-  virtual GeneratedAudio Generate(const std::string &text) const = 0;
+  virtual GeneratedAudio Generate(const std::string &text,
+                                  int64_t sid = 0) const = 0;
 };
 
 }  // namespace sherpa_onnx
@@ -23,7 +23,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
         lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
                  model_->Punctuations()) {}
 
-  GeneratedAudio Generate(const std::string &text) const override {
+  GeneratedAudio Generate(const std::string &text,
+                          int64_t sid = 0) const override {
     std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
     if (x.empty()) {
       SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
@@ -47,7 +48,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
     Ort::Value x_tensor = Ort::Value::CreateTensor(
         memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
 
-    Ort::Value audio = model_->Run(std::move(x_tensor));
+    Ort::Value audio = model_->Run(std::move(x_tensor), sid);
 
     std::vector<int64_t> audio_shape =
         audio.GetTensorTypeAndShapeInfo().GetShape();
@@ -13,6 +13,11 @@ void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
   po->Register("vits-model", &model, "Path to VITS model");
   po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
   po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
+  po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models");
+  po->Register("vits-noise-scale-w", &noise_scale_w,
+               "noise_scale_w for VITS models");
+  po->Register("vits-length-scale", &length_scale,
+               "length_scale for VITS models");
 }
 
 bool OfflineTtsVitsModelConfig::Validate() const {
@@ -55,7 +60,10 @@ std::string OfflineTtsVitsModelConfig::ToString() const {
   os << "OfflineTtsVitsModelConfig(";
   os << "model=\"" << model << "\", ";
   os << "lexicon=\"" << lexicon << "\", ";
-  os << "tokens=\"" << tokens << "\")";
+  os << "tokens=\"" << tokens << "\", ";
+  os << "noise_scale=" << noise_scale << ", ";
+  os << "noise_scale_w=" << noise_scale_w << ", ";
+  os << "length_scale=" << length_scale << ")";
 
   return os.str();
 }
@@ -16,12 +16,26 @@ struct OfflineTtsVitsModelConfig {
   std::string lexicon;
   std::string tokens;
 
+  float noise_scale = 0.667;
+  float noise_scale_w = 0.8;
+  float length_scale = 1;
+
+  // Used only for multi-speaker models, e.g., the VCTK speech dataset.
+  // Not applicable to single-speaker models, e.g., the LJ Speech dataset.
  25 +
19 OfflineTtsVitsModelConfig() = default; 26 OfflineTtsVitsModelConfig() = default;
20 27
21 OfflineTtsVitsModelConfig(const std::string &model, 28 OfflineTtsVitsModelConfig(const std::string &model,
22 const std::string &lexicon, 29 const std::string &lexicon,
23 - const std::string &tokens)  
24 - : model(model), lexicon(lexicon), tokens(tokens) {} 30 + const std::string &tokens,
  31 + float noise_scale = 0.667,
  32 + float noise_scale_w = 0.8, float length_scale = 1)
  33 + : model(model),
  34 + lexicon(lexicon),
  35 + tokens(tokens),
  36 + noise_scale(noise_scale),
  37 + noise_scale_w(noise_scale_w),
  38 + length_scale(length_scale) {}
25 39
26 void Register(ParseOptions *po); 40 void Register(ParseOptions *po);
27 bool Validate() const; 41 bool Validate() const;
@@ -26,7 +26,7 @@ class OfflineTtsVitsModel::Impl {
     Init(buf.data(), buf.size());
   }
 
-  Ort::Value Run(Ort::Value x) {
+  Ort::Value Run(Ort::Value x, int64_t sid) {
     auto memory_info =
         Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
 
@@ -44,20 +44,33 @@ class OfflineTtsVitsModel::Impl {
         Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
 
     int64_t scale_shape = 1;
-    float noise_scale = 1;
-    float length_scale = 1;
-    float noise_scale_w = 1;
+    float noise_scale = config_.vits.noise_scale;
+    float length_scale = config_.vits.length_scale;
+    float noise_scale_w = config_.vits.noise_scale_w;
 
     Ort::Value noise_scale_tensor =
         Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
+
     Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
         memory_info, &length_scale, 1, &scale_shape, 1);
+
     Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
         memory_info, &noise_scale_w, 1, &scale_shape, 1);
 
-    std::array<Ort::Value, 5> inputs = {
-        std::move(x), std::move(x_length), std::move(noise_scale_tensor),
-        std::move(length_scale_tensor), std::move(noise_scale_w_tensor)};
+    Ort::Value sid_tensor =
+        Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);
+
+    std::vector<Ort::Value> inputs;
+    inputs.reserve(6);
+    inputs.push_back(std::move(x));
+    inputs.push_back(std::move(x_length));
+    inputs.push_back(std::move(noise_scale_tensor));
+    inputs.push_back(std::move(length_scale_tensor));
+    inputs.push_back(std::move(noise_scale_w_tensor));
+
+    if (input_names_.size() == 6 && input_names_.back() == "sid") {
+      inputs.push_back(std::move(sid_tensor));
+    }
 
     auto out =
         sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
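
The `sid` tensor is appended only when the session declares a sixth input named "sid", so five-input single-speaker LJ Speech exports keep working unchanged. Whether a given model expects `sid` can be checked from Python, e.g.:

import onnxruntime

sess = onnxruntime.InferenceSession(
    "vits-vctk.onnx", providers=["CPUExecutionProvider"]
)
# A multi-speaker VCTK export prints:
# ['x', 'x_length', 'noise_scale', 'length_scale', 'noise_scale_w', 'sid']
print([i.name for i in sess.get_inputs()])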
@@ -93,6 +106,7 @@ class OfflineTtsVitsModel::Impl {
     Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
     SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
     SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
+    SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
     SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
   }
 
98 112
@@ -112,6 +126,7 @@ class OfflineTtsVitsModel::Impl { @@ -112,6 +126,7 @@ class OfflineTtsVitsModel::Impl {
112 126
113 int32_t sample_rate_; 127 int32_t sample_rate_;
114 int32_t add_blank_; 128 int32_t add_blank_;
  129 + int32_t n_speakers_;
115 std::string punctuations_; 130 std::string punctuations_;
116 }; 131 };
117 132
@@ -120,8 +135,8 @@ OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) @@ -120,8 +135,8 @@ OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
120 135
121 OfflineTtsVitsModel::~OfflineTtsVitsModel() = default; 136 OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;
122 137
123 -Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) {  
124 - return impl_->Run(std::move(x)); 138 +Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/) {
  139 + return impl_->Run(std::move(x), sid);
125 } 140 }
126 141
127 int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); } 142 int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }
@@ -22,10 +22,14 @@ class OfflineTtsVitsModel {
   /** Run the model.
    *
    * @param x A int64 tensor of shape (1, num_tokens)
+   * @param sid Speaker ID. Used only for multi-speaker models, e.g., models
+   *            trained using the VCTK dataset. It is not used for
+   *            single-speaker models, e.g., models trained using the ljspeech
+   *            dataset.
    * @return Return a float32 tensor containing audio samples. You can flatten
    *         it to a 1-D tensor.
    */
-  Ort::Value Run(Ort::Value x);
+  Ort::Value Run(Ort::Value x, int64_t sid = 0);
 
   // Sample rate of the generated audio
   int32_t SampleRate() const;
@@ -28,8 +28,9 @@ OfflineTts::OfflineTts(const OfflineTtsConfig &config)
 
 OfflineTts::~OfflineTts() = default;
 
-GeneratedAudio OfflineTts::Generate(const std::string &text) const {
-  return impl_->Generate(text);
+GeneratedAudio OfflineTts::Generate(const std::string &text,
+                                    int64_t sid /*= 0*/) const {
+  return impl_->Generate(text, sid);
 }
 
 }  // namespace sherpa_onnx
@@ -39,7 +39,11 @@ class OfflineTts {
   ~OfflineTts();
   explicit OfflineTts(const OfflineTtsConfig &config);
   // @param text A string containing words separated by spaces
-  GeneratedAudio Generate(const std::string &text) const;
+  // @param sid Speaker ID. Used only for multi-speaker models, e.g., models
+  //            trained using the VCTK dataset. It is not used for
+  //            single-speaker models, e.g., models trained using the ljspeech
+  //            dataset.
+  GeneratedAudio Generate(const std::string &text, int64_t sid = 0) const;
 
  private:
   std::unique_ptr<OfflineTtsImpl> impl_;
@@ -13,11 +13,12 @@ int main(int32_t argc, char *argv[]) {
 Offline text-to-speech with sherpa-onnx
 
 ./bin/sherpa-onnx-offline-tts \
- --vits-model /path/to/model.onnx \
- --vits-lexicon /path/to/lexicon.txt \
- --vits-tokens /path/to/tokens.txt
- --output-filename ./generated.wav \
- 'some text within single quotes'
+ --vits-model=/path/to/model.onnx \
+ --vits-lexicon=/path/to/lexicon.txt \
+ --vits-tokens=/path/to/tokens.txt \
+ --sid=0 \
+ --output-filename=./generated.wav \
+ 'some text within single quotes on linux/macos or use double quotes on windows'
 
 It will generate a file ./generated.wav as specified by --output-filename.
 
@@ -33,15 +34,27 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
   --vits-model=./vits-ljs.onnx \
   --vits-lexicon=./lexicon.txt \
   --vits-tokens=./tokens.txt \
+  --sid=0 \
   --output-filename=./generated.wav \
   'liliana, the most beautiful and lovely assistant of our team!'
+
+Please see
+https://k2-fsa.github.io/sherpa/onnx/tts/index.html
+for details.
 )usage";
 
   sherpa_onnx::ParseOptions po(kUsageMessage);
   std::string output_filename = "./generated.wav";
+  int32_t sid = 0;
+
   po.Register("output-filename", &output_filename,
               "Path to save the generated audio");
 
+  po.Register("sid", &sid,
+              "Speaker ID. Used only for multi-speaker models, e.g., models "
+              "trained using the VCTK dataset. Not used for single-speaker "
+              "models, e.g., models trained using the LJSpeech dataset");
+
   sherpa_onnx::OfflineTtsConfig config;
 
   config.Register(&po);
@@ -67,7 +80,7 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
   }
 
   sherpa_onnx::OfflineTts tts(config);
-  auto audio = tts.Generate(po.GetArg(1));
+  auto audio = tts.Generate(po.GetArg(1), sid);
 
   bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
                                    audio.samples.data(), audio.samples.size());
@@ -76,7 +89,8 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
     exit(EXIT_FAILURE);
   }
 
-  fprintf(stderr, "The text is: %s\n", po.GetArg(1).c_str());
+  fprintf(stderr, "The text is: %s. Speaker ID: %d\n", po.GetArg(1).c_str(),
+          sid);
   fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str());
 
   return 0;
@@ -16,11 +16,16 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) {
   py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
       .def(py::init<>())
      .def(py::init<const std::string &, const std::string &,
-                    const std::string &>(),
-           py::arg("model"), py::arg("lexicon"), py::arg("tokens"))
+                    const std::string &, float, float, float>(),
+           py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
+           py::arg("noise_scale") = 0.667, py::arg("noise_scale_w") = 0.8,
+           py::arg("length_scale") = 1.0)
       .def_readwrite("model", &PyClass::model)
       .def_readwrite("lexicon", &PyClass::lexicon)
       .def_readwrite("tokens", &PyClass::tokens)
+      .def_readwrite("noise_scale", &PyClass::noise_scale)
+      .def_readwrite("noise_scale_w", &PyClass::noise_scale_w)
+      .def_readwrite("length_scale", &PyClass::length_scale)
       .def("__str__", &PyClass::ToString);
 }
 
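
On the Python side, the three new knobs are keyword arguments with the same defaults as the C++ struct, so existing callers are unaffected. For example (file names are placeholders):

vits = sherpa_onnx.OfflineTtsVitsModelConfig(
    model="./vits-vctk.onnx",
    lexicon="./lexicon.txt",
    tokens="./tokens-vctk.txt",
    noise_scale=0.667,
    noise_scale_w=0.8,
    length_scale=1.0,  # values > 1 slow speech down; < 1 speed it up
)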
@@ -40,7 +40,7 @@ void PybindOfflineTts(py::module *m) {
   using PyClass = OfflineTts;
   py::class_<PyClass>(*m, "OfflineTts")
       .def(py::init<const OfflineTtsConfig &>(), py::arg("config"))
-      .def("generate", &PyClass::Generate);
+      .def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0);
 }
 
 }  // namespace sherpa_onnx