Fangjun Kuang
Committed by GitHub

Support English for MeloTTS models. (#1134)

@@ -20,7 +20,7 @@ jobs: @@ -20,7 +20,7 @@ jobs:
20 strategy: 20 strategy:
21 fail-fast: false 21 fail-fast: false
22 matrix: 22 matrix:
23 - os: [windows-latest] 23 + os: [windows-2019]
24 24
25 steps: 25 steps:
26 - uses: actions/checkout@v4 26 - uses: actions/checkout@v4
@@ -6,9 +6,13 @@ import torch @@ -6,9 +6,13 @@ import torch
6 from melo.api import TTS 6 from melo.api import TTS
7 from melo.text import language_id_map, language_tone_start_map 7 from melo.text import language_id_map, language_tone_start_map
8 from melo.text.chinese import pinyin_to_symbol_map 8 from melo.text.chinese import pinyin_to_symbol_map
  9 +from melo.text.english import eng_dict, refine_syllables
9 from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict 10 from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
  11 +from melo.text.symbols import language_tone_start_map
10 12
11 for k, v in pinyin_to_symbol_map.items(): 13 for k, v in pinyin_to_symbol_map.items():
  14 + if isinstance(v, list):
  15 + break
12 pinyin_to_symbol_map[k] = v.split() 16 pinyin_to_symbol_map[k] = v.split()
13 17
14 18
@@ -79,6 +83,16 @@ def generate_lexicon(): @@ -79,6 +83,16 @@ def generate_lexicon():
79 word_dict = pinyin_dict.pinyin_dict 83 word_dict = pinyin_dict.pinyin_dict
80 phrases = phrases_dict.phrases_dict 84 phrases = phrases_dict.phrases_dict
81 with open("lexicon.txt", "w", encoding="utf-8") as f: 85 with open("lexicon.txt", "w", encoding="utf-8") as f:
  86 + for word in eng_dict:
  87 + phones, tones = refine_syllables(eng_dict[word])
  88 + tones = [t + language_tone_start_map["EN"] for t in tones]
  89 + tones = [str(t) for t in tones]
  90 +
  91 + phones = " ".join(phones)
  92 + tones = " ".join(tones)
  93 +
  94 + f.write(f"{word.lower()} {phones} {tones}\n")
  95 +
82 for key in word_dict: 96 for key in word_dict:
83 if not (0x4E00 <= key <= 0x9FA5): 97 if not (0x4E00 <= key <= 0x9FA5):
84 continue 98 continue
@@ -125,15 +139,13 @@ class ModelWrapper(torch.nn.Module): @@ -125,15 +139,13 @@ class ModelWrapper(torch.nn.Module):
125 def __init__(self, model: "SynthesizerTrn"): 139 def __init__(self, model: "SynthesizerTrn"):
126 super().__init__() 140 super().__init__()
127 self.model = model 141 self.model = model
  142 + self.lang_id = language_id_map[model.language]
128 143
129 def forward( 144 def forward(
130 self, 145 self,
131 x, 146 x,
132 x_lengths, 147 x_lengths,
133 tones, 148 tones,
134 - lang_id,  
135 - bert,  
136 - ja_bert,  
137 sid, 149 sid,
138 noise_scale, 150 noise_scale,
139 length_scale, 151 length_scale,
@@ -147,7 +159,11 @@ class ModelWrapper(torch.nn.Module): @@ -147,7 +159,11 @@ class ModelWrapper(torch.nn.Module):
147 lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,) 159 lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
148 sid: an integer 160 sid: an integer
149 """ 161 """
150 - return self.model.infer( 162 + bert = torch.zeros(x.shape[0], 1024, x.shape[1], dtype=torch.float32)
  163 + ja_bert = torch.zeros(x.shape[0], 768, x.shape[1], dtype=torch.float32)
  164 + lang_id = torch.zeros_like(x)
  165 + lang_id[:, 1::2] = self.lang_id
  166 + return self.model.model.infer(
151 x=x, 167 x=x,
152 x_lengths=x_lengths, 168 x_lengths=x_lengths,
153 sid=sid, 169 sid=sid,
@@ -169,7 +185,7 @@ def main(): @@ -169,7 +185,7 @@ def main():
169 185
170 generate_tokens(model.hps["symbols"]) 186 generate_tokens(model.hps["symbols"])
171 187
172 - torch_model = ModelWrapper(model.model) 188 + torch_model = ModelWrapper(model)
173 189
174 opset_version = 13 190 opset_version = 13
175 x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64) 191 x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
@@ -177,19 +193,13 @@ def main(): @@ -177,19 +193,13 @@ def main():
177 x_lengths = torch.tensor([x.size(0)], dtype=torch.int64) 193 x_lengths = torch.tensor([x.size(0)], dtype=torch.int64)
178 sid = torch.tensor([1], dtype=torch.int64) 194 sid = torch.tensor([1], dtype=torch.int64)
179 tones = torch.zeros_like(x) 195 tones = torch.zeros_like(x)
180 - lang_id = torch.ones_like(x) 196 +
181 noise_scale = torch.tensor([1.0], dtype=torch.float32) 197 noise_scale = torch.tensor([1.0], dtype=torch.float32)
182 length_scale = torch.tensor([1.0], dtype=torch.float32) 198 length_scale = torch.tensor([1.0], dtype=torch.float32)
183 noise_scale_w = torch.tensor([1.0], dtype=torch.float32) 199 noise_scale_w = torch.tensor([1.0], dtype=torch.float32)
184 200
185 - bert = torch.zeros(1024, x.shape[0], dtype=torch.float32)  
186 - ja_bert = torch.zeros(768, x.shape[0], dtype=torch.float32)  
187 -  
188 x = x.unsqueeze(0) 201 x = x.unsqueeze(0)
189 tones = tones.unsqueeze(0) 202 tones = tones.unsqueeze(0)
190 - lang_id = lang_id.unsqueeze(0)  
191 - bert = bert.unsqueeze(0)  
192 - ja_bert = ja_bert.unsqueeze(0)  
193 203
194 filename = "model.onnx" 204 filename = "model.onnx"
195 205
@@ -199,9 +209,6 @@ def main(): @@ -199,9 +209,6 @@ def main():
199 x, 209 x,
200 x_lengths, 210 x_lengths,
201 tones, 211 tones,
202 - lang_id,  
203 - bert,  
204 - ja_bert,  
205 sid, 212 sid,
206 noise_scale, 213 noise_scale,
207 length_scale, 214 length_scale,
@@ -213,9 +220,6 @@ def main(): @@ -213,9 +220,6 @@ def main():
213 "x", 220 "x",
214 "x_lengths", 221 "x_lengths",
215 "tones", 222 "tones",
216 - "lang_id",  
217 - "bert",  
218 - "ja_bert",  
219 "sid", 223 "sid",
220 "noise_scale", 224 "noise_scale",
221 "length_scale", 225 "length_scale",
@@ -226,9 +230,6 @@ def main(): @@ -226,9 +230,6 @@ def main():
226 "x": {0: "N", 1: "L"}, 230 "x": {0: "N", 1: "L"},
227 "x_lengths": {0: "N"}, 231 "x_lengths": {0: "N"},
228 "tones": {0: "N", 1: "L"}, 232 "tones": {0: "N", 1: "L"},
229 - "lang_id": {0: "N", 1: "L"},  
230 - "bert": {0: "N", 2: "L"},  
231 - "ja_bert": {0: "N", 2: "L"},  
232 "y": {0: "N", 1: "S", 2: "T"}, 233 "y": {0: "N", 1: "S", 2: "T"},
233 }, 234 },
234 ) 235 )
@@ -28,6 +28,8 @@ echo "pwd: $PWD" @@ -28,6 +28,8 @@ echo "pwd: $PWD"
28 28
29 ls -lh 29 ls -lh
30 30
  31 +./show-info.py
  32 +
