正在显示 5 个修改的文件，包含 130 行增加和 49 行删除
| @@ -83,8 +83,8 @@ static std::vector<int32_t> ConvertTokensToIds( | @@ -83,8 +83,8 @@ static std::vector<int32_t> ConvertTokensToIds( | ||
| 83 | 83 | ||
| 84 | Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, | 84 | Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, |
| 85 | const std::string &punctuations, const std::string &language, | 85 | const std::string &punctuations, const std::string &language, |
| 86 | - bool debug /*= false*/) | ||
| 87 | - : debug_(debug) { | 86 | + bool debug /*= false*/, bool is_piper /*= false*/) |
| 87 | + : debug_(debug), is_piper_(is_piper) { | ||
| 88 | InitLanguage(language); | 88 | InitLanguage(language); |
| 89 | 89 | ||
| 90 | { | 90 | { |
| @@ -103,8 +103,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, | @@ -103,8 +103,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, | ||
| 103 | #if __ANDROID_API__ >= 9 | 103 | #if __ANDROID_API__ >= 9 |
| 104 | Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, | 104 | Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, |
| 105 | const std::string &tokens, const std::string &punctuations, | 105 | const std::string &tokens, const std::string &punctuations, |
| 106 | - const std::string &language, bool debug /*= false*/) | ||
| 107 | - : debug_(debug) { | 106 | + const std::string &language, bool debug /*= false*/, |
| 107 | + bool is_piper /*= false*/) | ||
| 108 | + : debug_(debug), is_piper_(is_piper) { | ||
| 108 | InitLanguage(language); | 109 | InitLanguage(language); |
| 109 | 110 | ||
| 110 | { | 111 | { |
| @@ -206,6 +207,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | @@ -206,6 +207,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | ||
| 206 | int32_t blank = token2id_.at(" "); | 207 | int32_t blank = token2id_.at(" "); |
| 207 | 208 | ||
| 208 | std::vector<int64_t> ans; | 209 | std::vector<int64_t> ans; |
| 210 | + if (is_piper_) { | ||
| 211 | + ans.push_back(token2id_.at("^")); // sos | ||
| 212 | + } | ||
| 213 | + | ||
| 209 | for (const auto &w : words) { | 214 | for (const auto &w : words) { |
| 210 | if (punctuations_.count(w)) { | 215 | if (punctuations_.count(w)) { |
| 211 | ans.push_back(token2id_.at(w)); | 216 | ans.push_back(token2id_.at(w)); |
| @@ -227,6 +232,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | @@ -227,6 +232,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( | ||
| 227 | ans.resize(ans.size() - 1); | 232 | ans.resize(ans.size() - 1); |
| 228 | } | 233 | } |
| 229 | 234 | ||
| 235 | + if (is_piper_) { | ||
| 236 | + ans.push_back(token2id_.at("$")); // eos | ||
| 237 | + } | ||
| 238 | + | ||
| 230 | return ans; | 239 | return ans; |
| 231 | } | 240 | } |
| 232 | 241 |
| @@ -24,12 +24,13 @@ class Lexicon { | @@ -24,12 +24,13 @@ class Lexicon { | ||
| 24 | public: | 24 | public: |
| 25 | Lexicon(const std::string &lexicon, const std::string &tokens, | 25 | Lexicon(const std::string &lexicon, const std::string &tokens, |
| 26 | const std::string &punctuations, const std::string &language, | 26 | const std::string &punctuations, const std::string &language, |
| 27 | - bool debug = false); | 27 | + bool debug = false, bool is_piper = false); |
| 28 | 28 | ||
| 29 | #if __ANDROID_API__ >= 9 | 29 | #if __ANDROID_API__ >= 9 |
| 30 | Lexicon(AAssetManager *mgr, const std::string &lexicon, | 30 | Lexicon(AAssetManager *mgr, const std::string &lexicon, |
| 31 | const std::string &tokens, const std::string &punctuations, | 31 | const std::string &tokens, const std::string &punctuations, |
| 32 | - const std::string &language, bool debug = false); | 32 | + const std::string &language, bool debug = false, |
| 33 | + bool is_piper = false); | ||
| 33 | #endif | 34 | #endif |
| 34 | 35 | ||
| 35 | std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const; | 36 | std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const; |
| @@ -59,7 +60,7 @@ class Lexicon { | @@ -59,7 +60,7 @@ class Lexicon { | ||
| 59 | std::unordered_map<std::string, int32_t> token2id_; | 60 | std::unordered_map<std::string, int32_t> token2id_; |
| 60 | Language language_; | 61 | Language language_; |
| 61 | bool debug_; | 62 | bool debug_; |
| 62 | - // | 63 | + bool is_piper_; |
| 63 | }; | 64 | }; |
| 64 | 65 | ||
| 65 | } // namespace sherpa_onnx | 66 | } // namespace sherpa_onnx |
| @@ -26,15 +26,15 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -26,15 +26,15 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 26 | explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config) | 26 | explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config) |
| 27 | : model_(std::make_unique<OfflineTtsVitsModel>(config.model)), | 27 | : model_(std::make_unique<OfflineTtsVitsModel>(config.model)), |
| 28 | lexicon_(config.model.vits.lexicon, config.model.vits.tokens, | 28 | lexicon_(config.model.vits.lexicon, config.model.vits.tokens, |
| 29 | - model_->Punctuations(), model_->Language(), | ||
| 30 | - config.model.debug) {} | 29 | + model_->Punctuations(), model_->Language(), config.model.debug, |
| 30 | + model_->IsPiper()) {} | ||
| 31 | 31 | ||
| 32 | #if __ANDROID_API__ >= 9 | 32 | #if __ANDROID_API__ >= 9 |
| 33 | OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config) | 33 | OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config) |
| 34 | : model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)), | 34 | : model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)), |
| 35 | lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens, | 35 | lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens, |
| 36 | - model_->Punctuations(), model_->Language(), | ||
| 37 | - config.model.debug) {} | 36 | + model_->Punctuations(), model_->Language(), config.model.debug, |
| 37 | + model_->IsPiper()) {} | ||
| 38 | #endif | 38 | #endif |
| 39 | 39 | ||
| 40 | GeneratedAudio Generate(const std::string &text, int64_t sid = 0, | 40 | GeneratedAudio Generate(const std::string &text, int64_t sid = 0, |
| @@ -43,17 +43,16 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -43,17 +43,16 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 43 | if (num_speakers == 0 && sid != 0) { | 43 | if (num_speakers == 0 && sid != 0) { |
| 44 | SHERPA_ONNX_LOGE( | 44 | SHERPA_ONNX_LOGE( |
| 45 | "This is a single-speaker model and supports only sid 0. Given sid: " | 45 | "This is a single-speaker model and supports only sid 0. Given sid: " |
| 46 | - "%d", | 46 | + "%d. sid is ignored", |
| 47 | static_cast<int32_t>(sid)); | 47 | static_cast<int32_t>(sid)); |
| 48 | - return {}; | ||
| 49 | } | 48 | } |
| 50 | 49 | ||
| 51 | if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { | 50 | if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { |
| 52 | SHERPA_ONNX_LOGE( | 51 | SHERPA_ONNX_LOGE( |
| 53 | "This model contains only %d speakers. sid should be in the range " | 52 | "This model contains only %d speakers. sid should be in the range " |
| 54 | - "[%d, %d]. Given: %d", | 53 | + "[%d, %d]. Given: %d. Use sid=0", |
| 55 | num_speakers, 0, num_speakers - 1, static_cast<int32_t>(sid)); | 54 | num_speakers, 0, num_speakers - 1, static_cast<int32_t>(sid)); |
| 56 | - return {}; | 55 | + sid = 0; |
| 57 | } | 56 | } |
| 58 | 57 | ||
| 59 | std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text); | 58 | std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text); |
| @@ -38,6 +38,107 @@ class OfflineTtsVitsModel::Impl { | @@ -38,6 +38,107 @@ class OfflineTtsVitsModel::Impl { | ||
| 38 | #endif | 38 | #endif |
| 39 | 39 | ||
| 40 | Ort::Value Run(Ort::Value x, int64_t sid, float speed) { | 40 | Ort::Value Run(Ort::Value x, int64_t sid, float speed) { |
| 41 | + if (is_piper_) { | ||
| 42 | + return RunVitsPiper(std::move(x), sid, speed); | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + return RunVits(std::move(x), sid, speed); | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + int32_t SampleRate() const { return sample_rate_; } | ||
| 49 | + | ||
| 50 | + bool AddBlank() const { return add_blank_; } | ||
| 51 | + | ||
| 52 | + std::string Punctuations() const { return punctuations_; } | ||
| 53 | + std::string Language() const { return language_; } | ||
| 54 | + bool IsPiper() const { return is_piper_; } | ||
| 55 | + int32_t NumSpeakers() const { return num_speakers_; } | ||
| 56 | + | ||
| 57 | + private: | ||
| 58 | + void Init(void *model_data, size_t model_data_length) { | ||
| 59 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 60 | + sess_opts_); | ||
| 61 | + | ||
| 62 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 63 | + | ||
| 64 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 65 | + | ||
| 66 | + // get meta data | ||
| 67 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 68 | + if (config_.debug) { | ||
| 69 | + std::ostringstream os; | ||
| 70 | + os << "---vits model---\n"; | ||
| 71 | + PrintModelMetadata(os, meta_data); | ||
| 72 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 76 | + SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); | ||
| 77 | + SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); | ||
| 78 | + SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers"); | ||
| 79 | + SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); | ||
| 80 | + SHERPA_ONNX_READ_META_DATA_STR(language_, "language"); | ||
| 81 | + | ||
| 82 | + std::string comment; | ||
| 83 | + SHERPA_ONNX_READ_META_DATA_STR(comment, "comment"); | ||
| 84 | + if (comment.find("piper") != std::string::npos) { | ||
| 85 | + is_piper_ = true; | ||
| 86 | + } | ||
| 87 | + } | ||
| 88 | + | ||
| 89 | + Ort::Value RunVitsPiper(Ort::Value x, int64_t sid, float speed) { | ||
| 90 | + auto memory_info = | ||
| 91 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 92 | + | ||
| 93 | + std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 94 | + if (x_shape[0] != 1) { | ||
| 95 | + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", | ||
| 96 | + static_cast<int32_t>(x_shape[0])); | ||
| 97 | + exit(-1); | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + int64_t len = x_shape[1]; | ||
| 101 | + int64_t len_shape = 1; | ||
| 102 | + | ||
| 103 | + Ort::Value x_length = | ||
| 104 | + Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1); | ||
| 105 | + | ||
| 106 | + float noise_scale = config_.vits.noise_scale; | ||
| 107 | + float length_scale = config_.vits.length_scale; | ||
| 108 | + float noise_scale_w = config_.vits.noise_scale_w; | ||
| 109 | + | ||
| 110 | + if (speed != 1 && speed > 0) { | ||
| 111 | + length_scale = 1. / speed; | ||
| 112 | + } | ||
| 113 | + std::array<float, 3> scales = {noise_scale, length_scale, noise_scale_w}; | ||
| 114 | + | ||
| 115 | + int64_t scale_shape = 3; | ||
| 116 | + | ||
| 117 | + Ort::Value scales_tensor = Ort::Value::CreateTensor( | ||
| 118 | + memory_info, scales.data(), scales.size(), &scale_shape, 1); | ||
| 119 | + | ||
| 120 | + int64_t sid_shape = 1; | ||
| 121 | + Ort::Value sid_tensor = | ||
| 122 | + Ort::Value::CreateTensor(memory_info, &sid, 1, &sid_shape, 1); | ||
| 123 | + | ||
| 124 | + std::vector<Ort::Value> inputs; | ||
| 125 | + inputs.reserve(4); | ||
| 126 | + inputs.push_back(std::move(x)); | ||
| 127 | + inputs.push_back(std::move(x_length)); | ||
| 128 | + inputs.push_back(std::move(scales_tensor)); | ||
| 129 | + | ||
| 130 | + if (input_names_.size() == 4 && input_names_.back() == "sid") { | ||
| 131 | + inputs.push_back(std::move(sid_tensor)); | ||
| 132 | + } | ||
| 133 | + | ||
| 134 | + auto out = | ||
| 135 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 136 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 137 | + | ||
| 138 | + return std::move(out[0]); | ||
| 139 | + } | ||
| 140 | + | ||
| 141 | + Ort::Value RunVits(Ort::Value x, int64_t sid, float speed) { | ||
| 41 | auto memory_info = | 142 | auto memory_info = |
| 42 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 143 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); |
| 43 | 144 | ||
| @@ -94,40 +195,6 @@ class OfflineTtsVitsModel::Impl { | @@ -94,40 +195,6 @@ class OfflineTtsVitsModel::Impl { | ||
| 94 | return std::move(out[0]); | 195 | return std::move(out[0]); |
| 95 | } | 196 | } |
| 96 | 197 | ||
| 97 | - int32_t SampleRate() const { return sample_rate_; } | ||
| 98 | - | ||
| 99 | - bool AddBlank() const { return add_blank_; } | ||
| 100 | - | ||
| 101 | - std::string Punctuations() const { return punctuations_; } | ||
| 102 | - std::string Language() const { return language_; } | ||
| 103 | - int32_t NumSpeakers() const { return num_speakers_; } | ||
| 104 | - | ||
| 105 | - private: | ||
| 106 | - void Init(void *model_data, size_t model_data_length) { | ||
| 107 | - sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 108 | - sess_opts_); | ||
| 109 | - | ||
| 110 | - GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 111 | - | ||
| 112 | - GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 113 | - | ||
| 114 | - // get meta data | ||
| 115 | - Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 116 | - if (config_.debug) { | ||
| 117 | - std::ostringstream os; | ||
| 118 | - os << "---vits model---\n"; | ||
| 119 | - PrintModelMetadata(os, meta_data); | ||
| 120 | - SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 121 | - } | ||
| 122 | - | ||
| 123 | - Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 124 | - SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); | ||
| 125 | - SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); | ||
| 126 | - SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers"); | ||
| 127 | - SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); | ||
| 128 | - SHERPA_ONNX_READ_META_DATA_STR(language_, "language"); | ||
| 129 | - } | ||
| 130 | - | ||
| 131 | private: | 198 | private: |
| 132 | OfflineTtsModelConfig config_; | 199 | OfflineTtsModelConfig config_; |
| 133 | Ort::Env env_; | 200 | Ort::Env env_; |
| @@ -147,6 +214,8 @@ class OfflineTtsVitsModel::Impl { | @@ -147,6 +214,8 @@ class OfflineTtsVitsModel::Impl { | ||
| 147 | int32_t num_speakers_; | 214 | int32_t num_speakers_; |
| 148 | std::string punctuations_; | 215 | std::string punctuations_; |
| 149 | std::string language_; | 216 | std::string language_; |
| 217 | + | ||
| 218 | + bool is_piper_ = false; | ||
| 150 | }; | 219 | }; |
| 151 | 220 | ||
| 152 | OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) | 221 | OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) |
| @@ -175,6 +244,8 @@ std::string OfflineTtsVitsModel::Punctuations() const { | @@ -175,6 +244,8 @@ std::string OfflineTtsVitsModel::Punctuations() const { | ||
| 175 | 244 | ||
| 176 | std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); } | 245 | std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); } |
| 177 | 246 | ||
| 247 | +bool OfflineTtsVitsModel::IsPiper() const { return impl_->IsPiper(); } | ||
| 248 | + | ||
| 178 | int32_t OfflineTtsVitsModel::NumSpeakers() const { | 249 | int32_t OfflineTtsVitsModel::NumSpeakers() const { |
| 179 | return impl_->NumSpeakers(); | 250 | return impl_->NumSpeakers(); |
| 180 | } | 251 | } |
| @@ -47,6 +47,7 @@ class OfflineTtsVitsModel { | @@ -47,6 +47,7 @@ class OfflineTtsVitsModel { | ||
| 47 | 47 | ||
| 48 | std::string Punctuations() const; | 48 | std::string Punctuations() const; |
| 49 | std::string Language() const; | 49 | std::string Language() const; |
| 50 | + bool IsPiper() const; | ||
| 50 | int32_t NumSpeakers() const; | 51 | int32_t NumSpeakers() const; |
| 51 | 52 | ||
| 52 | private: | 53 | private: |
-
请注册或登录后发表评论