Fangjun Kuang
Committed by GitHub

Support vits models from piper (#390)

@@ -83,8 +83,8 @@ static std::vector<int32_t> ConvertTokensToIds( @@ -83,8 +83,8 @@ static std::vector<int32_t> ConvertTokensToIds(
83 83
84 Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, 84 Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
85 const std::string &punctuations, const std::string &language, 85 const std::string &punctuations, const std::string &language,
86 - bool debug /*= false*/)  
87 - : debug_(debug) { 86 + bool debug /*= false*/, bool is_piper /*= false*/)
  87 + : debug_(debug), is_piper_(is_piper) {
88 InitLanguage(language); 88 InitLanguage(language);
89 89
90 { 90 {
@@ -103,8 +103,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, @@ -103,8 +103,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
103 #if __ANDROID_API__ >= 9 103 #if __ANDROID_API__ >= 9
104 Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, 104 Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
105 const std::string &tokens, const std::string &punctuations, 105 const std::string &tokens, const std::string &punctuations,
106 - const std::string &language, bool debug /*= false*/)  
107 - : debug_(debug) { 106 + const std::string &language, bool debug /*= false*/,
  107 + bool is_piper /*= false*/)
  108 + : debug_(debug), is_piper_(is_piper) {
108 InitLanguage(language); 109 InitLanguage(language);
109 110
110 { 111 {
@@ -206,6 +207,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( @@ -206,6 +207,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
206 int32_t blank = token2id_.at(" "); 207 int32_t blank = token2id_.at(" ");
207 208
208 std::vector<int64_t> ans; 209 std::vector<int64_t> ans;
  210 + if (is_piper_) {
  211 + ans.push_back(token2id_.at("^")); // sos
  212 + }
  213 +
209 for (const auto &w : words) { 214 for (const auto &w : words) {
210 if (punctuations_.count(w)) { 215 if (punctuations_.count(w)) {
211 ans.push_back(token2id_.at(w)); 216 ans.push_back(token2id_.at(w));
@@ -227,6 +232,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish( @@ -227,6 +232,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
227 ans.resize(ans.size() - 1); 232 ans.resize(ans.size() - 1);
228 } 233 }
229 234
  235 + if (is_piper_) {
  236 + ans.push_back(token2id_.at("$")); // eos
  237 + }
  238 +
230 return ans; 239 return ans;
231 } 240 }
232 241
@@ -24,12 +24,13 @@ class Lexicon { @@ -24,12 +24,13 @@ class Lexicon {
24 public: 24 public:
25 Lexicon(const std::string &lexicon, const std::string &tokens, 25 Lexicon(const std::string &lexicon, const std::string &tokens,
26 const std::string &punctuations, const std::string &language, 26 const std::string &punctuations, const std::string &language,
27 - bool debug = false); 27 + bool debug = false, bool is_piper = false);
28 28
29 #if __ANDROID_API__ >= 9 29 #if __ANDROID_API__ >= 9
30 Lexicon(AAssetManager *mgr, const std::string &lexicon, 30 Lexicon(AAssetManager *mgr, const std::string &lexicon,
31 const std::string &tokens, const std::string &punctuations, 31 const std::string &tokens, const std::string &punctuations,
32 - const std::string &language, bool debug = false); 32 + const std::string &language, bool debug = false,
  33 + bool is_piper = false);
33 #endif 34 #endif
34 35
35 std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const; 36 std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;
@@ -59,7 +60,7 @@ class Lexicon { @@ -59,7 +60,7 @@ class Lexicon {
59 std::unordered_map<std::string, int32_t> token2id_; 60 std::unordered_map<std::string, int32_t> token2id_;
60 Language language_; 61 Language language_;
61 bool debug_; 62 bool debug_;
62 - // 63 + bool is_piper_;
63 }; 64 };
64 65
65 } // namespace sherpa_onnx 66 } // namespace sherpa_onnx
@@ -26,15 +26,15 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -26,15 +26,15 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
26 explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config) 26 explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
27 : model_(std::make_unique<OfflineTtsVitsModel>(config.model)), 27 : model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
28 lexicon_(config.model.vits.lexicon, config.model.vits.tokens, 28 lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
29 - model_->Punctuations(), model_->Language(),  
30 - config.model.debug) {} 29 + model_->Punctuations(), model_->Language(), config.model.debug,
  30 + model_->IsPiper()) {}
31 31
32 #if __ANDROID_API__ >= 9 32 #if __ANDROID_API__ >= 9
33 OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config) 33 OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config)
34 : model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)), 34 : model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)),
35 lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens, 35 lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens,
36 - model_->Punctuations(), model_->Language(),  
37 - config.model.debug) {} 36 + model_->Punctuations(), model_->Language(), config.model.debug,
  37 + model_->IsPiper()) {}
