Fangjun Kuang
Committed by GitHub

Break text into sentences for tts. (#460)

This is for models that are not using piper-phonemize as their front-end.
@@ -88,8 +88,8 @@ static std::vector<int32_t> ConvertTokensToIds( @@ -88,8 +88,8 @@ static std::vector<int32_t> ConvertTokensToIds(
88 88
89 Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, 89 Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
90 const std::string &punctuations, const std::string &language, 90 const std::string &punctuations, const std::string &language,
91 - bool debug /*= false*/, bool is_piper /*= false*/)  
92 - : debug_(debug), is_piper_(is_piper) { 91 + bool debug /*= false*/)
  92 + : debug_(debug) {
93 InitLanguage(language); 93 InitLanguage(language);
94 94
95 { 95 {
@@ -108,9 +108,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, @@ -108,9 +108,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
108 #if __ANDROID_API__ >= 9 108 #if __ANDROID_API__ >= 9
109 Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, 109 Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
110 const std::string &tokens, const std::string &punctuations, 110 const std::string &tokens, const std::string &punctuations,
111 - const std::string &language, bool debug /*= false*/,  
112 - bool is_piper /*= false*/)  
113 - : debug_(debug), is_piper_(is_piper) { 111 + const std::string &language, bool debug /*= false*/
  112 + )
  113 + : debug_(debug) {
114 InitLanguage(language); 114 InitLanguage(language);
115 115
116 { 116 {
@@ -132,16 +132,10 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, @@ -132,16 +132,10 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
132 std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds( 132 std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
133 const std::string &text, const std::string & /*voice*/ /*= ""*/) const { 133 const std::string &text, const std::string & /*voice*/ /*= ""*/) const {
134 switch (language_) { 134 switch (language_) {
135 - case Language::kEnglish:  
136 - return ConvertTextToTokenIdsEnglish(text);  
137 - case Language::kGerman:  
138 - return ConvertTextToTokenIdsGerman(text);  
139 - case Language::kSpanish:  
140 - return ConvertTextToTokenIdsSpanish(text);  
141 - case Language::kFrench:  
142 - return ConvertTextToTokenIdsFrench(text);  
143 case Language::kChinese: 135 case Language::kChinese:
144 return ConvertTextToTokenIdsChinese(text); 136 return ConvertTextToTokenIdsChinese(text);
  137 + case Language::kNotChinese:
  138 + return ConvertTextToTokenIdsNotChinese(text);
145 default: 139 default:
146 SHERPA_ONNX_LOGE("Unknown language: %d", static_cast<int32_t>(language_)); 140 SHERPA_ONNX_LOGE("Unknown language: %d", static_cast<int32_t>(language_));
147 exit(-1); 141 exit(-1);
@@ -197,7 +191,8 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( @@ -197,7 +191,8 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
197 fprintf(stderr, "\n"); 191 fprintf(stderr, "\n");
198 } 192 }
199 193
200 - std::vector<int64_t> ans; 194 + std::vector<std::vector<int64_t>> ans;
  195 + std::vector<int64_t> this_sentence;
201 196
202 int32_t blank = -1; 197 int32_t blank = -1;
203 if (token2id_.count(" ")) { 198 if (token2id_.count(" ")) {
@@ -212,15 +207,32 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( @@ -212,15 +207,32 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
212 } 207 }
213 208
214 if (sil != -1) { 209 if (sil != -1) {
215 - ans.push_back(sil); 210 + this_sentence.push_back(sil);
216 } 211 }
217 212
218 for (const auto &w : words) { 213 for (const auto &w : words) {
  214 + if (w == "." || w == ";" || w == "!" || w == "?" || w == "-" || w == ":" ||
  215 + w == "。" || w == ";" || w == "!" || w == "?" || w == ":" ||
  216 + w == "”" ||
  217 + // not sentence break
  218 + w == "," || w == "“" || w == "," || w == "、") {
219 if (punctuations_.count(w)) { 219 if (punctuations_.count(w)) {
220 if (token2id_.count(w)) { 220 if (token2id_.count(w)) {
221 - ans.push_back(token2id_.at(w)); 221 + this_sentence.push_back(token2id_.at(w));
222 } else if (sil != -1) { 222 } else if (sil != -1) {
223 - ans.push_back(sil); 223 + this_sentence.push_back(sil);
  224 + }
  225 + }
  226 +
  227 + if (w != "," && w != "“" && w != "," && w != "、") {
  228 + if (eos != -1) {
  229 + this_sentence.push_back(eos);
  230 + }
  231 + ans.push_back(std::move(this_sentence));
  232 +
  233 + if (sil != -1) {
  234 + this_sentence.push_back(sil);
  235 + }
224 } 236 }
225 continue; 237 continue;
226 } 238 }
@@ -231,24 +243,26 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( @@ -231,24 +243,26 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
231 } 243 }
232 244
233 const auto &token_ids = word2ids_.at(w); 245 const auto &token_ids = word2ids_.at(w);
234 - ans.insert(ans.end(), token_ids.begin(), token_ids.end()); 246 + this_sentence.insert(this_sentence.end(), token_ids.begin(),
  247 + token_ids.end());
235 if (blank != -1) { 248 if (blank != -1) {
236 - ans.push_back(blank); 249 + this_sentence.push_back(blank);
237 } 250 }
238 } 251 }
239 252
240 if (sil != -1) { 253 if (sil != -1) {
241 - ans.push_back(sil); 254 + this_sentence.push_back(sil);
242 } 255 }
243 256
244 if (eos != -1) { 257 if (eos != -1) {
245 - ans.push_back(eos); 258 + this_sentence.push_back(eos);
246 } 259 }
  260 + ans.push_back(std::move(this_sentence));
247 261
248 - return {ans}; 262 + return ans;
249 } 263 }
250 264
251 -std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish( 265 +std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
252 const std::string &_text) const { 266 const std::string &_text) const {
253 std::string text(_text); 267 std::string text(_text);
254 ToLowerCase(&text); 268 ToLowerCase(&text);
@@ -271,14 +285,22 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish( @@ -271,14 +285,22 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish(
271 285
272 int32_t blank = token2id_.at(" "); 286 int32_t blank = token2id_.at(" ");
273 287
274 - std::vector<int64_t> ans;  
275 - if (is_piper_ && token2id_.count("^")) {  
276 - ans.push_back(token2id_.at("^")); // sos  
277 - } 288 + std::vector<std::vector<int64_t>> ans;
  289 + std::vector<int64_t> this_sentence;
278 290
279 for (const auto &w : words) { 291 for (const auto &w : words) {
  292 + if (w == "." || w == ";" || w == "!" || w == "?" || w == "-" || w == ":" ||
  293 + // not sentence break
  294 + w == ",") {
280 if (punctuations_.count(w)) { 295 if (punctuations_.count(w)) {
281 - ans.push_back(token2id_.at(w)); 296 + this_sentence.push_back(token2id_.at(w));
  297 + }
  298 +
  299 + if (w != ",") {
  300 + this_sentence.push_back(blank);
  301 + ans.push_back(std::move(this_sentence));
  302 + }
  303 +
282 continue; 304 continue;
283 } 305 }
284 306
@@ -288,20 +310,21 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish( @@ -288,20 +310,21 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish(
288 } 310 }
289 311
290 const auto &token_ids = word2ids_.at(w); 312 const auto &token_ids = word2ids_.at(w);
291 - ans.insert(ans.end(), token_ids.begin(), token_ids.end());  
292 - ans.push_back(blank); 313 + this_sentence.insert(this_sentence.end(), token_ids.begin(),
  314 + token_ids.end());
  315 + this_sentence.push_back(blank);
293 } 316 }
294 317
295 - if (!ans.empty()) { 318 + if (!this_sentence.empty()) {
296 // remove the last blank 319 // remove the last blank
297 - ans.resize(ans.size() - 1); 320 + this_sentence.resize(this_sentence.size() - 1);
298 } 321 }
299 322
300 - if (is_piper_ && token2id_.count("$")) {  
301 - ans.push_back(token2id_.at("$")); // eos 323 + if (!this_sentence.empty()) {
  324 + ans.push_back(std::move(this_sentence));
302 } 325 }
303 326
304 - return {ans}; 327 + return ans;
305 } 328 }
306 329
307 void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); } 330 void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); }
@@ -309,16 +332,10 @@ void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); } @@ -309,16 +332,10 @@ void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); }
309 void Lexicon::InitLanguage(const std::string &_lang) { 332 void Lexicon::InitLanguage(const std::string &_lang) {
310 std::string lang(_lang); 333 std::string lang(_lang);
311 ToLowerCase(&lang); 334 ToLowerCase(&lang);
312 - if (lang == "english") {  
313 - language_ = Language::kEnglish;  
314 - } else if (lang == "german") {  
315 - language_ = Language::kGerman;  
316 - } else if (lang == "spanish") {  
317 - language_ = Language::kSpanish;  
318 - } else if (lang == "french") {  
319 - language_ = Language::kFrench;  
320 - } else if (lang == "chinese") { 335 + if (lang == "chinese") {
321 language_ = Language::kChinese; 336 language_ = Language::kChinese;
  337 + } else if (!lang.empty()) {
  338 + language_ = Language::kNotChinese;
322 } else { 339 } else {
323 SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str()); 340 SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
324 exit(-1); 341 exit(-1);
@@ -29,35 +29,19 @@ class Lexicon : public OfflineTtsFrontend { @@ -29,35 +29,19 @@ class Lexicon : public OfflineTtsFrontend {
29 // Note: for models from piper, we won't use this class. 29 // Note: for models from piper, we won't use this class.
30 Lexicon(const std::string &lexicon, const std::string &tokens, 30 Lexicon(const std::string &lexicon, const std::string &tokens,
31 const std::string &punctuations, const std::string &language, 31 const std::string &punctuations, const std::string &language,
32 - bool debug = false, bool is_piper = false); 32 + bool debug = false);
33 33
34 #if __ANDROID_API__ >= 9 34 #if __ANDROID_API__ >= 9
35 Lexicon(AAssetManager *mgr, const std::string &lexicon, 35 Lexicon(AAssetManager *mgr, const std::string &lexicon,
36 const std::string &tokens, const std::string &punctuations, 36 const std::string &tokens, const std::string &punctuations,
37 - const std::string &language, bool debug = false,  
38 - bool is_piper = false); 37 + const std::string &language, bool debug = false);
39 #endif 38 #endif
40 39
41 std::vector<std::vector<int64_t>> ConvertTextToTokenIds( 40 std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
42 const std::string &text, const std::string &voice = "") const override; 41 const std::string &text, const std::string &voice = "") const override;
43 42
44 private: 43 private:
45 - std::vector<std::vector<int64_t>> ConvertTextToTokenIdsGerman(  
46 - const std::string &text) const {  
47 - return ConvertTextToTokenIdsEnglish(text);  
48 - }  
49 -  
50 - std::vector<std::vector<int64_t>> ConvertTextToTokenIdsSpanish(  
51 - const std::string &text) const {  
52 - return ConvertTextToTokenIdsEnglish(text);  
53 - }  
54 -  
55 - std::vector<std::vector<int64_t>> ConvertTextToTokenIdsFrench(  
56 - const std::string &text) const {  
57 - return ConvertTextToTokenIdsEnglish(text);  
58 - }  
59 -  
60 - std::vector<std::vector<int64_t>> ConvertTextToTokenIdsEnglish( 44 + std::vector<std::vector<int64_t>> ConvertTextToTokenIdsNotChinese(
61 const std::string &text) const; 45 const std::string &text) const;
62 46
63 std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese( 47 std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese(
@@ -70,10 +54,7 @@ class Lexicon : public OfflineTtsFrontend { @@ -70,10 +54,7 @@ class Lexicon : public OfflineTtsFrontend {
70 54
71 private: 55 private:
72 enum class Language { 56 enum class Language {
73 - kEnglish,  
74 - kGerman,  
75 - kSpanish,  
76 - kFrench, 57 + kNotChinese,
77 kChinese, 58 kChinese,
78 kUnknown, 59 kUnknown,
79 }; 60 };
@@ -84,7 +65,6 @@ class Lexicon : public OfflineTtsFrontend { @@ -84,7 +65,6 @@ class Lexicon : public OfflineTtsFrontend {
84 std::unordered_map<std::string, int32_t> token2id_; 65 std::unordered_map<std::string, int32_t> token2id_;
85 Language language_; 66 Language language_;
86 bool debug_; 67 bool debug_;
87 - bool is_piper_;  
88 68
89 // for Chinese polyphones 69 // for Chinese polyphones
90 std::unique_ptr<std::regex> pattern_; 70 std::unique_ptr<std::regex> pattern_;
@@ -195,8 +195,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -195,8 +195,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
195 } else { 195 } else {
196 frontend_ = std::make_unique<Lexicon>( 196 frontend_ = std::make_unique<Lexicon>(
197 mgr, config_.model.vits.lexicon, config_.model.vits.tokens, 197 mgr, config_.model.vits.lexicon, config_.model.vits.tokens,
198 - model_->Punctuations(), model_->Language(), config_.model.debug,  
199 - model_->IsPiper()); 198 + model_->Punctuations(), model_->Language(), config_.model.debug);
200 } 199 }
201 } 200 }
202 #endif 201 #endif
@@ -208,8 +207,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -208,8 +207,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
208 } else { 207 } else {
209 frontend_ = std::make_unique<Lexicon>( 208 frontend_ = std::make_unique<Lexicon>(
210 config_.model.vits.lexicon, config_.model.vits.tokens, 209 config_.model.vits.lexicon, config_.model.vits.tokens,
211 - model_->Punctuations(), model_->Language(), config_.model.debug,  
212 - model_->IsPiper()); 210 + model_->Punctuations(), model_->Language(), config_.model.debug);
213 } 211 }
214 } 212 }
215 213