Fangjun Kuang
Committed by GitHub

Add C++ runtime for kitten-tts (#2460)

... ... @@ -192,6 +192,8 @@ if(SHERPA_ONNX_ENABLE_TTS)
offline-tts-character-frontend.cc
offline-tts-frontend.cc
offline-tts-impl.cc
offline-tts-kitten-model-config.cc
offline-tts-kitten-model.cc
offline-tts-kokoro-model-config.cc
offline-tts-kokoro-model.cc
offline-tts-matcha-model-config.cc
... ...
... ... @@ -260,7 +260,7 @@ class KokoroMultiLangLexicon::Impl {
std::vector<std::vector<int32_t>> ConvertTextToTokenIDsWithEspeak(
const std::string &text, const std::string &voice) const {
auto temp = ConvertTextToTokenIdsKokoro(
auto temp = ConvertTextToTokenIdsKokoroOrKitten(
phoneme2id_, meta_data_.max_token_len, text, voice);
std::vector<std::vector<int32_t>> ans;
ans.reserve(temp.size());
... ...
... ... @@ -59,7 +59,7 @@ class OfflineTtsFrontend {
void InitEspeak(const std::string &data_dir);
// implementation in ./piper-phonemize-lexicon.cc
std::vector<TokenIDs> ConvertTextToTokenIdsKokoro(
std::vector<TokenIDs> ConvertTextToTokenIdsKokoroOrKitten(
const std::unordered_map<char32_t, int32_t> &token2id,
int32_t max_token_len, const std::string &text,
const std::string &voice = "");
... ...
... ... @@ -16,6 +16,7 @@
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/offline-tts-kitten-impl.h"
#include "sherpa-onnx/csrc/offline-tts-kokoro-impl.h"
#include "sherpa-onnx/csrc/offline-tts-matcha-impl.h"
#include "sherpa-onnx/csrc/offline-tts-vits-impl.h"
... ... @@ -40,9 +41,15 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
return std::make_unique<OfflineTtsVitsImpl>(config);
} else if (!config.model.matcha.acoustic_model.empty()) {
return std::make_unique<OfflineTtsMatchaImpl>(config);
} else if (!config.model.kokoro.model.empty()) {
return std::make_unique<OfflineTtsKokoroImpl>(config);
} else if (!config.model.kitten.model.empty()) {
return std::make_unique<OfflineTtsKittenImpl>(config);
}
return std::make_unique<OfflineTtsKokoroImpl>(config);
SHERPA_ONNX_LOGE("Please provide a tts model.");
return {};
}
template <typename Manager>
... ... @@ -52,9 +59,14 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
return std::make_unique<OfflineTtsVitsImpl>(mgr, config);
} else if (!config.model.matcha.acoustic_model.empty()) {
return std::make_unique<OfflineTtsMatchaImpl>(mgr, config);
} else if (!config.model.kokoro.model.empty()) {
return std::make_unique<OfflineTtsKokoroImpl>(mgr, config);
} else if (!config.model.kitten.model.empty()) {
return std::make_unique<OfflineTtsKittenImpl>(mgr, config);
}
return std::make_unique<OfflineTtsKokoroImpl>(mgr, config);
SHERPA_ONNX_LOGE("Please provide a tts model.");
return {};
}
#if __ANDROID_API__ >= 9
... ...
// sherpa-onnx/csrc/offline-tts-kitten-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_IMPL_H_
#include <iomanip>
#include <ios>
#include <memory>
#include <string>
#include <strstream>
#include <utility>
#include <vector>
#include "fst/extensions/far/far.h"
#include "kaldifst/csrc/kaldi-fst-io.h"
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
#include "sherpa-onnx/csrc/offline-tts-kitten-model.h"
#include "sherpa-onnx/csrc/piper-phonemize-lexicon.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineTtsKittenImpl : public OfflineTtsImpl {
public:
explicit OfflineTtsKittenImpl(const OfflineTtsConfig &config)
: config_(config),
model_(std::make_unique<OfflineTtsKittenModel>(config.model)) {
InitFrontend();
if (!config.rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(config.rule_fsts, ",", false, &files);
tn_list_.reserve(files.size());
for (const auto &f : files) {
if (config.model.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("rule fst: %{public}s", f.c_str());
#else
SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
#endif
}
tn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
}
}
if (!config.rule_fars.empty()) {
if (config.model.debug) {
SHERPA_ONNX_LOGE("Loading FST archives");
}
std::vector<std::string> files;
SplitStringToVector(config.rule_fars, ",", false, &files);
tn_list_.reserve(files.size() + tn_list_.size());
for (const auto &f : files) {
if (config.model.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str());
#else
SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
#endif
}
std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
fst::FarReader<fst::StdArc>::Open(f));
for (; !reader->Done(); reader->Next()) {
std::unique_ptr<fst::StdConstFst> r(
fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
tn_list_.push_back(
std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
}
}
if (config.model.debug) {
SHERPA_ONNX_LOGE("FST archives loaded!");
}
}
}
template <typename Manager>
OfflineTtsKittenImpl(Manager *mgr, const OfflineTtsConfig &config)
: config_(config),
model_(std::make_unique<OfflineTtsKittenModel>(mgr, config.model)) {
InitFrontend(mgr);
if (!config.rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(config.rule_fsts, ",", false, &files);
tn_list_.reserve(files.size());
for (const auto &f : files) {
if (config.model.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("rule fst: %{public}s", f.c_str());
#else
SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
#endif
}
auto buf = ReadFile(mgr, f);
std::istrstream is(buf.data(), buf.size());
tn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
}
}
if (!config.rule_fars.empty()) {
std::vector<std::string> files;
SplitStringToVector(config.rule_fars, ",", false, &files);
tn_list_.reserve(files.size() + tn_list_.size());
for (const auto &f : files) {
if (config.model.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str());
#else
SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
#endif
}
auto buf = ReadFile(mgr, f);
std::unique_ptr<std::istream> s(
new std::istrstream(buf.data(), buf.size()));
std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
fst::FarReader<fst::StdArc>::Open(std::move(s)));
for (; !reader->Done(); reader->Next()) {
std::unique_ptr<fst::StdConstFst> r(
fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
tn_list_.push_back(
std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
} // for (; !reader->Done(); reader->Next())
} // for (const auto &f : files)
} // if (!config.rule_fars.empty())
}
int32_t SampleRate() const override {
return model_->GetMetaData().sample_rate;
}
int32_t NumSpeakers() const override {
return model_->GetMetaData().num_speakers;
}
GeneratedAudio Generate(
const std::string &_text, int64_t sid = 0, float speed = 1.0,
GeneratedAudioCallback callback = nullptr) const override {
const auto &meta_data = model_->GetMetaData();
int32_t num_speakers = meta_data.num_speakers;
if (num_speakers == 0 && sid != 0) {
#if __OHOS__
SHERPA_ONNX_LOGE(
"This is a single-speaker model and supports only sid 0. Given sid: "
"%{public}d. sid is ignored",
static_cast<int32_t>(sid));
#else
SHERPA_ONNX_LOGE(
"This is a single-speaker model and supports only sid 0. Given sid: "
"%d. sid is ignored",
static_cast<int32_t>(sid));
#endif
}
if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) {
#if __OHOS__
SHERPA_ONNX_LOGE(
"This model contains only %{public}d speakers. sid should be in the "
"range [%{public}d, %{public}d]. Given: %{public}d. Use sid=0",
num_speakers, 0, num_speakers - 1, static_cast<int32_t>(sid));
#else
SHERPA_ONNX_LOGE(
"This model contains only %d speakers. sid should be in the range "
"[%d, %d]. Given: %d. Use sid=0",
num_speakers, 0, num_speakers - 1, static_cast<int32_t>(sid));
#endif
sid = 0;
}
std::string text = _text;
if (config_.model.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("Raw text: %{public}s", text.c_str());
#else
SHERPA_ONNX_LOGE("Raw text: %s", text.c_str());
#endif
std::ostringstream os;
os << "In bytes (hex):\n";
const auto p = reinterpret_cast<const uint8_t *>(text.c_str());
for (int32_t i = 0; i != text.size(); ++i) {
os << std::setw(2) << std::setfill('0') << std::hex
<< static_cast<uint32_t>(p[i]) << " ";
}
os << "\n";
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s", os.str().c_str());
#endif
}
if (!tn_list_.empty()) {
for (const auto &tn : tn_list_) {
text = tn->Normalize(text);
if (config_.model.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("After normalizing: %{public}s", text.c_str());
#else
SHERPA_ONNX_LOGE("After normalizing: %s", text.c_str());
#endif
}
}
}
std::vector<TokenIDs> token_ids =
frontend_->ConvertTextToTokenIds(text, meta_data.voice);
if (token_ids.empty() ||
(token_ids.size() == 1 && token_ids[0].tokens.empty())) {
#if __OHOS__
SHERPA_ONNX_LOGE("Failed to convert '%{public}s' to token IDs",
text.c_str());
#else
SHERPA_ONNX_LOGE("Failed to convert '%s' to token IDs", text.c_str());
#endif
return {};
}
std::vector<std::vector<int64_t>> x;
x.reserve(token_ids.size());
for (auto &i : token_ids) {
x.push_back(std::move(i.tokens));
}
int32_t x_size = static_cast<int32_t>(x.size());
if (config_.max_num_sentences != 1) {
#if __OHOS__
SHERPA_ONNX_LOGE(
"max_num_sentences (%{public}d) != 1 is ignored for Kitten TTS "
"models",
config_.max_num_sentences);
#else
SHERPA_ONNX_LOGE(
"max_num_sentences (%d) != 1 is ignored for Kitten TTS models",
config_.max_num_sentences);
#endif
}
// If the input text is long, we process its sentences in batches to
// avoid OOM. For Kitten TTS the batch size is fixed at 1.
std::vector<std::vector<int64_t>> batch_x;
int32_t batch_size = 1;
batch_x.reserve(batch_size);
int32_t num_batches = x_size / batch_size;
if (config_.model.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE(
"Split it into %{public}d batches. batch size: "
"%{public}d. Number of sentences: %{public}d",
num_batches, batch_size, x_size);
#else
SHERPA_ONNX_LOGE(
"Split it into %d batches. batch size: %d. Number "
"of sentences: %d",
num_batches, batch_size, x_size);
#endif
}
GeneratedAudio ans;
int32_t should_continue = 1;
int32_t k = 0;
for (int32_t b = 0; b != num_batches && should_continue; ++b) {
batch_x.clear();
for (int32_t i = 0; i != batch_size; ++i, ++k) {
batch_x.push_back(std::move(x[k]));
}
auto audio = Process(batch_x, sid, speed);
ans.sample_rate = audio.sample_rate;
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
if (callback) {
should_continue = callback(audio.samples.data(), audio.samples.size(),
(b + 1) * 1.0 / num_batches);
// Caution(fangjun): audio is freed once the callback returns, so users
// should copy the data inside the callback if they want to access it
// afterwards; otherwise they risk a segmentation fault.
}
}
batch_x.clear();
while (k < static_cast<int32_t>(x.size()) && should_continue) {
batch_x.push_back(std::move(x[k]));
++k;
}
if (!batch_x.empty()) {
auto audio = Process(batch_x, sid, speed);
ans.sample_rate = audio.sample_rate;
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
if (callback) {
callback(audio.samples.data(), audio.samples.size(), 1.0);
// Caution(fangjun): audio is freed once the callback returns, so users
// should copy the data inside the callback if they want to access it
// afterwards; otherwise they risk a segmentation fault.
}
}
return ans;
}
private:
template <typename Manager>
void InitFrontend(Manager *mgr) {
const auto &meta_data = model_->GetMetaData();
frontend_ = std::make_unique<PiperPhonemizeLexicon>(
mgr, config_.model.kitten.tokens, config_.model.kitten.data_dir,
meta_data);
}
void InitFrontend() {
const auto &meta_data = model_->GetMetaData();
frontend_ = std::make_unique<PiperPhonemizeLexicon>(
config_.model.kitten.tokens, config_.model.kitten.data_dir, meta_data);
}
GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens,
int32_t sid, float speed) const {
int32_t num_tokens = 0;
for (const auto &k : tokens) {
num_tokens += k.size();
}
std::vector<int64_t> x;
x.reserve(num_tokens);
for (const auto &k : tokens) {
x.insert(x.end(), k.begin(), k.end());
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 2> x_shape = {1, static_cast<int64_t>(x.size())};
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed);
std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();
int64_t total = 1;
// The output shape may be (1, 1, total) or (1, total) or (total,)
for (auto i : audio_shape) {
total *= i;
}
const float *p = audio.GetTensorData<float>();
GeneratedAudio ans;
ans.sample_rate = model_->GetMetaData().sample_rate;
ans.samples = std::vector<float>(p, p + total);
float silence_scale = config_.silence_scale;
if (silence_scale != 1) {
ans = ans.ScaleSilence(silence_scale);
}
return ans;
}
private:
OfflineTtsConfig config_;
std::unique_ptr<OfflineTtsKittenModel> model_;
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> tn_list_;
std::unique_ptr<OfflineTtsFrontend> frontend_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_IMPL_H_
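Since the audio buffer handed to the callback in Generate() is freed as soon as the callback returns, a caller has to copy the samples inside the callback. Below is a minimal caller-side sketch, not part of the runtime; it assumes the callback signature int32_t(const float *, int32_t, float) implied by the calls above.
// Sketch only: a caller-owned collector that copies each streamed chunk
// before the callback returns. Returning 1 keeps generation going;
// returning 0 stops it early (see should_continue above).
#include <cstdint>
#include <vector>
struct AudioCollector {
  std::vector<float> samples;  // remains valid after Generate() returns
  int32_t OnAudio(const float *p, int32_t n, float /*progress*/) {
    samples.insert(samples.end(), p, p + n);  // copy now; p is freed later
    return 1;
  }
};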
... ...
// sherpa-onnx/csrc/offline-tts-kitten-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-kitten-model-config.h"
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
void OfflineTtsKittenModelConfig::Register(ParseOptions *po) {
po->Register("kitten-model", &model, "Path to kitten model");
po->Register("kitten-voices", &voices,
"Path to voices.bin for kitten models");
po->Register("kitten-tokens", &tokens,
"Path to tokens.txt for kitten models");
po->Register("kitten-data-dir", &data_dir,
"Path to the directory containing dict for espeak-ng.");
po->Register("kitten-length-scale", &length_scale,
"Inverse of speech speed. Larger->Slower; Smaller->faster.");
}
bool OfflineTtsKittenModelConfig::Validate() const {
if (model.empty()) {
SHERPA_ONNX_LOGE("Please provide --kitten-model");
return false;
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("--kitten-model: '%s' does not exist", model.c_str());
return false;
}
if (voices.empty()) {
SHERPA_ONNX_LOGE("Please provide --kitten-voices");
return false;
}
if (!FileExists(voices)) {
SHERPA_ONNX_LOGE("--kitten-voices: '%s' does not exist", voices.c_str());
return false;
}
if (tokens.empty()) {
SHERPA_ONNX_LOGE("Please provide --kitten-tokens");
return false;
}
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("--kitten-tokens: '%s' does not exist", tokens.c_str());
return false;
}
if (data_dir.empty()) {
SHERPA_ONNX_LOGE("Please provide --kitten-data-dir");
return false;
}
if (!FileExists(data_dir + "/phontab")) {
SHERPA_ONNX_LOGE(
"'%s/phontab' does not exist. Please check --kitten-data-dir",
data_dir.c_str());
return false;
}
if (!FileExists(data_dir + "/phonindex")) {
SHERPA_ONNX_LOGE(
"'%s/phonindex' does not exist. Please check --kitten-data-dir",
data_dir.c_str());
return false;
}
if (!FileExists(data_dir + "/phondata")) {
SHERPA_ONNX_LOGE(
"'%s/phondata' does not exist. Please check --kitten-data-dir",
data_dir.c_str());
return false;
}
if (!FileExists(data_dir + "/intonations")) {
SHERPA_ONNX_LOGE(
"'%s/intonations' does not exist. Please check --kitten-data-dir",
data_dir.c_str());
return false;
}
if (length_scale <= 0) {
SHERPA_ONNX_LOGE(
"Please provide a positive length_scale for --kitten-length-scale. "
"Given: %.3f",
length_scale);
return false;
}
return true;
}
std::string OfflineTtsKittenModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineTtsKittenModelConfig(";
os << "model=\"" << model << "\", ";
os << "voices=\"" << voices << "\", ";
os << "tokens=\"" << tokens << "\", ";
os << "data_dir=\"" << data_dir << "\", ";
os << "length_scale=" << length_scale << ")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-tts-kitten-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineTtsKittenModelConfig {
std::string model;
std::string voices;
std::string tokens;
std::string data_dir;
// speed = 1 / length_scale
float length_scale = 1.0;
OfflineTtsKittenModelConfig() = default;
OfflineTtsKittenModelConfig(const std::string &model,
const std::string &voices,
const std::string &tokens,
const std::string &data_dir, float length_scale)
: model(model),
voices(voices),
tokens(tokens),
data_dir(data_dir),
length_scale(length_scale) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_MODEL_CONFIG_H_
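A minimal usage sketch for this config follows. The file names are placeholders, not files shipped with the project; only the constructor, Validate(), and ToString() declared above are used.
// Sketch only: build a kitten config, validate it, and print it.
#include <cstdio>
#include "sherpa-onnx/csrc/offline-tts-kitten-model-config.h"
int main() {
  sherpa_onnx::OfflineTtsKittenModelConfig config(
      /*model=*/"model.onnx", /*voices=*/"voices.bin",
      /*tokens=*/"tokens.txt", /*data_dir=*/"espeak-ng-data",
      /*length_scale=*/1.0f);
  if (!config.Validate()) {
    // Validate() logs which flag or file is missing and returns false.
    return -1;
  }
  std::printf("%s\n", config.ToString().c_str());
  return 0;
}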
... ...
// sherpa-onnx/csrc/offline-tts-kitten-model-meta-data.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_MODEL_META_DATA_H_
#include <cstdint>
#include <string>
namespace sherpa_onnx {
// please refer to
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kitten-tts/nano_v0_1/add_meta_data.py
struct OfflineTtsKittenModelMetaData {
int32_t sample_rate = 0;
int32_t num_speakers = 0;
int32_t version = 1;
int32_t has_espeak = 1;
int32_t max_token_len = 256;
std::string voice;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_MODEL_META_DATA_H_
... ...
// sherpa-onnx/csrc/offline-tts-kitten-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-kitten-model.h"
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineTtsKittenModel::Impl {
public:
explicit Impl(const OfflineTtsModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto model_buf = ReadFile(config.kitten.model);
auto voices_buf = ReadFile(config.kitten.voices);
Init(model_buf.data(), model_buf.size(), voices_buf.data(),
voices_buf.size());
}
template <typename Manager>
Impl(Manager *mgr, const OfflineTtsModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto model_buf = ReadFile(mgr, config.kitten.model);
auto voices_buf = ReadFile(mgr, config.kitten.voices);
Init(model_buf.data(), model_buf.size(), voices_buf.data(),
voices_buf.size());
}
const OfflineTtsKittenModelMetaData &GetMetaData() const {
return meta_data_;
}
Ort::Value Run(Ort::Value x, int32_t sid, float speed) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape();
if (x_shape[0] != 1) {
SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
static_cast<int32_t>(x_shape[0]));
SHERPA_ONNX_EXIT(-1);
}
int32_t num_speakers = meta_data_.num_speakers;
int32_t dim1 = style_dim_[1];
/*const*/ float *p = styles_.data() + sid * dim1;
std::array<int64_t, 2> style_embedding_shape = {1, dim1};
Ort::Value style_embedding = Ort::Value::CreateTensor(
memory_info, p, dim1, style_embedding_shape.data(),
style_embedding_shape.size());
int64_t speed_shape = 1;
if (config_.kitten.length_scale != 1 && speed == 1) {
speed = 1. / config_.kitten.length_scale;
}
Ort::Value speed_tensor =
Ort::Value::CreateTensor(memory_info, &speed, 1, &speed_shape, 1);
std::array<Ort::Value, 3> inputs = {
std::move(x), std::move(style_embedding), std::move(speed_tensor)};
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return std::move(out[0]);
}
private:
void Init(void *model_data, size_t model_data_length, const char *voices_data,
size_t voices_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---kitten model---\n";
PrintModelMetadata(os, meta_data);
os << "----------input names----------\n";
int32_t i = 0;
for (const auto &s : input_names_) {
os << i << " " << s << "\n";
++i;
}
os << "----------output names----------\n";
i = 0;
for (const auto &s : output_names_) {
os << i << " " << s << "\n";
++i;
}
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
std::string model_type;
SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
if (model_type != "kitten-tts") {
SHERPA_ONNX_LOGE(
"Please download the kitten tts model from us containing meta data");
SHERPA_ONNX_EXIT(-1);
}
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 1);
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers");
SHERPA_ONNX_READ_META_DATA(meta_data_.has_espeak, "has_espeak");
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.voice, "voice",
"en-us");
if (meta_data_.has_espeak != 1) {
SHERPA_ONNX_LOGE("It should require espeak-ng");
SHERPA_ONNX_EXIT(-1);
}
if (config_.debug) {
std::vector<std::string> speaker_names;
SHERPA_ONNX_READ_META_DATA_VEC_STRING(speaker_names, "speaker_names");
std::ostringstream os;
os << "\n";
for (int32_t i = 0; i != speaker_names.size(); ++i) {
os << i << "->" << speaker_names[i] << ", ";
}
os << "\n";
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
}
SHERPA_ONNX_READ_META_DATA_VEC(style_dim_, "style_dim");
if (style_dim_.size() != 2) {
SHERPA_ONNX_LOGE("style_dim should be 2-d, given: %d",
static_cast<int32_t>(style_dim_.size()));
SHERPA_ONNX_EXIT(-1);
}
if (style_dim_[0] != 1) {
SHERPA_ONNX_LOGE("style_dim[0] should be 1, given: %d", style_dim_[0]);
SHERPA_ONNX_EXIT(-1);
}
int32_t actual_num_floats = voices_data_length / sizeof(float);
int32_t expected_num_floats =
style_dim_[0] * style_dim_[1] * meta_data_.num_speakers;
if (actual_num_floats != expected_num_floats) {
#if __OHOS__
SHERPA_ONNX_LOGE(
"Corrupted --kitten-voices '%{public}s'. Expected #floats: "
"%{public}d, actual: %{public}d",
config_.kitten.voices.c_str(), expected_num_floats,
actual_num_floats);
#else
SHERPA_ONNX_LOGE(
"Corrupted --kitten-voices '%s'. Expected #floats: %d, actual: %d",
config_.kitten.voices.c_str(), expected_num_floats,
actual_num_floats);
#endif
SHERPA_ONNX_EXIT(-1);
}
styles_ = std::vector<float>(
reinterpret_cast<const float *>(voices_data),
reinterpret_cast<const float *>(voices_data) + expected_num_floats);
}
private:
OfflineTtsModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
OfflineTtsKittenModelMetaData meta_data_;
std::vector<int32_t> style_dim_;
// (num_speakers, style_dim_[1])
std::vector<float> styles_;
};
OfflineTtsKittenModel::OfflineTtsKittenModel(
const OfflineTtsModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
OfflineTtsKittenModel::OfflineTtsKittenModel(
Manager *mgr, const OfflineTtsModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
OfflineTtsKittenModel::~OfflineTtsKittenModel() = default;
const OfflineTtsKittenModelMetaData &OfflineTtsKittenModel::GetMetaData()
const {
return impl_->GetMetaData();
}
Ort::Value OfflineTtsKittenModel::Run(Ort::Value x, int64_t sid /*= 0*/,
float speed /*= 1.0*/) const {
return impl_->Run(std::move(x), sid, speed);
}
#if __ANDROID_API__ >= 9
template OfflineTtsKittenModel::OfflineTtsKittenModel(
AAssetManager *mgr, const OfflineTtsModelConfig &config);
#endif
#if __OHOS__
template OfflineTtsKittenModel::OfflineTtsKittenModel(
NativeResourceManager *mgr, const OfflineTtsModelConfig &config);
#endif
} // namespace sherpa_onnx
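Impl::Init() above treats --kitten-voices as a raw float32 matrix of shape (num_speakers, style_dim[1]) stored row by row, and Impl::Run() slices out row sid as the style embedding. The standalone helper below is a hypothetical sketch of the same size check and slicing, not part of the runtime.
// Sketch only: mirrors the voices.bin handling in Impl::Init() and Impl::Run().
#include <cstdint>
#include <vector>
std::vector<float> SelectStyleRow(const std::vector<char> &voices_buf,
                                  int32_t num_speakers, int32_t style_dim,
                                  int32_t sid) {
  int32_t actual = static_cast<int32_t>(voices_buf.size() / sizeof(float));
  if (actual != num_speakers * style_dim) {
    return {};  // corrupted voices file; the runtime logs an error and exits
  }
  const auto *p = reinterpret_cast<const float *>(voices_buf.data());
  const float *row = p + sid * style_dim;  // row sid is the style embedding
  return std::vector<float>(row, row + style_dim);  // later fed as (1, style_dim)
}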
... ...
// sherpa-onnx/csrc/offline-tts-kitten-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_MODEL_H_
#include <memory>
#include <string>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-tts-kitten-model-meta-data.h"
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
namespace sherpa_onnx {
class OfflineTtsKittenModel {
public:
~OfflineTtsKittenModel();
explicit OfflineTtsKittenModel(const OfflineTtsModelConfig &config);
template <typename Manager>
OfflineTtsKittenModel(Manager *mgr, const OfflineTtsModelConfig &config);
// @params x An int64 tensor of shape (1, num_tokens)
// @return Return a float32 tensor containing the
// samples of shape (num_samples,)
Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0) const;
const OfflineTtsKittenModelMetaData &GetMetaData() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_KITTEN_MODEL_H_
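A minimal calling sketch for Run() follows; the free function is hypothetical and mirrors Process() in offline-tts-kitten-impl.h, assuming the input/output shapes documented in the comment above.
// Sketch only: feed int64 token IDs of shape (1, num_tokens) to Run() and
// copy the float32 samples out of the returned tensor.
#include <array>
#include <cstdint>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-tts-kitten-model.h"
std::vector<float> Synthesize(const sherpa_onnx::OfflineTtsKittenModel &model,
                              std::vector<int64_t> tokens) {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  std::array<int64_t, 2> shape = {1, static_cast<int64_t>(tokens.size())};
  Ort::Value x = Ort::Value::CreateTensor(memory_info, tokens.data(),
                                          tokens.size(), shape.data(),
                                          shape.size());
  Ort::Value audio = model.Run(std::move(x), /*sid=*/0, /*speed=*/1.0f);
  auto n = audio.GetTensorTypeAndShapeInfo().GetElementCount();
  const float *p = audio.GetTensorData<float>();
  return std::vector<float>(p, p + n);
}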
... ...
... ... @@ -11,7 +11,9 @@
namespace sherpa_onnx {
// please refer to
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/v0.19/add_meta_data.py
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/v1.0/add_meta_data.py
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/v1.1-zh/add_meta_data.py
struct OfflineTtsKokoroModelMetaData {
int32_t sample_rate = 0;
int32_t num_speakers = 0;
... ...
... ... @@ -170,7 +170,7 @@ class OfflineTtsKokoroModel::Impl {
}
if (style_dim_[1] != 1) {
SHERPA_ONNX_LOGE("style_dim[0] should be 1, given: %d", style_dim_[1]);
SHERPA_ONNX_LOGE("style_dim[1] should be 1, given: %d", style_dim_[1]);
SHERPA_ONNX_EXIT(-1);
}
... ...
... ... @@ -23,8 +23,8 @@ class OfflineTtsKokoroModel {
template <typename Manager>
OfflineTtsKokoroModel(Manager *mgr, const OfflineTtsModelConfig &config);
// Return a float32 tensor containing the mel
// of shape (batch_size, mel_dim, num_frames)
// Return a float32 tensor containing the samples
// of shape (batch_size, num_samples)
Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0) const;
const OfflineTtsKokoroModelMetaData &GetMetaData() const;
... ...
... ... @@ -12,6 +12,7 @@ void OfflineTtsModelConfig::Register(ParseOptions *po) {
vits.Register(po);
matcha.Register(po);
kokoro.Register(po);
kitten.Register(po);
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
... ... @@ -37,7 +38,17 @@ bool OfflineTtsModelConfig::Validate() const {
return matcha.Validate();
}
if (!kokoro.model.empty()) {
return kokoro.Validate();
}
if (!kitten.model.empty()) {
return kitten.Validate();
}
SHERPA_ONNX_LOGE("Please provide at exactly one tts model.");
return false;
}
std::string OfflineTtsModelConfig::ToString() const {
... ... @@ -47,6 +58,7 @@ std::string OfflineTtsModelConfig::ToString() const {
os << "vits=" << vits.ToString() << ", ";
os << "matcha=" << matcha.ToString() << ", ";
os << "kokoro=" << kokoro.ToString() << ", ";
os << "kitten=" << kitten.ToString() << ", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")";
... ...
... ... @@ -7,6 +7,7 @@
#include <string>
#include "sherpa-onnx/csrc/offline-tts-kitten-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-kokoro-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-matcha-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"
... ... @@ -18,6 +19,7 @@ struct OfflineTtsModelConfig {
OfflineTtsVitsModelConfig vits;
OfflineTtsMatchaModelConfig matcha;
OfflineTtsKokoroModelConfig kokoro;
OfflineTtsKittenModelConfig kitten;
int32_t num_threads = 1;
bool debug = false;
... ... @@ -28,11 +30,13 @@ struct OfflineTtsModelConfig {
OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits,
const OfflineTtsMatchaModelConfig &matcha,
const OfflineTtsKokoroModelConfig &kokoro,
const OfflineTtsKittenModelConfig &kitten,
int32_t num_threads, bool debug,
const std::string &provider)
: vits(vits),
matcha(matcha),
kokoro(kokoro),
kitten(kitten),
num_threads(num_threads),
debug(debug),
provider(provider) {}
... ...
... ... @@ -180,7 +180,7 @@ static std::vector<int64_t> PiperPhonemesToIdsMatcha(
return ans;
}
static std::vector<std::vector<int64_t>> PiperPhonemesToIdsKokoro(
static std::vector<std::vector<int64_t>> PiperPhonemesToIdsKokoroOrKitten(
const std::unordered_map<char32_t, int32_t> &token2id,
const std::vector<piper::Phoneme> &phonemes, int32_t max_len) {
std::vector<std::vector<int64_t>> ans;
... ... @@ -277,7 +277,6 @@ static std::vector<int64_t> CoquiPhonemesToIds(
void InitEspeak(const std::string &data_dir) {
static std::once_flag init_flag;
std::call_once(init_flag, [data_dir]() {
#if __ANDROID_API__ >= 9 || defined(__OHOS__)
if (data_dir[0] != '/') {
SHERPA_ONNX_LOGE(
... ... @@ -358,6 +357,18 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon(
InitEspeak(data_dir);
}
PiperPhonemizeLexicon::PiperPhonemizeLexicon(
const std::string &tokens, const std::string &data_dir,
const OfflineTtsKittenModelMetaData &kitten_meta_data)
: kitten_meta_data_(kitten_meta_data), is_kitten_(true) {
{
std::ifstream is(tokens);
token2id_ = ReadTokens(is);
}
InitEspeak(data_dir);
}
template <typename Manager>
PiperPhonemizeLexicon::PiperPhonemizeLexicon(
Manager *mgr, const std::string &tokens, const std::string &data_dir,
... ... @@ -392,13 +403,33 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon(
InitEspeak(data_dir);
}
template <typename Manager>
PiperPhonemizeLexicon::PiperPhonemizeLexicon(
Manager *mgr, const std::string &tokens, const std::string &data_dir,
const OfflineTtsKittenModelMetaData &kitten_meta_data)
: kitten_meta_data_(kitten_meta_data), is_kitten_(true) {
{
auto buf = ReadFile(mgr, tokens);
std::istrstream is(buf.data(), buf.size());
token2id_ = ReadTokens(is);
}
// We should copy the directory of espeak-ng-data from the asset to
// some internal or external storage and then pass the directory to
// data_dir.
InitEspeak(data_dir);
}
std::vector<TokenIDs> PiperPhonemizeLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string &voice /*= ""*/) const {
if (is_matcha_) {
return ConvertTextToTokenIdsMatcha(text, voice);
} else if (is_kokoro_) {
return ConvertTextToTokenIdsKokoro(
return ConvertTextToTokenIdsKokoroOrKitten(
token2id_, kokoro_meta_data_.max_token_len, text, voice);
} else if (is_kitten_) {
return ConvertTextToTokenIdsKokoroOrKitten(
token2id_, kitten_meta_data_.max_token_len, text, voice);
} else {
return ConvertTextToTokenIdsVits(text, voice);
}
... ... @@ -429,7 +460,7 @@ std::vector<TokenIDs> PiperPhonemizeLexicon::ConvertTextToTokenIdsMatcha(
return ans;
}
std::vector<TokenIDs> ConvertTextToTokenIdsKokoro(
std::vector<TokenIDs> ConvertTextToTokenIdsKokoroOrKitten(
const std::unordered_map<char32_t, int32_t> &token2id,
int32_t max_token_len, const std::string &text,
const std::string &voice /*= ""*/) {
... ... @@ -446,7 +477,8 @@ std::vector<TokenIDs> ConvertTextToTokenIdsKokoro(
std::vector<TokenIDs> ans;
for (const auto &p : phonemes) {
auto phoneme_ids = PiperPhonemesToIdsKokoro(token2id, p, max_token_len);
auto phoneme_ids =
PiperPhonemesToIdsKokoroOrKitten(token2id, p, max_token_len);
for (auto &ids : phoneme_ids) {
ans.emplace_back(std::move(ids));
... ...
... ... @@ -10,6 +10,7 @@
#include <vector>
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-kitten-model-meta-data.h"
#include "sherpa-onnx/csrc/offline-tts-kokoro-model-meta-data.h"
#include "sherpa-onnx/csrc/offline-tts-matcha-model-meta-data.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-meta-data.h"
... ... @@ -27,6 +28,9 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend {
PiperPhonemizeLexicon(const std::string &tokens, const std::string &data_dir,
const OfflineTtsKokoroModelMetaData &kokoro_meta_data);
PiperPhonemizeLexicon(const std::string &tokens, const std::string &data_dir,
const OfflineTtsKittenModelMetaData &kitten_meta_data);
template <typename Manager>
PiperPhonemizeLexicon(Manager *mgr, const std::string &tokens,
const std::string &data_dir,
... ... @@ -42,6 +46,11 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend {
const std::string &data_dir,
const OfflineTtsKokoroModelMetaData &kokoro_meta_data);
template <typename Manager>
PiperPhonemizeLexicon(Manager *mgr, const std::string &tokens,
const std::string &data_dir,
const OfflineTtsKittenModelMetaData &kitten_meta_data);
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const override;
... ... @@ -58,8 +67,10 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend {
OfflineTtsVitsModelMetaData vits_meta_data_;
OfflineTtsMatchaModelMetaData matcha_meta_data_;
OfflineTtsKokoroModelMetaData kokoro_meta_data_;
OfflineTtsKittenModelMetaData kitten_meta_data_;
bool is_matcha_ = false;
bool is_kokoro_ = false;
bool is_kitten_ = false;
};
} // namespace sherpa_onnx
... ...
... ... @@ -101,6 +101,7 @@ or details.
float duration = audio.samples.size() / static_cast<float>(audio.sample_rate);
float rtf = elapsed_seconds / duration;
fprintf(stderr, "Number of threads: %d\n", config.model.num_threads);
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
fprintf(stderr, "Audio duration: %.3f s\n", duration);
fprintf(stderr, "Real-time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds,
... ...
... ... @@ -67,6 +67,7 @@ endif()
if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND srcs
offline-tts-kitten-model-config.cc
offline-tts-kokoro-model-config.cc
offline-tts-matcha-model-config.cc
offline-tts-model-config.cc
... ...
// sherpa-onnx/python/csrc/offline-tts-kitten-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-tts-kitten-model-config.h"
#include <string>
#include "sherpa-onnx/csrc/offline-tts-kitten-model-config.h"
namespace sherpa_onnx {
void PybindOfflineTtsKittenModelConfig(py::module *m) {
using PyClass = OfflineTtsKittenModelConfig;
py::class_<PyClass>(*m, "OfflineTtsKittenModelConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &,
const std::string &, const std::string &, float>(),
py::arg("model"), py::arg("voices"), py::arg("tokens"),
py::arg("data_dir"), py::arg("length_scale") = 1.0)
.def_readwrite("model", &PyClass::model)
.def_readwrite("voices", &PyClass::voices)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("data_dir", &PyClass::data_dir)
.def_readwrite("length_scale", &PyClass::length_scale)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-tts-kitten-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_KITTEN_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_KITTEN_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineTtsKittenModelConfig(py::module *m);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_KITTEN_MODEL_CONFIG_H_
... ...
... ... @@ -7,6 +7,7 @@
#include <string>
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tts-kitten-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tts-kokoro-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h"
... ... @@ -17,6 +18,7 @@ void PybindOfflineTtsModelConfig(py::module *m) {
PybindOfflineTtsVitsModelConfig(m);
PybindOfflineTtsMatchaModelConfig(m);
PybindOfflineTtsKokoroModelConfig(m);
PybindOfflineTtsKittenModelConfig(m);
using PyClass = OfflineTtsModelConfig;
... ... @@ -24,16 +26,19 @@ void PybindOfflineTtsModelConfig(py::module *m) {
.def(py::init<>())
.def(py::init<const OfflineTtsVitsModelConfig &,
const OfflineTtsMatchaModelConfig &,
const OfflineTtsKokoroModelConfig &, int32_t, bool,
const OfflineTtsKokoroModelConfig &,
const OfflineTtsKittenModelConfig &, int32_t, bool,
const std::string &>(),
py::arg("vits") = OfflineTtsVitsModelConfig{},
py::arg("matcha") = OfflineTtsMatchaModelConfig{},
py::arg("kokoro") = OfflineTtsKokoroModelConfig{},
py::arg("kitten") = OfflineTtsKittenModelConfig{},
py::arg("num_threads") = 1, py::arg("debug") = false,
py::arg("provider") = "cpu")
.def_readwrite("vits", &PyClass::vits)
.def_readwrite("matcha", &PyClass::matcha)
.def_readwrite("kokoro", &PyClass::kokoro)
.def_readwrite("kitten", &PyClass::kitten)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
... ...