正在显示
8 个修改的文件
包含
51 行增加
和
14 行删除
| @@ -20,7 +20,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) | @@ -20,7 +20,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) | ||
| 20 | option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) | 20 | option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) |
| 21 | option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) | 21 | option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) |
| 22 | option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF) | 22 | option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF) |
| 23 | -option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is ON on Linux" ON) | 23 | +option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON) |
| 24 | 24 | ||
| 25 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") | 25 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") |
| 26 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") | 26 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") |
| @@ -124,6 +124,11 @@ def main(): | @@ -124,6 +124,11 @@ def main(): | ||
| 124 | start = time.time() | 124 | start = time.time() |
| 125 | audio = tts.generate(args.text, sid=args.sid) | 125 | audio = tts.generate(args.text, sid=args.sid) |
| 126 | end = time.time() | 126 | end = time.time() |
| 127 | + | ||
| 128 | + if len(audio.samples) == 0: | ||
| 129 | + print("Error in generating audios. Please read previous error messages.") | ||
| 130 | + return | ||
| 131 | + | ||
| 127 | elapsed_seconds = end - start | 132 | elapsed_seconds = end - start |
| 128 | audio_duration = len(audio.samples) / audio.sample_rate | 133 | audio_duration = len(audio.samples) / audio.sample_rate |
| 129 | real_time_factor = elapsed_seconds / audio_duration | 134 | real_time_factor = elapsed_seconds / audio_duration |
| @@ -104,9 +104,17 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese( | @@ -104,9 +104,17 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese( | ||
| 104 | 104 | ||
| 105 | std::vector<int64_t> ans; | 105 | std::vector<int64_t> ans; |
| 106 | 106 | ||
| 107 | - ans.push_back(token2id_.at("sil")); | 107 | + auto sil = token2id_.at("sil"); |
| 108 | + auto eos = token2id_.at("eos"); | ||
| 109 | + | ||
| 110 | + ans.push_back(sil); | ||
| 108 | 111 | ||
| 109 | for (const auto &w : words) { | 112 | for (const auto &w : words) { |
| 113 | + if (punctuations_.count(w)) { | ||
| 114 | + ans.push_back(sil); | ||
| 115 | + continue; | ||
| 116 | + } | ||
| 117 | + | ||
| 110 | if (!word2ids_.count(w)) { | 118 | if (!word2ids_.count(w)) { |
| 111 | SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str()); | 119 | SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str()); |
| 112 | continue; | 120 | continue; |
| @@ -115,8 +123,8 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese( | @@ -115,8 +123,8 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese( | ||
| 115 | const auto &token_ids = word2ids_.at(w); | 123 | const auto &token_ids = word2ids_.at(w); |
| 116 | ans.insert(ans.end(), token_ids.begin(), token_ids.end()); | 124 | ans.insert(ans.end(), token_ids.begin(), token_ids.end()); |
| 117 | } | 125 | } |
| 118 | - ans.push_back(token2id_.at("sil")); | ||
| 119 | - ans.push_back(token2id_.at("eos")); | 126 | + ans.push_back(sil); |
| 127 | + ans.push_back(eos); | ||
| 120 | return ans; | 128 | return ans; |
| 121 | } | 129 | } |
| 122 | 130 | ||
| @@ -126,6 +134,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | @@ -126,6 +134,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | ||
| 126 | ToLowerCase(&text); | 134 | ToLowerCase(&text); |
| 127 | 135 | ||
| 128 | std::vector<std::string> words = SplitUtf8(text); | 136 | std::vector<std::string> words = SplitUtf8(text); |
| 137 | + int32_t blank = token2id_.at(" "); | ||
| 129 | 138 | ||
| 130 | std::vector<int64_t> ans; | 139 | std::vector<int64_t> ans; |
| 131 | for (const auto &w : words) { | 140 | for (const auto &w : words) { |
| @@ -141,12 +150,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | @@ -141,12 +150,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | ||
| 141 | 150 | ||
| 142 | const auto &token_ids = word2ids_.at(w); | 151 | const auto &token_ids = word2ids_.at(w); |
| 143 | ans.insert(ans.end(), token_ids.begin(), token_ids.end()); | 152 | ans.insert(ans.end(), token_ids.begin(), token_ids.end()); |
| 144 | - if (blank_ != -1) { | ||
| 145 | - ans.push_back(blank_); | ||
| 146 | - } | 153 | + ans.push_back(blank); |
| 147 | } | 154 | } |
| 148 | 155 | ||
| 149 | - if (blank_ != -1 && !ans.empty()) { | 156 | + if (!ans.empty()) { |
| 150 | // remove the last blank | 157 | // remove the last blank |
| 151 | ans.resize(ans.size() - 1); | 158 | ans.resize(ans.size() - 1); |
| 152 | } | 159 | } |
| @@ -156,9 +163,6 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | @@ -156,9 +163,6 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | ||
| 156 | 163 | ||
| 157 | void Lexicon::InitTokens(const std::string &tokens) { | 164 | void Lexicon::InitTokens(const std::string &tokens) { |
| 158 | token2id_ = ReadTokens(tokens); | 165 | token2id_ = ReadTokens(tokens); |
| 159 | - if (token2id_.count(" ")) { | ||
| 160 | - blank_ = token2id_.at(" "); | ||
| 161 | - } | ||
| 162 | } | 166 | } |
| 163 | 167 | ||
| 164 | void Lexicon::InitLanguage(const std::string &_lang) { | 168 | void Lexicon::InitLanguage(const std::string &_lang) { |
| @@ -44,7 +44,6 @@ class Lexicon { | @@ -44,7 +44,6 @@ class Lexicon { | ||
| 44 | std::unordered_map<std::string, std::vector<int32_t>> word2ids_; | 44 | std::unordered_map<std::string, std::vector<int32_t>> word2ids_; |
| 45 | std::unordered_set<std::string> punctuations_; | 45 | std::unordered_set<std::string> punctuations_; |
| 46 | std::unordered_map<std::string, int32_t> token2id_; | 46 | std::unordered_map<std::string, int32_t> token2id_; |
| 47 | - int32_t blank_ = -1; // ID for the blank token | ||
| 48 | Language language_; | 47 | Language language_; |
| 49 | // | 48 | // |
| 50 | }; | 49 | }; |
| @@ -25,6 +25,23 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -25,6 +25,23 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 25 | 25 | ||
| 26 | GeneratedAudio Generate(const std::string &text, | 26 | GeneratedAudio Generate(const std::string &text, |
| 27 | int64_t sid = 0) const override { | 27 | int64_t sid = 0) const override { |
| 28 | + int32_t num_speakers = model_->NumSpeakers(); | ||
| 29 | + if (num_speakers == 0 && sid != 0) { | ||
| 30 | + SHERPA_ONNX_LOGE( | ||
| 31 | + "This is a single-speaker model and supports only sid 0. Given sid: " | ||
| 32 | + "%d", | ||
| 33 | + sid); | ||
| 34 | + return {}; | ||
| 35 | + } | ||
| 36 | + | ||
| 37 | + if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { | ||
| 38 | + SHERPA_ONNX_LOGE( | ||
| 39 | + "This model contains only %d speakers. sid should be in the range " | ||
| 40 | + "[%d, %d]. Given: %d", | ||
| 41 | + num_speakers, 0, num_speakers - 1, sid); | ||
| 42 | + return {}; | ||
| 43 | + } | ||
| 44 | + | ||
| 28 | std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text); | 45 | std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text); |
| 29 | if (x.empty()) { | 46 | if (x.empty()) { |
| 30 | SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); | 47 | SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); |
| @@ -85,6 +85,7 @@ class OfflineTtsVitsModel::Impl { | @@ -85,6 +85,7 @@ class OfflineTtsVitsModel::Impl { | ||
| 85 | 85 | ||
| 86 | std::string Punctuations() const { return punctuations_; } | 86 | std::string Punctuations() const { return punctuations_; } |
| 87 | std::string Language() const { return language_; } | 87 | std::string Language() const { return language_; } |
| 88 | + int32_t NumSpeakers() const { return num_speakers_; } | ||
| 88 | 89 | ||
| 89 | private: | 90 | private: |
| 90 | void Init(void *model_data, size_t model_data_length) { | 91 | void Init(void *model_data, size_t model_data_length) { |
| @@ -107,7 +108,7 @@ class OfflineTtsVitsModel::Impl { | @@ -107,7 +108,7 @@ class OfflineTtsVitsModel::Impl { | ||
| 107 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | 108 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below |
| 108 | SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); | 109 | SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); |
| 109 | SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); | 110 | SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); |
| 110 | - SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers"); | 111 | + SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers"); |
| 111 | SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); | 112 | SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); |
| 112 | SHERPA_ONNX_READ_META_DATA_STR(language_, "language"); | 113 | SHERPA_ONNX_READ_META_DATA_STR(language_, "language"); |
| 113 | } | 114 | } |
| @@ -128,7 +129,7 @@ class OfflineTtsVitsModel::Impl { | @@ -128,7 +129,7 @@ class OfflineTtsVitsModel::Impl { | ||
| 128 | 129 | ||
| 129 | int32_t sample_rate_; | 130 | int32_t sample_rate_; |
| 130 | int32_t add_blank_; | 131 | int32_t add_blank_; |
| 131 | - int32_t n_speakers_; | 132 | + int32_t num_speakers_; |
| 132 | std::string punctuations_; | 133 | std::string punctuations_; |
| 133 | std::string language_; | 134 | std::string language_; |
| 134 | }; | 135 | }; |
| @@ -152,4 +153,8 @@ std::string OfflineTtsVitsModel::Punctuations() const { | @@ -152,4 +153,8 @@ std::string OfflineTtsVitsModel::Punctuations() const { | ||
| 152 | 153 | ||
| 153 | std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); } | 154 | std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); } |
| 154 | 155 | ||
| 156 | +int32_t OfflineTtsVitsModel::NumSpeakers() const { | ||
| 157 | + return impl_->NumSpeakers(); | ||
| 158 | +} | ||
| 159 | + | ||
| 155 | } // namespace sherpa_onnx | 160 | } // namespace sherpa_onnx |
| @@ -39,6 +39,7 @@ class OfflineTtsVitsModel { | @@ -39,6 +39,7 @@ class OfflineTtsVitsModel { | ||
| 39 | 39 | ||
| 40 | std::string Punctuations() const; | 40 | std::string Punctuations() const; |
| 41 | std::string Language() const; | 41 | std::string Language() const; |
| 42 | + int32_t NumSpeakers() const; | ||
| 42 | 43 | ||
| 43 | private: | 44 | private: |
| 44 | class Impl; | 45 | class Impl; |
| @@ -81,6 +81,12 @@ or detailes. | @@ -81,6 +81,12 @@ or detailes. | ||
| 81 | 81 | ||
| 82 | sherpa_onnx::OfflineTts tts(config); | 82 | sherpa_onnx::OfflineTts tts(config); |
| 83 | auto audio = tts.Generate(po.GetArg(1), sid); | 83 | auto audio = tts.Generate(po.GetArg(1), sid); |
| 84 | + if (audio.samples.empty()) { | ||
| 85 | + fprintf( | ||
| 86 | + stderr, | ||
| 87 | + "Error in generating audios. Please read previous error messages.\n"); | ||
| 88 | + exit(EXIT_FAILURE); | ||
| 89 | + } | ||
| 84 | 90 | ||
| 85 | bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate, | 91 | bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate, |
| 86 | audio.samples.data(), audio.samples.size()); | 92 | audio.samples.data(), audio.samples.size()); |
-
请 注册 或 登录 后发表评论