31 head lexicon.txt 33 head lexicon.txt
32 echo "---" 34 echo "---"
33 tail lexicon.txt 35 tail lexicon.txt
  1 +#!/usr/bin/env python3
  2 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +import onnxruntime
  5 +
  6 +
  7 +def show(filename):
  8 + session_opts = onnxruntime.SessionOptions()
  9 + session_opts.log_severity_level = 3
  10 + sess = onnxruntime.InferenceSession(filename, session_opts)
  11 + for i in sess.get_inputs():
  12 + print(i)
  13 +
  14 + print("-----")
  15 +
  16 + for i in sess.get_outputs():
  17 + print(i)
  18 +
  19 + meta = sess.get_modelmeta().custom_metadata_map
  20 + print("*****************************************")
  21 + print("meta\n", meta)
  22 +
  23 +
  24 +def main():
  25 + print("=========model==========")
  26 + show("./model.onnx")
  27 +
  28 +
  29 +if __name__ == "__main__":
  30 + main()
  31 +
  32 +"""
  33 +=========model==========
  34 +NodeArg(name='x', type='tensor(int64)', shape=['N', 'L'])
  35 +NodeArg(name='x_lengths', type='tensor(int64)', shape=['N'])
  36 +NodeArg(name='tones', type='tensor(int64)', shape=['N', 'L'])
  37 +NodeArg(name='sid', type='tensor(int64)', shape=[1])
  38 +NodeArg(name='noise_scale', type='tensor(float)', shape=[1])
  39 +NodeArg(name='length_scale', type='tensor(float)', shape=[1])
  40 +NodeArg(name='noise_scale_w', type='tensor(float)', shape=[1])
  41 +-----
  42 +NodeArg(name='y', type='tensor(float)', shape=['N', 'S', 'T'])
  43 +*****************************************
  44 +meta
  45 + {'description': 'MeloTTS is a high-quality multi-lingual text-to-speech library by MyShell.ai',
  46 + 'model_type': 'melo-vits', 'license': 'MIT license', 'sample_rate': '44100', 'add_blank': '1',
  47 + 'n_speakers': '1', 'bert_dim': '1024', 'language': 'Chinese + English',
  48 + 'ja_bert_dim': '768', 'speaker_id': '1', 'comment': 'melo', 'lang_id': '3',
  49 + 'tone_start': '0', 'url': 'https://github.com/myshell-ai/MeloTTS'}
  50 +"""