38 #endif 38 #endif
39 39
40 GeneratedAudio Generate(const std::string &text, int64_t sid = 0, 40 GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
@@ -43,17 +43,16 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -43,17 +43,16 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
43 if (num_speakers == 0 && sid != 0) { 43 if (num_speakers == 0 && sid != 0) {
44 SHERPA_ONNX_LOGE( 44 SHERPA_ONNX_LOGE(
45 "This is a single-speaker model and supports only sid 0. Given sid: " 45 "This is a single-speaker model and supports only sid 0. Given sid: "
46 - "%d", 46 + "%d. sid is ignored",
47 static_cast<int32_t>(sid)); 47 static_cast<int32_t>(sid));
48 - return {};  
49 } 48 }
50 49
51 if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { 50 if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) {
52 SHERPA_ONNX_LOGE( 51 SHERPA_ONNX_LOGE(
53 "This model contains only %d speakers. sid should be in the range " 52 "This model contains only %d speakers. sid should be in the range "
54 - "[%d, %d]. Given: %d", 53 + "[%d, %d]. Given: %d. Use sid=0",
55 num_speakers, 0, num_speakers - 1, static_cast<int32_t>(sid)); 54 num_speakers, 0, num_speakers - 1, static_cast<int32_t>(sid));
56 - return {}; 55 + sid = 0;
57 } 56 }
58 57
59 std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text); 58 std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
@@ -38,6 +38,107 @@ class OfflineTtsVitsModel::Impl { @@ -38,6 +38,107 @@ class OfflineTtsVitsModel::Impl {
38 #endif 38 #endif
39 39
40 Ort::Value Run(Ort::Value x, int64_t sid, float speed) { 40 Ort::Value Run(Ort::Value x, int64_t sid, float speed) {
  41 + if (is_piper_) {
  42 + return RunVitsPiper(std::move(x), sid, speed);
  43 + }
  44 +
  45 + return RunVits(std::move(x), sid, speed);
  46 + }
  47 +
  48 + int32_t SampleRate() const { return sample_rate_; }
  49 +
  50 + bool AddBlank() const { return add_blank_; }
  51 +
  52 + std::string Punctuations() const { return punctuations_; }
  53 + std::string Language() const { return language_; }
  54 + bool IsPiper() const { return is_piper_; }
  55 + int32_t NumSpeakers() const { return num_speakers_; }
  56 +
  57 + private:
  58 + void Init(void *model_data, size_t model_data_length) {
  59 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  60 + sess_opts_);
  61 +
  62 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  63 +
  64 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  65 +
  66 + // get meta data
  67 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  68 + if (config_.debug) {
  69 + std::ostringstream os;
  70 + os << "---vits model---\n";
  71 + PrintModelMetadata(os, meta_data);
  72 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  73 + }
  74 +
  75 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  76 + SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
  77 + SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
  78 + SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers");
  79 + SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
  80 + SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
  81 +
  82 + std::string comment;
  83 + SHERPA_ONNX_READ_META_DATA_STR(comment, "comment");
  84 + if (comment.find("piper") != std::string::npos) {
  85 + is_piper_ = true;
  86 + }
  87 + }
  88 +
  89 + Ort::Value RunVitsPiper(Ort::Value x, int64_t sid, float speed) {
  90 + auto memory_info =
  91 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  92 +
  93 + std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape();
  94 + if (x_shape[0] != 1) {
  95 + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
  96 + static_cast<int32_t>(x_shape[0]));
  97 + exit(-1);
  98 + }
  99 +
  100 + int64_t len = x_shape[1];
  101 + int64_t len_shape = 1;
  102 +
  103 + Ort::Value x_length =
  104 + Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
  105 +
  106 + float noise_scale = config_.vits.noise_scale;
  107 + float length_scale = config_.vits.length_scale;
  108 + float noise_scale_w = config_.vits.noise_scale_w;
  109 +
  110 + if (speed != 1 && speed > 0) {
  111 + length_scale = 1. / speed;
  112 + }
  113 + std::array<float, 3> scales = {noise_scale, length_scale, noise_scale_w};
  114 +
  115 + int64_t scale_shape = 3;
  116 +
  117 + Ort::Value scales_tensor = Ort::Value::CreateTensor(
  118 + memory_info, scales.data(), scales.size(), &scale_shape, 1);
  119 +
  120 + int64_t sid_shape = 1;
  121 + Ort::Value sid_tensor =
  122 + Ort::Value::CreateTensor(memory_info, &sid, 1, &sid_shape, 1);
  123 +
  124 + std::vector<Ort::Value> inputs;
  125 + inputs.reserve(4);
  126 + inputs.push_back(std::move(x));
  127 + inputs.push_back(std::move(x_length));
  128 + inputs.push_back(std::move(scales_tensor));
  129 +
  130 + if (input_names_.size() == 4 && input_names_.back() == "sid") {
  131 + inputs.push_back(std::move(sid_tensor));
  132 + }
  133 +
  134 + auto out =
  135 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  136 + output_names_ptr_.data(), output_names_ptr_.size());
  137 +
  138 + return std::move(out[0]);
  139 + }
  140 +
  141 + Ort::Value RunVits(Ort::Value x, int64_t sid, float speed) {
41 auto memory_info = 142 auto memory_info =
42 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 143 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
43 144
@@ -94,40 +195,6 @@ class OfflineTtsVitsModel::Impl { @@ -94,40 +195,6 @@ class OfflineTtsVitsModel::Impl {
94 return std::move(out[0]); 195 return std::move(out[0]);
95 } 196 }
96 197
97 - int32_t SampleRate() const { return sample_rate_; }  
98 -  
99 - bool AddBlank() const { return add_blank_; }  
100 -  
101 - std::string Punctuations() const { return punctuations_; }  
102 - std::string Language() const { return language_; }  
103 - int32_t NumSpeakers() const { return num_speakers_; }  
104 -  
105 - private:  
106 - void Init(void *model_data, size_t model_data_length) {  
107 - sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,  
108 - sess_opts_);  
109 -  
110 - GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);  
111 -  
112 - GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);  
113 -  
114 - // get meta data  
115 - Ort::ModelMetadata meta_data = sess_->GetModelMetadata();  
116 - if (config_.debug) {  
117 - std::ostringstream os;  
118 - os << "---vits model---\n";  
119 - PrintModelMetadata(os, meta_data);  
120 - SHERPA_ONNX_LOGE("%s\n", os.str().c_str());  
121 - }  
122 -  
123 - Ort::AllocatorWithDefaultOptions allocator; // used in the macro below  
124 - SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");  
125 - SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");  
126 - SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers");  
127 - SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");  
128 - SHERPA_ONNX_READ_META_DATA_STR(language_, "language");  
129 - }  
130 -  
131 private: 198 private:
132 OfflineTtsModelConfig config_; 199 OfflineTtsModelConfig config_;
133 Ort::Env env_; 200 Ort::Env env_;
@@ -147,6 +214,8 @@ class OfflineTtsVitsModel::Impl { @@ -147,6 +214,8 @@ class OfflineTtsVitsModel::Impl {
147 int32_t num_speakers_; 214 int32_t num_speakers_;
148 std::string punctuations_; 215 std::string punctuations_;
149 std::string language_; 216 std::string language_;
  217 +
  218 + bool is_piper_ = false;
150 }; 219 };
151 220
152 OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) 221 OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
@@ -175,6 +244,8 @@ std::string OfflineTtsVitsModel::Punctuations() const { @@ -175,6 +244,8 @@ std::string OfflineTtsVitsModel::Punctuations() const {
175 244
176 std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); } 245 std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }
177 246
  247 +bool OfflineTtsVitsModel::IsPiper() const { return impl_->IsPiper(); }
  248 +
178 int32_t OfflineTtsVitsModel::NumSpeakers() const { 249 int32_t OfflineTtsVitsModel::NumSpeakers() const {
179 return impl_->NumSpeakers(); 250 return impl_->NumSpeakers();
180 } 251 }
@@ -47,6 +47,7 @@ class OfflineTtsVitsModel { @@ -47,6 +47,7 @@ class OfflineTtsVitsModel {
47 47
48 std::string Punctuations() const; 48 std::string Punctuations() const;
49 std::string Language() const; 49 std::string Language() const;
  50 + bool IsPiper() const;
50 int32_t NumSpeakers() const; 51 int32_t NumSpeakers() const;
51 52
52 private: 53 private: