Fangjun Kuang
Committed by GitHub

Use espeak-ng for coqui-ai/TTS VITS English models. (#466)
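Summary: this change switches the English VITS models exported from coqui-ai/TTS over to the espeak-ng based frontend (PiperPhonemizeLexicon) that the piper models already use. To carry the extra per-model settings this needs, the per-field getters on OfflineTtsVitsModel (SampleRate(), AddBlank(), IsPiper(), ...) are replaced by a single OfflineTtsVitsModelMetaData struct, which also gains coqui-specific fields (blank_id, bos_id, eos_id, use_eos_bos). A new CoquiPhonemesToIds() converts espeak-ng phonemes to coqui token IDs, ReadTokens() now tolerates the trailing <BLNK> entry in coqui tokens.txt files, the CI model list gains get_coqui_models(), and the Kotlin API's default maxNumSentences drops from 2 to 1.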

@@ -23,7 +23,7 @@ data class OfflineTtsModelConfig(
 data class OfflineTtsConfig(
     var model: OfflineTtsModelConfig,
     var ruleFsts: String = "",
-    var maxNumSentences: Int = 2,
+    var maxNumSentences: Int = 1,
 )

 class GeneratedAudio(
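The Kotlin default above should track the C++ OfflineTtsConfig. A minimal sketch of the C++ side, assuming its fields mirror the Kotlin data class (the structs here are simplified stand-ins, not the real headers):

#include <cstdint>
#include <string>

// Simplified stand-in for the C++ OfflineTtsConfig; field names follow the
// Kotlin data class above and are assumptions, not verified against csrc.
struct OfflineTtsModelConfig {};

struct OfflineTtsConfig {
  OfflineTtsModelConfig model;
  std::string rule_fsts;
  // Lowered from 2 to 1 in this commit; processing fewer sentences per
  // batch presumably lets callback-based playback start sooner.
  int32_t max_num_sentences = 1;
};

int main() {
  OfflineTtsConfig config;       // picks up the new default of 1
  config.max_num_sentences = 2;  // callers can still restore the old value
  return 0;
}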
@@ -311,6 +311,9 @@ def main():

     if len(audio.samples) == 0:
         print("Error in generating audios. Please read previous error messages.")
+        global killed
+        killed = True
+        play_back_thread.join()
         return

     elapsed_seconds = end - start
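The three added lines make the error path signal the playback thread and wait for it before returning. The same stop-and-join pattern, sketched in C++ with std::atomic and std::thread (all names here are hypothetical; the Python script uses a module-level killed flag and play_back_thread):

#include <atomic>
#include <chrono>
#include <thread>

std::atomic<bool> killed{false};

// Stand-in for the playback loop: poll the flag and exit once it is set.
void PlayBack() {
  while (!killed.load()) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
}

int main() {
  std::thread play_back_thread(PlayBack);

  bool generation_failed = true;  // pretend generation produced no samples
  if (generation_failed) {
    killed = true;            // tell the playback loop to stop
    play_back_thread.join();  // wait for it so we do not exit mid-playback
    return 1;
  }

  killed = true;
  play_back_thread.join();
  return 0;
}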
@@ -33,6 +33,23 @@ class TtsModel:
     data_dir: Optional[str] = None


+def get_coqui_models() -> List[TtsModel]:
+    # English (coqui-ai/TTS)
+    models = [
+        TtsModel(model_dir="vits-coqui-en-ljspeech"),
+        TtsModel(model_dir="vits-coqui-en-ljspeech-neon"),
+        TtsModel(model_dir="vits-coqui-en-vctk"),
+        # TtsModel(model_dir="vits-coqui-en-jenny"),
+    ]
+
+    for m in models:
+        m.data_dir = m.model_dir + "/" + "espeak-ng-data"
+        m.model_name = "model.onnx"
+        m.lang = "en"
+
+    return models
+
+
 def get_piper_models() -> List[TtsModel]:
     models = [
         TtsModel(model_dir="vits-piper-ar_JO-kareem-low"),
@@ -137,6 +154,7 @@ def get_piper_models() -> List[TtsModel]:
         TtsModel(model_dir="vits-piper-vi_VN-vivos-x_low"),
         TtsModel(model_dir="vits-piper-zh_CN-huayan-medium"),
     ]
+
     for m in models:
         m.data_dir = m.model_dir + "/" + "espeak-ng-data"
         m.model_name = m.model_dir[len("vits-piper-") :] + ".onnx"
@@ -145,7 +163,7 @@ def get_piper_models() -> List[TtsModel]:
     return models


-def get_all_models() -> List[TtsModel]:
+def get_vits_models() -> List[TtsModel]:
     return [
         # Chinese
         TtsModel(
@@ -202,12 +220,6 @@ def get_all_models() -> List[TtsModel]:
             lang="zh",
             rule_fsts="vits-zh-hf-theresa/rule.fst",
         ),
-        # English (coqui-ai/TTS)
-        # fmt: off
-        TtsModel(model_dir="vits-coqui-en-ljspeech", model_name="model.onnx", lang="en"),
-        TtsModel(model_dir="vits-coqui-en-ljspeech-neon", model_name="model.onnx", lang="en"),
-        TtsModel(model_dir="vits-coqui-en-vctk", model_name="model.onnx", lang="en"),
-        # TtsModel(model_dir="vits-coqui-en-jenny", model_name="model.onnx", lang="en"),
         # English (US)
         TtsModel(model_dir="vits-vctk", model_name="vits-vctk.onnx", lang="en"),
         TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"),
@@ -225,8 +237,11 @@ def main():
         s = f.read()
     template = environment.from_string(s)
     d = dict()
-    # all_model_list = get_all_models()
+
+    # all_model_list = get_vits_models()
     all_model_list = get_piper_models()
+    all_model_list += get_coqui_models()
+
     num_models = len(all_model_list)

     num_per_runner = num_models // total
@@ -69,12 +69,16 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
   }
 #endif

-  int32_t SampleRate() const override { return model_->SampleRate(); }
+  int32_t SampleRate() const override {
+    return model_->GetMetaData().sample_rate;
+  }

   GeneratedAudio Generate(
       const std::string &_text, int64_t sid = 0, float speed = 1.0,
       GeneratedAudioCallback callback = nullptr) const override {
-    int32_t num_speakers = model_->NumSpeakers();
+    const auto &meta_data = model_->GetMetaData();
+    int32_t num_speakers = meta_data.num_speakers;
+
     if (num_speakers == 0 && sid != 0) {
       SHERPA_ONNX_LOGE(
           "This is a single-speaker model and supports only sid 0. Given sid: "
@@ -105,14 +109,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
     }

     std::vector<std::vector<int64_t>> x =
-        frontend_->ConvertTextToTokenIds(text, model_->Voice());
+        frontend_->ConvertTextToTokenIds(text, meta_data.voice);

     if (x.empty() || (x.size() == 1 && x[0].empty())) {
       SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
       return {};
     }

-    if (model_->AddBlank() && config_.model.vits.data_dir.empty()) {
+    if (meta_data.add_blank && config_.model.vits.data_dir.empty()) {
       for (auto &k : x) {
         k = AddBlank(k);
       }
@@ -189,25 +193,33 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
 private:
 #if __ANDROID_API__ >= 9
   void InitFrontend(AAssetManager *mgr) {
-    if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) {
+    const auto &meta_data = model_->GetMetaData();
+
+    if ((meta_data.is_piper || meta_data.is_coqui) &&
+        !config_.model.vits.data_dir.empty()) {
       frontend_ = std::make_unique<PiperPhonemizeLexicon>(
-          mgr, config_.model.vits.tokens, config_.model.vits.data_dir);
+          mgr, config_.model.vits.tokens, config_.model.vits.data_dir,
+          meta_data);
     } else {
       frontend_ = std::make_unique<Lexicon>(
           mgr, config_.model.vits.lexicon, config_.model.vits.tokens,
-          model_->Punctuations(), model_->Language(), config_.model.debug);
+          meta_data.punctuations, meta_data.language, config_.model.debug);
     }
   }
 #endif

   void InitFrontend() {
-    if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) {
+    const auto &meta_data = model_->GetMetaData();
+
+    if ((meta_data.is_piper || meta_data.is_coqui) &&
+        !config_.model.vits.data_dir.empty()) {
       frontend_ = std::make_unique<PiperPhonemizeLexicon>(
-          config_.model.vits.tokens, config_.model.vits.data_dir);
+          config_.model.vits.tokens, config_.model.vits.data_dir,
+          model_->GetMetaData());
     } else {
       frontend_ = std::make_unique<Lexicon>(
           config_.model.vits.lexicon, config_.model.vits.tokens,
-          model_->Punctuations(), model_->Language(), config_.model.debug);
+          meta_data.punctuations, meta_data.language, config_.model.debug);
     }
   }

@@ -256,7 +268,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
     const float *p = audio.GetTensorData<float>();

     GeneratedAudio ans;
-    ans.sample_rate = model_->SampleRate();
+    ans.sample_rate = model_->GetMetaData().sample_rate;
     ans.samples = std::vector<float>(p, p + total);
     return ans;
   }
@@ -46,7 +46,8 @@ bool OfflineTtsVitsModelConfig::Validate() const {

   if (data_dir.empty()) {
     if (lexicon.empty()) {
-      SHERPA_ONNX_LOGE("Please provide --vits-lexicon");
+      SHERPA_ONNX_LOGE(
+          "Please provide --vits-lexicon if you leave --vits-data-dir empty");
       return false;
     }

@@ -0,0 +1,34 @@
+// sherpa-onnx/csrc/offline-tts-vits-model-metadata.h
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_METADATA_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_METADATA_H_
+
+#include <cstdint>
+#include <string>
+
+namespace sherpa_onnx {
+
+struct OfflineTtsVitsModelMetaData {
+  int32_t sample_rate;
+  int32_t add_blank = 0;
+  int32_t num_speakers = 0;
+
+  std::string punctuations;
+  std::string language;
+  std::string voice;
+
+  bool is_piper = false;
+  bool is_coqui = false;
+
+  // the following options are for models from coqui-ai/TTS
+  int32_t blank_id = 0;
+  int32_t bos_id = 0;
+  int32_t eos_id = 0;
+  int32_t use_eos_bos = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_METADATA_H_
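A hedged sketch of how callers use the new struct: take a const reference once, then read fields. The branch below mirrors the dispatch added to InitFrontend() and Run() in this commit (the Dispatch function itself is illustrative only):

#include <cstdio>

// Trimmed copy of the struct above, enough for the illustration.
struct OfflineTtsVitsModelMetaData {
  bool is_piper = false;
  bool is_coqui = false;
};

// Mirrors InitFrontend(): piper and coqui models share the espeak-ng
// based frontend; everything else falls back to the lexicon frontend.
void Dispatch(const OfflineTtsVitsModelMetaData &m) {
  if (m.is_piper || m.is_coqui) {
    std::printf("PiperPhonemizeLexicon (espeak-ng)\n");
  } else {
    std::printf("Lexicon\n");
  }
}

int main() {
  OfflineTtsVitsModelMetaData m;
  m.is_coqui = true;
  Dispatch(m);  // prints: PiperPhonemizeLexicon (espeak-ng)
  return 0;
}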
@@ -38,22 +38,14 @@ class OfflineTtsVitsModel::Impl {
 #endif

   Ort::Value Run(Ort::Value x, int64_t sid, float speed) {
-    if (is_piper_) {
-      return RunVitsPiper(std::move(x), sid, speed);
+    if (meta_data_.is_piper || meta_data_.is_coqui) {
+      return RunVitsPiperOrCoqui(std::move(x), sid, speed);
     }

     return RunVits(std::move(x), sid, speed);
   }

-  int32_t SampleRate() const { return sample_rate_; }
-
-  bool AddBlank() const { return add_blank_; }
-
-  std::string Punctuations() const { return punctuations_; }
-  std::string Language() const { return language_; }
-  std::string Voice() const { return voice_; }
-  bool IsPiper() const { return is_piper_; }
-  int32_t NumSpeakers() const { return num_speakers_; }
+  const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; }

 private:
   void Init(void *model_data, size_t model_data_length) {
@@ -70,27 +62,52 @@ class OfflineTtsVitsModel::Impl {
       std::ostringstream os;
       os << "---vits model---\n";
       PrintModelMetadata(os, meta_data);
+
+      os << "----------input names----------\n";
+      int32_t i = 0;
+      for (const auto &s : input_names_) {
+        os << i << " " << s << "\n";
+        ++i;
+      }
+      os << "----------output names----------\n";
+      i = 0;
+      for (const auto &s : output_names_) {
+        os << i << " " << s << "\n";
+        ++i;
+      }
+
       SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
     }

     Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
-    SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
-    SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(add_blank_, "add_blank", 0);
-    SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers");
-    SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(punctuations_, "punctuation",
-                                                "");
-    SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
-    SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(voice_, "voice", "");
+    SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
+    SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank",
+                                            0);
+    SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers");
+    SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations,
+                                                "punctuation", "");
+    SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
+    SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.voice, "voice", "");
+
+    SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.blank_id, "blank_id", 0);
+    SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.bos_id, "bos_id", 0);
+    SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.eos_id, "eos_id", 0);
+    SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_eos_bos,
+                                            "use_eos_bos", 0);

     std::string comment;
     SHERPA_ONNX_READ_META_DATA_STR(comment, "comment");
-    if (comment.find("piper") != std::string::npos ||
-        comment.find("coqui") != std::string::npos) {
-      is_piper_ = true;
+
+    if (comment.find("piper") != std::string::npos) {
+      meta_data_.is_piper = true;
+    }
+
+    if (comment.find("coqui") != std::string::npos) {
+      meta_data_.is_coqui = true;
     }
   }

-  Ort::Value RunVitsPiper(Ort::Value x, int64_t sid, float speed) {
+  Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) {
     auto memory_info =
         Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

@@ -213,14 +230,7 @@ class OfflineTtsVitsModel::Impl {
   std::vector<std::string> output_names_;
   std::vector<const char *> output_names_ptr_;

-  int32_t sample_rate_;
-  int32_t add_blank_;
-  int32_t num_speakers_;
-  std::string punctuations_;
-  std::string language_;
-  std::string voice_;
-
-  bool is_piper_ = false;
+  OfflineTtsVitsModelMetaData meta_data_;
 };

 OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
@@ -239,21 +249,8 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/,
   return impl_->Run(std::move(x), sid, speed);
 }

-int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }
-
-bool OfflineTtsVitsModel::AddBlank() const { return impl_->AddBlank(); }
-
-std::string OfflineTtsVitsModel::Punctuations() const {
-  return impl_->Punctuations();
-}
-
-std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }
-std::string OfflineTtsVitsModel::Voice() const { return impl_->Voice(); }
-
-bool OfflineTtsVitsModel::IsPiper() const { return impl_->IsPiper(); }
-
-int32_t OfflineTtsVitsModel::NumSpeakers() const {
-  return impl_->NumSpeakers();
+const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const {
+  return impl_->GetMetaData();
 }

 }  // namespace sherpa_onnx
@@ -15,6 +15,7 @@

 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/offline-tts-model-config.h"
+#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"

 namespace sherpa_onnx {
@@ -39,17 +40,7 @@ class OfflineTtsVitsModel {
    */
   Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0);

-  // Sample rate of the generated audio
-  int32_t SampleRate() const;
-
-  // true to insert a blank between each token
-  bool AddBlank() const;
-
-  std::string Punctuations() const;
-  std::string Language() const;  // e.g., Chinese, English, German, etc.
-  std::string Voice() const;     // e.g., en-us, for espeak-ng
-  bool IsPiper() const;
-  int32_t NumSpeakers() const;
+  const OfflineTtsVitsModelMetaData &GetMetaData() const;

 private:
   class Impl;
@@ -57,10 +57,17 @@ static std::unordered_map<char32_t, int32_t> ReadTokens(std::istream &is) {

     s = conv.from_bytes(sym);
     if (s.size() != 1) {
+      // for tokens.txt from coqui-ai/TTS, the last token is <BLNK>
+      if (s.size() == 6 && s[0] == '<' && s[1] == 'B' && s[2] == 'L' &&
+          s[3] == 'N' && s[4] == 'K' && s[5] == '>') {
+        continue;
+      }
+
       SHERPA_ONNX_LOGE("Error when reading tokens at Line %s. size: %d",
                        line.c_str(), static_cast<int32_t>(s.size()));
       exit(-1);
     }
+
     char32_t c = s[0];

     if (token2id.count(c)) {
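For context: each line of tokens.txt maps one phoneme symbol to an integer ID, and ReadTokens() rejects multi-character symbols. Coqui exports end with a multi-character <BLNK> entry, which the early continue above now skips instead of aborting. A hypothetical excerpt (symbols and IDs made up for illustration):

  _ 0
  a 1
  b 2
  , 3
  <BLNK> 4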
@@ -77,7 +84,7 @@ static std::unordered_map<char32_t, int32_t> ReadTokens(std::istream &is) {

 // see the function "phonemes_to_ids" from
 // https://github.com/rhasspy/piper/blob/master/notebooks/piper_inference_(ONNX).ipynb
-static std::vector<int64_t> PhonemesToIds(
+static std::vector<int64_t> PiperPhonemesToIds(
     const std::unordered_map<char32_t, int32_t> &token2id,
     const std::vector<piper::Phoneme> &phonemes) {
   // see
83 // see 90 // see
@@ -104,6 +111,65 @@ static std::vector<int64_t> PhonemesToIds( @@ -104,6 +111,65 @@ static std::vector<int64_t> PhonemesToIds(
104 return ans; 111 return ans;
105 } 112 }
106 113
  114 +static std::vector<int64_t> CoquiPhonemesToIds(
  115 + const std::unordered_map<char32_t, int32_t> &token2id,
  116 + const std::vector<piper::Phoneme> &phonemes,
  117 + const OfflineTtsVitsModelMetaData &meta_data) {
  118 + // see
  119 + // https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87
  120 + int32_t use_eos_bos = meta_data.use_eos_bos;
  121 + int32_t bos_id = meta_data.bos_id;
  122 + int32_t eos_id = meta_data.eos_id;
  123 + int32_t blank_id = meta_data.blank_id;
  124 + int32_t add_blank = meta_data.add_blank;
  125 + int32_t comma_id = token2id.at(',');
  126 + SHERPA_ONNX_LOGE("comma id: %d", comma_id);
  127 +
  128 + std::vector<int64_t> ans;
  129 + if (add_blank) {
  130 + ans.reserve(phonemes.size() * 2 + 3);
  131 + } else {
  132 + ans.reserve(phonemes.size() + 2);
  133 + }
  134 +
  135 + if (use_eos_bos) {
  136 + ans.push_back(bos_id);
  137 + }
  138 +
  139 + if (add_blank) {
  140 + ans.push_back(blank_id);
  141 +
  142 + for (auto p : phonemes) {
  143 + if (token2id.count(p)) {
  144 + ans.push_back(token2id.at(p));
  145 + ans.push_back(blank_id);
  146 + } else {
  147 + SHERPA_ONNX_LOGE("Skip unknown phonemes. Unicode codepoint: \\U+%04x.",
  148 + static_cast<uint32_t>(p));
  149 + }
  150 + }
  151 + } else {
  152 + // not adding blank
  153 + for (auto p : phonemes) {
  154 + if (token2id.count(p)) {
  155 + ans.push_back(token2id.at(p));
  156 + } else {
  157 + SHERPA_ONNX_LOGE("Skip unknown phonemes. Unicode codepoint: \\U+%04x.",
  158 + static_cast<uint32_t>(p));
  159 + }
  160 + }
  161 + }
  162 +
  163 + // add a comma at the end of a sentence so that we can have a longer pause.
  164 + ans.push_back(comma_id);
  165 +
  166 + if (use_eos_bos) {
  167 + ans.push_back(eos_id);
  168 + }
  169 +
  170 + return ans;
  171 +}
  172 +
107 void InitEspeak(const std::string &data_dir) { 173 void InitEspeak(const std::string &data_dir) {
108 static std::once_flag init_flag; 174 static std::once_flag init_flag;
109 std::call_once(init_flag, [data_dir]() { 175 std::call_once(init_flag, [data_dir]() {
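A worked example of CoquiPhonemesToIds: with use_eos_bos=1, add_blank=1, bos_id=2, eos_id=3, blank_id=0, phonemes 'a b' mapped to 10 and 11, and ',' mapped to 5, the result is [2, 0, 10, 0, 11, 0, 5, 3]: BOS, a blank interleaved around every phoneme, the trailing comma for a longer pause, then EOS. A standalone sketch of the same interleaving with toy IDs (no piper dependency):

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // Toy token table; real IDs come from tokens.txt.
  std::unordered_map<char32_t, int32_t> token2id = {
      {U'a', 10}, {U'b', 11}, {U',', 5}};
  std::vector<char32_t> phonemes = {U'a', U'b'};

  // Hard-coded metadata for the example: use_eos_bos=1, add_blank=1,
  // bos_id=2, eos_id=3, blank_id=0.
  std::vector<int64_t> ans;
  ans.push_back(2);  // bos_id
  ans.push_back(0);  // leading blank
  for (auto p : phonemes) {
    ans.push_back(token2id.at(p));
    ans.push_back(0);  // blank after each phoneme
  }
  ans.push_back(token2id.at(U','));  // trailing comma => longer pause
  ans.push_back(3);                  // eos_id

  for (auto i : ans) std::cout << i << ' ';  // prints: 2 0 10 0 11 0 5 3
  std::cout << '\n';
  return 0;
}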
@@ -119,21 +185,23 @@ void InitEspeak(const std::string &data_dir) {
   });
 }

-PiperPhonemizeLexicon::PiperPhonemizeLexicon(const std::string &tokens,
-                                             const std::string &data_dir)
-    : data_dir_(data_dir) {
+PiperPhonemizeLexicon::PiperPhonemizeLexicon(
+    const std::string &tokens, const std::string &data_dir,
+    const OfflineTtsVitsModelMetaData &meta_data)
+    : meta_data_(meta_data) {
   {
     std::ifstream is(tokens);
     token2id_ = ReadTokens(is);
   }

-  InitEspeak(data_dir_);
+  InitEspeak(data_dir);
 }

 #if __ANDROID_API__ >= 9
-PiperPhonemizeLexicon::PiperPhonemizeLexicon(AAssetManager *mgr,
-                                             const std::string &tokens,
-                                             const std::string &data_dir) {
+PiperPhonemizeLexicon::PiperPhonemizeLexicon(
+    AAssetManager *mgr, const std::string &tokens, const std::string &data_dir,
+    const OfflineTtsVitsModelMetaData &meta_data)
+    : meta_data_(meta_data) {
   {
     auto buf = ReadFile(mgr, tokens);
     std::istrstream is(buf.data(), buf.size());
@@ -141,8 +209,9 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon(AAssetManager *mgr,
   }

   // We should copy the directory of espeak-ng-data from the asset to
-  // some internal or external storage and then pass the directory to data_dir.
-  InitEspeak(data_dir_);
+  // some internal or external storage and then pass the directory to
+  // data_dir.
+  InitEspeak(data_dir);
 }
 #endif

@@ -160,9 +229,21 @@ std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds(
   std::vector<std::vector<int64_t>> ans;

   std::vector<int64_t> phoneme_ids;
-  for (const auto &p : phonemes) {
-    phoneme_ids = PhonemesToIds(token2id_, p);
-    ans.push_back(std::move(phoneme_ids));
+
+  if (meta_data_.is_piper) {
+    for (const auto &p : phonemes) {
+      phoneme_ids = PiperPhonemesToIds(token2id_, p);
+      ans.push_back(std::move(phoneme_ids));
+    }
+  } else if (meta_data_.is_coqui) {
+    for (const auto &p : phonemes) {
+      phoneme_ids = CoquiPhonemesToIds(token2id_, p, meta_data_);
+      ans.push_back(std::move(phoneme_ids));
+    }
+
+  } else {
+    SHERPA_ONNX_LOGE("Unsupported model");
+    exit(-1);
   }

   return ans;
@@ -15,25 +15,28 @@
 #endif

 #include "sherpa-onnx/csrc/offline-tts-frontend.h"
+#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"

 namespace sherpa_onnx {

 class PiperPhonemizeLexicon : public OfflineTtsFrontend {
  public:
-  PiperPhonemizeLexicon(const std::string &tokens, const std::string &data_dir);
+  PiperPhonemizeLexicon(const std::string &tokens, const std::string &data_dir,
+                        const OfflineTtsVitsModelMetaData &meta_data);

 #if __ANDROID_API__ >= 9
   PiperPhonemizeLexicon(AAssetManager *mgr, const std::string &tokens,
-                        const std::string &data_dir);
+                        const std::string &data_dir,
+                        const OfflineTtsVitsModelMetaData &meta_data);
 #endif

   std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
       const std::string &text, const std::string &voice = "") const override;

 private:
-  std::string data_dir_;
   // map unicode codepoint to an integer ID
   std::unordered_map<char32_t, int32_t> token2id_;
+  OfflineTtsVitsModelMetaData meta_data_;
 };

 }  // namespace sherpa_onnx
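Putting the frontend changes together, a hedged end-to-end sketch of the new three-argument constructor (paths are placeholders; in the real code the metadata comes from OfflineTtsVitsModel::GetMetaData() and construction happens inside OfflineTtsVitsImpl::InitFrontend()):

#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
#include "sherpa-onnx/csrc/piper-phonemize-lexicon.h"

int main() {
  sherpa_onnx::OfflineTtsVitsModelMetaData meta_data;
  meta_data.is_coqui = true;  // normally derived from the "comment" metadata
  meta_data.add_blank = 1;

  sherpa_onnx::PiperPhonemizeLexicon lexicon(
      "vits-coqui-en-ljspeech/tokens.txt",      // placeholder path
      "vits-coqui-en-ljspeech/espeak-ng-data",  // placeholder path
      meta_data);

  auto ids = lexicon.ConvertTextToTokenIds("hello world.", "en-us");
  return ids.empty() ? 1 : 0;
}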