@@ -30,6 +30,8 @@ class Lexicon: @@ -30,6 +30,8 @@ class Lexicon:
30 tones = [int(t) for t in tones] 30 tones = [int(t) for t in tones]
31 31
32 lexicon[word_or_phrase] = (phones, tones) 32 lexicon[word_or_phrase] = (phones, tones)
  33 + lexicon["呣"] = lexicon["母"]
  34 + lexicon["嗯"] = lexicon["恩"]
33 self.lexicon = lexicon 35 self.lexicon = lexicon
34 36
35 punctuation = ["!", "?", "…", ",", ".", "'", "-"] 37 punctuation = ["!", "?", "…", ",", ".", "'", "-"]
@@ -98,20 +100,16 @@ class OnnxModel: @@ -98,20 +100,16 @@ class OnnxModel:
98 self.lang_id = int(meta["lang_id"]) 100 self.lang_id = int(meta["lang_id"])
99 self.sample_rate = int(meta["sample_rate"]) 101 self.sample_rate = int(meta["sample_rate"])
100 102
101 - def __call__(self, x, tones, lang): 103 + def __call__(self, x, tones):
102 """ 104 """
103 Args: 105 Args:
104 x: 1-D int64 torch tensor 106 x: 1-D int64 torch tensor
105 tones: 1-D int64 torch tensor 107 tones: 1-D int64 torch tensor
106 - lang: 1-D int64 torch tensor  
107 """ 108 """
108 x = x.unsqueeze(0) 109 x = x.unsqueeze(0)
109 tones = tones.unsqueeze(0) 110 tones = tones.unsqueeze(0)
110 - lang = lang.unsqueeze(0)  
111 111
112 - print(x.shape, tones.shape, lang.shape)  
113 - bert = torch.zeros(1, self.bert_dim, x.shape[-1])  
114 - ja_bert = torch.zeros(1, self.ja_bert_dim, x.shape[-1]) 112 + print(x.shape, tones.shape)
115 sid = torch.tensor([self.speaker_id], dtype=torch.int64) 113 sid = torch.tensor([self.speaker_id], dtype=torch.int64)
116 noise_scale = torch.tensor([0.6], dtype=torch.float32) 114 noise_scale = torch.tensor([0.6], dtype=torch.float32)
117 length_scale = torch.tensor([1.0], dtype=torch.float32) 115 length_scale = torch.tensor([1.0], dtype=torch.float32)
@@ -125,9 +123,6 @@ class OnnxModel: @@ -125,9 +123,6 @@ class OnnxModel:
125 "x": x.numpy(), 123 "x": x.numpy(),
126 "x_lengths": x_lengths.numpy(), 124 "x_lengths": x_lengths.numpy(),
127 "tones": tones.numpy(), 125 "tones": tones.numpy(),
128 - "lang_id": lang.numpy(),  
129 - "bert": bert.numpy(),  
130 - "ja_bert": ja_bert.numpy(),  
131 "sid": sid.numpy(), 126 "sid": sid.numpy(),
132 "noise_scale": noise_scale.numpy(), 127 "noise_scale": noise_scale.numpy(),
133 "noise_scale_w": noise_scale_w.numpy(), 128 "noise_scale_w": noise_scale_w.numpy(),
@@ -140,34 +135,46 @@ class OnnxModel: @@ -140,34 +135,46 @@ class OnnxModel:
140 def main(): 135 def main():
141 lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt") 136 lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt")
142 137
143 - text = "永远相信,美好的事情即将发生。多音字测试, 银行,行不行?长沙长大" 138 + text = "永远相信,美好的事情即将发生。"
144 s = jieba.cut(text, HMM=True) 139 s = jieba.cut(text, HMM=True)
145 140
146 phones, tones = lexicon.convert(s) 141 phones, tones = lexicon.convert(s)
147 142
  143 + en_text = "how are you ?".split()
  144 +
  145 + phones_en, tones_en = lexicon.convert(en_text)
  146 + phones += [0]
  147 + tones += [0]
  148 +
  149 + phones += phones_en
  150 + tones += tones_en
  151 +
  152 + text = "多音字测试, 银行,行不行?长沙长大"
  153 + s = jieba.cut(text, HMM=True)
  154 +
  155 + phones2, tones2 = lexicon.convert(s)
  156 +
  157 + phones += phones2
  158 + tones += tones2
  159 +
