Support VITS VCTK models (#367)
Committed by GitHub

* Support VITS VCTK models
* Release v1.8.1
Showing 16 changed files with 332 additions and 31 deletions.

@@ -59,6 +59,16 @@ def get_args():
     )
 
     parser.add_argument(
+        "--sid",
+        type=int,
+        default=0,
+        help="""Speaker ID. Used only for multi-speaker models, e.g.,
+        models trained using the VCTK dataset. Not used for single-speaker
+        models, e.g., models trained using the LJ speech dataset.
+        """,
+    )
+
+    parser.add_argument(
         "--debug",
         type=bool,
         default=False,

@@ -105,7 +115,7 @@ def main():
         )
     )
     tts = sherpa_onnx.OfflineTts(tts_config)
-    audio = tts.generate(args.text)
+    audio = tts.generate(args.text, sid=args.sid)
     sf.write(
         args.output_filename,
         audio.samples,
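
The same pattern applies from user code. Below is a minimal sketch of the updated Python API, assuming the keyword-style constructors exposed by the bindings later in this commit; the model, lexicon, and token paths are placeholders.

# Minimal sketch of the updated Python API (paths are placeholders).
import sherpa_onnx
import soundfile as sf

config = sherpa_onnx.OfflineTtsConfig(
    model=sherpa_onnx.OfflineTtsModelConfig(
        vits=sherpa_onnx.OfflineTtsVitsModelConfig(
            model="./vits-vctk.onnx",
            lexicon="./lexicon.txt",
            tokens="./tokens.txt",
        ),
    ),
)
tts = sherpa_onnx.OfflineTts(config)

# sid selects the speaker; single-speaker models ignore it.
audio = tts.generate("Hello from speaker ten.", sid=10)
sf.write("./generated.wav", audio.samples, samplerate=audio.sample_rate)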

@@ -191,6 +191,7 @@ def main():
         "comment": "ljspeech",
         "language": "English",
         "add_blank": int(hps.data.add_blank),
+        "n_speakers": int(hps.data.n_speakers),
         "sample_rate": hps.data.sampling_rate,
         "punctuation": " ".join(list(_punctuation)),
     }
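
The LJSpeech export script now records n_speakers in the ONNX metadata as well, so the runtime can tell single-speaker and multi-speaker models apart. A quick way to verify the new field, sketched with onnxruntime (the file name is a placeholder):

# Sketch: read back the metadata written by add_meta_data().
import onnxruntime

sess = onnxruntime.InferenceSession(
    "./vits-ljs.onnx", providers=["CPUExecutionProvider"]
)
meta = sess.get_modelmeta().custom_metadata_map
# An LJSpeech export is single-speaker, so n_speakers is expected to be 0.
print(meta["n_speakers"], meta["sample_rate"], meta["add_blank"])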

scripts/vits/export-onnx-vctk.py (new file, mode 100755):
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.  (authors: Fangjun Kuang)

"""
This script converts vits models trained using the VCTK dataset.

Usage:

(1) Download vits

cd /Users/fangjun/open-source
git clone https://github.com/jaywalnut310/vits

(2) Download pre-trained models from
https://huggingface.co/csukuangfj/vits-vctk/tree/main

wget https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth

(3) Run this file

./export-onnx-vctk.py \
  --config ~/open-source/vits/configs/vctk_base.json \
  --checkpoint ~/open-source/icefall-models/vits-vctk/pretrained_vctk.pth

It will generate the following two files:

$ ls -lh *.onnx
-rw-r--r--  1 fangjun  staff    37M Oct 16 10:57 vits-vctk.int8.onnx
-rw-r--r--  1 fangjun  staff   116M Oct 16 10:57 vits-vctk.onnx
"""
import sys

# Please change this line to point to the vits directory.
# You can download vits from
# https://github.com/jaywalnut310/vits
sys.path.insert(0, "/Users/fangjun/open-source/vits")  # noqa

import argparse
from pathlib import Path
from typing import Dict, Any

import commons
import onnx
import torch
import utils
from models import SynthesizerTrn
from onnxruntime.quantization import QuantType, quantize_dynamic
from text import text_to_sequence
from text.symbols import symbols
from text.symbols import _punctuation


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="""Path to vctk_base.json.
        You can find it at
        https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vctk_base.json
        """,
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="""Path to the checkpoint file.
        You can find it at
        https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth
        """,
    )

    return parser.parse_args()


class OnnxModel(torch.nn.Module):
    def __init__(self, model: SynthesizerTrn):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=1,
        length_scale=1,
        noise_scale_w=1.0,
        sid=0,
        max_len=None,
    ):
        return self.model.infer(
            x=x,
            x_lengths=x_lengths,
            sid=sid,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            max_len=max_len,
        )[0]


def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm


def check_args(args):
    assert Path(args.config).is_file(), args.config
    assert Path(args.checkpoint).is_file(), args.checkpoint


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


def generate_tokens():
    with open("tokens-vctk.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(symbols):
            f.write(f"{s} {i}\n")
    print("Generated tokens-vctk.txt")


@torch.no_grad()
def main():
    args = get_args()
    check_args(args)

    generate_tokens()

    hps = utils.get_hparams_from_file(args.config)

    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    _ = net_g.eval()

    _ = utils.load_checkpoint(args.checkpoint, net_g, None)

    x = get_text("Liliana is the most beautiful assistant", hps)
    x = x.unsqueeze(0)

    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_w = torch.tensor([1], dtype=torch.float32)
    sid = torch.tensor([0], dtype=torch.int64)

    model = OnnxModel(net_g)

    opset_version = 13

    filename = "vits-vctk.onnx"

    torch.onnx.export(
        model,
        (x, x_length, noise_scale, length_scale, noise_scale_w, sid),
        filename,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_length",
            "noise_scale",
            "length_scale",
            "noise_scale_w",
            "sid",
        ],
        output_names=["y"],
        dynamic_axes={
| 192 | + "x": {0: "N", 1: "L"}, # n_audio is also known as batch_size | ||
| 193 | + "x_length": {0: "N"}, | ||
| 194 | + "y": {0: "N", 2: "L"}, | ||
| 195 | + }, | ||
| 196 | + ) | ||
| 197 | + meta_data = { | ||
| 198 | + "model_type": "vits", | ||
| 199 | + "comment": "vctk", | ||
| 200 | + "language": "English", | ||
| 201 | + "add_blank": int(hps.data.add_blank), | ||
| 202 | + "n_speakers": int(hps.data.n_speakers), | ||
| 203 | + "sample_rate": hps.data.sampling_rate, | ||
| 204 | + "punctuation": " ".join(list(_punctuation)), | ||
| 205 | + } | ||
| 206 | + print("meta_data", meta_data) | ||
| 207 | + add_meta_data(filename=filename, meta_data=meta_data) | ||
| 208 | + | ||
| 209 | + print("Generate int8 quantization models") | ||
| 210 | + | ||
| 211 | + filename_int8 = "vits-vctk.int8.onnx" | ||
| 212 | + quantize_dynamic( | ||
| 213 | + model_input=filename, | ||
| 214 | + model_output=filename_int8, | ||
| 215 | + weight_type=QuantType.QUInt8, | ||
| 216 | + ) | ||
| 217 | + | ||
| 218 | + print(f"Saved to {filename} and {filename_int8}") | ||
| 219 | + | ||
| 220 | + | ||
| 221 | +if __name__ == "__main__": | ||
| 222 | + main() |
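
After exporting, a short onnxruntime session can sanity-check the model before it is wired into sherpa-onnx. This is only a sketch: the token IDs below are made up, and real input must come from tokens-vctk.txt plus a lexicon.

# Sketch: run the exported multi-speaker model once with onnxruntime.
import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    "./vits-vctk.onnx", providers=["CPUExecutionProvider"]
)
x = np.array([[50, 157, 43, 135, 55]], dtype=np.int64)  # made-up token IDs
(y,) = sess.run(
    ["y"],
    {
        "x": x,
        "x_length": np.array([x.shape[1]], dtype=np.int64),
        "noise_scale": np.array([0.667], dtype=np.float32),
        "length_scale": np.array([1.0], dtype=np.float32),
        "noise_scale_w": np.array([0.8], dtype=np.float32),
        "sid": np.array([10], dtype=np.int64),  # pick any valid speaker ID
    },
)
print(y.shape)  # (1, 1, num_samples)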

@@ -18,7 +18,8 @@ class OfflineTtsImpl {
 
   static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);
 
-  virtual GeneratedAudio Generate(const std::string &text) const = 0;
+  virtual GeneratedAudio Generate(const std::string &text,
+                                  int64_t sid = 0) const = 0;
 };
 
 }  // namespace sherpa_onnx

@@ -23,7 +23,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
         lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
                  model_->Punctuations()) {}
 
-  GeneratedAudio Generate(const std::string &text) const override {
+  GeneratedAudio Generate(const std::string &text,
+                          int64_t sid = 0) const override {
     std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
     if (x.empty()) {
       SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());

@@ -47,7 +48,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
     Ort::Value x_tensor = Ort::Value::CreateTensor(
         memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
 
-    Ort::Value audio = model_->Run(std::move(x_tensor));
+    Ort::Value audio = model_->Run(std::move(x_tensor), sid);
 
     std::vector<int64_t> audio_shape =
         audio.GetTensorTypeAndShapeInfo().GetShape();

@@ -13,6 +13,11 @@ void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
   po->Register("vits-model", &model, "Path to VITS model");
   po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
   po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
+  po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models");
+  po->Register("vits-noise-scale-w", &noise_scale_w,
+               "noise_scale_w for VITS models");
+  po->Register("vits-length-scale", &length_scale,
+               "length_scale for VITS models");
 }
 
 bool OfflineTtsVitsModelConfig::Validate() const {

@@ -55,7 +60,10 @@ std::string OfflineTtsVitsModelConfig::ToString() const {
   os << "OfflineTtsVitsModelConfig(";
   os << "model=\"" << model << "\", ";
   os << "lexicon=\"" << lexicon << "\", ";
-  os << "tokens=\"" << tokens << "\")";
+  os << "tokens=\"" << tokens << "\", ";
+  os << "noise_scale=" << noise_scale << ", ";
+  os << "noise_scale_w=" << noise_scale_w << ", ";
+  os << "length_scale=" << length_scale << ")";
 
   return os.str();
 }

@@ -16,12 +16,26 @@ struct OfflineTtsVitsModelConfig {
   std::string lexicon;
   std::string tokens;
 
+  float noise_scale = 0.667;
+  float noise_scale_w = 0.8;
+  float length_scale = 1;
+
+  // Used only for multi-speaker models, e.g., the VCTK speech dataset.
+  // Not applicable to single-speaker models, e.g., the LJSpeech dataset.
+
   OfflineTtsVitsModelConfig() = default;
 
   OfflineTtsVitsModelConfig(const std::string &model,
                             const std::string &lexicon,
-                            const std::string &tokens)
-      : model(model), lexicon(lexicon), tokens(tokens) {}
+                            const std::string &tokens,
+                            float noise_scale = 0.667,
+                            float noise_scale_w = 0.8, float length_scale = 1)
+      : model(model),
+        lexicon(lexicon),
+        tokens(tokens),
+        noise_scale(noise_scale),
+        noise_scale_w(noise_scale_w),
+        length_scale(length_scale) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;

@@ -26,7 +26,7 @@ class OfflineTtsVitsModel::Impl {
     Init(buf.data(), buf.size());
   }
 
-  Ort::Value Run(Ort::Value x) {
+  Ort::Value Run(Ort::Value x, int64_t sid) {
     auto memory_info =
         Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
 

@@ -44,20 +44,33 @@ class OfflineTtsVitsModel::Impl {
         Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
 
     int64_t scale_shape = 1;
-    float noise_scale = 1;
-    float length_scale = 1;
-    float noise_scale_w = 1;
+    float noise_scale = config_.vits.noise_scale;
+    float length_scale = config_.vits.length_scale;
+    float noise_scale_w = config_.vits.noise_scale_w;
 
     Ort::Value noise_scale_tensor =
         Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
+
     Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
         memory_info, &length_scale, 1, &scale_shape, 1);
+
     Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
         memory_info, &noise_scale_w, 1, &scale_shape, 1);
 
-    std::array<Ort::Value, 5> inputs = {
-        std::move(x), std::move(x_length), std::move(noise_scale_tensor),
-        std::move(length_scale_tensor), std::move(noise_scale_w_tensor)};
+    Ort::Value sid_tensor =
+        Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);
+
+    std::vector<Ort::Value> inputs;
+    inputs.reserve(6);
+    inputs.push_back(std::move(x));
+    inputs.push_back(std::move(x_length));
+    inputs.push_back(std::move(noise_scale_tensor));
+    inputs.push_back(std::move(length_scale_tensor));
+    inputs.push_back(std::move(noise_scale_w_tensor));
+
+    if (input_names_.size() == 6 && input_names_.back() == "sid") {
+      inputs.push_back(std::move(sid_tensor));
+    }
 
     auto out =
         sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),

@@ -93,6 +106,7 @@ class OfflineTtsVitsModel::Impl {
    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
    SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
+   SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
    SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
  }
| 98 | 112 | ||
| @@ -112,6 +126,7 @@ class OfflineTtsVitsModel::Impl { | @@ -112,6 +126,7 @@ class OfflineTtsVitsModel::Impl { | ||
| 112 | 126 | ||
| 113 | int32_t sample_rate_; | 127 | int32_t sample_rate_; |
| 114 | int32_t add_blank_; | 128 | int32_t add_blank_; |
| 129 | + int32_t n_speakers_; | ||
| 115 | std::string punctuations_; | 130 | std::string punctuations_; |
| 116 | }; | 131 | }; |
| 117 | 132 | ||

@@ -120,8 +135,8 @@ OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
 
 OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;
 
-Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) {
-  return impl_->Run(std::move(x));
+Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*= 0*/) {
+  return impl_->Run(std::move(x), sid);
 }
 
 int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }

@@ -22,10 +22,14 @@ class OfflineTtsVitsModel {
   /** Run the model.
    *
    * @param x A int64 tensor of shape (1, num_tokens)
+   * @param sid Speaker ID. Used only for multi-speaker models, e.g., models
+   *            trained using the VCTK dataset. It is not used for
+   *            single-speaker models, e.g., models trained using the LJSpeech
+   *            dataset.
    * @return Return a float32 tensor containing audio samples. You can flatten
    *         it to a 1-D tensor.
    */
-  Ort::Value Run(Ort::Value x);
+  Ort::Value Run(Ort::Value x, int64_t sid = 0);
 
   // Sample rate of the generated audio
   int32_t SampleRate() const;

@@ -28,8 +28,9 @@ OfflineTts::OfflineTts(const OfflineTtsConfig &config)
 
 OfflineTts::~OfflineTts() = default;
 
-GeneratedAudio OfflineTts::Generate(const std::string &text) const {
-  return impl_->Generate(text);
+GeneratedAudio OfflineTts::Generate(const std::string &text,
+                                    int64_t sid /*= 0*/) const {
+  return impl_->Generate(text, sid);
 }
 
 }  // namespace sherpa_onnx

@@ -39,7 +39,11 @@ class OfflineTts {
   ~OfflineTts();
   explicit OfflineTts(const OfflineTtsConfig &config);
   // @param text A string containing words separated by spaces
-  GeneratedAudio Generate(const std::string &text) const;
+  // @param sid Speaker ID. Used only for multi-speaker models, e.g., models
+  //            trained using the VCTK dataset. It is not used for
+  //            single-speaker models, e.g., models trained using the LJSpeech
+  //            dataset.
+  GeneratedAudio Generate(const std::string &text, int64_t sid = 0) const;
 
  private:
   std::unique_ptr<OfflineTtsImpl> impl_;

@@ -13,11 +13,12 @@ int main(int32_t argc, char *argv[]) {
 Offline text-to-speech with sherpa-onnx
 
 ./bin/sherpa-onnx-offline-tts \
-  --vits-model /path/to/model.onnx \
-  --vits-lexicon /path/to/lexicon.txt \
-  --vits-tokens /path/to/tokens.txt
-  --output-filename ./generated.wav \
-  'some text within single quotes'
+  --vits-model=/path/to/model.onnx \
+  --vits-lexicon=/path/to/lexicon.txt \
+  --vits-tokens=/path/to/tokens.txt \
+  --sid=0 \
+  --output-filename=./generated.wav \
+  'some text within single quotes on linux/macos or use double quotes on windows'
 
 It will generate a file ./generated.wav as specified by --output-filename.

@@ -33,15 +34,27 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
   --vits-model=./vits-ljs.onnx \
   --vits-lexicon=./lexicon.txt \
   --vits-tokens=./tokens.txt \
+  --sid=0 \
   --output-filename=./generated.wav \
   'liliana, the most beautiful and lovely assistant of our team!'
+
+Please see
+https://k2-fsa.github.io/sherpa/onnx/tts/index.html
+for details.
 )usage";
 
   sherpa_onnx::ParseOptions po(kUsageMessage);
   std::string output_filename = "./generated.wav";
+  int32_t sid = 0;
+
   po.Register("output-filename", &output_filename,
               "Path to save the generated audio");
 
+  po.Register("sid", &sid,
+              "Speaker ID. Used only for multi-speaker models, e.g., models "
+              "trained using the VCTK dataset. Not used for single-speaker "
+              "models, e.g., models trained using the LJSpeech dataset");
+
   sherpa_onnx::OfflineTtsConfig config;
 
   config.Register(&po);

@@ -67,7 +80,7 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
   }
 
   sherpa_onnx::OfflineTts tts(config);
-  auto audio = tts.Generate(po.GetArg(1));
+  auto audio = tts.Generate(po.GetArg(1), sid);
 
   bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
                                    audio.samples.data(), audio.samples.size());

@@ -76,7 +89,8 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
     exit(EXIT_FAILURE);
   }
 
-  fprintf(stderr, "The text is: %s\n", po.GetArg(1).c_str());
+  fprintf(stderr, "The text is: %s. Speaker ID: %d\n", po.GetArg(1).c_str(),
+          sid);
   fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str());
 
   return 0;

@@ -16,11 +16,16 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) {
   py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
       .def(py::init<>())
       .def(py::init<const std::string &, const std::string &,
-                    const std::string &>(),
-           py::arg("model"), py::arg("lexicon"), py::arg("tokens"))
+                    const std::string &, float, float, float>(),
+           py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
+           py::arg("noise_scale") = 0.667, py::arg("noise_scale_w") = 0.8,
+           py::arg("length_scale") = 1.0)
       .def_readwrite("model", &PyClass::model)
       .def_readwrite("lexicon", &PyClass::lexicon)
       .def_readwrite("tokens", &PyClass::tokens)
+      .def_readwrite("noise_scale", &PyClass::noise_scale)
+      .def_readwrite("noise_scale_w", &PyClass::noise_scale_w)
+      .def_readwrite("length_scale", &PyClass::length_scale)
       .def("__str__", &PyClass::ToString);
 }
 

@@ -40,7 +40,7 @@ void PybindOfflineTts(py::module *m) {
   using PyClass = OfflineTts;
   py::class_<PyClass>(*m, "OfflineTts")
       .def(py::init<const OfflineTtsConfig &>(), py::arg("config"))
-      .def("generate", &PyClass::Generate);
+      .def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0);
 }
 
 }  // namespace sherpa_onnx
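
Taken together, the two bindings let Python code tune the sampling knobs per model and pick a speaker per call. A sketch using the new keyword arguments (the values shown are just the defaults, and the file paths are placeholders):

# Sketch: the new keyword arguments exposed by the bindings above.
import sherpa_onnx

vits = sherpa_onnx.OfflineTtsVitsModelConfig(
    model="./vits-vctk.onnx",
    lexicon="./lexicon.txt",
    tokens="./tokens.txt",
    noise_scale=0.667,   # defaults written out for clarity
    noise_scale_w=0.8,
    length_scale=1.0,    # > 1 slows speech down; < 1 speeds it up
)
print(vits)  # __str__ is bound to OfflineTtsVitsModelConfig::ToString

# generate() now accepts the speaker ID as a keyword argument:
#   tts.generate("some text", sid=42)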