Fangjun Kuang
Committed by GitHub

Add TTS with VITS (#360)

... ... @@ -86,6 +86,15 @@ set(sources
wave-reader.cc
)
list(APPEND sources
lexicon.cc
offline-tts-impl.cc
offline-tts-model-config.cc
offline-tts-vits-model-config.cc
offline-tts-vits-model.cc
offline-tts.cc
)
if(SHERPA_ONNX_ENABLE_CHECK)
list(APPEND sources log.cc)
endif()
... ... @@ -135,23 +144,31 @@ endif()
add_executable(sherpa-onnx sherpa-onnx.cc)
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
target_link_libraries(sherpa-onnx sherpa-onnx-core)
target_link_libraries(sherpa-onnx-offline sherpa-onnx-core)
target_link_libraries(sherpa-onnx-offline-parallel sherpa-onnx-core)
target_link_libraries(sherpa-onnx-offline-tts sherpa-onnx-core)
if(NOT WIN32)
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
if(SHERPA_ONNX_ENABLE_PYTHON)
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
endif()
endif()
... ... @@ -170,6 +187,7 @@ install(
sherpa-onnx
sherpa-onnx-offline
sherpa-onnx-offline-parallel
sherpa-onnx-offline-tts
DESTINATION
bin
)
... ...
// sherpa-onnx/csrc/lexicon.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/lexicon.h"
#include <algorithm>
#include <cctype>
#include <fstream>
#include <sstream>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
static void ToLowerCase(std::string *in_out) {
std::transform(in_out->begin(), in_out->end(), in_out->begin(),
[](unsigned char c) { return std::tolower(c); });
}
// Note: We don't use SymbolTable here since tokens may contain a blank
// in the first column
static std::unordered_map<std::string, int32_t> ReadTokens(
const std::string &tokens) {
std::unordered_map<std::string, int32_t> token2id;
std::ifstream is(tokens);
std::string line;
std::string sym;
int32_t id;
while (std::getline(is, line)) {
std::istringstream iss(line);
iss >> sym;
if (iss.eof()) {
id = atoi(sym.c_str());
sym = " ";
} else {
iss >> id;
}
if (!iss.eof()) {
SHERPA_ONNX_LOGE("Error: %s", line.c_str());
exit(-1);
}
#if 0
if (token2id.count(sym)) {
SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
sym.c_str(), line.c_str(), token2id.at(sym));
exit(-1);
}
#endif
token2id.insert({sym, id});
}
return token2id;
}
static std::vector<int32_t> ConvertTokensToIds(
const std::unordered_map<std::string, int32_t> &token2id,
const std::vector<std::string> &tokens) {
std::vector<int32_t> ids;
ids.reserve(tokens.size());
for (const auto &s : tokens) {
if (!token2id.count(s)) {
return {};
}
int32_t id = token2id.at(s);
ids.push_back(id);
}
return ids;
}
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
const std::string &punctuations) {
token2id_ = ReadTokens(tokens);
std::ifstream is(lexicon);
std::string word;
std::vector<std::string> token_list;
std::string line;
std::string phone;
while (std::getline(is, line)) {
std::istringstream iss(line);
token_list.clear();
iss >> word;
ToLowerCase(&word);
if (word2ids_.count(word)) {
SHERPA_ONNX_LOGE("Duplicated word: %s", word.c_str());
return;
}
while (iss >> phone) {
token_list.push_back(std::move(phone));
}
std::vector<int32_t> ids = ConvertTokensToIds(token2id_, token_list);
if (ids.empty()) {
continue;
}
word2ids_.insert({std::move(word), std::move(ids)});
}
// process punctuations
std::vector<std::string> punctuation_list;
SplitStringToVector(punctuations, " ", false, &punctuation_list);
for (auto &s : punctuation_list) {
punctuations_.insert(std::move(s));
}
}
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);
std::vector<std::string> words;
SplitStringToVector(text, " ", false, &words);
std::vector<int64_t> ans;
for (auto w : words) {
std::vector<int64_t> prefix;
while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
// if w begins with a punctuation
prefix.push_back(token2id_.at(std::string(1, w[0])));
w = std::string(w.begin() + 1, w.end());
}
std::vector<int64_t> suffix;
while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
suffix.push_back(token2id_.at(std::string(1, w.back())));
w = std::string(w.begin(), w.end() - 1);
}
if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}
const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), prefix.begin(), prefix.end());
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
}
return ans;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/lexicon.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_LEXICON_H_
#define SHERPA_ONNX_CSRC_LEXICON_H_
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace sherpa_onnx {
class Lexicon {
public:
Lexicon(const std::string &lexicon, const std::string &tokens,
const std::string &punctuations);
std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;
private:
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
std::unordered_set<std::string> punctuations_;
std::unordered_map<std::string, int32_t> token2id_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_LEXICON_H_
... ...
// sherpa-onnx/csrc/offline-tts-impl.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-impl.h"
#include <memory>
#include "sherpa-onnx/csrc/offline-tts-vits-impl.h"
namespace sherpa_onnx {
std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
const OfflineTtsConfig &config) {
// TODO(fangjun): Support other types
return std::make_unique<OfflineTtsVitsImpl>(config);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-tts-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/offline-tts.h"
namespace sherpa_onnx {
class OfflineTtsImpl {
public:
virtual ~OfflineTtsImpl() = default;
static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);
virtual GeneratedAudio Generate(const std::string &text) const = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_
... ...
// sherpa-onnx/csrc/offline-tts-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineTtsModelConfig::Register(ParseOptions *po) {
vits.Register(po);
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
po->Register("debug", &debug,
"true to print model information while loading it.");
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
}
bool OfflineTtsModelConfig::Validate() const {
if (num_threads < 1) {
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
return false;
}
return vits.Validate();
}
std::string OfflineTtsModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineTtsModelConfig(";
os << "vits=" << vits.ToString() << ", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-tts-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineTtsModelConfig {
OfflineTtsVitsModelConfig vits;
int32_t num_threads = 1;
bool debug = false;
std::string provider = "cpu";
OfflineTtsModelConfig() = default;
OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits,
int32_t num_threads, bool debug,
const std::string &provider)
: vits(vits),
num_threads(num_threads),
debug(debug),
provider(provider) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-tts-vits-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"
namespace sherpa_onnx {
class OfflineTtsVitsImpl : public OfflineTtsImpl {
public:
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
: model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations()) {
SHERPA_ONNX_LOGE("config: %s\n", config.ToString().c_str());
}
GeneratedAudio Generate(const std::string &text) const override {
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
if (x.empty()) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
return {};
}
if (model_->AddBlank()) {
std::vector<int64_t> buffer(x.size() * 2 + 1);
int32_t i = 1;
for (auto k : x) {
buffer[i] = k;
i += 2;
}
x = std::move(buffer);
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 2> x_shape = {1, static_cast<int32_t>(x.size())};
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
Ort::Value audio = model_->Run(std::move(x_tensor));
std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();
int64_t total = 1;
// The output shape may be (1, 1, total) or (1, total) or (total,)
for (auto i : audio_shape) {
total *= i;
}
const float *p = audio.GetTensorData<float>();
GeneratedAudio ans;
ans.sample_rate = model_->SampleRate();
ans.samples = std::vector<float>(p, p + total);
return ans;
}
private:
std::unique_ptr<OfflineTtsVitsModel> model_;
Lexicon lexicon_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_
... ...
// sherpa-onnx/csrc/offline-tts-vits-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
po->Register("vits-model", &model, "Path to VITS model");
po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
}
bool OfflineTtsVitsModelConfig::Validate() const {
if (model.empty()) {
SHERPA_ONNX_LOGE("Please provide --vits-model");
return false;
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("--vits-model: %s does not exist", model.c_str());
return false;
}
if (lexicon.empty()) {
SHERPA_ONNX_LOGE("Please provide --vits-lexicon");
return false;
}
if (!FileExists(lexicon)) {
SHERPA_ONNX_LOGE("--vits-lexicon: %s does not exist", lexicon.c_str());
return false;
}
if (tokens.empty()) {
SHERPA_ONNX_LOGE("Please provide --vits-tokens");
return false;
}
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("--vits-tokens: %s does not exist", tokens.c_str());
return false;
}
return true;
}
std::string OfflineTtsVitsModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineTtsVitsModelConfig(";
os << "model=\"" << model << "\", ";
os << "lexicon=\"" << lexicon << "\", ";
os << "tokens=\"" << tokens << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-tts-vits-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineTtsVitsModelConfig {
std::string model;
std::string lexicon;
std::string tokens;
OfflineTtsVitsModelConfig() = default;
OfflineTtsVitsModelConfig(const std::string &model,
const std::string &lexicon,
const std::string &tokens)
: model(model), lexicon(lexicon), tokens(tokens) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-tts-vits-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
namespace sherpa_onnx {
class OfflineTtsVitsModel::Impl {
public:
explicit Impl(const OfflineTtsModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config.vits.model);
Init(buf.data(), buf.size());
}
Ort::Value Run(Ort::Value x) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape();
if (x_shape[0] != 1) {
SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
static_cast<int32_t>(x_shape[0]));
exit(-1);
}
int64_t len = x_shape[1];
int64_t len_shape = 1;
Ort::Value x_length =
Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
int64_t scale_shape = 1;
float noise_scale = 1;
float length_scale = 1;
float noise_scale_w = 1;
Ort::Value noise_scale_tensor =
Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
memory_info, &length_scale, 1, &scale_shape, 1);
Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
memory_info, &noise_scale_w, 1, &scale_shape, 1);
std::array<Ort::Value, 5> inputs = {
std::move(x), std::move(x_length), std::move(noise_scale_tensor),
std::move(length_scale_tensor), std::move(noise_scale_w_tensor)};
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return std::move(out[0]);
}
int32_t SampleRate() const { return sample_rate_; }
bool AddBlank() const { return add_blank_; }
std::string Punctuations() const { return punctuations_; }
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---vits model---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
}
private:
OfflineTtsModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
int32_t sample_rate_;
int32_t add_blank_;
std::string punctuations_;
};
OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) {
return impl_->Run(std::move(x));
}
int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }
bool OfflineTtsVitsModel::AddBlank() const { return impl_->AddBlank(); }
std::string OfflineTtsVitsModel::Punctuations() const {
return impl_->Punctuations();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-tts-vits-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_
#include <memory>
#include <string>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
namespace sherpa_onnx {
class OfflineTtsVitsModel {
public:
~OfflineTtsVitsModel();
explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config);
/** Run the model.
*
* @param x A int64 tensor of shape (1, num_tokens)
* @return Return a float32 tensor containing audio samples. You can flatten
* it to a 1-D tensor.
*/
Ort::Value Run(Ort::Value x);
// Sample rate of the generated audio
int32_t SampleRate() const;
// true to insert a blank between each token
bool AddBlank() const;
std::string Punctuations() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_
... ...
// sherpa-onnx/csrc/offline-tts.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts.h"
#include <string>
#include "sherpa-onnx/csrc/offline-tts-impl.h"
namespace sherpa_onnx {
void OfflineTtsConfig::Register(ParseOptions *po) { model.Register(po); }
bool OfflineTtsConfig::Validate() const { return model.Validate(); }
std::string OfflineTtsConfig::ToString() const {
std::ostringstream os;
os << "OfflineTtsConfig(";
os << "model=" << model.ToString() << ")";
return os.str();
}
OfflineTts::OfflineTts(const OfflineTtsConfig &config)
: impl_(OfflineTtsImpl::Create(config)) {}
OfflineTts::~OfflineTts() = default;
GeneratedAudio OfflineTts::Generate(const std::string &text) const {
return impl_->Generate(text);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-tts.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_H_
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineTtsConfig {
OfflineTtsModelConfig model;
OfflineTtsConfig() = default;
explicit OfflineTtsConfig(const OfflineTtsModelConfig &model)
: model(model) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
struct GeneratedAudio {
std::vector<float> samples;
int32_t sample_rate;
};
class OfflineTtsImpl;
class OfflineTts {
public:
~OfflineTts();
explicit OfflineTts(const OfflineTtsConfig &config);
// @param text A string containing words separated by spaces
GeneratedAudio Generate(const std::string &text) const;
private:
std::unique_ptr<OfflineTtsImpl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_H_
... ...
... ... @@ -87,4 +87,8 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
} // namespace sherpa_onnx
... ...
... ... @@ -8,6 +8,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/vad-model-config.h"
... ... @@ -23,6 +24,8 @@ Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);
Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SESSION_H_
... ...
// sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include <fstream>
#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/parse-options.h"
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Offline text-to-speech with sherpa-onnx
./bin/sherpa-onnx-offline-tts \
--vits-model /path/to/model.onnx \
--vits-lexicon /path/to/lexicon.txt \
--vits-tokens /path/to/tokens.txt
'some text within single quotes'
It will generate a file test.wav.
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::OfflineTtsConfig config;
config.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() == 0) {
fprintf(stderr, "Error: Please provide the text to generate audio.\n\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}
if (po.NumArgs() > 1) {
fprintf(stderr,
"Error: Accept only one positional argument. Please use single "
"quotes to wrap your text\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
exit(EXIT_FAILURE);
}
sherpa_onnx::OfflineTts tts(config);
auto audio = tts.Generate(po.GetArg(1));
std::ofstream os("t.pcm", std::ios::binary);
os.write(reinterpret_cast<const char *>(audio.samples.data()),
sizeof(float) * audio.samples.size());
// sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav
return 0;
}
... ...