148 model = OnnxModel("./model.onnx") 160 model = OnnxModel("./model.onnx")
149 - langs = [model.lang_id] * len(phones)  
150 161
151 if model.add_blank: 162 if model.add_blank:
152 new_phones = [0] * (2 * len(phones) + 1) 163 new_phones = [0] * (2 * len(phones) + 1)
153 new_tones = [0] * (2 * len(tones) + 1) 164 new_tones = [0] * (2 * len(tones) + 1)
154 - new_langs = [0] * (2 * len(langs) + 1)  
155 165
156 new_phones[1::2] = phones 166 new_phones[1::2] = phones
157 new_tones[1::2] = tones 167 new_tones[1::2] = tones
158 - new_langs[1::2] = langs  
159 168
160 phones = new_phones 169 phones = new_phones
161 tones = new_tones 170 tones = new_tones
162 - langs = new_langs  
163 171
164 phones = torch.tensor(phones, dtype=torch.int64) 172 phones = torch.tensor(phones, dtype=torch.int64)
165 tones = torch.tensor(tones, dtype=torch.int64) 173 tones = torch.tensor(tones, dtype=torch.int64)
166 - langs = torch.tensor(langs, dtype=torch.int64)  
167 174
168 - print(phones.shape, tones.shape, langs.shape) 175 + print(phones.shape, tones.shape)
169 176
170 - y = model(x=phones, tones=tones, lang=langs) 177 + y = model(x=phones, tones=tones)
171 sf.write("./test.wav", y, model.sample_rate) 178 sf.write("./test.wav", y, model.sample_rate)
172 179
173 180