Fangjun Kuang
Committed by GitHub

Support VITS TTS models from coqui-ai/TTS (#416)

* Support VITS TTS models from coqui-ai/TTS

* release v1.8.9
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)
 
-set(SHERPA_ONNX_VERSION "1.8.8")
+set(SHERPA_ONNX_VERSION "1.8.9")
 
 # Disable warning about
 #
@@ -196,20 +196,27 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
 
   std::vector<int64_t> ans;
 
+  int32_t blank = -1;
+  if (token2id_.count(" ")) {
+    blank = token2id_.at(" ");
+  }
+
   int32_t sil = -1;
   int32_t eos = -1;
   if (token2id_.count("sil")) {
     sil = token2id_.at("sil");
     eos = token2id_.at("eos");
-  } else {
-    sil = 0;
   }
 
-  ans.push_back(sil);
+  if (sil != -1) {
+    ans.push_back(sil);
+  }
 
   for (const auto &w : words) {
     if (punctuations_.count(w)) {
-      ans.push_back(sil);
+      if (sil != -1) {
+        ans.push_back(sil);
+      }
       continue;
     }
 
@@ -220,11 +227,19 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
 
     const auto &token_ids = word2ids_.at(w);
     ans.insert(ans.end(), token_ids.begin(), token_ids.end());
+    if (blank != -1) {
+      ans.push_back(blank);
+    }
+  }
+
+  if (sil != -1) {
+    ans.push_back(sil);
   }
-  ans.push_back(sil);
+
   if (eos != -1) {
     ans.push_back(eos);
   }
+
   return ans;
 }
 
@@ -252,7 +267,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
   int32_t blank = token2id_.at(" ");
 
   std::vector<int64_t> ans;
-  if (is_piper_) {
+  if (is_piper_ && token2id_.count("^")) {
     ans.push_back(token2id_.at("^"));  // sos
   }
 
@@ -277,7 +292,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
     ans.resize(ans.size() - 1);
   }
 
-  if (is_piper_) {
+  if (is_piper_ && token2id_.count("$")) {
     ans.push_back(token2id_.at("$"));  // eos
   }
 
@@ -81,7 +81,8 @@ class OfflineTtsVitsModel::Impl {
 
   std::string comment;
   SHERPA_ONNX_READ_META_DATA_STR(comment, "comment");
-  if (comment.find("piper") != std::string::npos) {
+  if (comment.find("piper") != std::string::npos ||
+      comment.find("coqui") != std::string::npos) {
     is_piper_ = true;
   }
 }