Fangjun Kuang
Committed by GitHub

Validate input sid (#369)

@@ -20,7 +20,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) @@ -20,7 +20,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
20 option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) 20 option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
21 option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) 21 option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
22 option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF) 22 option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
23 -option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is ON on Linux" ON) 23 +option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON)
24 24
25 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 25 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
26 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 26 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
@@ -124,6 +124,11 @@ def main(): @@ -124,6 +124,11 @@ def main():
124 start = time.time() 124 start = time.time()
125 audio = tts.generate(args.text, sid=args.sid) 125 audio = tts.generate(args.text, sid=args.sid)
126 end = time.time() 126 end = time.time()
  127 +
  128 + if len(audio.samples) == 0:
  129 + print("Error in generating audios. Please read previous error messages.")
  130 + return
  131 +
127 elapsed_seconds = end - start 132 elapsed_seconds = end - start
128 audio_duration = len(audio.samples) / audio.sample_rate 133 audio_duration = len(audio.samples) / audio.sample_rate
129 real_time_factor = elapsed_seconds / audio_duration 134 real_time_factor = elapsed_seconds / audio_duration
@@ -104,9 +104,17 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese( @@ -104,9 +104,17 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
104 104
105 std::vector<int64_t> ans; 105 std::vector<int64_t> ans;
106 106
107 - ans.push_back(token2id_.at("sil")); 107 + auto sil = token2id_.at("sil");
  108 + auto eos = token2id_.at("eos");
  109 +
  110 + ans.push_back(sil);
108 111
109 for (const auto &w : words) { 112 for (const auto &w : words) {
  113 + if (punctuations_.count(w)) {
  114 + ans.push_back(sil);
  115 + continue;
  116 + }
  117 +
110 if (!word2ids_.count(w)) { 118 if (!word2ids_.count(w)) {
111 SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str()); 119 SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
112 continue; 120 continue;
@@ -115,8 +123,8 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese( @@ -115,8 +123,8 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
115 const auto &token_ids = word2ids_.at(w); 123 const auto &token_ids = word2ids_.at(w);
116 ans.insert(ans.end(), token_ids.begin(), token_ids.end()); 124 ans.insert(ans.end(), token_ids.begin(), token_ids.end());
117 } 125 }
118 - ans.push_back(token2id_.at("sil"));  
119 - ans.push_back(token2id_.at("eos")); 126 + ans.push_back(sil);
  127 + ans.push_back(eos);
120 return ans; 128 return ans;
121 } 129 }
122 130
@@ -126,6 +134,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( @@ -126,6 +134,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
126 ToLowerCase(&text); 134 ToLowerCase(&text);
127 135
128 std::vector<std::string> words = SplitUtf8(text); 136 std::vector<std::string> words = SplitUtf8(text);
  137 + int32_t blank = token2id_.at(" ");
129 138
130 std::vector<int64_t> ans; 139 std::vector<int64_t> ans;
131 for (const auto &w : words) { 140 for (const auto &w : words) {
@@ -141,12 +150,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( @@ -141,12 +150,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
141 150
142 const auto &token_ids = word2ids_.at(w); 151 const auto &token_ids = word2ids_.at(w);
143 ans.insert(ans.end(), token_ids.begin(), token_ids.end()); 152 ans.insert(ans.end(), token_ids.begin(), token_ids.end());
144 - if (blank_ != -1) {  
145 - ans.push_back(blank_);  
146 - } 153 + ans.push_back(blank);
147 } 154 }
148 155
149 - if (blank_ != -1 && !ans.empty()) { 156 + if (!ans.empty()) {
150 // remove the last blank 157 // remove the last blank
151 ans.resize(ans.size() - 1); 158 ans.resize(ans.size() - 1);
152 } 159 }
@@ -156,9 +163,6 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( @@ -156,9 +163,6 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
156 163
157 void Lexicon::InitTokens(const std::string &tokens) { 164 void Lexicon::InitTokens(const std::string &tokens) {
158 token2id_ = ReadTokens(tokens); 165 token2id_ = ReadTokens(tokens);
159 - if (token2id_.count(" ")) {  
160 - blank_ = token2id_.at(" ");  
161 - }  
162 } 166 }
163 167
164 void Lexicon::InitLanguage(const std::string &_lang) { 168 void Lexicon::InitLanguage(const std::string &_lang) {
@@ -44,7 +44,6 @@ class Lexicon { @@ -44,7 +44,6 @@ class Lexicon {
44 std::unordered_map<std::string, std::vector<int32_t>> word2ids_; 44 std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
45 std::unordered_set<std::string> punctuations_; 45 std::unordered_set<std::string> punctuations_;
46 std::unordered_map<std::string, int32_t> token2id_; 46 std::unordered_map<std::string, int32_t> token2id_;
47 - int32_t blank_ = -1; // ID for the blank token  
48 Language language_; 47 Language language_;
49 // 48 //
50 }; 49 };
@@ -25,6 +25,23 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -25,6 +25,23 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
25 25
26 GeneratedAudio Generate(const std::string &text, 26 GeneratedAudio Generate(const std::string &text,
27 int64_t sid = 0) const override { 27 int64_t sid = 0) const override {
  28 + int32_t num_speakers = model_->NumSpeakers();
  29 + if (num_speakers == 0 && sid != 0) {
  30 + SHERPA_ONNX_LOGE(
  31 + "This is a single-speaker model and supports only sid 0. Given sid: "
  32 + "%d",
  33 + sid);
  34 + return {};
  35 + }
  36 +
  37 + if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) {
  38 + SHERPA_ONNX_LOGE(
  39 + "This model contains only %d speakers. sid should be in the range "
  40 + "[%d, %d]. Given: %d",
  41 + num_speakers, 0, num_speakers - 1, sid);
  42 + return {};
  43 + }
  44 +
28 std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text); 45 std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
29 if (x.empty()) { 46 if (x.empty()) {
30 SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); 47 SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
@@ -85,6 +85,7 @@ class OfflineTtsVitsModel::Impl { @@ -85,6 +85,7 @@ class OfflineTtsVitsModel::Impl {
85 85
86 std::string Punctuations() const { return punctuations_; } 86 std::string Punctuations() const { return punctuations_; }
87 std::string Language() const { return language_; } 87 std::string Language() const { return language_; }
  88 + int32_t NumSpeakers() const { return num_speakers_; }
88 89
89 private: 90 private:
90 void Init(void *model_data, size_t model_data_length) { 91 void Init(void *model_data, size_t model_data_length) {
@@ -107,7 +108,7 @@ class OfflineTtsVitsModel::Impl { @@ -107,7 +108,7 @@ class OfflineTtsVitsModel::Impl {
107 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below 108 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
108 SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); 109 SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
109 SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); 110 SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
110 - SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers"); 111 + SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers");
111 SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); 112 SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
112 SHERPA_ONNX_READ_META_DATA_STR(language_, "language"); 113 SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
113 } 114 }
@@ -128,7 +129,7 @@ class OfflineTtsVitsModel::Impl { @@ -128,7 +129,7 @@ class OfflineTtsVitsModel::Impl {
128 129
129 int32_t sample_rate_; 130 int32_t sample_rate_;
130 int32_t add_blank_; 131 int32_t add_blank_;
131 - int32_t n_speakers_; 132 + int32_t num_speakers_;
132 std::string punctuations_; 133 std::string punctuations_;
133 std::string language_; 134 std::string language_;
134 }; 135 };
@@ -152,4 +153,8 @@ std::string OfflineTtsVitsModel::Punctuations() const { @@ -152,4 +153,8 @@ std::string OfflineTtsVitsModel::Punctuations() const {
152 153
153 std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); } 154 std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }
154 155
  156 +int32_t OfflineTtsVitsModel::NumSpeakers() const {
  157 + return impl_->NumSpeakers();
  158 +}
  159 +
155 } // namespace sherpa_onnx 160 } // namespace sherpa_onnx
@@ -39,6 +39,7 @@ class OfflineTtsVitsModel { @@ -39,6 +39,7 @@ class OfflineTtsVitsModel {
39 39
40 std::string Punctuations() const; 40 std::string Punctuations() const;
41 std::string Language() const; 41 std::string Language() const;
  42 + int32_t NumSpeakers() const;
42 43
43 private: 44 private:
44 class Impl; 45 class Impl;
@@ -81,6 +81,12 @@ or detailes. @@ -81,6 +81,12 @@ or detailes.
81 81
82 sherpa_onnx::OfflineTts tts(config); 82 sherpa_onnx::OfflineTts tts(config);
83 auto audio = tts.Generate(po.GetArg(1), sid); 83 auto audio = tts.Generate(po.GetArg(1), sid);
  84 + if (audio.samples.empty()) {
  85 + fprintf(
  86 + stderr,
  87 + "Error in generating audios. Please read previous error messages.\n");
  88 + exit(EXIT_FAILURE);
  89 + }
84 90
85 bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate, 91 bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
86 audio.samples.data(), audio.samples.size()); 92 audio.samples.data(), audio.samples.size());