Fangjun Kuang
Committed by GitHub

Export the English TTS model from MeloTTS (#1509)

@@ -40,7 +40,7 @@ jobs:
           name: test.wav
           path: scripts/melo-tts/test.wav
 
-      - name: Publish to huggingface
+      - name: Publish to huggingface (Chinese + English)
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         uses: nick-fields/retry@v3
@@ -61,14 +61,14 @@ jobs:
            git fetch
            git pull
            echo "pwd: $PWD"
-           ls -lh ../scripts/melo-tts
+           ls -lh ../scripts/melo-tts/zh_en
 
            rm -rf ./
 
-           cp -v ../scripts/melo-tts/*.onnx .
-           cp -v ../scripts/melo-tts/lexicon.txt .
-           cp -v ../scripts/melo-tts/tokens.txt .
-           cp -v ../scripts/melo-tts/README.md .
+           cp -v ../scripts/melo-tts/zh_en/*.onnx .
+           cp -v ../scripts/melo-tts/zh_en/lexicon.txt .
+           cp -v ../scripts/melo-tts/zh_en/tokens.txt .
+           cp -v ../scripts/melo-tts/zh_en/README.md .
 
            curl -SL -O https://raw.githubusercontent.com/myshell-ai/MeloTTS/main/LICENSE
 
@@ -102,6 +102,60 @@ jobs:
            tar cjvf $dst.tar.bz2 $dst
            rm -rf $dst
 
+      - name: Publish to huggingface (English)
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        uses: nick-fields/retry@v3
+        with:
+          max_attempts: 20
+          timeout_seconds: 200
+          shell: bash
+          command: |
+            git config --global user.email "csukuangfj@gmail.com"
+            git config --global user.name "Fangjun Kuang"
+
+            rm -rf huggingface
+            export GIT_LFS_SKIP_SMUDGE=1
+            export GIT_CLONE_PROTECTION_ACTIVE=false
+
+            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-en huggingface
+            cd huggingface
+            git fetch
+            git pull
+            echo "pwd: $PWD"
+            ls -lh ../scripts/melo-tts/en
+
+            rm -rf ./
+
+            cp -v ../scripts/melo-tts/en/*.onnx .
+            cp -v ../scripts/melo-tts/en/lexicon.txt .
+            cp -v ../scripts/melo-tts/en/tokens.txt .
+            cp -v ../scripts/melo-tts/en/README.md .
+
+            curl -SL -O https://raw.githubusercontent.com/myshell-ai/MeloTTS/main/LICENSE
+
+            git lfs track "*.onnx"
+            git add .
+
+            ls -lh
+
+            git status
+
+            git diff
+
+            git commit -m "add models"
+            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-en main || true
+
+            cd ..
+
+            rm -rf huggingface/.git*
+            dst=vits-melo-tts-en
+
+            mv huggingface $dst
+
+            tar cjvf $dst.tar.bz2 $dst
+            rm -rf $dst
+
       - name: Release
         uses: svenstaro/upload-release-action@v2
         with:
@@ -3,4 +3,5 @@
 Models in this directory are converted from
 https://github.com/myshell-ai/MeloTTS
 
-Note there is only a single female speaker in the model.
+Note there is only a single female speaker in the Chinese+English TTS model,
+whereas there are 5 female speakers in the English TTS model.
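For readers who want to try the English model once it is exported, here is a minimal usage sketch with the sherpa-onnx Python bindings. It is not part of this commit; the file names follow the export scripts below (model.onnx, lexicon.txt, tokens.txt), and the soundfile dependency is an extra assumption.

#!/usr/bin/env python3
# Minimal sketch (not part of this commit) for running the exported English
# MeloTTS model with the sherpa-onnx Python bindings.
# Assumes model.onnx, lexicon.txt and tokens.txt from the English export are
# in the current directory and that soundfile is installed.

import sherpa_onnx
import soundfile as sf

config = sherpa_onnx.OfflineTtsConfig(
    model=sherpa_onnx.OfflineTtsModelConfig(
        vits=sherpa_onnx.OfflineTtsVitsModelConfig(
            model="model.onnx",
            lexicon="lexicon.txt",
            tokens="tokens.txt",
        ),
        num_threads=1,
    ),
)

tts = sherpa_onnx.OfflineTts(config)

# The English model has 5 speakers; per export-onnx-en.py the mapping is
# {'EN-US': 0, 'EN-BR': 1, 'EN_INDIA': 2, 'EN-AU': 3, 'EN-Default': 4}.
# The Chinese+English model has a single speaker, so sid is ignored there.
audio = tts.generate("This is a test of the English MeloTTS model.", sid=0, speed=1.0)
sf.write("test-en.wav", audio.samples, samplerate=audio.sample_rate)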
+#!/usr/bin/env python3
+# This script exports the English-only TTS model.
+# It has 5 speakers.
+# {'EN-US': 0, 'EN-BR': 1, 'EN_INDIA': 2, 'EN-AU': 3, 'EN-Default': 4}
+
+from typing import Any, Dict
+
+import onnx
+import torch
+from melo.api import TTS
+from melo.text import language_id_map, language_tone_start_map
+from melo.text.chinese import pinyin_to_symbol_map
+from melo.text.english import eng_dict, refine_syllables
+from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
+
+
+def generate_tokens(symbol_list):
+    with open("tokens.txt", "w", encoding="utf-8") as f:
+        for i, s in enumerate(symbol_list):
+            f.write(f"{s} {i}\n")
+
+
+def add_new_english_words(lexicon):
+    """
+    Args:
+      lexicon:
+        Please modify it in-place.
+    """
+
+    # Please have a look at
+    # https://github.com/myshell-ai/MeloTTS/blob/main/melo/text/cmudict.rep
+
+    # We give several examples below about how to add new words
+
+    # Example 1. Add a new word kaldi
+
+    # cmudict.rep does not contain the word kaldi,
+    # so if we add the following line to cmudict.rep
+    #
+    #   KALDI K AH0 - L D IH0
+    #
+    # then we need to change the lexicon like below
+    lexicon["kaldi"] = [["K", "AH0"], ["L", "D", "IH0"]]
+    #
+    # K AH0 and L D IH0 are separated by a dash "-", so
+    # ["K", "AH0"] is in one list and ["L", "D", "IH0"] is in a separate list
+
+    # Note: Either kaldi or KALDI is fine. You can use either lowercase or
+    # uppercase or both
+
+    # Example 2. Add a new word SF
+    #
+    # If we add the following line
+    #
+    #   SF EH1 S - EH1 F
+    #
+    # to cmudict.rep, then we need to change the lexicon like below:
+    lexicon["SF"] = [["EH1", "S"], ["EH1", "F"]]
+
+    # Please add your new words here
+
+    # No need to return lexicon since it is changed in-place
+
+
+def generate_lexicon():
+    add_new_english_words(eng_dict)
+    with open("lexicon.txt", "w", encoding="utf-8") as f:
+        for word in eng_dict:
+            phones, tones = refine_syllables(eng_dict[word])
+            tones = [t + language_tone_start_map["EN"] for t in tones]
+            tones = [str(t) for t in tones]
+
+            phones = " ".join(phones)
+            tones = " ".join(tones)
+
+            f.write(f"{word.lower()} {phones} {tones}\n")
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, Any]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    while len(model.metadata_props):
+        model.metadata_props.pop()
+
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
+class ModelWrapper(torch.nn.Module):
+    def __init__(self, model: "SynthesizerTrn"):
+        super().__init__()
+        self.model = model
+        self.lang_id = language_id_map[model.language]
+
+    def forward(
+        self,
+        x,
+        x_lengths,
+        tones,
+        sid,
+        noise_scale,
+        length_scale,
+        noise_scale_w,
+        max_len=None,
+    ):
+        """
+        Args:
+          x: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
+          tones: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
+          lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
+          sid: an integer
+        """
+        bert = torch.zeros(x.shape[0], 1024, x.shape[1], dtype=torch.float32)
+        ja_bert = torch.zeros(x.shape[0], 768, x.shape[1], dtype=torch.float32)
+        lang_id = torch.zeros_like(x)
+        lang_id[:, 1::2] = self.lang_id
+        return self.model.model.infer(
+            x=x,
+            x_lengths=x_lengths,
+            sid=sid,
+            tone=tones,
+            language=lang_id,
+            bert=bert,
+            ja_bert=ja_bert,
+            noise_scale=noise_scale,
+            noise_scale_w=noise_scale_w,
+            length_scale=length_scale,
+        )[0]
+
+
+def main():
+    generate_lexicon()
+
+    language = "EN"
+    model = TTS(language=language, device="cpu")
+
+    generate_tokens(model.hps["symbols"])
+
+    torch_model = ModelWrapper(model)
+
+    opset_version = 13
+    x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
+    print(x.shape)
+    x_lengths = torch.tensor([x.size(0)], dtype=torch.int64)
+    sid = torch.tensor([1], dtype=torch.int64)
+    tones = torch.zeros_like(x)
+
+    noise_scale = torch.tensor([1.0], dtype=torch.float32)
+    length_scale = torch.tensor([1.0], dtype=torch.float32)
+    noise_scale_w = torch.tensor([1.0], dtype=torch.float32)
+
+    x = x.unsqueeze(0)
+    tones = tones.unsqueeze(0)
+
+    filename = "model.onnx"
+
+    torch.onnx.export(
+        torch_model,
+        (
+            x,
+            x_lengths,
+            tones,
+            sid,
+            noise_scale,
+            length_scale,
+            noise_scale_w,
+        ),
+        filename,
+        opset_version=opset_version,
+        input_names=[
+            "x",
+            "x_lengths",
+            "tones",
+            "sid",
+            "noise_scale",
+            "length_scale",
+            "noise_scale_w",
+        ],
+        output_names=["y"],
+        dynamic_axes={
+            "x": {0: "N", 1: "L"},
+            "x_lengths": {0: "N"},
+            "tones": {0: "N", 1: "L"},
+            "y": {0: "N", 1: "S", 2: "T"},
+        },
+    )
+
+    meta_data = {
+        "model_type": "melo-vits",
+        "comment": "melo",
+        "version": 2,
+        "language": "English",
+        "add_blank": int(model.hps.data.add_blank),
+        "n_speakers": len(model.hps.data.spk2id),  # 5
+        "jieba": 0,
+        "sample_rate": model.hps.data.sampling_rate,
+        "bert_dim": 1024,
+        "ja_bert_dim": 768,
+        "speaker_id": 0,
+        "lang_id": language_id_map[model.language],
+        "tone_start": language_tone_start_map[model.language],
+        "url": "https://github.com/myshell-ai/MeloTTS",
+        "license": "MIT license",
+        "description": "MeloTTS is a high-quality multi-lingual text-to-speech library by MyShell.ai",
+    }
+    add_meta_data(filename, meta_data)
+
+
+if __name__ == "__main__":
+    main()
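A quick way to sanity-check the exported model.onnx is to run it with onnxruntime using the input and output names declared in torch.onnx.export above. This smoke test is not part of the commit; it assumes onnxruntime and numpy are installed, and the random token IDs produce noise rather than speech.

#!/usr/bin/env python3
# Smoke test (not part of this commit) for the exported English model.onnx.
# Assumes onnxruntime and numpy are installed and model.onnx is in the
# current directory.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Print the metadata written by add_meta_data() above;
# e.g. "n_speakers" should be "5" and "language" should be "English".
print(sess.get_modelmeta().custom_metadata_map)

num_tokens = 60
x = np.random.randint(low=0, high=10, size=(1, num_tokens)).astype(np.int64)
x_lengths = np.array([num_tokens], dtype=np.int64)
tones = np.zeros((1, num_tokens), dtype=np.int64)
sid = np.array([0], dtype=np.int64)  # any of 0..4 for the English model

y = sess.run(
    ["y"],
    {
        "x": x,
        "x_lengths": x_lengths,
        "tones": tones,
        "sid": sid,
        "noise_scale": np.array([0.667], dtype=np.float32),
        "length_scale": np.array([1.0], dtype=np.float32),
        "noise_scale_w": np.array([0.8], dtype=np.float32),
    },
)[0]
print(y.shape)  # (N, S, T): a batch of mono audio samples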
 #!/usr/bin/env python3
+# This script exports the ZH_EN TTS model, which supports both Chinese and English.
+# This model has only 1 speaker.
+
 from typing import Any, Dict
 
 import onnx
@@ -38,4 +38,24 @@ tail tokens.txt
 
 ./test.py
 
+mkdir zh_en
+mv -v *.onnx zh_en/
+mv -v lexicon.txt zh_en
+mv -v tokens.txt zh_en
+cp -v README.md zh_en
+
+ls -lh
+echo "---"
+ls -lh zh_en
+
+./export-onnx-en.py
+
+mkdir en
+mv -v *.onnx en/
+mv -v lexicon.txt en
+mv -v tokens.txt en
+cp -v README.md en
+
+ls -lh en
+
 ls -lh
@@ -152,10 +152,6 @@
 #define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key)            \
   do {                                                                      \
     auto value = LookupCustomModelMetaData(meta_data, src_key, allocator);  \
-    if (value.empty()) {                                                    \
-      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key);     \
-      exit(-1);                                                             \
-    }                                                                       \
                                                                             \
     dst = std::move(value);                                                 \
   } while (0)
@@ -48,6 +48,20 @@ class MeloTtsLexicon::Impl {
     }
   }
 
+  Impl(const std::string &lexicon, const std::string &tokens,
+       const OfflineTtsVitsModelMetaData &meta_data, bool debug)
+      : meta_data_(meta_data), debug_(debug) {
+    {
+      std::ifstream is(tokens);
+      InitTokens(is);
+    }
+
+    {
+      std::ifstream is(lexicon);
+      InitLexicon(is);
+    }
+  }
+
   std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
     std::string text = ToLowerCase(_text);
     // see
@@ -65,21 +79,39 @@ class MeloTtsLexicon::Impl {
     s = std::regex_replace(s, punct_re4, "!");
 
     std::vector<std::string> words;
-    bool is_hmm = true;
-    jieba_->Cut(text, words, is_hmm);
-
-    if (debug_) {
-      SHERPA_ONNX_LOGE("input text: %s", text.c_str());
-      SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
-
-      std::ostringstream os;
-      std::string sep = "";
-      for (const auto &w : words) {
-        os << sep << w;
-        sep = "_";
-      }
+    if (jieba_) {
+      bool is_hmm = true;
+      jieba_->Cut(text, words, is_hmm);
+
+      if (debug_) {
+        SHERPA_ONNX_LOGE("input text: %s", text.c_str());
+        SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
+
+        std::ostringstream os;
+        std::string sep = "";
+        for (const auto &w : words) {
+          os << sep << w;
+          sep = "_";
+        }
 
-      SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
+        SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
+      }
+    } else {
+      words = SplitUtf8(text);
+
+      if (debug_) {
+        fprintf(stderr, "Input text in string (lowercase): %s\n", text.c_str());
+        fprintf(stderr, "Input text in bytes (lowercase):");
+        for (uint8_t c : text) {
+          fprintf(stderr, " %02x", c);
+        }
+        fprintf(stderr, "\n");
+        fprintf(stderr, "After splitting to words:");
+        for (const auto &w : words) {
+          fprintf(stderr, " %s", w.c_str());
+        }
+        fprintf(stderr, "\n");
+      }
     }
 
     std::vector<TokenIDs> ans;
@@ -241,6 +273,7 @@ class MeloTtsLexicon::Impl {
         {std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
   }
 
+  // For Chinese+English MeloTTS
   word2ids_["呣"] = word2ids_["母"];
   word2ids_["嗯"] = word2ids_["恩"];
 }
@@ -268,6 +301,12 @@ MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
     : impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
                                    debug)) {}
 
+MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
+                               const std::string &tokens,
+                               const OfflineTtsVitsModelMetaData &meta_data,
+                               bool debug)
+    : impl_(std::make_unique<Impl>(lexicon, tokens, meta_data, debug)) {}
+
 std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
     const std::string &text, const std::string & /*unused_voice = ""*/) const {
   return impl_->ConvertTextToTokenIds(text);
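The new English-only code path above lowercases the text, splits it without jieba, and looks each word up in lexicon.txt. The rough Python sketch below, not part of the commit, mimics that lookup using the lexicon.txt and tokens.txt formats produced by export-onnx-en.py; the whitespace splitting is a simplification of SplitUtf8 and the punctuation handling.

#!/usr/bin/env python3
# Rough sketch (not part of this commit) of the English lexicon lookup.
# Assumes lexicon.txt and tokens.txt from export-onnx-en.py are in the
# current directory.

token2id = {}
with open("tokens.txt", encoding="utf-8") as f:
    for line in f:
        sym, idx = line.rstrip("\n").rsplit(" ", 1)
        token2id[sym] = int(idx)

# lexicon.txt lines look like: word phone1 ... phoneN tone1 ... toneN
word2phones = {}
word2tones = {}
with open("lexicon.txt", encoding="utf-8") as f:
    for line in f:
        fields = line.split()
        word, rest = fields[0], fields[1:]
        n = len(rest) // 2  # first half: phones, second half: tones
        word2phones[word] = rest[:n]
        word2tones[word] = [int(t) for t in rest[n:]]

text = "kaldi is a speech recognition toolkit".lower()
for w in text.split():
    if w in word2phones:
        ids = [token2id[p] for p in word2phones[w]]
        print(w, ids, word2tones[w])
    else:
        print(w, "-> not in lexicon.txt")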
@@ -22,6 +22,9 @@ class MeloTtsLexicon : public OfflineTtsFrontend {
                  const std::string &dict_dir,
                  const OfflineTtsVitsModelMetaData &meta_data, bool debug);
 
+  MeloTtsLexicon(const std::string &lexicon, const std::string &tokens,
+                 const OfflineTtsVitsModelMetaData &meta_data, bool debug);
+
   std::vector<TokenIDs> ConvertTextToTokenIds(
       const std::string &text,
       const std::string &unused_voice = "") const override;
@@ -349,6 +349,10 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
           config_.model.vits.lexicon, config_.model.vits.tokens,
           config_.model.vits.dict_dir, model_->GetMetaData(),
           config_.model.debug);
+    } else if (meta_data.is_melo_tts && meta_data.language == "English") {
+      frontend_ = std::make_unique<MeloTtsLexicon>(
+          config_.model.vits.lexicon, config_.model.vits.tokens,
+          model_->GetMetaData(), config_.model.debug);
     } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
       frontend_ = std::make_unique<JiebaLexicon>(
           config_.model.vits.lexicon, config_.model.vits.tokens,
@@ -46,8 +46,10 @@ class OfflineTtsVitsModel::Impl {
   }
 
   Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) {
-    // For MeloTTS, we hardcode sid to the one contained in the meta data
-    sid = meta_data_.speaker_id;
+    if (meta_data_.num_speakers == 1) {
+      // For MeloTTS, we hardcode sid to the one contained in the meta data
+      sid = meta_data_.speaker_id;
+    }
 
     auto memory_info =
         Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -408,10 +408,10 @@ std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
   // For other versions, we may need to change it
 #if ORT_API_VERSION >= 12
   auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
-  return v.get();
+  return v ? v.get() : "";
 #else
   auto v = meta_data.LookupCustomMetadataMap(key, allocator);
-  std::string ans = v;
+  std::string ans = v ? v : "";
   allocator->Free(allocator, v);
   return ans;
 #endif
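The two changes above make a missing metadata key yield an empty string instead of aborting. The small helper below, not part of the commit, lists which keys an exported model actually carries; dict.get with a default plays the same role on the Python side. The key some_optional_key is made up, only to show the missing-key case.

#!/usr/bin/env python3
# Not part of this commit: inspect the custom metadata of an exported model.
# Assumes the onnx package is installed and model.onnx is in the current
# directory.

import onnx

model = onnx.load("model.onnx")
meta = {p.key: p.value for p in model.metadata_props}

# "some_optional_key" is a made-up name used to demonstrate a missing key.
for key in ["model_type", "language", "n_speakers", "some_optional_key"]:
    # meta.get(key, "") mirrors the relaxed LookupCustomModelMetaData() above:
    # an absent key now comes back as an empty string instead of being fatal.
    print(f"{key!r} -> {meta.get(key, '')!r}")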