Fangjun Kuang
Committed by GitHub

Export the English TTS model from MeloTTS (#1509)

... ... @@ -40,7 +40,7 @@ jobs:
name: test.wav
path: scripts/melo-tts/test.wav
- name: Publish to huggingface
- name: Publish to huggingface (Chinese + English)
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
... ... @@ -61,14 +61,14 @@ jobs:
git fetch
git pull
echo "pwd: $PWD"
ls -lh ../scripts/melo-tts
ls -lh ../scripts/melo-tts/zh_en
rm -rf ./
cp -v ../scripts/melo-tts/*.onnx .
cp -v ../scripts/melo-tts/lexicon.txt .
cp -v ../scripts/melo-tts/tokens.txt .
cp -v ../scripts/melo-tts/README.md .
cp -v ../scripts/melo-tts/zh_en/*.onnx .
cp -v ../scripts/melo-tts/zh_en/lexicon.txt .
cp -v ../scripts/melo-tts/zh_en/tokens.txt .
cp -v ../scripts/melo-tts/zh_en/README.md .
curl -SL -O https://raw.githubusercontent.com/myshell-ai/MeloTTS/main/LICENSE
... ... @@ -102,6 +102,60 @@ jobs:
tar cjvf $dst.tar.bz2 $dst
rm -rf $dst
- name: Publish to huggingface (English)
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
export GIT_CLONE_PROTECTION_ACTIVE=false
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-en huggingface
cd huggingface
git fetch
git pull
echo "pwd: $PWD"
ls -lh ../scripts/melo-tts/en
rm -rf ./
cp -v ../scripts/melo-tts/en/*.onnx .
cp -v ../scripts/melo-tts/en/lexicon.txt .
cp -v ../scripts/melo-tts/en/tokens.txt .
cp -v ../scripts/melo-tts/en/README.md .
curl -SL -O https://raw.githubusercontent.com/myshell-ai/MeloTTS/main/LICENSE
git lfs track "*.onnx"
git add .
ls -lh
git status
git diff
git commit -m "add models"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-en main || true
cd ..
rm -rf huggingface/.git*
dst=vits-melo-tts-en
mv huggingface $dst
tar cjvf $dst.tar.bz2 $dst
rm -rf $dst
- name: Release
uses: svenstaro/upload-release-action@v2
with:
... ...
... ... @@ -3,4 +3,5 @@
Models in this directory are converted from
https://github.com/myshell-ai/MeloTTS
Note there is only a single female speaker in the model.
Note there is only a single female speaker in the model for Chinese+English TTS.
TTS model, whereas there are 5 female speakers in the model For English TTS.
... ...
#!/usr/bin/env python3
# This model exports the English-only TTS model.
# It has 5 speakers.
# {'EN-US': 0, 'EN-BR': 1, 'EN_INDIA': 2, 'EN-AU': 3, 'EN-Default': 4}
from typing import Any, Dict
import onnx
import torch
from melo.api import TTS
from melo.text import language_id_map, language_tone_start_map
from melo.text.chinese import pinyin_to_symbol_map
from melo.text.english import eng_dict, refine_syllables
from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
def generate_tokens(symbol_list):
with open("tokens.txt", "w", encoding="utf-8") as f:
for i, s in enumerate(symbol_list):
f.write(f"{s} {i}\n")
def add_new_english_words(lexicon):
"""
Args:
lexicon:
Please modify it in-place.
"""
# Please have a look at
# https://github.com/myshell-ai/MeloTTS/blob/main/melo/text/cmudict.rep
# We give several examples below about how to add new words
# Example 1. Add a new word kaldi
# It does not contain the word kaldi in cmudict.rep
# so if we add the following line to cmudict.rep
#
# KALDI K AH0 - L D IH0
#
# then we need to change the lexicon like below
lexicon["kaldi"] = [["K", "AH0"], ["L", "D", "IH0"]]
#
# K AH0 and L D IH0 are separated by a dash "-", so
# ["K", "AH0"] is a in list and ["L", "D", "IH0"] is in a separate list
# Note: Either kaldi or KALDI is fine. You can use either lowercase or
# uppercase or both
# Example 2. Add a new word SF
#
# If we add the following line to cmudict.rep
#
# SF EH1 S - EH1 F
#
# to cmudict.rep, then we need to change the lexicon like below:
lexicon["SF"] = [["EH1", "S"], ["EH1", "F"]]
# Please add your new words here
# No need to return lexicon since it is changed in-place
def generate_lexicon():
add_new_english_words(eng_dict)
with open("lexicon.txt", "w", encoding="utf-8") as f:
for word in eng_dict:
phones, tones = refine_syllables(eng_dict[word])
tones = [t + language_tone_start_map["EN"] for t in tones]
tones = [str(t) for t in tones]
phones = " ".join(phones)
tones = " ".join(tones)
f.write(f"{word.lower()} {phones} {tones}\n")
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
class ModelWrapper(torch.nn.Module):
def __init__(self, model: "SynthesizerTrn"):
super().__init__()
self.model = model
self.lang_id = language_id_map[model.language]
def forward(
self,
x,
x_lengths,
tones,
sid,
noise_scale,
length_scale,
noise_scale_w,
max_len=None,
):
"""
Args:
x: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
tones: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
sid: an integer
"""
bert = torch.zeros(x.shape[0], 1024, x.shape[1], dtype=torch.float32)
ja_bert = torch.zeros(x.shape[0], 768, x.shape[1], dtype=torch.float32)
lang_id = torch.zeros_like(x)
lang_id[:, 1::2] = self.lang_id
return self.model.model.infer(
x=x,
x_lengths=x_lengths,
sid=sid,
tone=tones,
language=lang_id,
bert=bert,
ja_bert=ja_bert,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
)[0]
def main():
generate_lexicon()
language = "EN"
model = TTS(language=language, device="cpu")
generate_tokens(model.hps["symbols"])
torch_model = ModelWrapper(model)
opset_version = 13
x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
print(x.shape)
x_lengths = torch.tensor([x.size(0)], dtype=torch.int64)
sid = torch.tensor([1], dtype=torch.int64)
tones = torch.zeros_like(x)
noise_scale = torch.tensor([1.0], dtype=torch.float32)
length_scale = torch.tensor([1.0], dtype=torch.float32)
noise_scale_w = torch.tensor([1.0], dtype=torch.float32)
x = x.unsqueeze(0)
tones = tones.unsqueeze(0)
filename = "model.onnx"
torch.onnx.export(
torch_model,
(
x,
x_lengths,
tones,
sid,
noise_scale,
length_scale,
noise_scale_w,
),
filename,
opset_version=opset_version,
input_names=[
"x",
"x_lengths",
"tones",
"sid",
"noise_scale",
"length_scale",
"noise_scale_w",
],
output_names=["y"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"x_lengths": {0: "N"},
"tones": {0: "N", 1: "L"},
"y": {0: "N", 1: "S", 2: "T"},
},
)
meta_data = {
"model_type": "melo-vits",
"comment": "melo",
"version": 2,
"language": "English",
"add_blank": int(model.hps.data.add_blank),
"n_speakers": len(model.hps.data.spk2id), # 5
"jieba": 0,
"sample_rate": model.hps.data.sampling_rate,
"bert_dim": 1024,
"ja_bert_dim": 768,
"speaker_id": 0,
"lang_id": language_id_map[model.language],
"tone_start": language_tone_start_map[model.language],
"url": "https://github.com/myshell-ai/MeloTTS",
"license": "MIT license",
"description": "MeloTTS is a high-quality multi-lingual text-to-speech library by MyShell.ai",
}
add_meta_data(filename, meta_data)
if __name__ == "__main__":
main()
... ...
#!/usr/bin/env python3
# This script export ZH_EN TTS model, which supports both Chinese and English.
# This model has only 1 speaker.
from typing import Any, Dict
import onnx
... ...
... ... @@ -38,4 +38,24 @@ tail tokens.txt
./test.py
mkdir zh_en
mv -v *.onnx zh_en/
mv -v lexicon.txt zh_en
mv -v tokens.txt zh_en
cp -v README.md zh_en
ls -lh
echo "---"
ls -lh zh_en
./export-onnx-en.py
mkdir en
mv -v *.onnx en/
mv -v lexicon.txt en
mv -v tokens.txt en
cp -v README.md en
ls -lh en
ls -lh
... ...
... ... @@ -152,10 +152,6 @@
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
do { \
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
if (value.empty()) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
exit(-1); \
} \
\
dst = std::move(value); \
} while (0)
... ...
... ... @@ -48,6 +48,20 @@ class MeloTtsLexicon::Impl {
}
}
Impl(const std::string &lexicon, const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
: meta_data_(meta_data), debug_(debug) {
{
std::ifstream is(tokens);
InitTokens(is);
}
{
std::ifstream is(lexicon);
InitLexicon(is);
}
}
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
std::string text = ToLowerCase(_text);
// see
... ... @@ -65,21 +79,39 @@ class MeloTtsLexicon::Impl {
s = std::regex_replace(s, punct_re4, "!");
std::vector<std::string> words;
bool is_hmm = true;
jieba_->Cut(text, words, is_hmm);
if (debug_) {
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
std::ostringstream os;
std::string sep = "";
for (const auto &w : words) {
os << sep << w;
sep = "_";
}
if (jieba_) {
bool is_hmm = true;
jieba_->Cut(text, words, is_hmm);
if (debug_) {
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
std::ostringstream os;
std::string sep = "";
for (const auto &w : words) {
os << sep << w;
sep = "_";
}
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
}
} else {
words = SplitUtf8(text);
if (debug_) {
fprintf(stderr, "Input text in string (lowercase): %s\n", text.c_str());
fprintf(stderr, "Input text in bytes (lowercase):");
for (uint8_t c : text) {
fprintf(stderr, " %02x", c);
}
fprintf(stderr, "\n");
fprintf(stderr, "After splitting to words:");
for (const auto &w : words) {
fprintf(stderr, " %s", w.c_str());
}
fprintf(stderr, "\n");
}
}
std::vector<TokenIDs> ans;
... ... @@ -241,6 +273,7 @@ class MeloTtsLexicon::Impl {
{std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
}
// For Chinese+English MeloTTS
word2ids_["呣"] = word2ids_["母"];
word2ids_["嗯"] = word2ids_["恩"];
}
... ... @@ -268,6 +301,12 @@ MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
debug)) {}
MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data,
bool debug)
: impl_(std::make_unique<Impl>(lexicon, tokens, meta_data, debug)) {}
std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*unused_voice = ""*/) const {
return impl_->ConvertTextToTokenIds(text);
... ...
... ... @@ -22,6 +22,9 @@ class MeloTtsLexicon : public OfflineTtsFrontend {
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
MeloTtsLexicon(const std::string &lexicon, const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text,
const std::string &unused_voice = "") const override;
... ...
... ... @@ -349,6 +349,10 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
config_.model.vits.lexicon, config_.model.vits.tokens,
config_.model.vits.dict_dir, model_->GetMetaData(),
config_.model.debug);
} else if (meta_data.is_melo_tts && meta_data.language == "English") {
frontend_ = std::make_unique<MeloTtsLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
model_->GetMetaData(), config_.model.debug);
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
frontend_ = std::make_unique<JiebaLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
... ...
... ... @@ -46,8 +46,10 @@ class OfflineTtsVitsModel::Impl {
}
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) {
// For MeloTTS, we hardcode sid to the one contained in the meta data
sid = meta_data_.speaker_id;
if (meta_data_.num_speakers == 1) {
// For MeloTTS, we hardcode sid to the one contained in the meta data
sid = meta_data_.speaker_id;
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
... ...
... ... @@ -408,10 +408,10 @@ std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
// For other versions, we may need to change it
#if ORT_API_VERSION >= 12
auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
return v.get();
return v ? v.get() : "";
#else
auto v = meta_data.LookupCustomMetadataMap(key, allocator);
std::string ans = v;
std::string ans = v ? v : "";
allocator->Free(allocator, v);
return ans;
#endif
... ...