Wei Kang
Committed by GitHub

Add Zipvoice (#2487)

Co-authored-by: yaozengwei <yaozengwei@outlook.com>
... ... @@ -372,6 +372,7 @@ endif()
include(kaldi-native-fbank)
include(kaldi-decoder)
include(onnxruntime)
include(cppinyin)
include(simple-sentencepiece)
set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR})
message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}")
... ...
function(download_cppinyin)
include(FetchContent)
set(cppinyin_URL "https://github.com/pkufool/cppinyin/archive/refs/tags/v0.10.tar.gz")
set(cppinyin_URL2 "https://gh-proxy.com/https://github.com/pkufool/cppinyin/archive/refs/tags/v0.10.tar.gz")
set(cppinyin_HASH "SHA256=abe6584d7ee56829e8f4b5fbda3b50ecdf49a13be8e413a78d1b0d5d5c019982")
# If you don't have access to the Internet,
# please pre-download cppinyin
set(possible_file_locations
$ENV{HOME}/Downloads/cppinyin-0.10.tar.gz
${CMAKE_SOURCE_DIR}/cppinyin-0.10.tar.gz
${CMAKE_BINARY_DIR}/cppinyin-0.10.tar.gz
/tmp/cppinyin-0.10.tar.gz
/star-fj/fangjun/download/github/cppinyin-0.10.tar.gz
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(cppinyin_URL "${f}")
file(TO_CMAKE_PATH "${cppinyin_URL}" cppinyin_URL)
message(STATUS "Found local downloaded cppinyin: ${cppinyin_URL}")
set(cppinyin_URL2)
break()
endif()
endforeach()
set(CPPINYIN_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
set(CPPINYIN_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
FetchContent_Declare(cppinyin
URL
${cppinyin_URL}
${cppinyin_URL2}
URL_HASH
${cppinyin_HASH}
)
FetchContent_GetProperties(cppinyin)
if(NOT cppinyin_POPULATED)
message(STATUS "Downloading cppinyin ${cppinyin_URL}")
FetchContent_Populate(cppinyin)
endif()
message(STATUS "cppinyin is downloaded to ${cppinyin_SOURCE_DIR}")
if(BUILD_SHARED_LIBS)
set(_build_shared_libs_bak ${BUILD_SHARED_LIBS})
set(BUILD_SHARED_LIBS OFF)
endif()
add_subdirectory(${cppinyin_SOURCE_DIR} ${cppinyin_BINARY_DIR} EXCLUDE_FROM_ALL)
if(_build_shared_libs_bak)
set_target_properties(cppinyin_core
PROPERTIES
POSITION_INDEPENDENT_CODE ON
C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden
)
set(BUILD_SHARED_LIBS ON)
endif()
target_include_directories(cppinyin_core
PUBLIC
${cppinyin_SOURCE_DIR}/
)
if(NOT BUILD_SHARED_LIBS)
install(TARGETS cppinyin_core DESTINATION lib)
endif()
endfunction()
download_cppinyin()
... ...
#!/usr/bin/env python3
#
# Copyright (c) 2023-2025 Xiaomi Corporation
"""
This file demonstrates how to use the sherpa-onnx Python API to generate audio
... ... @@ -453,7 +453,9 @@ def main():
end = time.time()
if len(audio.samples) == 0:
print("Error in generating audios. Please read previous error messages.")
print(
"Error in generating audios. Please read previous error messages."
)
return
elapsed_seconds = end - start
... ... @@ -470,7 +472,9 @@ def main():
print(f"The text is '{args.text}'")
print(f"Elapsed seconds: {elapsed_seconds:.3f}")
print(f"Audio duration in seconds: {audio_duration:.3f}")
print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
print(
f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
)
if __name__ == "__main__":
... ...
#!/usr/bin/env python3
#
# Copyright (c) 2025 Xiaomi Corporation
"""
This file demonstrates how to use the sherpa-onnx Python API to generate audio
from text with a prompt, i.e., zero-shot text-to-speech.
Usage:
Example (zipvoice)
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2
tar xf sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2
python3 ./python-api-examples/offline-zeroshot-tts.py \
--zipvoice-flow-matching-model sherpa-onnx-zipvoice-distill-zh-en-emilia/fm_decoder.onnx \
--zipvoice-text-model sherpa-onnx-zipvoice-distill-zh-en-emilia/text_encoder.onnx \
--zipvoice-data-dir sherpa-onnx-zipvoice-distill-zh-en-emilia/espeak-ng-data \
--zipvoice-pinyin-dict sherpa-onnx-zipvoice-distill-zh-en-emilia/pinyin.raw \
--zipvoice-tokens sherpa-onnx-zipvoice-distill-zh-en-emilia/tokens.txt \
--zipvoice-vocoder sherpa-onnx-zipvoice-distill-zh-en-emilia/vocos_24khz.onnx \
--prompt-audio sherpa-onnx-zipvoice-distill-zh-en-emilia/prompt.wav \
--zipvoice-num-steps 4 \
--num-threads 4 \
--prompt-text "周日被我射熄火了,所以今天是周一。" \
"我是中国人民的儿子,我爱我的祖国。我得祖国是一个伟大的国家,拥有五千年的文明史。"
"""
import argparse
import time
import wave
import numpy as np
from typing import Tuple
import sherpa_onnx
import soundfile as sf
def add_zipvoice_args(parser):
parser.add_argument(
"--zipvoice-tokens",
type=str,
default="",
help="Path to tokens.txt for Zipvoice models.",
)
parser.add_argument(
"--zipvoice-text-model",
type=str,
default="",
help="Path to zipvoice text model.",
)
parser.add_argument(
"--zipvoice-flow-matching-model",
type=str,
default="",
help="Path to zipvoice flow matching model.",
)
parser.add_argument(
"--zipvoice-data-dir",
type=str,
default="",
help="Path to the dict directory of espeak-ng.",
)
parser.add_argument(
"--zipvoice-pinyin-dict",
type=str,
default="",
help="Path to the pinyin dictionary.",
)
parser.add_argument(
"--zipvoice-vocoder",
type=str,
default="",
help="Path to the vocos vocoder.",
)
parser.add_argument(
"--zipvoice-num-steps",
type=int,
default=4,
help="Number of steps for Zipvoice.",
)
parser.add_argument(
"--zipvoice-feat-scale",
type=float,
default=0.1,
help="Scale factor for Zipvoice features.",
)
parser.add_argument(
"--zipvoice-t-shift",
type=float,
default=0.5,
help="Shift t to smaller ones if t-shift < 1.0.",
)
parser.add_argument(
"--zipvoice-target-rms",
type=float,
default=0.1,
help="Target speech normalization RMS value for Zipvoice.",
)
parser.add_argument(
"--zipvoice-guidance-scale",
type=float,
default=1.0,
help="The scale of classifier-free guidance during inference for for Zipvoice.",
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""
with wave.open(wave_filename) as f:
assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32)
samples_float32 = samples_float32 / 32768
return samples_float32, f.getframerate()
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
add_zipvoice_args(parser)
parser.add_argument(
"--tts-rule-fsts",
type=str,
default="",
help="Path to rule.fst",
)
parser.add_argument(
"--max-num-sentences",
type=int,
default=1,
help="""Max number of sentences in a batch to avoid OOM if the input
text is very long. Set it to -1 to process all the sentences in a
single batch. A smaller value does not mean it is slower compared
to a larger one on CPU.
""",
)
parser.add_argument(
"--output-filename",
type=str,
default="./generated.wav",
help="Path to save generated wave",
)
parser.add_argument(
"--debug",
type=bool,
default=False,
help="True to show debug messages",
)
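# Note: argparse's type=bool treats any non-empty string as True
# (bool("False") is True), so omit --debug, or pass an empty string,
# to keep debugging disabled.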
parser.add_argument(
"--provider",
type=str,
default="cpu",
help="valid values: cpu, cuda, coreml",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--speed",
type=float,
default=1.0,
help="Speech speed. Larger->faster; smaller->slower",
)
parser.add_argument(
"--prompt-text",
type=str,
required=True,
help="The transcription of prompt audio (Zipvoice)",
)
parser.add_argument(
"--prompt-audio",
type=str,
required=True,
help="The path to prompt audio (Zipvoice).",
)
parser.add_argument(
"text",
type=str,
help="The input text to generate audio for",
)
return parser.parse_args()
def main():
args = get_args()
print(args)
tts_config = sherpa_onnx.OfflineTtsConfig(
model=sherpa_onnx.OfflineTtsModelConfig(
zipvoice=sherpa_onnx.OfflineTtsZipvoiceModelConfig(
tokens=args.zipvoice_tokens,
text_model=args.zipvoice_text_model,
flow_matching_model=args.zipvoice_flow_matching_model,
data_dir=args.zipvoice_data_dir,
pinyin_dict=args.zipvoice_pinyin_dict,
vocoder=args.zipvoice_vocoder,
feat_scale=args.zipvoice_feat_scale,
t_shift=args.zipvoice_t_shift,
target_rms=args.zipvoice_target_rms,
guidance_scale=args.zipvoice_guidance_scale,
),
provider=args.provider,
debug=args.debug,
num_threads=args.num_threads,
),
rule_fsts=args.tts_rule_fsts,
max_num_sentences=args.max_num_sentences,
)
if not tts_config.validate():
raise ValueError("Please check your config")
tts = sherpa_onnx.OfflineTts(tts_config)
start = time.time()
prompt_samples, sample_rate = read_wave(args.prompt_audio)
audio = tts.generate(
args.text,
args.prompt_text,
prompt_samples,
sample_rate,
speed=args.speed,
num_steps=args.zipvoice_num_steps,
)
end = time.time()
if len(audio.samples) == 0:
print(
"Error in generating audios. Please read previous error messages."
)
return
elapsed_seconds = end - start
audio_duration = len(audio.samples) / audio.sample_rate
real_time_factor = elapsed_seconds / audio_duration
sf.write(
args.output_filename,
audio.samples,
samplerate=audio.sample_rate,
subtype="PCM_16",
)
print(f"Saved to {args.output_filename}")
print(f"The text is '{args.text}'")
print(f"Elapsed seconds: {elapsed_seconds:.3f}")
print(f"Audio duration in seconds: {audio_duration:.3f}")
print(
f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
)
if __name__ == "__main__":
main()
... ...
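The new sherpa-onnx-offline-zeroshot-tts binary added by this PR is not shown in the diff. For orientation, here is a minimal hedged sketch (not the actual sherpa-onnx-offline-zeroshot-tts.cc) of how the new C++ zero-shot Generate() overload can be driven, mirroring the Python example above. The model paths come from the release tarball used above; loading of prompt.wav and the prompt transcript are placeholders.

// Minimal sketch: configure ZipVoice, then call the zero-shot Generate().
#include <cstdint>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/offline-tts.h"

int main() {
  sherpa_onnx::OfflineTtsConfig config;
  auto &zip = config.model.zipvoice;
  zip.tokens = "sherpa-onnx-zipvoice-distill-zh-en-emilia/tokens.txt";
  zip.text_model =
      "sherpa-onnx-zipvoice-distill-zh-en-emilia/text_encoder.onnx";
  zip.flow_matching_model =
      "sherpa-onnx-zipvoice-distill-zh-en-emilia/fm_decoder.onnx";
  zip.vocoder = "sherpa-onnx-zipvoice-distill-zh-en-emilia/vocos_24khz.onnx";
  zip.data_dir = "sherpa-onnx-zipvoice-distill-zh-en-emilia/espeak-ng-data";
  zip.pinyin_dict = "sherpa-onnx-zipvoice-distill-zh-en-emilia/pinyin.raw";
  config.model.num_threads = 4;
  if (!config.Validate()) return -1;

  sherpa_onnx::OfflineTts tts(config);

  // Mono float samples in [-1, 1] read from prompt.wav; loading elided here.
  std::vector<float> prompt_samples;
  int32_t prompt_sample_rate = 24000;
  std::string prompt_text = "transcript of prompt.wav";  // placeholder

  auto audio =
      tts.Generate("Hello, this is a zero-shot TTS test.", prompt_text,
                   prompt_samples, prompt_sample_rate,
                   /*speed=*/1.0f, /*num_steps=*/4);
  // audio.samples / audio.sample_rate can then be written out as a wav file.
  return audio.samples.empty() ? 1 : 0;
}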
... ... @@ -201,6 +201,9 @@ if(SHERPA_ONNX_ENABLE_TTS)
offline-tts-model-config.cc
offline-tts-vits-model-config.cc
offline-tts-vits-model.cc
offline-tts-zipvoice-frontend.cc
offline-tts-zipvoice-model.cc
offline-tts-zipvoice-model-config.cc
offline-tts.cc
piper-phonemize-lexicon.cc
vocoder.cc
... ... @@ -265,6 +268,7 @@ if(ANDROID_NDK)
endif()
target_link_libraries(sherpa-onnx-core
cppinyin_core
kaldi-native-fbank-core
kaldi-decoder-core
ssentencepiece_core
... ... @@ -348,6 +352,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
if(SHERPA_ONNX_ENABLE_TTS)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
add_executable(sherpa-onnx-offline-zeroshot-tts sherpa-onnx-offline-zeroshot-tts.cc)
endif()
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
... ... @@ -370,6 +375,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND main_exes
sherpa-onnx-offline-tts
sherpa-onnx-offline-zeroshot-tts
)
endif()
... ... @@ -667,6 +673,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND sherpa_onnx_test_srcs
cppjieba-test.cc
offline-tts-zipvoice-frontend-test.cc
piper-phonemize-test.cc
)
endif()
... ...
... ... @@ -20,6 +20,7 @@
#include "sherpa-onnx/csrc/offline-tts-kokoro-impl.h"
#include "sherpa-onnx/csrc/offline-tts-matcha-impl.h"
#include "sherpa-onnx/csrc/offline-tts-vits-impl.h"
#include "sherpa-onnx/csrc/offline-tts-zipvoice-impl.h"
namespace sherpa_onnx {
... ... @@ -41,6 +42,9 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
return std::make_unique<OfflineTtsVitsImpl>(config);
} else if (!config.model.matcha.acoustic_model.empty()) {
return std::make_unique<OfflineTtsMatchaImpl>(config);
} else if (!config.model.zipvoice.text_model.empty() &&
!config.model.zipvoice.flow_matching_model.empty()) {
return std::make_unique<OfflineTtsZipvoiceImpl>(config);
} else if (!config.model.kokoro.model.empty()) {
return std::make_unique<OfflineTtsKokoroImpl>(config);
} else if (!config.model.kitten.model.empty()) {
... ... @@ -59,6 +63,9 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
return std::make_unique<OfflineTtsVitsImpl>(mgr, config);
} else if (!config.model.matcha.acoustic_model.empty()) {
return std::make_unique<OfflineTtsMatchaImpl>(mgr, config);
} else if (!config.model.zipvoice.text_model.empty() &&
!config.model.zipvoice.flow_matching_model.empty()) {
return std::make_unique<OfflineTtsZipvoiceImpl>(mgr, config);
} else if (!config.model.kokoro.model.empty()) {
return std::make_unique<OfflineTtsKokoroImpl>(mgr, config);
} else if (!config.model.kitten.model.empty()) {
... ...
... ... @@ -6,9 +6,11 @@
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts.h"
namespace sherpa_onnx {
... ... @@ -25,14 +27,29 @@ class OfflineTtsImpl {
virtual GeneratedAudio Generate(
const std::string &text, int64_t sid = 0, float speed = 1.0,
GeneratedAudioCallback callback = nullptr) const {
throw std::runtime_error(
"OfflineTtsImpl backend does not support non zero-shot Generate()");
}
virtual GeneratedAudio Generate(
const std::string &text, const std::string &prompt_text,
const std::vector<float> &prompt_samples, int32_t sample_rate,
float speed = 1.0, int32_t num_steps = 4,
GeneratedAudioCallback callback = nullptr) const {
throw std::runtime_error(
"OfflineTtsImpl backend does not support zero-shot Generate()");
}
// Return the sample rate of the generated audio
virtual int32_t SampleRate() const = 0;
// Number of supported speakers.
// If it supports only a single speaker, then it returns 0 or 1.
virtual int32_t NumSpeakers() const {
throw std::runtime_error(
"Zero-shot OfflineTts does not support NumSpeakers()");
}
std::vector<int64_t> AddBlank(const std::vector<int64_t> &x,
int32_t blank_id = 0) const;
... ...
... ... @@ -12,6 +12,7 @@ void OfflineTtsModelConfig::Register(ParseOptions *po) {
vits.Register(po);
matcha.Register(po);
kokoro.Register(po);
zipvoice.Register(po);
kitten.Register(po);
po->Register("num-threads", &num_threads,
... ... @@ -38,6 +39,10 @@ bool OfflineTtsModelConfig::Validate() const {
return matcha.Validate();
}
if (!zipvoice.flow_matching_model.empty()) {
return zipvoice.Validate();
}
if (!kokoro.model.empty()) {
return kokoro.Validate();
}
... ... @@ -58,6 +63,7 @@ std::string OfflineTtsModelConfig::ToString() const {
os << "vits=" << vits.ToString() << ", ";
os << "matcha=" << matcha.ToString() << ", ";
os << "kokoro=" << kokoro.ToString() << ", ";
os << "zipvoice=" << zipvoice.ToString() << ", ";
os << "kitten=" << kitten.ToString() << ", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
... ...
... ... @@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/offline-tts-kokoro-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-matcha-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
... ... @@ -19,6 +20,7 @@ struct OfflineTtsModelConfig {
OfflineTtsVitsModelConfig vits;
OfflineTtsMatchaModelConfig matcha;
OfflineTtsKokoroModelConfig kokoro;
OfflineTtsZipvoiceModelConfig zipvoice;
OfflineTtsKittenModelConfig kitten;
int32_t num_threads = 1;
... ... @@ -30,12 +32,14 @@ struct OfflineTtsModelConfig {
OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits,
const OfflineTtsMatchaModelConfig &matcha,
const OfflineTtsKokoroModelConfig &kokoro,
const OfflineTtsZipvoiceModelConfig &zipvoice,
const OfflineTtsKittenModelConfig &kitten,
int32_t num_threads, bool debug,
const std::string &provider)
: vits(vits),
matcha(matcha),
kokoro(kokoro),
zipvoice(zipvoice),
kitten(kitten),
num_threads(num_threads),
debug(debug),
... ...
// sherpa-onnx/csrc/offline-tts-zipvoice-frontend-test.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h"
#include "espeak-ng/speak_lib.h"
#include "gtest/gtest.h"
#include "phoneme_ids.hpp"
#include "phonemize.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
TEST(ZipVoiceFrontend, Case1) {
std::string data_dir = "../zipvoice/espeak-ng-data";
if (!FileExists(data_dir + "/en_dict")) {
SHERPA_ONNX_LOGE("%s/en_dict does not exist. Skipping test",
data_dir.c_str());
return;
}
if (!FileExists(data_dir + "/phontab")) {
SHERPA_ONNX_LOGE("%s/phontab does not exist. Skipping test",
data_dir.c_str());
return;
}
if (!FileExists(data_dir + "/phonindex")) {
SHERPA_ONNX_LOGE("%s/phonindex does not exist. Skipping test",
data_dir.c_str());
return;
}
if (!FileExists(data_dir + "/phondata")) {
SHERPA_ONNX_LOGE("%s/phondata does not exist. Skipping test",
data_dir.c_str());
return;
}
if (!FileExists(data_dir + "/intonations")) {
SHERPA_ONNX_LOGE("%s/intonations does not exist. Skipping test",
data_dir.c_str());
return;
}
std::string pinyin_dict = data_dir + "/../pinyin.dict";
if (!FileExists(pinyin_dict)) {
SHERPA_ONNX_LOGE("%s does not exist. Skipping test", pinyin_dict.c_str());
return;
}
std::string tokens_file = data_dir + "/../tokens.txt";
if (!FileExists(tokens_file)) {
SHERPA_ONNX_LOGE("%s does not exist. Skipping test", tokens_file.c_str());
return;
}
auto frontend = OfflineTtsZipvoiceFrontend(
tokens_file, data_dir, pinyin_dict,
OfflineTtsZipvoiceModelMetaData{.use_espeak = true, .use_pinyin = true},
true);
std::string text = "how are you doing?";
std::vector<sherpa_onnx::TokenIDs> ans =
frontend.ConvertTextToTokenIds(text, "en-us");
text = "这是第一句。这是第二句。";
ans = frontend.ConvertTextToTokenIds(text, "en-us");
text =
"这是第一句。这是第二句。<pin1><yin2>测试 [S1]and hello "
"world[S2]这是第三句。";
ans = frontend.ConvertTextToTokenIds(text, "en-us");
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-tts-zipvoice-frontend.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include <algorithm>
#include <cctype>
#include <codecvt>
#include <fstream>
#include <locale>
#include <regex> // NOLINT
#include <sstream>
#include <strstream>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "cppinyin/csrc/cppinyin.h"
#include "espeak-ng/speak_lib.h"
#include "phoneme_ids.hpp"
#include "phonemize.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
void CallPhonemizeEspeak(const std::string &text,
piper::eSpeakPhonemeConfig &config, // NOLINT
std::vector<std::vector<piper::Phoneme>> *phonemes);
static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
std::unordered_map<std::string, int32_t> token2id;
std::string line;
std::string sym;
int32_t id = 0;
while (std::getline(is, line)) {
std::istringstream iss(line);
iss >> sym;
if (iss.eof()) {
id = atoi(sym.c_str());
sym = " ";
} else {
iss >> id;
}
// eat the trailing \r\n on windows
iss >> std::ws;
if (!iss.eof()) {
SHERPA_ONNX_LOGE("Error when reading tokens: %s", line.c_str());
exit(-1);
}
if (token2id.count(sym)) {
SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
sym.c_str(), line.c_str(), token2id.at(sym));
exit(-1);
}
token2id.insert({sym, id});
}
return token2id;
}
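// Example tokens.txt lines (format assumed from the parsing above):
//   a0 15    -> token "a0" maps to id 15
//   " 3"     -> a line whose only field is an id maps the space
//               token " " to that id (here 3)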
static std::string MapPunctuations(
const std::string &text,
const std::unordered_map<std::string, std::string> &punct_map) {
std::string result = text;
for (const auto &kv : punct_map) {
// Replace all occurrences of kv.first with kv.second
size_t pos = 0;
while ((pos = result.find(kv.first, pos)) != std::string::npos) {
result.replace(pos, kv.first.length(), kv.second);
pos += kv.second.length();
}
}
return result;
}
static void ProcessPinyin(
const std::string &pinyin, const cppinyin::PinyinEncoder *pinyin_encoder,
const std::unordered_map<std::string, int32_t> &token2id,
std::vector<int64_t> *tokens_ids, std::vector<std::string> *tokens) {
auto initial = pinyin_encoder->ToInitial(pinyin);
if (!initial.empty()) {
// append '0' to avoid conflict with espeak tokens
initial = initial + "0";
if (token2id.count(initial)) {
tokens_ids->push_back(token2id.at(initial));
tokens->push_back(initial);
} else {
SHERPA_ONNX_LOGE("Skip unknown initial %s", initial.c_str());
}
}
auto final_t = pinyin_encoder->ToFinal(pinyin);
if (!final_t.empty()) {
if (!std::isdigit(final_t.back())) {
final_t = final_t + "5"; // use 5 for neutral tone
}
if (token2id.count(final_t)) {
tokens_ids->push_back(token2id.at(final_t));
tokens->push_back(final_t);
} else {
SHERPA_ONNX_LOGE("Skip unknown final %s", final_t.c_str());
}
}
}
static void TokenizeZh(const std::string &words,
const cppinyin::PinyinEncoder *pinyin_encoder,
const std::unordered_map<std::string, int32_t> &token2id,
std::vector<int64_t> *token_ids,
std::vector<std::string> *tokens) {
std::vector<std::string> pinyins;
pinyin_encoder->Encode(words, &pinyins, "number" /*tone*/, false /*partial*/);
for (const auto &pinyin : pinyins) {
if (pinyin_encoder->ValidPinyin(pinyin, "number" /*tone*/)) {
ProcessPinyin(pinyin, pinyin_encoder, token2id, token_ids, tokens);
} else {
auto wstext = ToWideString(pinyin);
for (auto &wc : wstext) {
auto c = ToString(std::wstring(1, wc));
if (token2id.count(c)) {
token_ids->push_back(token2id.at(c));
tokens->push_back(c);
} else {
SHERPA_ONNX_LOGE("Skip unknown character %s", c.c_str());
}
}
}
}
}
static void TokenizeEn(const std::string &words,
const std::unordered_map<std::string, int32_t> &token2id,
const std::string &voice,
std::vector<int64_t> *token_ids,
std::vector<std::string> *tokens) {
piper::eSpeakPhonemeConfig config;
// ./bin/espeak-ng-bin --path ./install/share/espeak-ng-data/ --voices
// to list available voices
config.voice = voice; // e.g., voice is en-us
std::vector<std::vector<piper::Phoneme>> phonemes;
CallPhonemizeEspeak(words, config, &phonemes);
for (const auto &p : phonemes) {
for (const auto &ph : p) {
auto token = Utf32ToUtf8(std::u32string(1, ph));
if (token2id.count(token)) {
token_ids->push_back(token2id.at(token));
tokens->push_back(token);
} else {
SHERPA_ONNX_LOGE("Skip unknown phoneme %s", token.c_str());
}
}
}
}
static void TokenizeTag(
const std::string &words,
const std::unordered_map<std::string, int32_t> &token2id,
std::vector<int64_t> *tokens_ids, std::vector<std::string> *tokens) {
// in ZipVoice, tags are all upper case
std::string tag = ToUpperAscii(words);
if (token2id.count(tag)) {
tokens_ids->push_back(token2id.at(tag));
tokens->push_back(tag);
} else {
SHERPA_ONNX_LOGE("Skip unknown tag %s", tag.c_str());
}
}
static void TokenizePinyin(
const std::string &words, const cppinyin::PinyinEncoder *pinyin_encoder,
const std::unordered_map<std::string, int32_t> &token2id,
std::vector<int64_t> *tokens_ids, std::vector<std::string> *tokens) {
// words are in the form of <ha3>, <ha4>
std::string pinyin = words.substr(1, words.size() - 2);
if (!pinyin.empty()) {
if (pinyin[pinyin.size() - 1] == '5') {
pinyin = pinyin.substr(0, pinyin.size() - 1); // strip the neutral tone '5'
}
if (pinyin_encoder->ValidPinyin(pinyin, "number" /*tone*/)) {
ProcessPinyin(pinyin, pinyin_encoder, token2id, tokens_ids, tokens);
} else {
SHERPA_ONNX_LOGE("Invalid pinyin %s", pinyin.c_str());
}
}
}
OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend(
const std::string &tokens, const std::string &data_dir,
const std::string &pinyin_dict,
const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug)
: debug_(debug), meta_data_(meta_data) {
std::ifstream is(tokens);
token2id_ = ReadTokens(is);
if (meta_data_.use_pinyin) {
pinyin_encoder_ = std::make_unique<cppinyin::PinyinEncoder>(pinyin_dict);
} else {
pinyin_encoder_ = nullptr;
}
if (meta_data_.use_espeak) {
// We should copy the directory of espeak-ng-data from the asset to
// some internal or external storage and then pass the directory to
// data_dir.
InitEspeak(data_dir);
}
}
template <typename Manager>
OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend(
Manager *mgr, const std::string &tokens, const std::string &data_dir,
const std::string &pinyin_dict,
const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug)
: debug_(debug), meta_data_(meta_data) {
auto buf = ReadFile(mgr, tokens);
std::istrstream is(buf.data(), buf.size());
token2id_ = ReadTokens(is);
if (meta_data_.use_pinyin) {
auto buf = ReadFile(mgr, pinyin_dict);
std::istringstream iss(std::string(buf.begin(), buf.end()));
pinyin_encoder_ = std::make_unique<cppinyin::PinyinEncoder>(iss);
} else {
pinyin_encoder_ = nullptr;
}
if (meta_data_.use_espeak) {
// We should copy the directory of espeak-ng-data from the asset to
// some internal or external storage and then pass the directory to
// data_dir.
InitEspeak(data_dir);
}
}
std::vector<TokenIDs> OfflineTtsZipvoiceFrontend::ConvertTextToTokenIds(
const std::string &_text, const std::string &voice) const {
std::string text = _text;
if (meta_data_.use_espeak) {
text = ToLowerAscii(_text);
}
text = MapPunctuations(text, punct_map_);
auto wstext = ToWideString(text);
std::vector<std::string> parts;
// Match <...>, [...], or single character
std::wregex part_pattern(LR"([<\[].*?[>\]]|.)");
auto words_begin =
std::wsregex_iterator(wstext.begin(), wstext.end(), part_pattern);
auto words_end = std::wsregex_iterator();
for (std::wsregex_iterator i = words_begin; i != words_end; ++i) {
parts.push_back(ToString(i->str()));
}
// types are en, zh, tag, pinyin, other
// tag is [...]
// pinyin is <...>
// other is any text that does not match the above, typically numbers and
// punctuation
std::vector<std::string> types;
for (auto &word : parts) {
if (word.size() == 1 && std::isalpha(word[0])) {
// single character, e.g., 'a', 'b', 'c'
types.push_back("en");
} else if (word.size() > 1 && word[0] == '<' && word.back() == '>') {
// e.g., <ha3>, <ha4>
types.push_back("pinyin");
} else if (word.size() > 1 && word[0] == '[' && word.back() == ']') {
types.push_back("tag");
} else if (ContainsCJK(word)) { // word contains a CJK character
types.push_back("zh");
} else {
types.push_back("other");
}
}
std::vector<std::pair<std::string, std::string>> parts_with_types;
std::ostringstream oss;
std::string t_lang;
oss.str("");
std::ostringstream debug_oss;
if (debug_) {
debug_oss << "Text : " << _text << ", Parts with types: \n";
}
for (int32_t i = 0; i < types.size(); ++i) {
if (i == 0) {
oss << parts[i];
t_lang = types[i];
} else {
if (t_lang == "other" && (types[i] != "tag" && types[i] != "pinyin")) {
// combine into current type if the previous part is "other"
// do not combine with "tag" or "pinyin"
oss << parts[i];
t_lang = types[i];
} else {
if ((t_lang == types[i] || types[i] == "other") && t_lang != "pinyin" &&
t_lang != "tag") {
// same language or other, continue
// do not combine other into "pinyin" or "tag"
oss << parts[i];
} else {
// different language, start a new sentence
std::string part = oss.str();
oss.str("");
parts_with_types.emplace_back(part, t_lang);
if (debug_) {
debug_oss << "(" << part << ", " << t_lang << "),";
}
oss << parts[i];
t_lang = types[i];
}
}
}
}
std::string part = oss.str();
oss.str("");
parts_with_types.emplace_back(part, t_lang);
if (debug_) {
debug_oss << "(" << part << ", " << t_lang << ")\n";
SHERPA_ONNX_LOGE("%s", debug_oss.str().c_str());
debug_oss.str("");
}
std::vector<int64_t> token_ids;
std::vector<std::string> tokens; // for debugging
for (const auto &pt : parts_with_types) {
if (pt.second == "zh") {
TokenizeZh(pt.first, pinyin_encoder_.get(), token2id_, &token_ids,
&tokens);
} else if (pt.second == "en") {
TokenizeEn(pt.first, token2id_, voice, &token_ids, &tokens);
} else if (pt.second == "pinyin") {
TokenizePinyin(pt.first, pinyin_encoder_.get(), token2id_, &token_ids,
&tokens);
} else if (pt.second == "tag") {
TokenizeTag(pt.first, token2id_, &token_ids, &tokens);
} else {
SHERPA_ONNX_LOGE("Unexpected type: %s", pt.second.c_str());
exit(-1);
}
}
if (debug_) {
debug_oss << "Tokens and IDs: \n";
for (int32_t i = 0; i < tokens.size(); i++) {
debug_oss << "(" << tokens[i] << ", " << token_ids[i] << "),";
}
debug_oss << "\n";
SHERPA_ONNX_LOGE("%s", debug_oss.str().c_str());
}
std::vector<TokenIDs> ans;
ans.push_back(TokenIDs(std::move(token_ids)));
return ans;
}
#if __ANDROID_API__ >= 9
template OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend(
AAssetManager *mgr, const std::string &tokens, const std::string &data_dir,
const std::string &pinyin_dict,
const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug);
#endif
#if __OHOS__
template OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend(
NativeResourceManager *mgr, const std::string &tokens,
const std::string &data_dir, const std::string &pinyin_dict,
const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug);
#endif
} // namespace sherpa_onnx
... ...
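A standalone sketch (not part of the PR) of the first segmentation pass in ConvertTextToTokenIds() above: the regex treats <...> pinyin overrides and [...] tags as single parts and everything else as single characters, before the per-part language classification.

// Sketch of the part-splitting regex used by the ZipVoice frontend.
#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
  std::wstring text = L"hi [S1]ok <ha3>";
  // Match <...>, [...], or a single character, exactly as in the frontend.
  std::wregex part_pattern(LR"([<\[].*?[>\]]|.)");
  std::vector<std::wstring> parts;
  for (std::wsregex_iterator it(text.begin(), text.end(), part_pattern), end;
       it != end; ++it) {
    parts.push_back(it->str());
  }
  // parts: "h" "i" " " "[S1]" "o" "k" " " "<ha3>"
  for (const auto &p : parts) std::wcout << L"[" << p << L"] ";
  std::wcout << L"\n";
  return 0;
}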
// sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_FRONTEND_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_FRONTEND_H_
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "cppinyin/csrc/cppinyin.h"
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-meta-data.h"
namespace sherpa_onnx {
class OfflineTtsZipvoiceFrontend : public OfflineTtsFrontend {
public:
OfflineTtsZipvoiceFrontend(const std::string &tokens,
const std::string &data_dir,
const std::string &pinyin_dict,
const OfflineTtsZipvoiceModelMetaData &meta_data,
bool debug = false);
template <typename Manager>
OfflineTtsZipvoiceFrontend(Manager *mgr, const std::string &tokens,
const std::string &data_dir,
const std::string &pinyin_dict,
const OfflineTtsZipvoiceModelMetaData &meta_data,
bool debug = false);
/** Convert a string to token IDs.
*
* @param text The input text.
* Example 1: "This is the first sample sentence; this is the
* second one." Example 2: "这是第一句。这是第二句。"
* @param voice Optional. It is for espeak-ng.
*
* @return Return a vector-of-vector of token IDs. Each subvector contains
* a sentence that can be processed independently.
* If a frontend does not support splitting the text into
* sentences, the resulting vector contains only one subvector.
*/
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const override;
private:
bool debug_ = false;
std::unordered_map<std::string, int32_t> token2id_;
const std::unordered_map<std::string, std::string> punct_map_ = {
{",", ","}, {"。", "."}, {"!", "!"}, {"?", "?"}, {";", ";"},
{":", ":"}, {"、", ","}, {"‘", "'"}, {"“", "\""}, {"”", "\""},
{"’", "'"}, {"⋯", "…"}, {"···", "…"}, {"・・・", "…"}, {"...", "…"}};
OfflineTtsZipvoiceModelMetaData meta_data_;
std::unique_ptr<cppinyin::PinyinEncoder> pinyin_encoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_FRONTEND_H_
... ...
// sherpa-onnx/csrc/offline-tts-zipvoice-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
#include <cmath>
#include <memory>
#include <string>
#include <strstream>
#include <utility>
#include <vector>
#include "kaldi-native-fbank/csrc/mel-computations.h"
#include "kaldi-native-fbank/csrc/stft.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
#include "sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-zipvoice-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/resample.h"
#include "sherpa-onnx/csrc/vocoder.h"
namespace sherpa_onnx {
class OfflineTtsZipvoiceImpl : public OfflineTtsImpl {
public:
explicit OfflineTtsZipvoiceImpl(const OfflineTtsConfig &config)
: config_(config),
model_(std::make_unique<OfflineTtsZipvoiceModel>(config.model)),
vocoder_(Vocoder::Create(config.model)) {
InitFrontend();
}
template <typename Manager>
OfflineTtsZipvoiceImpl(Manager *mgr, const OfflineTtsConfig &config)
: config_(config),
model_(std::make_unique<OfflineTtsZipvoiceModel>(mgr, config.model)),
vocoder_(Vocoder::Create(mgr, config.model)) {
InitFrontend(mgr);
}
int32_t SampleRate() const override {
return model_->GetMetaData().sample_rate;
}
GeneratedAudio Generate(
const std::string &text, const std::string &prompt_text,
const std::vector<float> &prompt_samples, int32_t sample_rate,
float speed, int32_t num_steps,
GeneratedAudioCallback callback = nullptr) const override {
std::vector<TokenIDs> text_token_ids =
frontend_->ConvertTextToTokenIds(text);
std::vector<TokenIDs> prompt_token_ids =
frontend_->ConvertTextToTokenIds(prompt_text);
if (text_token_ids.empty() ||
(text_token_ids.size() == 1 && text_token_ids[0].tokens.empty())) {
#if __OHOS__
SHERPA_ONNX_LOGE("Failed to convert '%{public}s' to token IDs",
text.c_str());
#else
SHERPA_ONNX_LOGE("Failed to convert '%s' to token IDs", text.c_str());
#endif
return {};
}
if (prompt_token_ids.empty() ||
(prompt_token_ids.size() == 1 && prompt_token_ids[0].tokens.empty())) {
#if __OHOS__
SHERPA_ONNX_LOGE(
"Failed to convert prompt text '%{public}s' to token IDs",
prompt_text.c_str());
#else
SHERPA_ONNX_LOGE("Failed to convert prompt text '%s' to token IDs",
prompt_text.c_str());
#endif
return {};
}
// we assume batch size is 1
std::vector<int64_t> tokens = text_token_ids[0].tokens;
std::vector<int64_t> prompt_tokens = prompt_token_ids[0].tokens;
return Process(tokens, prompt_tokens, prompt_samples, sample_rate, speed,
num_steps);
}
private:
template <typename Manager>
void InitFrontend(Manager *mgr) {
const auto &meta_data = model_->GetMetaData();
frontend_ = std::make_unique<OfflineTtsZipvoiceFrontend>(
mgr, config_.model.zipvoice.tokens, config_.model.zipvoice.data_dir,
config_.model.zipvoice.pinyin_dict, meta_data, config_.model.debug);
}
void InitFrontend() {
const auto &meta_data = model_->GetMetaData();
if (meta_data.use_pinyin && config_.model.zipvoice.pinyin_dict.empty()) {
SHERPA_ONNX_LOGE(
"Please provide --zipvoice-pinyin-dict for converting Chinese into "
"pinyin.");
exit(-1);
}
if (meta_data.use_espeak && config_.model.zipvoice.data_dir.empty()) {
SHERPA_ONNX_LOGE("Please provide --zipvoice-data-dir for espeak-ng.");
exit(-1);
}
frontend_ = std::make_unique<OfflineTtsZipvoiceFrontend>(
config_.model.zipvoice.tokens, config_.model.zipvoice.data_dir,
config_.model.zipvoice.pinyin_dict, meta_data, config_.model.debug);
}
std::vector<int32_t> ComputeMelSpectrogram(
const std::vector<float> &_samples, int32_t sample_rate,
std::vector<float> *prompt_features) const {
const auto &meta = model_->GetMetaData();
if (sample_rate != meta.sample_rate) {
SHERPA_ONNX_LOGE(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sample_rate, static_cast<int32_t>(meta.sample_rate));
float min_freq = std::min<int32_t>(sample_rate, meta.sample_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
auto resampler = std::make_unique<LinearResample>(
sample_rate, meta.sample_rate, lowpass_cutoff, lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(_samples.data(), _samples.size(), true, &samples);
return ComputeMelSpectrogram(samples, prompt_features);
} else {
// Use the original samples if the sample rate matches
return ComputeMelSpectrogram(_samples, prompt_features);
}
}
std::vector<int32_t> ComputeMelSpectrogram(
const std::vector<float> &samples,
std::vector<float> *prompt_features) const {
const auto &meta = model_->GetMetaData();
int32_t sample_rate = meta.sample_rate;
int32_t n_fft = meta.n_fft;
int32_t hop_length = meta.hop_length;
int32_t win_length = meta.window_length;
int32_t num_mels = meta.num_mels;
knf::StftConfig stft_config;
stft_config.n_fft = n_fft;
stft_config.hop_length = hop_length;
stft_config.win_length = win_length;
stft_config.window_type = "hann";
stft_config.center = true;
knf::Stft stft(stft_config);
auto stft_result = stft.Compute(samples.data(), samples.size());
int32_t num_frames = stft_result.num_frames;
int32_t fft_bins = n_fft / 2 + 1;
knf::FrameExtractionOptions frame_opts;
frame_opts.samp_freq = sample_rate;
frame_opts.frame_length_ms = win_length * 1000 / sample_rate;
frame_opts.frame_shift_ms = hop_length * 1000 / sample_rate;
frame_opts.window_type = "hanning";
knf::MelBanksOptions mel_opts;
mel_opts.num_bins = num_mels;
mel_opts.low_freq = 0;
mel_opts.high_freq = sample_rate / 2;
mel_opts.is_librosa = true;
mel_opts.use_slaney_mel_scale = false;
mel_opts.norm = "";
knf::MelBanks mel_banks(mel_opts, frame_opts, 1.0f);
prompt_features->clear();
prompt_features->reserve(num_frames * num_mels);
for (int32_t i = 0; i < num_frames; ++i) {
std::vector<float> magnitude_spectrum(fft_bins);
for (int32_t k = 0; k < fft_bins; ++k) {
float real = stft_result.real[i * fft_bins + k];
float imag = stft_result.imag[i * fft_bins + k];
magnitude_spectrum[k] = std::sqrt(real * real + imag * imag);
}
std::vector<float> mel_features(num_mels, 0.0f);
mel_banks.Compute(magnitude_spectrum.data(), mel_features.data());
for (auto &v : mel_features) {
v = std::log(v + 1e-10f);
}
// Append this frame's mel values to the flattened feature buffer
prompt_features->insert(prompt_features->end(), mel_features.begin(),
mel_features.end());
}
if (num_frames == 0) {
SHERPA_ONNX_LOGE("No frames extracted from the prompt audio");
return {0, 0};
} else {
return {num_frames, num_mels};
}
}
GeneratedAudio Process(const std::vector<int64_t> &tokens,
const std::vector<int64_t> &prompt_tokens,
const std::vector<float> &prompt_samples,
int32_t sample_rate, float speed,
int32_t num_steps) const {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 2> tokens_shape = {1,
static_cast<int64_t>(tokens.size())};
Ort::Value tokens_tensor = Ort::Value::CreateTensor(
memory_info, const_cast<int64_t *>(tokens.data()), tokens.size(),
tokens_shape.data(), tokens_shape.size());
std::array<int64_t, 2> prompt_tokens_shape = {
1, static_cast<int64_t>(prompt_tokens.size())};
Ort::Value prompt_tokens_tensor = Ort::Value::CreateTensor(
memory_info, const_cast<int64_t *>(prompt_tokens.data()),
prompt_tokens.size(), prompt_tokens_shape.data(),
prompt_tokens_shape.size());
float target_rms = config_.model.zipvoice.target_rms;
float feat_scale = config_.model.zipvoice.feat_scale;
// Scale prompt_samples
std::vector<float> prompt_samples_scaled = prompt_samples;
float prompt_rms = 0.0f;
double sum_sq = 0.0;
// Compute RMS of prompt_samples
for (float s : prompt_samples_scaled) {
sum_sq += s * s;
}
prompt_rms = std::sqrt(sum_sq / prompt_samples_scaled.size());
if (prompt_rms < target_rms && prompt_rms > 0.0f) {
float scale = target_rms / static_cast<float>(prompt_rms);
for (auto &s : prompt_samples_scaled) {
s *= scale;
}
}
std::vector<float> prompt_features;
auto res_shape = ComputeMelSpectrogram(prompt_samples_scaled, sample_rate,
&prompt_features);
int32_t num_frames = res_shape[0];
int32_t mel_dim = res_shape[1];
if (feat_scale != 1.0f) {
for (auto &item : prompt_features) {
item *= feat_scale;
}
}
std::array<int64_t, 3> shape = {1, num_frames, mel_dim};
auto prompt_features_tensor = Ort::Value::CreateTensor(
memory_info, prompt_features.data(), prompt_features.size(),
shape.data(), shape.size());
Ort::Value mel =
model_->Run(std::move(tokens_tensor), std::move(prompt_tokens_tensor),
std::move(prompt_features_tensor), speed, num_steps);
// Assume mel_shape = {1, T, C}
std::vector<int64_t> mel_shape = mel.GetTensorTypeAndShapeInfo().GetShape();
int64_t T = mel_shape[1], C = mel_shape[2];
float *mel_data = mel.GetTensorMutableData<float>();
std::vector<float> mel_permuted(C * T);
for (int64_t c = 0; c < C; ++c) {
for (int64_t t = 0; t < T; ++t) {
int64_t src_idx = t * C + c; // src: [T, C] (row major)
int64_t dst_idx = c * T + t; // dst: [C, T] (row major)
mel_permuted[dst_idx] = mel_data[src_idx] / feat_scale;
}
}
std::array<int64_t, 3> new_shape = {1, C, T};
Ort::Value mel_new = Ort::Value::CreateTensor<float>(
memory_info, mel_permuted.data(), mel_permuted.size(), new_shape.data(),
new_shape.size());
GeneratedAudio ans;
ans.samples = vocoder_->Run(std::move(mel_new));
ans.sample_rate = model_->GetMetaData().sample_rate;
if (prompt_rms < target_rms && target_rms > 0.0f) {
float scale = prompt_rms / target_rms;
for (auto &s : ans.samples) {
s *= scale;
}
}
return ans;
}
private:
OfflineTtsConfig config_;
std::unique_ptr<OfflineTtsZipvoiceModel> model_;
std::unique_ptr<Vocoder> vocoder_;
std::unique_ptr<OfflineTtsFrontend> frontend_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
... ...
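An editorial note on the normalization in Process() in the file above (a restatement of the code, not new behavior): the prompt is RMS-normalized before feature extraction, and the synthesized waveform is scaled back so it matches the prompt's original loudness,

\mathrm{rms} = \sqrt{\tfrac{1}{N}\sum_{i} s_i^2}, \qquad
s_i \leftarrow s_i \cdot \frac{r}{\mathrm{rms}} \;\;\text{if } 0 < \mathrm{rms} < r, \qquad
y_i \leftarrow y_i \cdot \frac{\mathrm{rms}}{r} \;\;\text{afterwards},

where r is --zipvoice-target-rms. The prompt features are log(mel + 1e-10) multiplied by --zipvoice-feat-scale, and the model's output mel is divided by the same factor before it is handed to the vocoder.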
// sherpa-onnx/csrc/offline-tts-zipvoice-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h"
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineTtsZipvoiceModelConfig::Register(ParseOptions *po) {
po->Register("zipvoice-tokens", &tokens,
"Path to tokens.txt for ZipVoice models");
po->Register("zipvoice-data-dir", &data_dir,
"Path to the directory containing dict for espeak-ng.");
po->Register("zipvoice-pinyin-dict", &pinyin_dict,
"Path to the pinyin dictionary for cppinyin (i.e converting "
"Chinese into phones).");
po->Register("zipvoice-text-model", &text_model,
"Path to zipvoice text model");
po->Register("zipvoice-flow-matching-model", &flow_matching_model,
"Path to zipvoice flow-matching model");
po->Register("zipvoice-vocoder", &vocoder, "Path to zipvoice vocoder");
po->Register("zipvoice-feat-scale", &feat_scale,
"Feature scale for ZipVoice (default: 0.1)");
po->Register("zipvoice-t-shift", &t_shift,
"Shift t to smaller ones if t_shift < 1.0 (default: 0.5)");
po->Register(
"zipvoice-target-rms", &target_rms,
"Target speech normalization rms value for ZipVoice (default: 0.1)");
po->Register(
"zipvoice-guidance-scale", &guidance_scale,
"The scale of classifier-free guidance during inference for ZipVoice "
"(default: 1.0)");
}
bool OfflineTtsZipvoiceModelConfig::Validate() const {
if (tokens.empty()) {
SHERPA_ONNX_LOGE("Please provide --zipvoice-tokens");
return false;
}
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("--zipvoice-tokens: '%s' does not exist", tokens.c_str());
return false;
}
if (text_model.empty()) {
SHERPA_ONNX_LOGE("Please provide --zipvoice-text-model");
return false;
}
if (!FileExists(text_model)) {
SHERPA_ONNX_LOGE("--zipvoice-text-model: '%s' does not exist",
text_model.c_str());
return false;
}
if (flow_matching_model.empty()) {
SHERPA_ONNX_LOGE("Please provide --zipvoice-flow-matching-model");
return false;
}
if (!FileExists(flow_matching_model)) {
SHERPA_ONNX_LOGE("--zipvoice-flow-matching-model: '%s' does not exist",
flow_matching_model.c_str());
return false;
}
if (vocoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --zipvoice-vocoder");
return false;
}
if (!FileExists(vocoder)) {
SHERPA_ONNX_LOGE("--zipvoice-vocoder: '%s' does not exist",
vocoder.c_str());
return false;
}
if (!data_dir.empty()) {
std::vector<std::string> required_files = {
"phontab",
"phonindex",
"phondata",
"intonations",
};
for (const auto &f : required_files) {
if (!FileExists(data_dir + "/" + f)) {
SHERPA_ONNX_LOGE(
"'%s/%s' does not exist. Please check zipvoice-data-dir",
data_dir.c_str(), f.c_str());
return false;
}
}
}
if (!pinyin_dict.empty() && !FileExists(pinyin_dict)) {
SHERPA_ONNX_LOGE("--zipvoice-pinyin-dict: '%s' does not exist",
pinyin_dict.c_str());
return false;
}
if (feat_scale <= 0) {
SHERPA_ONNX_LOGE("--zipvoice-feat-scale must be positive. Given: %f",
feat_scale);
return false;
}
if (t_shift < 0) {
SHERPA_ONNX_LOGE("--zipvoice-t-shift must be non-negative. Given: %f",
t_shift);
return false;
}
if (target_rms <= 0) {
SHERPA_ONNX_LOGE("--zipvoice-target-rms must be positive. Given: %f",
target_rms);
return false;
}
if (guidance_scale <= 0) {
SHERPA_ONNX_LOGE("--zipvoice-guidance-scale must be positive. Given: %f",
guidance_scale);
return false;
}
return true;
}
std::string OfflineTtsZipvoiceModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineTtsZipvoiceModelConfig(";
os << "tokens=\"" << tokens << "\", ";
os << "text_model=\"" << text_model << "\", ";
os << "flow_matching_model=\"" << flow_matching_model << "\", ";
os << "vocoder=\"" << vocoder << "\", ";
os << "data_dir=\"" << data_dir << "\", ";
os << "pinyin_dict=\"" << pinyin_dict << "\", ";
os << "feat_scale=" << feat_scale << ", ";
os << "t_shift=" << t_shift << ", ";
os << "target_rms=" << target_rms << ", ";
os << "guidance_scale=" << guidance_scale << ")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
#include <cstdint>
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineTtsZipvoiceModelConfig {
std::string tokens;
std::string text_model;
std::string flow_matching_model;
std::string vocoder;
// data_dir is for piper-phonemize, which uses espeak-ng;
// it is required when the model uses espeak to process English text
std::string data_dir;
// Used for converting Chinese characters to pinyin
std::string pinyin_dict;
float feat_scale = 0.1;
float t_shift = 0.5;
float target_rms = 0.1;
float guidance_scale = 1.0;
OfflineTtsZipvoiceModelConfig() = default;
OfflineTtsZipvoiceModelConfig(
const std::string &tokens, const std::string &text_model,
const std::string &flow_matching_model, const std::string &vocoder,
const std::string &data_dir, const std::string &pinyin_dict,
float feat_scale = 0.1, float t_shift = 0.5, float target_rms = 0.1,
float guidance_scale = 1.0)
: tokens(tokens),
text_model(text_model),
flow_matching_model(flow_matching_model),
vocoder(vocoder),
data_dir(data_dir),
pinyin_dict(pinyin_dict),
feat_scale(feat_scale),
t_shift(t_shift),
target_rms(target_rms),
guidance_scale(guidance_scale) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-tts-zipvoice-model-meta-data.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_META_DATA_H_
#include <cstdint>
#include <string>
namespace sherpa_onnx {
// If you are not sure what each field means, please
// have a look at the Python file in the model directory that
// you have downloaded.
struct OfflineTtsZipvoiceModelMetaData {
int32_t version = 1;
int32_t feat_dim = 100;
int32_t sample_rate = 24000;
int32_t n_fft = 1024;
int32_t hop_length = 256;
int32_t window_length = 1024;
int32_t num_mels = 100;
int32_t use_espeak = 1;
int32_t use_pinyin = 1;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_META_DATA_H_
... ...
// sherpa-onnx/csrc/offline-tts-zipvoice-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-zipvoice-model.h"
#include <algorithm>
#include <iostream>
#include <random>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineTtsZipvoiceModel::Impl {
public:
explicit Impl(const OfflineTtsModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto text_buf = ReadFile(config.zipvoice.text_model);
auto fm_buf = ReadFile(config.zipvoice.flow_matching_model);
Init(text_buf.data(), text_buf.size(), fm_buf.data(), fm_buf.size());
}
template <typename Manager>
Impl(Manager *mgr, const OfflineTtsModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto text_buf = ReadFile(mgr, config.zipvoice.text_model);
auto fm_buf = ReadFile(mgr, config.zipvoice.flow_matching_model);
Init(text_buf.data(), text_buf.size(), fm_buf.data(), fm_buf.size());
}
const OfflineTtsZipvoiceModelMetaData &GetMetaData() const {
return meta_data_;
}
Ort::Value Run(Ort::Value tokens, Ort::Value prompt_tokens,
Ort::Value prompt_features, float speed, int32_t num_steps) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<int64_t> tokens_shape =
tokens.GetTensorTypeAndShapeInfo().GetShape();
int64_t batch_size = tokens_shape[0];
if (batch_size != 1) {
SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
static_cast<int32_t>(batch_size));
exit(-1);
}
std::vector<int64_t> prompt_feat_shape =
prompt_features.GetTensorTypeAndShapeInfo().GetShape();
int64_t prompt_feat_len = prompt_feat_shape[1];
int64_t prompt_feat_len_shape = 1;
Ort::Value prompt_feat_len_tensor = Ort::Value::CreateTensor<int64_t>(
memory_info, &prompt_feat_len, 1, &prompt_feat_len_shape, 1);
int64_t speed_shape = 1;
Ort::Value speed_tensor = Ort::Value::CreateTensor<float>(
memory_info, &speed, 1, &speed_shape, 1);
std::vector<Ort::Value> text_inputs;
text_inputs.reserve(4);
text_inputs.push_back(std::move(tokens));
text_inputs.push_back(std::move(prompt_tokens));
text_inputs.push_back(std::move(prompt_feat_len_tensor));
text_inputs.push_back(std::move(speed_tensor));
// forward text-encoder
auto text_out =
text_sess_->Run({}, text_input_names_ptr_.data(), text_inputs.data(),
text_inputs.size(), text_output_names_ptr_.data(),
text_output_names_ptr_.size());
Ort::Value &text_condition = text_out[0];
std::vector<int64_t> text_cond_shape =
text_condition.GetTensorTypeAndShapeInfo().GetShape();
int64_t num_frames = text_cond_shape[1];
int64_t feat_dim = meta_data_.feat_dim;
std::vector<float> x_data(batch_size * num_frames * feat_dim);
std::default_random_engine rng(std::random_device{}());
std::normal_distribution<float> norm(0, 1);
for (auto &v : x_data) v = norm(rng);
std::vector<int64_t> x_shape = {batch_size, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor<float>(
memory_info, x_data.data(), x_data.size(), x_shape.data(),
x_shape.size());
std::vector<float> speech_cond_data(batch_size * num_frames * feat_dim,
0.0f);
const float *src = prompt_features.GetTensorData<float>();
float *dst = speech_cond_data.data();
std::memcpy(dst, src,
batch_size * prompt_feat_len * feat_dim * sizeof(float));
std::vector<int64_t> speech_cond_shape = {batch_size, num_frames, feat_dim};
Ort::Value speech_condition = Ort::Value::CreateTensor<float>(
memory_info, speech_cond_data.data(), speech_cond_data.size(),
speech_cond_shape.data(), speech_cond_shape.size());
float t_shift = config_.zipvoice.t_shift;
float guidance_scale = config_.zipvoice.guidance_scale;
std::vector<float> timesteps(num_steps + 1);
for (int32_t i = 0; i <= num_steps; ++i) {
float t = static_cast<float>(i) / num_steps;
timesteps[i] = t_shift * t / (1.0f + (t_shift - 1.0f) * t);
}
int64_t guidance_scale_shape = 1;
Ort::Value guidance_scale_tensor = Ort::Value::CreateTensor<float>(
memory_info, &guidance_scale, 1, &guidance_scale_shape, 1);
std::vector<Ort::Value> fm_inputs;
fm_inputs.reserve(5);
// fm_inputs[0] is the t tensor; it is set inside the loop below
fm_inputs.emplace_back(nullptr);
fm_inputs.push_back(std::move(x));
fm_inputs.push_back(std::move(text_condition));
fm_inputs.push_back(std::move(speech_condition));
fm_inputs.push_back(std::move(guidance_scale_tensor));
for (int32_t step = 0; step < num_steps; ++step) {
float t_val = timesteps[step];
int64_t t_shape = 1;
Ort::Value t_tensor =
Ort::Value::CreateTensor<float>(memory_info, &t_val, 1, &t_shape, 1);
fm_inputs[0] = std::move(t_tensor);
auto fm_out = fm_sess_->Run(
{}, fm_input_names_ptr_.data(), fm_inputs.data(), fm_inputs.size(),
fm_output_names_ptr_.data(), fm_output_names_ptr_.size());
Ort::Value &v = fm_out[0];
float delta_t = timesteps[step + 1] - timesteps[step];
float *x_ptr = fm_inputs[1].GetTensorMutableData<float>();
const float *v_ptr = v.GetTensorData<float>();
int64_t N = batch_size * num_frames * feat_dim;
for (int64_t i = 0; i < N; ++i) {
x_ptr[i] += v_ptr[i] * delta_t;
}
}
int64_t keep_frames = num_frames - prompt_feat_len;
x = std::move(fm_inputs[1]);
const float *x_ptr = x.GetTensorData<float>();
std::vector<int64_t> out_shape = {batch_size, keep_frames, feat_dim};
// Let the returned tensor own its buffer; creating it over a local
// std::vector would leave a dangling pointer once this function returns.
Ort::Value out = Ort::Value::CreateTensor<float>(
allocator_, out_shape.data(), out_shape.size());
float *out_ptr = out.GetTensorMutableData<float>();
for (int64_t b = 0; b < batch_size; ++b) {
std::memcpy(out_ptr + b * keep_frames * feat_dim,
x_ptr + (b * num_frames + prompt_feat_len) * feat_dim,
keep_frames * feat_dim * sizeof(float));
}
return out;
}
private:
void Init(void *text_model_data, size_t text_model_data_length,
void *fm_model_data, size_t fm_model_data_length) {
// Init text-encoder model
text_sess_ = std::make_unique<Ort::Session>(
env_, text_model_data, text_model_data_length, sess_opts_);
GetInputNames(text_sess_.get(), &text_input_names_, &text_input_names_ptr_);
GetOutputNames(text_sess_.get(), &text_output_names_,
&text_output_names_ptr_);
// Init flow-matching model
fm_sess_ = std::make_unique<Ort::Session>(env_, fm_model_data,
fm_model_data_length, sess_opts_);
GetInputNames(fm_sess_.get(), &fm_input_names_, &fm_input_names_ptr_);
GetOutputNames(fm_sess_.get(), &fm_output_names_, &fm_output_names_ptr_);
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
Ort::ModelMetadata meta_data = text_sess_->GetModelMetadata();
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_espeak, "use_espeak",
1);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_pinyin, "use_pinyin",
1);
meta_data = fm_sess_->GetModelMetadata();
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 1);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.feat_dim, "feat_dim",
100);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.sample_rate,
"sample_rate", 24000);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.n_fft, "n_fft", 1024);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.hop_length, "hop_length",
256);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.window_length,
"window_length", 1024);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.num_mels, "num_mels",
100);
if (config_.debug) {
std::ostringstream os;
os << "---zipvoice text-encoder model---\n";
Ort::ModelMetadata text_meta_data = text_sess_->GetModelMetadata();
PrintModelMetadata(os, text_meta_data);
os << "----------input names----------\n";
int32_t i = 0;
for (const auto &s : text_input_names_) {
os << i << " " << s << "\n";
++i;
}
os << "----------output names----------\n";
i = 0;
for (const auto &s : text_output_names_) {
os << i << " " << s << "\n";
++i;
}
os << "---zipvoice flow-matching model---\n";
PrintModelMetadata(os, meta_data);
os << "----------input names----------\n";
i = 0;
for (const auto &s : fm_input_names_) {
os << i << " " << s << "\n";
++i;
}
os << "----------output names----------\n";
i = 0;
for (const auto &s : fm_output_names_) {
os << i << " " << s << "\n";
++i;
}
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
}
}
private:
OfflineTtsModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> text_sess_;
std::unique_ptr<Ort::Session> fm_sess_;
std::vector<std::string> text_input_names_;
std::vector<const char *> text_input_names_ptr_;
std::vector<std::string> text_output_names_;
std::vector<const char *> text_output_names_ptr_;
std::vector<std::string> fm_input_names_;
std::vector<const char *> fm_input_names_ptr_;
std::vector<std::string> fm_output_names_;
std::vector<const char *> fm_output_names_ptr_;
OfflineTtsZipvoiceModelMetaData meta_data_;
};
OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel(
const OfflineTtsModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel(
Manager *mgr, const OfflineTtsModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
OfflineTtsZipvoiceModel::~OfflineTtsZipvoiceModel() = default;
const OfflineTtsZipvoiceModelMetaData &OfflineTtsZipvoiceModel::GetMetaData()
const {
return impl_->GetMetaData();
}
Ort::Value OfflineTtsZipvoiceModel::Run(Ort::Value tokens,
Ort::Value prompt_tokens,
Ort::Value prompt_features,
float speed /*= 1.0*/,
int32_t num_steps /*= 16*/) const {
return impl_->Run(std::move(tokens), std::move(prompt_tokens),
std::move(prompt_features), speed, num_steps);
}
#if __ANDROID_API__ >= 9
template OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel(
AAssetManager *mgr, const OfflineTtsModelConfig &config);
#endif
#if __OHOS__
template OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel(
NativeResourceManager *mgr, const OfflineTtsModelConfig &config);
#endif
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-tts-zipvoice-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_H_
#include <memory>
#include <string>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-meta-data.h"
namespace sherpa_onnx {
class OfflineTtsZipvoiceModel {
public:
~OfflineTtsZipvoiceModel();
explicit OfflineTtsZipvoiceModel(const OfflineTtsModelConfig &config);
template <typename Manager>
OfflineTtsZipvoiceModel(Manager *mgr, const OfflineTtsModelConfig &config);
// Return a float32 tensor containing the mel spectrogram,
// of shape (batch_size, mel_dim, num_frames)
Ort::Value Run(Ort::Value tokens, Ort::Value prompt_tokens,
Ort::Value prompt_features, float speed,
int32_t num_steps) const;
const OfflineTtsZipvoiceModelMetaData &GetMetaData() const;
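//
// A minimal caller sketch (hypothetical buffers and shapes; the tensor
// dtypes are assumptions and must match the exported ONNX model):
//
//   auto memory_info =
//       Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
//   std::array<int64_t, 2> shape{1, (int64_t)token_ids.size()};
//   Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
//       memory_info, token_ids.data(), token_ids.size(), shape.data(),
//       shape.size());
//   // prompt_tokens and prompt_features are built the same way.
//   Ort::Value mel = model.Run(std::move(tokens), std::move(prompt_tokens),
//                              std::move(prompt_features), /*speed=*/1.0f,
//                              /*num_steps=*/4);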
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_H_
... ...
... ... @@ -196,6 +196,46 @@ GeneratedAudio OfflineTts::Generate(
#endif
}
GeneratedAudio OfflineTts::Generate(
const std::string &text, const std::string &prompt_text,
const std::vector<float> &prompt_samples, int32_t sample_rate,
float speed /*=1.0*/, int32_t num_steps /*=4*/,
GeneratedAudioCallback callback /*=nullptr*/) const {
#if !defined(_WIN32)
return impl_->Generate(text, prompt_text, prompt_samples, sample_rate, speed,
num_steps, std::move(callback));
#else
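// On Windows, command-line text may arrive in the local code page
// (e.g., GB2312 for Chinese), so detect it and convert to UTF-8 first.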
static bool printed = false;
auto utf8_text = text;
if (IsGB2312(text)) {
utf8_text = Gb2312ToUtf8(text);
if (!printed) {
SHERPA_ONNX_LOGE("Detected GB2312 encoded text! Converting it to UTF8.");
printed = true;
}
}
auto utf8_prompt_text = prompt_text;
if (IsGB2312(prompt_text)) {
utf8_prompt_text = Gb2312ToUtf8(prompt_text);
if (!printed) {
SHERPA_ONNX_LOGE(
"Detected GB2312 encoded prompt text! Converting it to UTF8.");
printed = true;
}
}
if (IsUtf8(utf8_text) && IsUtf8(utf8_prompt_text)) {
return impl_->Generate(utf8_text, utf8_prompt_text, prompt_samples,
sample_rate, speed, num_steps, std::move(callback));
} else {
SHERPA_ONNX_LOGE(
    "A non-UTF-8 encoded string was received. You may not get the "
    "expected results!");
return impl_->Generate(utf8_text, utf8_prompt_text, prompt_samples,
sample_rate, speed, num_steps, std::move(callback));
}
#endif
}
int32_t OfflineTts::SampleRate() const { return impl_->SampleRate(); }
int32_t OfflineTts::NumSpeakers() const { return impl_->NumSpeakers(); }
... ...
... ... @@ -95,6 +95,26 @@ class OfflineTts {
float speed = 1.0,
GeneratedAudioCallback callback = nullptr) const;
// @param text The string to be synthesized.
// @param prompt_text The transcript of `prompt_samples`.
// @param prompt_samples The prompt audio samples (mono PCM floats in [-1,1]).
// @param sample_rate The sample rate of `prompt_samples` in Hz.
// @param speed The speed for the generated speech. E.g., 2 means 2x faster.
// @param num_steps The number of flow-matching steps used to generate the
//                  audio.
// @param callback If not NULL, it is called whenever config.max_num_sentences
//                 sentences have been processed. Note that the `samples`
//                 pointer passed to the callback may be invalidated after the
//                 callback returns, so the caller must not keep a reference
//                 to it; copy the data if the samples are needed afterwards.
//                 The callback is called in the current thread.
GeneratedAudio Generate(const std::string &text,
const std::string &prompt_text,
const std::vector<float> &prompt_samples,
int32_t sample_rate, float speed = 1.0,
int32_t num_steps = 4,
GeneratedAudioCallback callback = nullptr) const;
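//
// A minimal usage sketch (file names are placeholders; ReadWave/WriteWave
// are from wave-reader.h and wave-writer.h):
//
//   OfflineTts tts(config);
//   int32_t sr = -1;
//   bool is_ok = false;
//   std::vector<float> prompt = ReadWave("prompt.wav", &sr, &is_ok);
//   GeneratedAudio audio = tts.Generate(
//       "text to synthesize", "transcript of prompt.wav", prompt, sr);
//   WriteWave("generated.wav", audio.sample_rate, audio.samples.data(),
//             audio.samples.size());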
// Return the sample rate of the generated audio
int32_t SampleRate() const;
... ...
// sherpa-onnx/csrc/sherpa-onnx-offline-zeroshot-tts.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include <chrono> // NOLINT
#include <fstream>
#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"
static int32_t AudioCallback(const float * /*samples*/, int32_t n,
float progress) {
printf("sample=%d, progress=%f\n", n, progress);
return 1;
}
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Offline/Non-streaming zero-shot text-to-speech with sherpa-onnx
Usage example:
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2
tar xf sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2
./bin/sherpa-onnx-offline-zeroshot-tts \
--zipvoice-flow-matching-model=sherpa-onnx-zipvoice-distill-zh-en-emilia/fm_decoder.onnx \
--zipvoice-text-model=sherpa-onnx-zipvoice-distill-zh-en-emilia/text_encoder.onnx \
--zipvoice-data-dir=sherpa-onnx-zipvoice-distill-zh-en-emilia/espeak-ng-data \
--zipvoice-pinyin-dict=sherpa-onnx-zipvoice-distill-zh-en-emilia/pinyin.raw \
--zipvoice-tokens=sherpa-onnx-zipvoice-distill-zh-en-emilia/tokens.txt \
--zipvoice-vocoder=sherpa-onnx-zipvoice-distill-zh-en-emilia/vocos_24khz.onnx \
--prompt-audio=sherpa-onnx-zipvoice-distill-zh-en-emilia/prompt.wav \
--num-steps=4 \
--num-threads=4 \
--prompt-text="周日被我射熄火了,所以今天是周一。" \
"我是中国人民的儿子,我爱我的祖国。我得祖国是一个伟大的国家,拥有五千年的文明史。"
It will generate a file ./generated.wav as specified by --output-filename.
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
std::string output_filename = "./generated.wav";
int32_t num_steps = 4;
float speed = 1.0;
std::string prompt_text;
std::string prompt_audio;
po.Register("output-filename", &output_filename,
"Path to save the generated audio");
po.Register("num-steps", &num_steps,
"Number of inference steps for ZipVoice (default: 4)");
po.Register("speed", &speed,
"Speech speed for ZipVoice (default: 1.0, larger=faster, "
"smaller=slower)");
po.Register("prompt-text", &prompt_text, "The transcribe of prompt_samples.");
po.Register("prompt-audio", &prompt_audio,
"The prompt audio file, single channel pcm. ");
sherpa_onnx::OfflineTtsConfig config;
config.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() == 0) {
fprintf(stderr, "Error: Please provide the text to generate audio.\n\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}
if (po.NumArgs() > 1) {
fprintf(stderr,
        "Error: Only one positional argument is accepted. Please wrap "
        "your text in single quotes.\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}
if (config.model.debug) {
fprintf(stderr, "%s\n", config.model.ToString().c_str());
}
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
exit(EXIT_FAILURE);
}
if (prompt_text.empty() || prompt_audio.empty()) {
fprintf(stderr, "Please provide both --prompt-text and --prompt-audio\n");
exit(EXIT_FAILURE);
}
sherpa_onnx::OfflineTts tts(config);
int32_t sample_rate = -1;
bool is_ok = false;
const std::vector<float> prompt_samples =
sherpa_onnx::ReadWave(prompt_audio, &sample_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read '%s'\n", prompt_audio.c_str());
return -1;
}
const auto begin = std::chrono::steady_clock::now();
auto audio = tts.Generate(po.GetArg(1), prompt_text, prompt_samples,
sample_rate, speed, num_steps, AudioCallback);
const auto end = std::chrono::steady_clock::now();
if (audio.samples.empty()) {
fprintf(
stderr,
"Error in generating audio. Please read previous error messages.\n");
exit(EXIT_FAILURE);
}
float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
float duration = audio.samples.size() / static_cast<float>(audio.sample_rate);
float rtf = elapsed_seconds / duration;
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
fprintf(stderr, "Audio duration: %.3f s\n", duration);
fprintf(stderr, "Real-time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds,
duration, rtf);
bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
audio.samples.data(), audio.samples.size());
if (!ok) {
fprintf(stderr, "Failed to write wave to %s\n", output_filename.c_str());
exit(EXIT_FAILURE);
}
fprintf(stderr, "The text is: %s.\n", po.GetArg(1).c_str());
fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str());
return 0;
}
... ...
... ... @@ -4,6 +4,9 @@
#include "sherpa-onnx/csrc/text-utils.h"
#include <regex>
#include <sstream>
#include "gtest/gtest.h"
namespace sherpa_onnx {
... ... @@ -55,7 +58,6 @@ TEST(RemoveInvalidUtf8Sequences, Case1) {
EXPECT_EQ(s.size() + 4, v.size());
}
// Tests for sanitizeUtf8
TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) {
std::string input = "Valid UTF-8 🌍";
... ...
... ... @@ -724,4 +724,62 @@ std::vector<std::string> SplitString(const std::string &s, int32_t chunk_size) {
return ans;
}
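// NOTE: std::wstring_convert/std::codecvt_utf8 are deprecated since C++17,
// but they avoid pulling in an extra dependency for this conversion.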
std::u32string Utf8ToUtf32(const std::string &str) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
return conv.from_bytes(str);
}
std::string Utf32ToUtf8(const std::u32string &str) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
return conv.to_bytes(str);
}
// Helper: Convert ASCII chars in a std::string to uppercase (leaves non-ASCII
// unchanged)
std::string ToUpperAscii(const std::string &str) {
std::string out = str;
for (char &c : out) {
unsigned char uc = static_cast<unsigned char>(c);
if (uc >= 'a' && uc <= 'z') {
c = static_cast<char>(uc - 'a' + 'A');
}
}
return out;
}
// Helper: Convert ASCII chars in a std::string to lowercase (leaves non-ASCII
// unchanged)
std::string ToLowerAscii(const std::string &str) {
std::string out = str;
for (char &c : out) {
unsigned char uc = static_cast<unsigned char>(c);
if (uc >= 'A' && uc <= 'Z') {
c = static_cast<char>(uc - 'A' + 'a');
}
}
return out;
}
// Detect if a codepoint is a CJK character
bool IsCJK(char32_t cp) {
return (cp >= 0x1100 && cp <= 0x11FF) || (cp >= 0x2E80 && cp <= 0xA4CF) ||
(cp >= 0xA840 && cp <= 0xD7AF) || (cp >= 0xF900 && cp <= 0xFAFF) ||
(cp >= 0xFE30 && cp <= 0xFE4F) || (cp >= 0xFF65 && cp <= 0xFFDC) ||
(cp >= 0x20000 && cp <= 0x2FFFF);
}
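// Example: IsCJK(U'中') is true and IsCJK(U'A') is false, so
// ContainsCJK("hello 世界") returns true.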
bool ContainsCJK(const std::string &text) {
std::u32string utf32_text = Utf8ToUtf32(text);
return ContainsCJK(utf32_text);
}
bool ContainsCJK(const std::u32string &text) {
for (char32_t cp : text) {
if (IsCJK(cp)) {
return true;
}
}
return false;
}
} // namespace sherpa_onnx
... ...
... ... @@ -149,6 +149,29 @@ bool EndsWith(const std::string &haystack, const std::string &needle);
std::vector<std::string> SplitString(const std::string &s, int32_t chunk_size);
// Converts a UTF-8 std::string to a UTF-32 std::u32string
std::u32string Utf8ToUtf32(const std::string &str);
// Converts a UTF-32 std::u32string to a UTF-8 std::string
std::string Utf32ToUtf8(const std::u32string &str);
// Helper: Convert ASCII chars in a std::string to uppercase (leaves non-ASCII
// unchanged)
std::string ToUpperAscii(const std::string &str);
// Helper: Convert ASCII chars in a std::string to lowercase (leaves non-ASCII
// unchanged)
std::string ToLowerAscii(const std::string &str);
// Detect if a codepoint is a CJK character
bool IsCJK(char32_t cp);
bool ContainsCJK(const std::string &text);
bool ContainsCJK(const std::u32string &text);
bool StringToBool(const std::string &s);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
... ...
... ... @@ -74,7 +74,18 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
}
std::unique_ptr<Vocoder> Vocoder::Create(const OfflineTtsModelConfig &config) {
auto buffer = ReadFile(config.matcha.vocoder);
std::vector<char> buffer;
if (!config.matcha.vocoder.empty()) {
SHERPA_ONNX_LOGE("Using matcha vocoder: %s", config.matcha.vocoder.c_str());
buffer = ReadFile(config.matcha.vocoder);
} else if (!config.zipvoice.vocoder.empty()) {
SHERPA_ONNX_LOGE("Using zipvoice vocoder: %s",
config.zipvoice.vocoder.c_str());
buffer = ReadFile(config.zipvoice.vocoder);
} else {
SHERPA_ONNX_LOGE("No vocoder model provided in the config!");
exit(-1);
}
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
switch (model_type) {
... ... @@ -94,7 +105,19 @@ std::unique_ptr<Vocoder> Vocoder::Create(const OfflineTtsModelConfig &config) {
template <typename Manager>
std::unique_ptr<Vocoder> Vocoder::Create(Manager *mgr,
const OfflineTtsModelConfig &config) {
auto buffer = ReadFile(mgr, config.matcha.vocoder);
std::vector<char> buffer;
if (!config.matcha.vocoder.empty()) {
SHERPA_ONNX_LOGE("Using matcha vocoder: %s", config.matcha.vocoder.c_str());
buffer = ReadFile(mgr, config.matcha.vocoder);
} else if (!config.zipvoice.vocoder.empty()) {
SHERPA_ONNX_LOGE("Using zipvoice vocoder: %s",
config.zipvoice.vocoder.c_str());
buffer = ReadFile(mgr, config.zipvoice.vocoder);
} else {
SHERPA_ONNX_LOGE("No vocoder model provided in the config!");
return nullptr;
}
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
switch (model_type) {
... ...
... ... @@ -42,8 +42,16 @@ class VocosVocoder::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config.num_threads, config.provider)),
allocator_{} {
auto buf = ReadFile(config.matcha.vocoder);
Init(buf.data(), buf.size());
std::vector<char> buffer;
if (!config.matcha.vocoder.empty()) {
buffer = ReadFile(config.matcha.vocoder);
} else if (!config.zipvoice.vocoder.empty()) {
buffer = ReadFile(config.zipvoice.vocoder);
} else {
SHERPA_ONNX_LOGE("No vocoder model provided in the config!");
exit(-1);
}
Init(buffer.data(), buffer.size());
}
template <typename Manager>
... ... @@ -52,8 +60,16 @@ class VocosVocoder::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config.num_threads, config.provider)),
allocator_{} {
auto buf = ReadFile(mgr, config.matcha.vocoder);
Init(buf.data(), buf.size());
std::vector<char> buffer;
if (!config.matcha.vocoder.empty()) {
buffer = ReadFile(mgr, config.matcha.vocoder);
} else if (!config.zipvoice.vocoder.empty()) {
buffer = ReadFile(mgr, config.zipvoice.vocoder);
} else {
SHERPA_ONNX_LOGE("No vocoder model provided in the config!");
exit(-1);
}
Init(buffer.data(), buffer.size());
}
std::vector<float> Run(Ort::Value mel) const {
... ...
... ... @@ -72,6 +72,7 @@ if(SHERPA_ONNX_ENABLE_TTS)
offline-tts-matcha-model-config.cc
offline-tts-model-config.cc
offline-tts-vits-model-config.cc
offline-tts-zipvoice-model-config.cc
offline-tts.cc
)
endif()
... ...
... ... @@ -11,6 +11,7 @@
#include "sherpa-onnx/python/csrc/offline-tts-kokoro-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.h"
namespace sherpa_onnx {
... ... @@ -18,6 +19,7 @@ void PybindOfflineTtsModelConfig(py::module *m) {
PybindOfflineTtsVitsModelConfig(m);
PybindOfflineTtsMatchaModelConfig(m);
PybindOfflineTtsKokoroModelConfig(m);
PybindOfflineTtsZipvoiceModelConfig(m);
PybindOfflineTtsKittenModelConfig(m);
using PyClass = OfflineTtsModelConfig;
... ... @@ -27,17 +29,20 @@ void PybindOfflineTtsModelConfig(py::module *m) {
.def(py::init<const OfflineTtsVitsModelConfig &,
const OfflineTtsMatchaModelConfig &,
const OfflineTtsKokoroModelConfig &,
const OfflineTtsZipvoiceModelConfig &,
const OfflineTtsKittenModelConfig &, int32_t, bool,
const std::string &>(),
py::arg("vits") = OfflineTtsVitsModelConfig{},
py::arg("matcha") = OfflineTtsMatchaModelConfig{},
py::arg("kokoro") = OfflineTtsKokoroModelConfig{},
py::arg("zipvoice") = OfflineTtsZipvoiceModelConfig{},
py::arg("kitten") = OfflineTtsKittenModelConfig{},
py::arg("num_threads") = 1, py::arg("debug") = false,
py::arg("provider") = "cpu")
.def_readwrite("vits", &PyClass::vits)
.def_readwrite("matcha", &PyClass::matcha)
.def_readwrite("kokoro", &PyClass::kokoro)
.def_readwrite("zipvoice", &PyClass::zipvoice)
.def_readwrite("kitten", &PyClass::kitten)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
... ...
// sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.h"
#include <string>
#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h"
namespace sherpa_onnx {
void PybindOfflineTtsZipvoiceModelConfig(py::module *m) {
using PyClass = OfflineTtsZipvoiceModelConfig;
py::class_<PyClass>(*m, "OfflineTtsZipvoiceModelConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &,
const std::string &, const std::string &,
const std::string &, const std::string &, float, float,
float, float>(),
py::arg("tokens"), py::arg("text_model"),
py::arg("flow_matching_model"), py::arg("vocoder"),
py::arg("data_dir") = "", py::arg("pinyin_dict") = "",
py::arg("feat_scale") = 0.1, py::arg("t_shift") = 0.5,
py::arg("target_rms") = 0.1, py::arg("guidance_scale") = 1.0)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("text_model", &PyClass::text_model)
.def_readwrite("flow_matching_model", &PyClass::flow_matching_model)
.def_readwrite("vocoder", &PyClass::vocoder)
.def_readwrite("data_dir", &PyClass::data_dir)
.def_readwrite("pinyin_dict", &PyClass::pinyin_dict)
.def_readwrite("feat_scale", &PyClass::feat_scale)
.def_readwrite("t_shift", &PyClass::t_shift)
.def_readwrite("target_rms", &PyClass::target_rms)
.def_readwrite("guidance_scale", &PyClass::guidance_scale)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineTtsZipvoiceModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
... ...
... ... @@ -84,6 +84,41 @@ void PybindOfflineTts(py::module *m) {
},
py::arg("text"), py::arg("sid") = 0, py::arg("speed") = 1.0,
py::arg("callback") = py::none(),
py::call_guard<py::gil_scoped_release>())
.def(
"generate",
[](const PyClass &self, const std::string &text,
const std::string &prompt_text,
const std::vector<float> &prompt_samples, int32_t sample_rate,
float speed, int32_t num_steps,
std::function<int32_t(py::array_t<float>, float)> callback)
-> GeneratedAudio {
if (!callback) {
return self.Generate(text, prompt_text, prompt_samples,
sample_rate, speed, num_steps);
}
std::function<int32_t(const float *, int32_t, float)>
callback_wrapper = [callback](const float *samples, int32_t n,
float progress) {
// CAUTION(fangjun): we have to copy samples since it is
// freed once the call back returns.
pybind11::gil_scoped_acquire acquire;
pybind11::array_t<float> array(n);
py::buffer_info buf = array.request();
auto p = static_cast<float *>(buf.ptr);
std::copy(samples, samples + n, p);
return callback(array, progress);
};
return self.Generate(text, prompt_text, prompt_samples, sample_rate,
speed, num_steps, callback_wrapper);
},
py::arg("text"), py::arg("prompt_text"), py::arg("prompt_samples"),
py::arg("sample_rate"), py::arg("speed") = 1.0,
py::arg("num_steps") = 4, py::arg("callback") = py::none(),
py::call_guard<py::gil_scoped_release>());
}
... ...
... ... @@ -49,6 +49,7 @@ from sherpa_onnx.lib._sherpa_onnx import (
OfflineTtsMatchaModelConfig,
OfflineTtsModelConfig,
OfflineTtsVitsModelConfig,
OfflineTtsZipvoiceModelConfig,
OfflineWenetCtcModelConfig,
OfflineWhisperModelConfig,
OfflineZipformerAudioTaggingModelConfig,
... ...