Fangjun Kuang
Committed by GitHub

Add TTS with VITS (#360)

@@ -86,6 +86,15 @@ set(sources
   wave-reader.cc
 )
 
+list(APPEND sources
+  lexicon.cc
+  offline-tts-impl.cc
+  offline-tts-model-config.cc
+  offline-tts-vits-model-config.cc
+  offline-tts-vits-model.cc
+  offline-tts.cc
+)
+
 if(SHERPA_ONNX_ENABLE_CHECK)
   list(APPEND sources log.cc)
 endif()

@@ -135,23 +144,31 @@ endif()
 add_executable(sherpa-onnx sherpa-onnx.cc)
 add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
+add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
 
 
 target_link_libraries(sherpa-onnx sherpa-onnx-core)
 target_link_libraries(sherpa-onnx-offline sherpa-onnx-core)
 target_link_libraries(sherpa-onnx-offline-parallel sherpa-onnx-core)
+target_link_libraries(sherpa-onnx-offline-tts sherpa-onnx-core)
 if(NOT WIN32)
   target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
   target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
 
   target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
   target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
+
+  target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
   target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
 
+  target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
+  target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
+
   if(SHERPA_ONNX_ENABLE_PYTHON)
     target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
     target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
     target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
+    target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
   endif()
 endif()
 
@@ -170,6 +187,7 @@ install(
   sherpa-onnx
   sherpa-onnx-offline
   sherpa-onnx-offline-parallel
+  sherpa-onnx-offline-tts
   DESTINATION
   bin
 )

// sherpa-onnx/csrc/lexicon.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/lexicon.h"

#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <fstream>
#include <sstream>
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

static void ToLowerCase(std::string *in_out) {
  std::transform(in_out->begin(), in_out->end(), in_out->begin(),
                 [](unsigned char c) { return std::tolower(c); });
}

// Note: We don't use SymbolTable here since tokens may contain a blank
// in the first column
static std::unordered_map<std::string, int32_t> ReadTokens(
    const std::string &tokens) {
  std::unordered_map<std::string, int32_t> token2id;

  std::ifstream is(tokens);
  std::string line;

  std::string sym;
  int32_t id;
  while (std::getline(is, line)) {
    std::istringstream iss(line);
    iss >> sym;
    if (iss.eof()) {
      id = atoi(sym.c_str());
      sym = " ";
    } else {
      iss >> id;
    }

    if (!iss.eof()) {
      SHERPA_ONNX_LOGE("Error: %s", line.c_str());
      exit(-1);
    }

#if 0
    if (token2id.count(sym)) {
      SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
                       sym.c_str(), line.c_str(), token2id.at(sym));
      exit(-1);
    }
#endif
    token2id.insert({sym, id});
  }

  return token2id;
}

static std::vector<int32_t> ConvertTokensToIds(
    const std::unordered_map<std::string, int32_t> &token2id,
    const std::vector<std::string> &tokens) {
  std::vector<int32_t> ids;
  ids.reserve(tokens.size());
  for (const auto &s : tokens) {
    if (!token2id.count(s)) {
      return {};
    }
    int32_t id = token2id.at(s);
    ids.push_back(id);
  }

  return ids;
}

Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
                 const std::string &punctuations) {
  token2id_ = ReadTokens(tokens);
  std::ifstream is(lexicon);

  std::string word;
  std::vector<std::string> token_list;
  std::string line;
  std::string phone;

  while (std::getline(is, line)) {
    std::istringstream iss(line);

    token_list.clear();

    iss >> word;
    ToLowerCase(&word);

    if (word2ids_.count(word)) {
      SHERPA_ONNX_LOGE("Duplicated word: %s", word.c_str());
      return;
    }

    while (iss >> phone) {
      token_list.push_back(std::move(phone));
    }

    std::vector<int32_t> ids = ConvertTokensToIds(token2id_, token_list);
    if (ids.empty()) {
      continue;
    }
    word2ids_.insert({std::move(word), std::move(ids)});
  }

  // process punctuations
  std::vector<std::string> punctuation_list;
  SplitStringToVector(punctuations, " ", false, &punctuation_list);
  for (auto &s : punctuation_list) {
    punctuations_.insert(std::move(s));
  }
}

std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
    const std::string &_text) const {
  std::string text(_text);
  ToLowerCase(&text);

  std::vector<std::string> words;
  SplitStringToVector(text, " ", false, &words);

  std::vector<int64_t> ans;
  for (auto w : words) {
    std::vector<int64_t> prefix;
    while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
      // if w begins with a punctuation
      prefix.push_back(token2id_.at(std::string(1, w[0])));
      w = std::string(w.begin() + 1, w.end());
    }

    std::vector<int64_t> suffix;
    while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
      suffix.push_back(token2id_.at(std::string(1, w.back())));
      w = std::string(w.begin(), w.end() - 1);
    }

    if (!word2ids_.count(w)) {
      SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
      continue;
    }

    const auto &token_ids = word2ids_.at(w);
    ans.insert(ans.end(), prefix.begin(), prefix.end());
    ans.insert(ans.end(), token_ids.begin(), token_ids.end());
    ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
  }

  return ans;
}

}  // namespace sherpa_onnx

// sherpa-onnx/csrc/lexicon.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_LEXICON_H_
#define SHERPA_ONNX_CSRC_LEXICON_H_

#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace sherpa_onnx {

class Lexicon {
 public:
  Lexicon(const std::string &lexicon, const std::string &tokens,
          const std::string &punctuations);

  std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;

 private:
  std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
  std::unordered_set<std::string> punctuations_;
  std::unordered_map<std::string, int32_t> token2id_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_LEXICON_H_

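As a minimal standalone sketch of how this class is meant to be used (not part of this commit; the file names and the punctuation set below are illustrative):

// Usage sketch for Lexicon (illustrative paths, not from the diff).
#include <cstdint>
#include <vector>

#include "sherpa-onnx/csrc/lexicon.h"

int main() {
  // lexicon.txt: one "word phone1 phone2 ..." entry per line
  // tokens.txt:  one "symbol id" entry per line
  // The punctuation set is a space-separated string, as expected by
  // SplitStringToVector() in the constructor.
  sherpa_onnx::Lexicon lexicon("./lexicon.txt", "./tokens.txt",
                               /*punctuations=*/", . ! ?");

  // Words are lower-cased before lookup; leading/trailing punctuation is
  // mapped to its own token IDs; OOV words are skipped with a warning.
  std::vector<int64_t> ids = lexicon.ConvertTextToTokenIds("hello, world!");
  return ids.empty() ? 1 : 0;
}
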
// sherpa-onnx/csrc/offline-tts-impl.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-tts-impl.h"

#include <memory>

#include "sherpa-onnx/csrc/offline-tts-vits-impl.h"

namespace sherpa_onnx {

std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
    const OfflineTtsConfig &config) {
  // TODO(fangjun): Support other types
  return std::make_unique<OfflineTtsVitsImpl>(config);
}

}  // namespace sherpa_onnx

// sherpa-onnx/csrc/offline-tts-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_

#include <memory>
#include <string>

#include "sherpa-onnx/csrc/offline-tts.h"

namespace sherpa_onnx {

class OfflineTtsImpl {
 public:
  virtual ~OfflineTtsImpl() = default;

  static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);

  virtual GeneratedAudio Generate(const std::string &text) const = 0;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_

// sherpa-onnx/csrc/offline-tts-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-tts-model-config.h"

#include <sstream>

#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

void OfflineTtsModelConfig::Register(ParseOptions *po) {
  vits.Register(po);

  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}

bool OfflineTtsModelConfig::Validate() const {
  if (num_threads < 1) {
    SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
    return false;
  }

  return vits.Validate();
}

std::string OfflineTtsModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineTtsModelConfig(";
  os << "vits=" << vits.ToString() << ", ";
  os << "num_threads=" << num_threads << ", ";
  os << "debug=" << (debug ? "True" : "False") << ", ";
  os << "provider=\"" << provider << "\")";

  return os.str();
}

}  // namespace sherpa_onnx

// sherpa-onnx/csrc/offline-tts-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

struct OfflineTtsModelConfig {
  OfflineTtsVitsModelConfig vits;

  int32_t num_threads = 1;
  bool debug = false;
  std::string provider = "cpu";

  OfflineTtsModelConfig() = default;

  OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits,
                        int32_t num_threads, bool debug,
                        const std::string &provider)
      : vits(vits),
        num_threads(num_threads),
        debug(debug),
        provider(provider) {}

  void Register(ParseOptions *po);
  bool Validate() const;

  std::string ToString() const;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_

// sherpa-onnx/csrc/offline-tts-vits-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_

#include <array>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"

namespace sherpa_onnx {

class OfflineTtsVitsImpl : public OfflineTtsImpl {
 public:
  explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
      : model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
        lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
                 model_->Punctuations()) {
    SHERPA_ONNX_LOGE("config: %s\n", config.ToString().c_str());
  }

  GeneratedAudio Generate(const std::string &text) const override {
    std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
    if (x.empty()) {
      SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
      return {};
    }

    if (model_->AddBlank()) {
      std::vector<int64_t> buffer(x.size() * 2 + 1);
      int32_t i = 1;
      for (auto k : x) {
        buffer[i] = k;
        i += 2;
      }
      x = std::move(buffer);
    }

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    std::array<int64_t, 2> x_shape = {1, static_cast<int32_t>(x.size())};
    Ort::Value x_tensor = Ort::Value::CreateTensor(
        memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());

    Ort::Value audio = model_->Run(std::move(x_tensor));

    std::vector<int64_t> audio_shape =
        audio.GetTensorTypeAndShapeInfo().GetShape();

    int64_t total = 1;
    // The output shape may be (1, 1, total) or (1, total) or (total,)
    for (auto i : audio_shape) {
      total *= i;
    }

    const float *p = audio.GetTensorData<float>();

    GeneratedAudio ans;
    ans.sample_rate = model_->SampleRate();
    ans.samples = std::vector<float>(p, p + total);
    return ans;
  }

 private:
  std::unique_ptr<OfflineTtsVitsModel> model_;
  Lexicon lexicon_;
};

}  // namespace sherpa_onnx
#endif  // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_

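The AddBlank() branch of Generate() above interleaves a blank token between the lexicon's token IDs before running the model. A small self-contained sketch of that step (an illustration, not part of this commit; it assumes the blank token has ID 0, which is what the value-initialized buffer above implies):

// {5, 9, 3} -> {0, 5, 0, 9, 0, 3, 0}
#include <cstdint>
#include <vector>

std::vector<int64_t> InterleaveWithBlank(const std::vector<int64_t> &x) {
  std::vector<int64_t> buffer(x.size() * 2 + 1);  // all entries start as 0
  int32_t i = 1;
  for (auto k : x) {
    buffer[i] = k;  // tokens go to odd positions; even positions stay blank
    i += 2;
  }
  return buffer;
}
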
// sherpa-onnx/csrc/offline-tts-vits-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"

#include <sstream>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
  po->Register("vits-model", &model, "Path to VITS model");
  po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
  po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
}

bool OfflineTtsVitsModelConfig::Validate() const {
  if (model.empty()) {
    SHERPA_ONNX_LOGE("Please provide --vits-model");
    return false;
  }

  if (!FileExists(model)) {
    SHERPA_ONNX_LOGE("--vits-model: %s does not exist", model.c_str());
    return false;
  }

  if (lexicon.empty()) {
    SHERPA_ONNX_LOGE("Please provide --vits-lexicon");
    return false;
  }

  if (!FileExists(lexicon)) {
    SHERPA_ONNX_LOGE("--vits-lexicon: %s does not exist", lexicon.c_str());
    return false;
  }

  if (tokens.empty()) {
    SHERPA_ONNX_LOGE("Please provide --vits-tokens");
    return false;
  }

  if (!FileExists(tokens)) {
    SHERPA_ONNX_LOGE("--vits-tokens: %s does not exist", tokens.c_str());
    return false;
  }

  return true;
}

std::string OfflineTtsVitsModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineTtsVitsModelConfig(";
  os << "model=\"" << model << "\", ";
  os << "lexicon=\"" << lexicon << "\", ";
  os << "tokens=\"" << tokens << "\")";

  return os.str();
}

}  // namespace sherpa_onnx

// sherpa-onnx/csrc/offline-tts-vits-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

struct OfflineTtsVitsModelConfig {
  std::string model;
  std::string lexicon;
  std::string tokens;

  OfflineTtsVitsModelConfig() = default;

  OfflineTtsVitsModelConfig(const std::string &model,
                            const std::string &lexicon,
                            const std::string &tokens)
      : model(model), lexicon(lexicon), tokens(tokens) {}

  void Register(ParseOptions *po);
  bool Validate() const;

  std::string ToString() const;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_

// sherpa-onnx/csrc/offline-tts-vits-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-tts-vits-model.h"

#include <algorithm>
#include <array>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"

namespace sherpa_onnx {

class OfflineTtsVitsModel::Impl {
 public:
  explicit Impl(const OfflineTtsModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(config.vits.model);
    Init(buf.data(), buf.size());
  }

  Ort::Value Run(Ort::Value x) {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape();
    if (x_shape[0] != 1) {
      SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
                       static_cast<int32_t>(x_shape[0]));
      exit(-1);
    }

    int64_t len = x_shape[1];
    int64_t len_shape = 1;

    Ort::Value x_length =
        Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);

    int64_t scale_shape = 1;
    float noise_scale = 1;
    float length_scale = 1;
    float noise_scale_w = 1;

    Ort::Value noise_scale_tensor =
        Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
    Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
        memory_info, &length_scale, 1, &scale_shape, 1);
    Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
        memory_info, &noise_scale_w, 1, &scale_shape, 1);

    std::array<Ort::Value, 5> inputs = {
        std::move(x), std::move(x_length), std::move(noise_scale_tensor),
        std::move(length_scale_tensor), std::move(noise_scale_w_tensor)};

    auto out =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());

    return std::move(out[0]);
  }

  int32_t SampleRate() const { return sample_rate_; }

  bool AddBlank() const { return add_blank_; }

  std::string Punctuations() const { return punctuations_; }

 private:
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---vits model---\n";
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
    SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
    SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
  }

 private:
  OfflineTtsModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t sample_rate_;
  int32_t add_blank_;
  std::string punctuations_;
};

OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;

Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) {
  return impl_->Run(std::move(x));
}

int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }

bool OfflineTtsVitsModel::AddBlank() const { return impl_->AddBlank(); }

std::string OfflineTtsVitsModel::Punctuations() const {
  return impl_->Punctuations();
}

}  // namespace sherpa_onnx

// sherpa-onnx/csrc/offline-tts-vits-model.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_

#include <memory>
#include <string>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-tts-model-config.h"

namespace sherpa_onnx {

class OfflineTtsVitsModel {
 public:
  ~OfflineTtsVitsModel();

  explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config);

  /** Run the model.
   *
   * @param x A int64 tensor of shape (1, num_tokens)
   * @return Return a float32 tensor containing audio samples. You can flatten
   *         it to a 1-D tensor.
   */
  Ort::Value Run(Ort::Value x);

  // Sample rate of the generated audio
  int32_t SampleRate() const;

  // true to insert a blank between each token
  bool AddBlank() const;

  std::string Punctuations() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_

// sherpa-onnx/csrc/offline-tts.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-tts.h"

#include <sstream>
#include <string>

#include "sherpa-onnx/csrc/offline-tts-impl.h"

namespace sherpa_onnx {

void OfflineTtsConfig::Register(ParseOptions *po) { model.Register(po); }

bool OfflineTtsConfig::Validate() const { return model.Validate(); }

std::string OfflineTtsConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineTtsConfig(";
  os << "model=" << model.ToString() << ")";

  return os.str();
}

OfflineTts::OfflineTts(const OfflineTtsConfig &config)
    : impl_(OfflineTtsImpl::Create(config)) {}

OfflineTts::~OfflineTts() = default;

GeneratedAudio OfflineTts::Generate(const std::string &text) const {
  return impl_->Generate(text);
}

}  // namespace sherpa_onnx

// sherpa-onnx/csrc/offline-tts.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_H_

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

struct OfflineTtsConfig {
  OfflineTtsModelConfig model;

  OfflineTtsConfig() = default;
  explicit OfflineTtsConfig(const OfflineTtsModelConfig &model)
      : model(model) {}

  void Register(ParseOptions *po);
  bool Validate() const;

  std::string ToString() const;
};

struct GeneratedAudio {
  std::vector<float> samples;
  int32_t sample_rate;
};

class OfflineTtsImpl;

class OfflineTts {
 public:
  ~OfflineTts();
  explicit OfflineTts(const OfflineTtsConfig &config);
  // @param text A string containing words separated by spaces
  GeneratedAudio Generate(const std::string &text) const;

 private:
  std::unique_ptr<OfflineTtsImpl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_TTS_H_

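A minimal C++ usage sketch of the OfflineTts API declared above (not part of this commit; the file paths are illustrative):

// Usage sketch for OfflineTts (illustrative paths, not from the diff).
#include "sherpa-onnx/csrc/offline-tts.h"

int main() {
  sherpa_onnx::OfflineTtsConfig config;
  config.model.vits.model = "./vits.onnx";
  config.model.vits.lexicon = "./lexicon.txt";
  config.model.vits.tokens = "./tokens.txt";
  config.model.num_threads = 2;

  if (!config.Validate()) {
    return -1;
  }

  sherpa_onnx::OfflineTts tts(config);
  sherpa_onnx::GeneratedAudio audio = tts.Generate("hello world");
  // audio.samples holds float samples at audio.sample_rate Hz.
  return audio.samples.empty() ? 1 : 0;
}
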
@@ -87,4 +87,8 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) {
   return GetSessionOptionsImpl(config.num_threads, config.provider);
 }
 
+Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
+  return GetSessionOptionsImpl(config.num_threads, config.provider);
+}
+
 }  // namespace sherpa_onnx

@@ -8,6 +8,7 @@
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/offline-lm-config.h"
 #include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/offline-tts-model-config.h"
 #include "sherpa-onnx/csrc/online-lm-config.h"
 #include "sherpa-onnx/csrc/online-model-config.h"
 #include "sherpa-onnx/csrc/vad-model-config.h"

@@ -23,6 +24,8 @@ Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);
 Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
 
 Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
+
+Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
 }  // namespace sherpa_onnx
 
 #endif  // SHERPA_ONNX_CSRC_SESSION_H_

// sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include <cstdio>
#include <cstdlib>
#include <fstream>

#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/parse-options.h"

int main(int32_t argc, char *argv[]) {
  const char *kUsageMessage = R"usage(
Offline text-to-speech with sherpa-onnx

./bin/sherpa-onnx-offline-tts \
  --vits-model=/path/to/model.onnx \
  --vits-lexicon=/path/to/lexicon.txt \
  --vits-tokens=/path/to/tokens.txt \
  'some text within single quotes'

It will generate a file ./t.pcm containing the raw float32 samples. You can
convert it to a wav file with:

  sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav
)usage";

  sherpa_onnx::ParseOptions po(kUsageMessage);
  sherpa_onnx::OfflineTtsConfig config;
  config.Register(&po);
  po.Read(argc, argv);

  if (po.NumArgs() == 0) {
    fprintf(stderr, "Error: Please provide the text to generate audio.\n\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  if (po.NumArgs() > 1) {
    fprintf(stderr,
            "Error: It accepts only one positional argument. Please use "
            "single quotes to wrap your text.\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  if (!config.Validate()) {
    fprintf(stderr, "Errors in config!\n");
    exit(EXIT_FAILURE);
  }

  sherpa_onnx::OfflineTts tts(config);
  auto audio = tts.Generate(po.GetArg(1));

  std::ofstream os("t.pcm", std::ios::binary);
  os.write(reinterpret_cast<const char *>(audio.samples.data()),
           sizeof(float) * audio.samples.size());

  // sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav

  return 0;
}
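
The binary above writes raw float samples and leaves the wav conversion to sox. As a sketch (not part of this commit; WriteInt16Wave is a hypothetical helper, and it assumes a little-endian machine and mono audio), the generated samples could also be saved directly as a 16-bit PCM wav file:

// Hypothetical helper: save mono float samples as a 16-bit PCM wav file.
// Assumes a little-endian host; not part of the sherpa-onnx API.
#include <algorithm>
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

static void WriteInt16Wave(const std::string &filename,
                           const std::vector<float> &samples,
                           int32_t sample_rate) {
  std::ofstream os(filename, std::ios::binary);
  auto put = [&os](const void *p, size_t n) {
    os.write(reinterpret_cast<const char *>(p), n);
  };

  int32_t data_bytes = static_cast<int32_t>(samples.size()) * 2;  // 16-bit mono
  int32_t riff_size = 36 + data_bytes;
  int32_t fmt_size = 16;
  int16_t audio_format = 1;  // PCM
  int16_t num_channels = 1;
  int32_t byte_rate = sample_rate * 2;
  int16_t block_align = 2;
  int16_t bits_per_sample = 16;

  // RIFF header, fmt chunk, data chunk
  put("RIFF", 4); put(&riff_size, 4); put("WAVE", 4);
  put("fmt ", 4); put(&fmt_size, 4);
  put(&audio_format, 2); put(&num_channels, 2);
  put(&sample_rate, 4); put(&byte_rate, 4);
  put(&block_align, 2); put(&bits_per_sample, 2);
  put("data", 4); put(&data_bytes, 4);

  for (float f : samples) {
    f = std::max(-1.0f, std::min(1.0f, f));  // clip to [-1, 1]
    int16_t s = static_cast<int16_t>(f * 32767);
    put(&s, 2);
  }
}

For example, WriteInt16Wave("t.wav", audio.samples, audio.sample_rate) would replace the raw t.pcm output and the sox step.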