Wei Kang
Committed by GitHub

Add Zipvoice (#2487)

Co-authored-by: yaozengwei <yaozengwei@outlook.com>
@@ -372,6 +372,7 @@ endif()
 include(kaldi-native-fbank)
 include(kaldi-decoder)
 include(onnxruntime)
+include(cppinyin)
 include(simple-sentencepiece)
 set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR})
 message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}")
+function(download_cppinyin)
+  include(FetchContent)
+
+  set(cppinyin_URL "https://github.com/pkufool/cppinyin/archive/refs/tags/v0.10.tar.gz")
+  set(cppinyin_URL2 "https://gh-proxy.com/https://github.com/pkufool/cppinyin/archive/refs/tags/v0.10.tar.gz")
+  set(cppinyin_HASH "SHA256=abe6584d7ee56829e8f4b5fbda3b50ecdf49a13be8e413a78d1b0d5d5c019982")
+
+  # If you don't have access to the Internet,
+  # please pre-download cppinyin
+  set(possible_file_locations
+    $ENV{HOME}/Downloads/cppinyin-0.10.tar.gz
+    ${CMAKE_SOURCE_DIR}/cppinyin-0.10.tar.gz
+    ${CMAKE_BINARY_DIR}/cppinyin-0.10.tar.gz
+    /tmp/cppinyin-0.10.tar.gz
+    /star-fj/fangjun/download/github/cppinyin-0.10.tar.gz
+  )
+
+  foreach(f IN LISTS possible_file_locations)
+    if(EXISTS ${f})
+      set(cppinyin_URL "${f}")
+      file(TO_CMAKE_PATH "${cppinyin_URL}" cppinyin_URL)
+      message(STATUS "Found local downloaded cppinyin: ${cppinyin_URL}")
+      set(cppinyin_URL2)
+      break()
+    endif()
+  endforeach()
+
+  set(CPPINYIN_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
+  set(CPPINYIN_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
+
+  FetchContent_Declare(cppinyin
+    URL
+      ${cppinyin_URL}
+      ${cppinyin_URL2}
+    URL_HASH
+      ${cppinyin_HASH}
+  )
+
+  FetchContent_GetProperties(cppinyin)
+  if(NOT cppinyin_POPULATED)
+    message(STATUS "Downloading cppinyin ${cppinyin_URL}")
+    FetchContent_Populate(cppinyin)
+  endif()
+  message(STATUS "cppinyin is downloaded to ${cppinyin_SOURCE_DIR}")
+
+  if(BUILD_SHARED_LIBS)
+    set(_build_shared_libs_bak ${BUILD_SHARED_LIBS})
+    set(BUILD_SHARED_LIBS OFF)
+  endif()
+
+  add_subdirectory(${cppinyin_SOURCE_DIR} ${cppinyin_BINARY_DIR} EXCLUDE_FROM_ALL)
+
+  if(_build_shared_libs_bak)
+    set_target_properties(cppinyin_core
+      PROPERTIES
+        POSITION_INDEPENDENT_CODE ON
+        C_VISIBILITY_PRESET hidden
+        CXX_VISIBILITY_PRESET hidden
+    )
+    set(BUILD_SHARED_LIBS ON)
+  endif()
+
+  target_include_directories(cppinyin_core
+    PUBLIC
+      ${cppinyin_SOURCE_DIR}/
+  )
+
+  if(NOT BUILD_SHARED_LIBS)
+    install(TARGETS cppinyin_core DESTINATION lib)
+  endif()
+
+endfunction()
+
+download_cppinyin()
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 #
-# Copyright (c) 2023 Xiaomi Corporation
+# Copyright (c) 2023-2025 Xiaomi Corporation

 """
 This file demonstrates how to use sherpa-onnx Python API to generate audio
@@ -453,7 +453,9 @@ def main():
     end = time.time()

     if len(audio.samples) == 0:
-        print("Error in generating audios. Please read previous error messages.")
+        print(
+            "Error in generating audios. Please read previous error messages."
+        )
         return

     elapsed_seconds = end - start
@@ -470,7 +472,9 @@ def main():
     print(f"The text is '{args.text}'")
     print(f"Elapsed seconds: {elapsed_seconds:.3f}")
     print(f"Audio duration in seconds: {audio_duration:.3f}")
-    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
+    print(
+        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
+    )


 if __name__ == "__main__":
+#!/usr/bin/env python3
+#
+# Copyright (c) 2025 Xiaomi Corporation
+
+"""
+This file demonstrates how to use the sherpa-onnx Python API to generate audio
+from text with a prompt, i.e., zero-shot text-to-speech.
+
+Usage:
+
+Example (zipvoice)
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2
+tar xf sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2
+
+python3 ./python-api-examples/offline-zeroshot-tts.py \
+  --zipvoice-flow-matching-model sherpa-onnx-zipvoice-distill-zh-en-emilia/fm_decoder.onnx \
+  --zipvoice-text-model sherpa-onnx-zipvoice-distill-zh-en-emilia/text_encoder.onnx \
+  --zipvoice-data-dir sherpa-onnx-zipvoice-distill-zh-en-emilia/espeak-ng-data \
+  --zipvoice-pinyin-dict sherpa-onnx-zipvoice-distill-zh-en-emilia/pinyin.raw \
+  --zipvoice-tokens sherpa-onnx-zipvoice-distill-zh-en-emilia/tokens.txt \
+  --zipvoice-vocoder sherpa-onnx-zipvoice-distill-zh-en-emilia/vocos_24khz.onnx \
+  --prompt-audio sherpa-onnx-zipvoice-distill-zh-en-emilia/prompt.wav \
+  --zipvoice-num-steps 4 \
+  --num-threads 4 \
+  --prompt-text "周日被我射熄火了,所以今天是周一。" \
+  "我是中国人民的儿子,我爱我的祖国。我的祖国是一个伟大的国家,拥有五千年的文明史。"
+"""
+
+import argparse
+import time
+import wave
+from typing import Tuple
+
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+
+
+def add_zipvoice_args(parser):
+    parser.add_argument(
+        "--zipvoice-tokens",
+        type=str,
+        default="",
+        help="Path to tokens.txt for ZipVoice models.",
+    )
+
+    parser.add_argument(
+        "--zipvoice-text-model",
+        type=str,
+        default="",
+        help="Path to the ZipVoice text model.",
+    )
+
+    parser.add_argument(
+        "--zipvoice-flow-matching-model",
+        type=str,
+        default="",
+        help="Path to the ZipVoice flow-matching model.",
+    )
+
+    parser.add_argument(
+        "--zipvoice-data-dir",
+        type=str,
+        default="",
+        help="Path to the dict directory of espeak-ng.",
+    )
+
+    parser.add_argument(
+        "--zipvoice-pinyin-dict",
+        type=str,
+        default="",
+        help="Path to the pinyin dictionary.",
+    )
+
+    parser.add_argument(
+        "--zipvoice-vocoder",
+        type=str,
+        default="",
+        help="Path to the vocos vocoder.",
+    )
+
+    parser.add_argument(
+        "--zipvoice-num-steps",
+        type=int,
+        default=4,
+        help="Number of sampling steps for ZipVoice.",
+    )
+
+    parser.add_argument(
+        "--zipvoice-feat-scale",
+        type=float,
+        default=0.1,
+        help="Scale factor for ZipVoice features.",
+    )
+
+    parser.add_argument(
+        "--zipvoice-t-shift",
+        type=float,
+        default=0.5,
+        help="Shift t towards smaller values if t-shift < 1.0.",
+    )
+
+    parser.add_argument(
+        "--zipvoice-target-rms",
+        type=float,
+        default=0.1,
+        help="Target speech normalization RMS value for ZipVoice.",
+    )
+
+    parser.add_argument(
+        "--zipvoice-guidance-scale",
+        type=float,
+        default=1.0,
+        help="The scale of classifier-free guidance during inference for ZipVoice.",
+    )
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel and each sample should
+        be 16-bit. Its sample rate does not need to be 16 kHz.
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples, which are
+         normalized to the range [-1, 1].
+       - The sample rate of the wave file.
+    """
+
+    with wave.open(wave_filename) as f:
+        assert f.getnchannels() == 1, f.getnchannels()
+        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+        num_samples = f.getnframes()
+        samples = f.readframes(num_samples)
+        samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_float32 = samples_int16.astype(np.float32)
+
+        samples_float32 = samples_float32 / 32768
+        return samples_float32, f.getframerate()
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    add_zipvoice_args(parser)
+
+    parser.add_argument(
+        "--tts-rule-fsts",
+        type=str,
+        default="",
+        help="Path to rule.fst",
+    )
+
+    parser.add_argument(
+        "--max-num-sentences",
+        type=int,
+        default=1,
+        help="""Max number of sentences in a batch to avoid OOM if the input
+        text is very long. Set it to -1 to process all the sentences in a
+        single batch. A smaller value does not mean it is slower compared
+        to a larger one on CPU.
+        """,
+    )
+
+    parser.add_argument(
+        "--output-filename",
+        type=str,
+        default="./generated.wav",
+        help="Path to save the generated wave",
+    )
+
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        help="Show debug messages",
+    )
+
+    parser.add_argument(
+        "--provider",
+        type=str,
+        default="cpu",
+        help="Valid values: cpu, cuda, coreml",
+    )
+
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "--speed",
+        type=float,
+        default=1.0,
+        help="Speech speed. Larger -> faster; smaller -> slower",
+    )
+
+    parser.add_argument(
+        "--prompt-text",
+        type=str,
+        required=True,
+        help="The transcription of the prompt audio (ZipVoice)",
+    )
+
+    parser.add_argument(
+        "--prompt-audio",
+        type=str,
+        required=True,
+        help="The path to the prompt audio (ZipVoice).",
+    )
+
+    parser.add_argument(
+        "text",
+        type=str,
+        help="The input text to generate audio for",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    print(args)
+
+    tts_config = sherpa_onnx.OfflineTtsConfig(
+        model=sherpa_onnx.OfflineTtsModelConfig(
+            zipvoice=sherpa_onnx.OfflineTtsZipvoiceModelConfig(
+                tokens=args.zipvoice_tokens,
+                text_model=args.zipvoice_text_model,
+                flow_matching_model=args.zipvoice_flow_matching_model,
+                data_dir=args.zipvoice_data_dir,
+                pinyin_dict=args.zipvoice_pinyin_dict,
+                vocoder=args.zipvoice_vocoder,
+                feat_scale=args.zipvoice_feat_scale,
+                t_shift=args.zipvoice_t_shift,
+                target_rms=args.zipvoice_target_rms,
+                guidance_scale=args.zipvoice_guidance_scale,
+            ),
+            provider=args.provider,
+            debug=args.debug,
+            num_threads=args.num_threads,
+        ),
+        rule_fsts=args.tts_rule_fsts,
+        max_num_sentences=args.max_num_sentences,
+    )
+    if not tts_config.validate():
+        raise ValueError("Please check your config")
+
+    tts = sherpa_onnx.OfflineTts(tts_config)
+
+    start = time.time()
+    prompt_samples, sample_rate = read_wave(args.prompt_audio)
+    audio = tts.generate(
+        args.text,
+        args.prompt_text,
+        prompt_samples,
+        sample_rate,
+        speed=args.speed,
+        num_steps=args.zipvoice_num_steps,
+    )
+    end = time.time()
+
+    if len(audio.samples) == 0:
+        print(
+            "Error in generating audios. Please read previous error messages."
+        )
+        return
+
+    elapsed_seconds = end - start
+    audio_duration = len(audio.samples) / audio.sample_rate
+    real_time_factor = elapsed_seconds / audio_duration
+
+    sf.write(
+        args.output_filename,
+        audio.samples,
+        samplerate=audio.sample_rate,
+        subtype="PCM_16",
+    )
+    print(f"Saved to {args.output_filename}")
+    print(f"The text is '{args.text}'")
+    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
+    print(f"Audio duration in seconds: {audio_duration:.3f}")
+    print(
+        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
+    )
+
+
+if __name__ == "__main__":
+    main()
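
Editor's note: read_wave() above insists on a 16-bit mono WAV for the prompt. A minimal sketch (not part of this PR) for converting an arbitrary recording into that format, reusing the numpy and soundfile imports the example already has:

import numpy as np
import soundfile as sf

def to_mono_pcm16(in_path: str, out_path: str) -> None:
    # always_2d gives shape (num_samples, num_channels) even for mono input
    samples, sample_rate = sf.read(in_path, dtype="float32", always_2d=True)
    mono = samples.mean(axis=1)  # average all channels down to one
    # PCM_16 matches the getsampwidth() == 2 assertion in read_wave()
    sf.write(out_path, mono, samplerate=sample_rate, subtype="PCM_16")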
@@ -201,6 +201,9 @@ if(SHERPA_ONNX_ENABLE_TTS)
     offline-tts-model-config.cc
     offline-tts-vits-model-config.cc
     offline-tts-vits-model.cc
+    offline-tts-zipvoice-frontend.cc
+    offline-tts-zipvoice-model.cc
+    offline-tts-zipvoice-model-config.cc
     offline-tts.cc
     piper-phonemize-lexicon.cc
     vocoder.cc
@@ -265,6 +268,7 @@ if(ANDROID_NDK)
 endif()

 target_link_libraries(sherpa-onnx-core
+  cppinyin_core
   kaldi-native-fbank-core
   kaldi-decoder-core
   ssentencepiece_core
@@ -348,6 +352,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)

   if(SHERPA_ONNX_ENABLE_TTS)
     add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
+    add_executable(sherpa-onnx-offline-zeroshot-tts sherpa-onnx-offline-zeroshot-tts.cc)
   endif()

@@ -370,6 +375,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
   if(SHERPA_ONNX_ENABLE_TTS)
     list(APPEND main_exes
       sherpa-onnx-offline-tts
+      sherpa-onnx-offline-zeroshot-tts
     )
   endif()

@@ -667,6 +673,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
   if(SHERPA_ONNX_ENABLE_TTS)
     list(APPEND sherpa_onnx_test_srcs
       cppjieba-test.cc
+      offline-tts-zipvoice-frontend-test.cc
       piper-phonemize-test.cc
     )
   endif()
@@ -20,6 +20,7 @@
 #include "sherpa-onnx/csrc/offline-tts-kokoro-impl.h"
 #include "sherpa-onnx/csrc/offline-tts-matcha-impl.h"
 #include "sherpa-onnx/csrc/offline-tts-vits-impl.h"
+#include "sherpa-onnx/csrc/offline-tts-zipvoice-impl.h"

 namespace sherpa_onnx {

@@ -41,6 +42,9 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
     return std::make_unique<OfflineTtsVitsImpl>(config);
   } else if (!config.model.matcha.acoustic_model.empty()) {
     return std::make_unique<OfflineTtsMatchaImpl>(config);
+  } else if (!config.model.zipvoice.text_model.empty() &&
+             !config.model.zipvoice.flow_matching_model.empty()) {
+    return std::make_unique<OfflineTtsZipvoiceImpl>(config);
   } else if (!config.model.kokoro.model.empty()) {
     return std::make_unique<OfflineTtsKokoroImpl>(config);
   } else if (!config.model.kitten.model.empty()) {
@@ -59,6 +63,9 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
     return std::make_unique<OfflineTtsVitsImpl>(mgr, config);
   } else if (!config.model.matcha.acoustic_model.empty()) {
     return std::make_unique<OfflineTtsMatchaImpl>(mgr, config);
+  } else if (!config.model.zipvoice.text_model.empty() &&
+             !config.model.zipvoice.flow_matching_model.empty()) {
+    return std::make_unique<OfflineTtsZipvoiceImpl>(mgr, config);
   } else if (!config.model.kokoro.model.empty()) {
     return std::make_unique<OfflineTtsKokoroImpl>(mgr, config);
   } else if (!config.model.kitten.model.empty()) {
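
Editor's note: the dispatch above means ZipVoice is selected only when both its text model and flow-matching model are configured; every other backend is keyed on a single model file. A hedged Python paraphrase of the selection order (a plain duck-typed config object, not an actual sherpa-onnx API):

def pick_backend(model) -> str:
    if model.vits.model:
        return "vits"
    if model.matcha.acoustic_model:
        return "matcha"
    if model.zipvoice.text_model and model.zipvoice.flow_matching_model:
        return "zipvoice"
    if model.kokoro.model:
        return "kokoro"
    if model.kitten.model:
        return "kitten"
    return "unknown"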
@@ -6,9 +6,11 @@
 #define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_

 #include <memory>
+#include <stdexcept>
 #include <string>
 #include <vector>

+#include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-tts.h"

 namespace sherpa_onnx {
@@ -25,14 +27,29 @@ class OfflineTtsImpl {

   virtual GeneratedAudio Generate(
       const std::string &text, int64_t sid = 0, float speed = 1.0,
-      GeneratedAudioCallback callback = nullptr) const = 0;
+      GeneratedAudioCallback callback = nullptr) const {
+    throw std::runtime_error(
+        "OfflineTtsImpl backend does not support non-zero-shot Generate()");
+  }
+
+  virtual GeneratedAudio Generate(
+      const std::string &text, const std::string &prompt_text,
+      const std::vector<float> &prompt_samples, int32_t sample_rate,
+      float speed = 1.0, int32_t num_steps = 4,
+      GeneratedAudioCallback callback = nullptr) const {
+    throw std::runtime_error(
+        "OfflineTtsImpl backend does not support zero-shot Generate()");
+  }

   // Return the sample rate of the generated audio
   virtual int32_t SampleRate() const = 0;

   // Number of supported speakers.
   // If it supports only a single speaker, then it returns 0 or 1.
-  virtual int32_t NumSpeakers() const = 0;
+  virtual int32_t NumSpeakers() const {
+    throw std::runtime_error(
+        "Zero-shot OfflineTts does not support NumSpeakers()");
+  }

   std::vector<int64_t> AddBlank(const std::vector<int64_t> &x,
                                 int32_t blank_id = 0) const;
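
Editor's note: Generate() and NumSpeakers() stop being pure virtual here, so a backend that lacks an overload now fails at call time with std::runtime_error instead of failing to compile. A defensive-calling sketch, assuming the Python binding surfaces that error as a RuntimeError (tts is any constructed sherpa_onnx.OfflineTts, as in the example script above):

def generate_zero_shot(tts, text, prompt_text, prompt_samples, sample_rate):
    try:
        # The zero-shot overload; non-zero-shot backends (e.g., VITS) throw.
        return tts.generate(text, prompt_text, prompt_samples, sample_rate)
    except RuntimeError as e:
        print("This backend does not support zero-shot generation:", e)
        return None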
@@ -12,6 +12,7 @@ void OfflineTtsModelConfig::Register(ParseOptions *po) {
   vits.Register(po);
   matcha.Register(po);
   kokoro.Register(po);
+  zipvoice.Register(po);
   kitten.Register(po);

   po->Register("num-threads", &num_threads,
@@ -38,6 +39,10 @@ bool OfflineTtsModelConfig::Validate() const {
     return matcha.Validate();
   }

+  if (!zipvoice.flow_matching_model.empty()) {
+    return zipvoice.Validate();
+  }
+
   if (!kokoro.model.empty()) {
     return kokoro.Validate();
   }
@@ -58,6 +63,7 @@ std::string OfflineTtsModelConfig::ToString() const {
   os << "vits=" << vits.ToString() << ", ";
   os << "matcha=" << matcha.ToString() << ", ";
   os << "kokoro=" << kokoro.ToString() << ", ";
+  os << "zipvoice=" << zipvoice.ToString() << ", ";
   os << "kitten=" << kitten.ToString() << ", ";
   os << "num_threads=" << num_threads << ", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
@@ -11,6 +11,7 @@
 #include "sherpa-onnx/csrc/offline-tts-kokoro-model-config.h"
 #include "sherpa-onnx/csrc/offline-tts-matcha-model-config.h"
 #include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"
+#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h"
 #include "sherpa-onnx/csrc/parse-options.h"

 namespace sherpa_onnx {
@@ -19,6 +20,7 @@ struct OfflineTtsModelConfig {
   OfflineTtsVitsModelConfig vits;
   OfflineTtsMatchaModelConfig matcha;
   OfflineTtsKokoroModelConfig kokoro;
+  OfflineTtsZipvoiceModelConfig zipvoice;
   OfflineTtsKittenModelConfig kitten;

   int32_t num_threads = 1;
@@ -30,12 +32,14 @@ struct OfflineTtsModelConfig {
   OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits,
                         const OfflineTtsMatchaModelConfig &matcha,
                         const OfflineTtsKokoroModelConfig &kokoro,
+                        const OfflineTtsZipvoiceModelConfig &zipvoice,
                         const OfflineTtsKittenModelConfig &kitten,
                         int32_t num_threads, bool debug,
                         const std::string &provider)
       : vits(vits),
         matcha(matcha),
         kokoro(kokoro),
+        zipvoice(zipvoice),
         kitten(kitten),
         num_threads(num_threads),
         debug(debug),
+// sherpa-onnx/csrc/offline-tts-zipvoice-frontend-test.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h"
+
+#include "espeak-ng/speak_lib.h"
+#include "gtest/gtest.h"
+#include "phoneme_ids.hpp"
+#include "phonemize.hpp"
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+TEST(ZipVoiceFrontend, Case1) {
+  std::string data_dir = "../zipvoice/espeak-ng-data";
+  for (const auto *f :
+       {"en_dict", "phontab", "phonindex", "phondata", "intonations"}) {
+    if (!FileExists(data_dir + "/" + f)) {
+      SHERPA_ONNX_LOGE("%s/%s does not exist. Skipping test", data_dir.c_str(),
+                       f);
+      return;
+    }
+  }
+
+  std::string pinyin_dict = data_dir + "/../pinyin.dict";
+  if (!FileExists(pinyin_dict)) {
+    SHERPA_ONNX_LOGE("%s does not exist. Skipping test", pinyin_dict.c_str());
+    return;
+  }
+
+  std::string tokens_file = data_dir + "/../tokens.txt";
+  if (!FileExists(tokens_file)) {
+    SHERPA_ONNX_LOGE("%s does not exist. Skipping test", tokens_file.c_str());
+    return;
+  }
+
+  auto frontend = OfflineTtsZipvoiceFrontend(
+      tokens_file, data_dir, pinyin_dict,
+      OfflineTtsZipvoiceModelMetaData{.use_espeak = true, .use_pinyin = true},
+      true);
+
+  std::string text = "how are you doing?";
+  std::vector<sherpa_onnx::TokenIDs> ans =
+      frontend.ConvertTextToTokenIds(text, "en-us");
+
+  text = "这是第一句。这是第二句。";
+  ans = frontend.ConvertTextToTokenIds(text, "en-us");
+
+  text =
+      "这是第一句。这是第二句。<pin1><yin2>测试 [S1]and hello "
+      "world[S2]这是第三句。";
+  ans = frontend.ConvertTextToTokenIds(text, "en-us");
+}
+
+}  // namespace sherpa_onnx
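
Editor's note: the last test case exercises ZipVoice's inline markup: each <...> span is a literal pinyin syllable (handled by TokenizePinyin below) and each [...] span is a tag that is upper-cased and looked up verbatim in tokens.txt (TokenizeTag). A hedged sketch of the same markup going through the Python API, with tts, prompt_text, prompt_samples, and sample_rate set up as in the example script earlier in this PR:

def demo_markup(tts, prompt_text, prompt_samples, sample_rate):
    # <pin1><yin2> are pinyin tokens; [S1]/[S2] must exist in tokens.txt.
    text = "这是第一句。<pin1><yin2>测试 [S1]and hello world[S2]这是第三句。"
    return tts.generate(text, prompt_text, prompt_samples, sample_rate)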
+// sherpa-onnx/csrc/offline-tts-zipvoice-frontend.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h"
+
+#include <algorithm>
+#include <cctype>
+#include <codecvt>
+#include <fstream>
+#include <locale>
+#include <regex>  // NOLINT
+#include <sstream>
+#include <strstream>
+#include <utility>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#if __OHOS__
+#include "rawfile/raw_file_manager.h"
+#endif
+
+#include "cppinyin/csrc/cppinyin.h"
+#include "espeak-ng/speak_lib.h"
+#include "phoneme_ids.hpp"
+#include "phonemize.hpp"
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+void CallPhonemizeEspeak(const std::string &text,
+                         piper::eSpeakPhonemeConfig &config,  // NOLINT
+                         std::vector<std::vector<piper::Phoneme>> *phonemes);
+
+static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
+  std::unordered_map<std::string, int32_t> token2id;
+
+  std::string line;
+  std::string sym;
+  int32_t id = 0;
+  while (std::getline(is, line)) {
+    std::istringstream iss(line);
+    iss >> sym;
+    if (iss.eof()) {
+      id = atoi(sym.c_str());
+      sym = " ";
+    } else {
+      iss >> id;
+    }
+    // eat the trailing \r\n on Windows
+    iss >> std::ws;
+    if (!iss.eof()) {
+      SHERPA_ONNX_LOGE("Error when reading tokens: %s", line.c_str());
+      exit(-1);
+    }
+
+    if (token2id.count(sym)) {
+      SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
+                       sym.c_str(), line.c_str(), token2id.at(sym));
+      exit(-1);
+    }
+    token2id.insert({sym, id});
+  }
+  return token2id;
+}
+
+static std::string MapPunctuations(
+    const std::string &text,
+    const std::unordered_map<std::string, std::string> &punct_map) {
+  std::string result = text;
+  for (const auto &kv : punct_map) {
+    // Replace all occurrences of kv.first with kv.second
+    size_t pos = 0;
+    while ((pos = result.find(kv.first, pos)) != std::string::npos) {
+      result.replace(pos, kv.first.length(), kv.second);
+      pos += kv.second.length();
+    }
+  }
+  return result;
+}
+
+static void ProcessPinyin(
+    const std::string &pinyin, const cppinyin::PinyinEncoder *pinyin_encoder,
+    const std::unordered_map<std::string, int32_t> &token2id,
+    std::vector<int64_t> *tokens_ids, std::vector<std::string> *tokens) {
+  auto initial = pinyin_encoder->ToInitial(pinyin);
+  if (!initial.empty()) {
+    // append '0' to avoid conflicts with espeak tokens
+    initial = initial + "0";
+    if (token2id.count(initial)) {
+      tokens_ids->push_back(token2id.at(initial));
+      tokens->push_back(initial);
+    } else {
+      SHERPA_ONNX_LOGE("Skip unknown initial %s", initial.c_str());
+    }
+  }
+  auto final_t = pinyin_encoder->ToFinal(pinyin);
+  if (!final_t.empty()) {
+    if (!std::isdigit(final_t.back())) {
+      final_t = final_t + "5";  // use 5 for the neutral tone
+    }
+    if (token2id.count(final_t)) {
+      tokens_ids->push_back(token2id.at(final_t));
+      tokens->push_back(final_t);
+    } else {
+      SHERPA_ONNX_LOGE("Skip unknown final %s", final_t.c_str());
+    }
+  }
+}
+
+static void TokenizeZh(const std::string &words,
+                       const cppinyin::PinyinEncoder *pinyin_encoder,
+                       const std::unordered_map<std::string, int32_t> &token2id,
+                       std::vector<int64_t> *token_ids,
+                       std::vector<std::string> *tokens) {
+  std::vector<std::string> pinyins;
+  pinyin_encoder->Encode(words, &pinyins, "number" /*tone*/, false /*partial*/);
+  for (const auto &pinyin : pinyins) {
+    if (pinyin_encoder->ValidPinyin(pinyin, "number" /*tone*/)) {
+      ProcessPinyin(pinyin, pinyin_encoder, token2id, token_ids, tokens);
+    } else {
+      auto wstext = ToWideString(pinyin);
+      for (auto &wc : wstext) {
+        auto c = ToString(std::wstring(1, wc));
+        if (token2id.count(c)) {
+          token_ids->push_back(token2id.at(c));
+          tokens->push_back(c);
+        } else {
+          SHERPA_ONNX_LOGE("Skip unknown character %s", c.c_str());
+        }
+      }
+    }
+  }
+}
+
+static void TokenizeEn(const std::string &words,
+                       const std::unordered_map<std::string, int32_t> &token2id,
+                       const std::string &voice,
+                       std::vector<int64_t> *token_ids,
+                       std::vector<std::string> *tokens) {
+  piper::eSpeakPhonemeConfig config;
+  // Run ./bin/espeak-ng-bin --path ./install/share/espeak-ng-data/ --voices
+  // to list available voices
+  config.voice = voice;  // e.g., voice is en-us
+
+  std::vector<std::vector<piper::Phoneme>> phonemes;
+
+  CallPhonemizeEspeak(words, config, &phonemes);
+
+  for (const auto &p : phonemes) {
+    for (const auto &ph : p) {
+      auto token = Utf32ToUtf8(std::u32string(1, ph));
+      if (token2id.count(token)) {
+        token_ids->push_back(token2id.at(token));
+        tokens->push_back(token);
+      } else {
+        SHERPA_ONNX_LOGE("Skip unknown phoneme %s", token.c_str());
+      }
+    }
+  }
+}
+
+static void TokenizeTag(
+    const std::string &words,
+    const std::unordered_map<std::string, int32_t> &token2id,
+    std::vector<int64_t> *tokens_ids, std::vector<std::string> *tokens) {
+  // in zipvoice, tags are all in upper case
+  std::string tag = ToUpperAscii(words);
+  if (token2id.count(tag)) {
+    tokens_ids->push_back(token2id.at(tag));
+    tokens->push_back(tag);
+  } else {
+    SHERPA_ONNX_LOGE("Skip unknown tag %s", tag.c_str());
+  }
+}
+
+static void TokenizePinyin(
+    const std::string &words, const cppinyin::PinyinEncoder *pinyin_encoder,
+    const std::unordered_map<std::string, int32_t> &token2id,
+    std::vector<int64_t> *tokens_ids, std::vector<std::string> *tokens) {
+  // words are in the form of <ha3>, <ha4>
+  std::string pinyin = words.substr(1, words.size() - 2);
+  if (!pinyin.empty()) {
+    if (pinyin[pinyin.size() - 1] == '5') {
+      pinyin = pinyin.substr(0, pinyin.size() - 1);  // remove the tone
+    }
+    if (pinyin_encoder->ValidPinyin(pinyin, "number" /*tone*/)) {
+      ProcessPinyin(pinyin, pinyin_encoder, token2id, tokens_ids, tokens);
+    } else {
+      SHERPA_ONNX_LOGE("Invalid pinyin %s", pinyin.c_str());
+    }
+  }
+}
+
+OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend(
+    const std::string &tokens, const std::string &data_dir,
+    const std::string &pinyin_dict,
+    const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug)
+    : debug_(debug), meta_data_(meta_data) {
+  std::ifstream is(tokens);
+  token2id_ = ReadTokens(is);
+  if (meta_data_.use_pinyin) {
+    pinyin_encoder_ = std::make_unique<cppinyin::PinyinEncoder>(pinyin_dict);
+  } else {
+    pinyin_encoder_ = nullptr;
+  }
+  if (meta_data_.use_espeak) {
+    // We should copy the directory of espeak-ng-data from the asset to
+    // some internal or external storage and then pass the directory to
+    // data_dir.
+    InitEspeak(data_dir);
+  }
+}
+
+template <typename Manager>
+OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend(
+    Manager *mgr, const std::string &tokens, const std::string &data_dir,
+    const std::string &pinyin_dict,
+    const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug)
+    : debug_(debug), meta_data_(meta_data) {
+  auto buf = ReadFile(mgr, tokens);
+  std::istrstream is(buf.data(), buf.size());
+  token2id_ = ReadTokens(is);
+  if (meta_data_.use_pinyin) {
+    auto dict_buf = ReadFile(mgr, pinyin_dict);
+    std::istringstream iss(std::string(dict_buf.begin(), dict_buf.end()));
+    pinyin_encoder_ = std::make_unique<cppinyin::PinyinEncoder>(iss);
+  } else {
+    pinyin_encoder_ = nullptr;
+  }
+  if (meta_data_.use_espeak) {
+    // We should copy the directory of espeak-ng-data from the asset to
+    // some internal or external storage and then pass the directory to
+    // data_dir.
+    InitEspeak(data_dir);
+  }
+}
+
+std::vector<TokenIDs> OfflineTtsZipvoiceFrontend::ConvertTextToTokenIds(
+    const std::string &_text, const std::string &voice) const {
+  std::string text = _text;
+  if (meta_data_.use_espeak) {
+    text = ToLowerAscii(_text);
+  }
+
+  text = MapPunctuations(text, punct_map_);
+
+  auto wstext = ToWideString(text);
+
+  std::vector<std::string> parts;
+  // Match <...>, [...], or a single character
+  std::wregex part_pattern(LR"([<\[].*?[>\]]|.)");
+  auto words_begin =
+      std::wsregex_iterator(wstext.begin(), wstext.end(), part_pattern);
+  auto words_end = std::wsregex_iterator();
+  for (std::wsregex_iterator i = words_begin; i != words_end; ++i) {
+    parts.push_back(ToString(i->str()));
+  }
+
+  // The types are en, zh, tag, pinyin, and other.
+  // tag is [...]
+  // pinyin is <...>
+  // other is any text that does not match the above, normally numbers and
+  // punctuation
+  std::vector<std::string> types;
+  for (auto &word : parts) {
+    if (word.size() == 1 && std::isalpha(word[0])) {
+      // a single character, e.g., 'a', 'b', 'c'
+      types.push_back("en");
+    } else if (word.size() > 1 && word[0] == '<' && word.back() == '>') {
+      // e.g., <ha3>, <ha4>
+      types.push_back("pinyin");
+    } else if (word.size() > 1 && word[0] == '[' && word.back() == ']') {
+      types.push_back("tag");
+    } else if (ContainsCJK(word)) {  // the word contains a CJK character
+      types.push_back("zh");
+    } else {
+      types.push_back("other");
+    }
+  }
+
+  std::vector<std::pair<std::string, std::string>> parts_with_types;
+  std::ostringstream oss;
+  std::string t_lang;
+  oss.str("");
+  std::ostringstream debug_oss;
+  if (debug_) {
+    debug_oss << "Text : " << _text << ", Parts with types: \n";
+  }
+  for (int32_t i = 0; i < types.size(); ++i) {
+    if (i == 0) {
+      oss << parts[i];
+      t_lang = types[i];
+    } else {
+      if (t_lang == "other" && (types[i] != "tag" && types[i] != "pinyin")) {
+        // combine into the current type if the previous part is "other";
+        // do not combine with "tag" or "pinyin"
+        oss << parts[i];
+        t_lang = types[i];
+      } else {
+        if ((t_lang == types[i] || types[i] == "other") && t_lang != "pinyin" &&
+            t_lang != "tag") {
+          // same language or other, continue;
+          // do not combine "other" into "pinyin" or "tag"
+          oss << parts[i];
+        } else {
+          // different language, start a new sentence
+          std::string part = oss.str();
+          oss.str("");
+          parts_with_types.emplace_back(part, t_lang);
+          if (debug_) {
+            debug_oss << "(" << part << ", " << t_lang << "),";
+          }
+          oss << parts[i];
+          t_lang = types[i];
+        }
+      }
+    }
+  }
+
+  std::string part = oss.str();
+  oss.str("");
+  parts_with_types.emplace_back(part, t_lang);
+  if (debug_) {
+    debug_oss << "(" << part << ", " << t_lang << ")\n";
+    SHERPA_ONNX_LOGE("%s", debug_oss.str().c_str());
+    debug_oss.str("");
+  }
+
+  std::vector<int64_t> token_ids;
+  std::vector<std::string> tokens;  // for debugging
+  for (const auto &pt : parts_with_types) {
+    if (pt.second == "zh") {
+      TokenizeZh(pt.first, pinyin_encoder_.get(), token2id_, &token_ids,
+                 &tokens);
+    } else if (pt.second == "en") {
+      TokenizeEn(pt.first, token2id_, voice, &token_ids, &tokens);
+    } else if (pt.second == "pinyin") {
+      TokenizePinyin(pt.first, pinyin_encoder_.get(), token2id_, &token_ids,
+                     &tokens);
+    } else if (pt.second == "tag") {
+      TokenizeTag(pt.first, token2id_, &token_ids, &tokens);
+    } else {
+      SHERPA_ONNX_LOGE("Unexpected type: %s", pt.second.c_str());
+      exit(-1);
+    }
+  }
+  if (debug_) {
+    debug_oss << "Tokens and IDs: \n";
+    for (int32_t i = 0; i < tokens.size(); i++) {
+      debug_oss << "(" << tokens[i] << ", " << token_ids[i] << "),";
+    }
+    debug_oss << "\n";
+    SHERPA_ONNX_LOGE("%s", debug_oss.str().c_str());
+  }
+
+  std::vector<TokenIDs> ans;
+  ans.push_back(TokenIDs(std::move(token_ids)));
+  return ans;
+}
+
+#if __ANDROID_API__ >= 9
+template OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend(
+    AAssetManager *mgr, const std::string &tokens, const std::string &data_dir,
+    const std::string &pinyin_dict,
+    const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug);
+#endif
+
+#if __OHOS__
+template OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend(
+    NativeResourceManager *mgr, const std::string &tokens,
+    const std::string &data_dir, const std::string &pinyin_dict,
+    const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug);
+#endif
+
+}  // namespace sherpa_onnx
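
Editor's note: a hedged Python paraphrase of ProcessPinyin() above, to make the token layout concrete (the initial/final split itself is done by cppinyin's ToInitial/ToFinal; here they are passed in directly):

def pinyin_tokens(initial: str, final: str) -> list:
    tokens = []
    if initial:
        tokens.append(initial + "0")  # '0' suffix avoids clashes with espeak tokens
    if final:
        if not final[-1].isdigit():
            final += "5"  # '5' marks the neutral tone
        tokens.append(final)
    return tokens

# e.g. for "hao3": pinyin_tokens("h", "ao3") -> ["h0", "ao3"]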
+// sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_FRONTEND_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_FRONTEND_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "cppinyin/csrc/cppinyin.h"
+#include "sherpa-onnx/csrc/offline-tts-frontend.h"
+#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-meta-data.h"
+
+namespace sherpa_onnx {
+
+class OfflineTtsZipvoiceFrontend : public OfflineTtsFrontend {
+ public:
+  OfflineTtsZipvoiceFrontend(const std::string &tokens,
+                             const std::string &data_dir,
+                             const std::string &pinyin_dict,
+                             const OfflineTtsZipvoiceModelMetaData &meta_data,
+                             bool debug = false);
+
+  template <typename Manager>
+  OfflineTtsZipvoiceFrontend(Manager *mgr, const std::string &tokens,
+                             const std::string &data_dir,
+                             const std::string &pinyin_dict,
+                             const OfflineTtsZipvoiceModelMetaData &meta_data,
+                             bool debug = false);
+
+  /** Convert a string to token IDs.
+   *
+   * @param text The input text.
+   *             Example 1: "This is the first sample sentence; this is the
+   *             second one."
+   *             Example 2: "这是第一句。这是第二句。"
+   * @param voice Optional. It is for espeak-ng.
+   *
+   * @return Return a vector-of-vector of token IDs. Each subvector contains
+   *         a sentence that can be processed independently.
+   *         If a frontend does not support splitting the text into
+   *         sentences, the resulting vector contains only one subvector.
+   */
+  std::vector<TokenIDs> ConvertTextToTokenIds(
+      const std::string &text, const std::string &voice = "") const override;
+
+ private:
+  bool debug_ = false;
+  std::unordered_map<std::string, int32_t> token2id_;
+  const std::unordered_map<std::string, std::string> punct_map_ = {
+      {"，", ","}, {"。", "."}, {"！", "!"}, {"？", "?"}, {"；", ";"},
+      {"：", ":"}, {"、", ","}, {"‘", "'"}, {"“", "\""}, {"”", "\""},
+      {"’", "'"}, {"⋯", "…"}, {"···", "…"}, {"・・・", "…"}, {"...", "…"}};
+  OfflineTtsZipvoiceModelMetaData meta_data_;
+  std::unique_ptr<cppinyin::PinyinEncoder> pinyin_encoder_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_FRONTEND_H_
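
Editor's note: punct_map_ above folds full-width CJK punctuation to its ASCII counterpart before tokenization; MapPunctuations() in the .cc file applies it by plain substring replacement. A Python equivalent over a subset of the table:

PUNCT_MAP = {"，": ",", "。": ".", "！": "!", "？": "?", "；": ";", "：": ":", "、": ","}

def map_punctuations(text: str) -> str:
    for src, dst in PUNCT_MAP.items():
        text = text.replace(src, dst)  # replace every occurrence
    return text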
+// sherpa-onnx/csrc/offline-tts-zipvoice-impl.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
+
+#include <array>
+#include <cmath>
+#include <memory>
+#include <string>
+#include <strstream>
+#include <utility>
+#include <vector>
+
+#include "kaldi-native-fbank/csrc/mel-computations.h"
+#include "kaldi-native-fbank/csrc/stft.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-tts-frontend.h"
+#include "sherpa-onnx/csrc/offline-tts-impl.h"
+#include "sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h"
+#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h"
+#include "sherpa-onnx/csrc/offline-tts-zipvoice-model.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/resample.h"
+#include "sherpa-onnx/csrc/vocoder.h"
+
+namespace sherpa_onnx {
+
+class OfflineTtsZipvoiceImpl : public OfflineTtsImpl {
+ public:
+  explicit OfflineTtsZipvoiceImpl(const OfflineTtsConfig &config)
+      : config_(config),
+        model_(std::make_unique<OfflineTtsZipvoiceModel>(config.model)),
+        vocoder_(Vocoder::Create(config.model)) {
+    InitFrontend();
+  }
+
+  template <typename Manager>
+  OfflineTtsZipvoiceImpl(Manager *mgr, const OfflineTtsConfig &config)
+      : config_(config),
+        model_(std::make_unique<OfflineTtsZipvoiceModel>(mgr, config.model)),
+        vocoder_(Vocoder::Create(mgr, config.model)) {
+    InitFrontend(mgr);
+  }
+
+  int32_t SampleRate() const override {
+    return model_->GetMetaData().sample_rate;
+  }
+
+  GeneratedAudio Generate(
+      const std::string &text, const std::string &prompt_text,
+      const std::vector<float> &prompt_samples, int32_t sample_rate,
+      float speed, int32_t num_steps,
+      GeneratedAudioCallback callback = nullptr) const override {
+    std::vector<TokenIDs> text_token_ids =
+        frontend_->ConvertTextToTokenIds(text);
+
+    std::vector<TokenIDs> prompt_token_ids =
+        frontend_->ConvertTextToTokenIds(prompt_text);
+
+    if (text_token_ids.empty() ||
+        (text_token_ids.size() == 1 && text_token_ids[0].tokens.empty())) {
+#if __OHOS__
+      SHERPA_ONNX_LOGE("Failed to convert '%{public}s' to token IDs",
+                       text.c_str());
+#else
+      SHERPA_ONNX_LOGE("Failed to convert '%s' to token IDs", text.c_str());
+#endif
+      return {};
+    }
+
+    if (prompt_token_ids.empty() ||
+        (prompt_token_ids.size() == 1 && prompt_token_ids[0].tokens.empty())) {
+#if __OHOS__
+      SHERPA_ONNX_LOGE(
+          "Failed to convert prompt text '%{public}s' to token IDs",
+          prompt_text.c_str());
+#else
+      SHERPA_ONNX_LOGE("Failed to convert prompt text '%s' to token IDs",
+                       prompt_text.c_str());
+#endif
+      return {};
+    }
+
+    // We assume the batch size is 1
+    std::vector<int64_t> tokens = text_token_ids[0].tokens;
+    std::vector<int64_t> prompt_tokens = prompt_token_ids[0].tokens;
+
+    return Process(tokens, prompt_tokens, prompt_samples, sample_rate, speed,
+                   num_steps);
+  }
+
+ private:
+  template <typename Manager>
+  void InitFrontend(Manager *mgr) {
+    const auto &meta_data = model_->GetMetaData();
+    frontend_ = std::make_unique<OfflineTtsZipvoiceFrontend>(
+        mgr, config_.model.zipvoice.tokens, config_.model.zipvoice.data_dir,
+        config_.model.zipvoice.pinyin_dict, meta_data, config_.model.debug);
+  }
+
+  void InitFrontend() {
+    const auto &meta_data = model_->GetMetaData();
+
+    if (meta_data.use_pinyin && config_.model.zipvoice.pinyin_dict.empty()) {
+      SHERPA_ONNX_LOGE(
+          "Please provide --zipvoice-pinyin-dict for converting Chinese into "
+          "pinyin.");
+      exit(-1);
+    }
+    if (meta_data.use_espeak && config_.model.zipvoice.data_dir.empty()) {
+      SHERPA_ONNX_LOGE("Please provide --zipvoice-data-dir for espeak-ng.");
+      exit(-1);
+    }
+    frontend_ = std::make_unique<OfflineTtsZipvoiceFrontend>(
+        config_.model.zipvoice.tokens, config_.model.zipvoice.data_dir,
+        config_.model.zipvoice.pinyin_dict, meta_data, config_.model.debug);
+  }
+
+  std::vector<int32_t> ComputeMelSpectrogram(
+      const std::vector<float> &_samples, int32_t sample_rate,
+      std::vector<float> *prompt_features) const {
+    const auto &meta = model_->GetMetaData();
+    if (sample_rate != meta.sample_rate) {
+      SHERPA_ONNX_LOGE(
+          "Creating a resampler:\n"
+          "   in_sample_rate: %d\n"
+          "   output_sample_rate: %d\n",
+          sample_rate, static_cast<int32_t>(meta.sample_rate));
+
+      float min_freq = std::min<int32_t>(sample_rate, meta.sample_rate);
+      float lowpass_cutoff = 0.99 * 0.5 * min_freq;
+
+      int32_t lowpass_filter_width = 6;
+      auto resampler = std::make_unique<LinearResample>(
+          sample_rate, meta.sample_rate, lowpass_cutoff, lowpass_filter_width);
+      std::vector<float> samples;
+      resampler->Resample(_samples.data(), _samples.size(), true, &samples);
+      return ComputeMelSpectrogram(samples, prompt_features);
+    } else {
+      // Use the original samples if the sample rate matches
+      return ComputeMelSpectrogram(_samples, prompt_features);
+    }
+  }
+
+  std::vector<int32_t> ComputeMelSpectrogram(
+      const std::vector<float> &samples,
+      std::vector<float> *prompt_features) const {
+    const auto &meta = model_->GetMetaData();
+
+    int32_t sample_rate = meta.sample_rate;
+    int32_t n_fft = meta.n_fft;
+    int32_t hop_length = meta.hop_length;
+    int32_t win_length = meta.window_length;
+    int32_t num_mels = meta.num_mels;
+
+    knf::StftConfig stft_config;
+    stft_config.n_fft = n_fft;
+    stft_config.hop_length = hop_length;
+    stft_config.win_length = win_length;
+    stft_config.window_type = "hann";
+    stft_config.center = true;
+
+    knf::Stft stft(stft_config);
+    auto stft_result = stft.Compute(samples.data(), samples.size());
+    int32_t num_frames = stft_result.num_frames;
+    int32_t fft_bins = n_fft / 2 + 1;
+
+    knf::FrameExtractionOptions frame_opts;
+    frame_opts.samp_freq = sample_rate;
+    // Use float math here; e.g., 1024 samples at 24 kHz is a fractional
+    // number of milliseconds and must not be truncated.
+    frame_opts.frame_length_ms = win_length * 1000.0f / sample_rate;
+    frame_opts.frame_shift_ms = hop_length * 1000.0f / sample_rate;
+    frame_opts.window_type = "hanning";
+
+    knf::MelBanksOptions mel_opts;
+    mel_opts.num_bins = num_mels;
+    mel_opts.low_freq = 0;
+    mel_opts.high_freq = sample_rate / 2;
+    mel_opts.is_librosa = true;
+    mel_opts.use_slaney_mel_scale = false;
+    mel_opts.norm = "";
+
+    knf::MelBanks mel_banks(mel_opts, frame_opts, 1.0f);
+
+    prompt_features->clear();
+    prompt_features->reserve(num_frames * num_mels);
+
+    for (int32_t i = 0; i < num_frames; ++i) {
+      std::vector<float> magnitude_spectrum(fft_bins);
+      for (int32_t k = 0; k < fft_bins; ++k) {
+        float real = stft_result.real[i * fft_bins + k];
+        float imag = stft_result.imag[i * fft_bins + k];
+        magnitude_spectrum[k] = std::sqrt(real * real + imag * imag);
+      }
+      std::vector<float> mel_features(num_mels, 0.0f);
+      mel_banks.Compute(magnitude_spectrum.data(), mel_features.data());
+      for (auto &v : mel_features) {
+        v = std::log(v + 1e-10f);
+      }
+      // Append this frame's mel values to the flattened
+      // [num_frames, num_mels] buffer
+      prompt_features->insert(prompt_features->end(), mel_features.begin(),
+                              mel_features.end());
+    }
+    if (num_frames == 0) {
+      SHERPA_ONNX_LOGE("No frames extracted from the prompt audio");
+      return {0, 0};
+    } else {
+      return {num_frames, num_mels};
+    }
+  }
+
+  GeneratedAudio Process(const std::vector<int64_t> &tokens,
+                         const std::vector<int64_t> &prompt_tokens,
+                         const std::vector<float> &prompt_samples,
+                         int32_t sample_rate, float speed,
+                         int32_t num_steps) const {
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 2> tokens_shape = {1,
+                                           static_cast<int64_t>(tokens.size())};
+    Ort::Value tokens_tensor = Ort::Value::CreateTensor(
+        memory_info, const_cast<int64_t *>(tokens.data()), tokens.size(),
+        tokens_shape.data(), tokens_shape.size());
+
+    std::array<int64_t, 2> prompt_tokens_shape = {
+        1, static_cast<int64_t>(prompt_tokens.size())};
+    Ort::Value prompt_tokens_tensor = Ort::Value::CreateTensor(
+        memory_info, const_cast<int64_t *>(prompt_tokens.data()),
+        prompt_tokens.size(), prompt_tokens_shape.data(),
+        prompt_tokens_shape.size());
+
+    float target_rms = config_.model.zipvoice.target_rms;
+    float feat_scale = config_.model.zipvoice.feat_scale;
+
+    // Scale prompt_samples
+    std::vector<float> prompt_samples_scaled = prompt_samples;
+    float prompt_rms = 0.0f;
+    double sum_sq = 0.0;
+    // Compute the RMS of prompt_samples
+    for (float s : prompt_samples_scaled) {
+      sum_sq += s * s;
+    }
+    prompt_rms = std::sqrt(sum_sq / prompt_samples_scaled.size());
+    if (prompt_rms < target_rms && prompt_rms > 0.0f) {
+      float scale = target_rms / prompt_rms;
+      for (auto &s : prompt_samples_scaled) {
+        s *= scale;
+      }
+    }
+
+    std::vector<float> prompt_features;
+    auto res_shape = ComputeMelSpectrogram(prompt_samples_scaled, sample_rate,
+                                           &prompt_features);
+
+    int32_t num_frames = res_shape[0];
+    int32_t mel_dim = res_shape[1];
+
+    if (feat_scale != 1.0f) {
+      for (auto &item : prompt_features) {
+        item *= feat_scale;
+      }
+    }
+
+    std::array<int64_t, 3> shape = {1, num_frames, mel_dim};
+    auto prompt_features_tensor = Ort::Value::CreateTensor(
+        memory_info, prompt_features.data(), prompt_features.size(),
+        shape.data(), shape.size());
+
+    Ort::Value mel =
+        model_->Run(std::move(tokens_tensor), std::move(prompt_tokens_tensor),
+                    std::move(prompt_features_tensor), speed, num_steps);
+
+    // Assume mel_shape = {1, T, C}
+    std::vector<int64_t> mel_shape = mel.GetTensorTypeAndShapeInfo().GetShape();
+    int64_t T = mel_shape[1], C = mel_shape[2];
+
+    float *mel_data = mel.GetTensorMutableData<float>();
+    std::vector<float> mel_permuted(C * T);
+
+    for (int64_t c = 0; c < C; ++c) {
+      for (int64_t t = 0; t < T; ++t) {
+        int64_t src_idx = t * C + c;  // src: [T, C] (row major)
+        int64_t dst_idx = c * T + t;  // dst: [C, T] (row major)
+        mel_permuted[dst_idx] = mel_data[src_idx] / feat_scale;
+      }
+    }
+
+    std::array<int64_t, 3> new_shape = {1, C, T};
+    Ort::Value mel_new = Ort::Value::CreateTensor<float>(
+        memory_info, mel_permuted.data(), mel_permuted.size(), new_shape.data(),
+        new_shape.size());
+
+    GeneratedAudio ans;
+    ans.samples = vocoder_->Run(std::move(mel_new));
+    ans.sample_rate = model_->GetMetaData().sample_rate;
+
+    if (prompt_rms < target_rms && target_rms > 0.0f) {
+      float scale = prompt_rms / target_rms;
+      for (auto &s : ans.samples) {
+        s *= scale;
+      }
+    }
+    return ans;
+  }
+
+ private:
+  OfflineTtsConfig config_;
+  std::unique_ptr<OfflineTtsZipvoiceModel> model_;
+  std::unique_ptr<Vocoder> vocoder_;
+  std::unique_ptr<OfflineTtsFrontend> frontend_;
+};
+
+}  // namespace sherpa_onnx
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
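
Editor's note: Process() above boosts a quiet prompt up to target_rms before feature extraction, then scales the generated samples back down by the inverse factor so the output keeps the prompt's original loudness. A numpy sketch of that round trip (illustration only, not a sherpa-onnx API):

import numpy as np

def rms_round_trip(prompt: np.ndarray, generated: np.ndarray,
                   target_rms: float = 0.1):
    rms = float(np.sqrt(np.mean(prompt ** 2)))
    if 0.0 < rms < target_rms:
        prompt = prompt * (target_rms / rms)        # boost the prompt
        generated = generated * (rms / target_rms)  # undo the boost on output
    return prompt, generated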
  1 +// sherpa-onnx/csrc/offline-tts-zipvoice-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h"
  6 +
  7 +#include <vector>
  8 +
  9 +#include "sherpa-onnx/csrc/file-utils.h"
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void OfflineTtsZipvoiceModelConfig::Register(ParseOptions *po) {
  15 + po->Register("zipvoice-tokens", &tokens,
  16 + "Path to tokens.txt for ZipVoice models");
  17 + po->Register("zipvoice-data-dir", &data_dir,
  18 + "Path to the directory containing dict for espeak-ng.");
  19 + po->Register("zipvoice-pinyin-dict", &pinyin_dict,
  20 + "Path to the pinyin dictionary for cppinyin (i.e converting "
  21 + "Chinese into phones).");
  22 + po->Register("zipvoice-text-model", &text_model,
  23 + "Path to zipvoice text model");
  24 + po->Register("zipvoice-flow-matching-model", &flow_matching_model,
  25 + "Path to zipvoice flow-matching model");
  26 + po->Register("zipvoice-vocoder", &vocoder, "Path to zipvoice vocoder");
  27 + po->Register("zipvoice-feat-scale", &feat_scale,
  28 + "Feature scale for ZipVoice (default: 0.1)");
  29 + po->Register("zipvoice-t-shift", &t_shift,
  30 + "Shift t to smaller ones if t_shift < 1.0 (default: 0.5)");
  31 + po->Register(
  32 + "zipvoice-target-rms", &target_rms,
  33 + "Target speech normalization rms value for ZipVoice (default: 0.1)");
  34 + po->Register(
  35 + "zipvoice-guidance-scale", &guidance_scale,
  36 + "The scale of classifier-free guidance during inference for ZipVoice "
  37 + "(default: 1.0)");
  38 +}
  39 +
  40 +bool OfflineTtsZipvoiceModelConfig::Validate() const {
  41 + if (tokens.empty()) {
  42 + SHERPA_ONNX_LOGE("Please provide --zipvoice-tokens");
  43 + return false;
  44 + }
  45 + if (!FileExists(tokens)) {
  46 + SHERPA_ONNX_LOGE("--zipvoice-tokens: '%s' does not exist", tokens.c_str());
  47 + return false;
  48 + }
  49 +
  50 + if (text_model.empty()) {
  51 + SHERPA_ONNX_LOGE("Please provide --zipvoice-text-model");
  52 + return false;
  53 + }
  54 + if (!FileExists(text_model)) {
  55 + SHERPA_ONNX_LOGE("--zipvoice-text-model: '%s' does not exist",
  56 + text_model.c_str());
  57 + return false;
  58 + }
  59 +
  60 + if (flow_matching_model.empty()) {
  61 + SHERPA_ONNX_LOGE("Please provide --zipvoice-flow-matching-model");
  62 + return false;
  63 + }
  64 + if (!FileExists(flow_matching_model)) {
  65 + SHERPA_ONNX_LOGE("--zipvoice-flow-matching-model: '%s' does not exist",
  66 + flow_matching_model.c_str());
  67 + return false;
  68 + }
  69 +
  70 + if (vocoder.empty()) {
  71 + SHERPA_ONNX_LOGE("Please provide --zipvoice-vocoder");
  72 + return false;
  73 + }
  74 +
  75 + if (!FileExists(vocoder)) {
  76 + SHERPA_ONNX_LOGE("--zipvoice-vocoder: '%s' does not exist",
  77 + vocoder.c_str());
  78 + return false;
  79 + }
  80 +
  81 + if (!data_dir.empty()) {
  82 + std::vector<std::string> required_files = {
  83 + "phontab",
  84 + "phonindex",
  85 + "phondata",
  86 + "intonations",
  87 + };
  88 + for (const auto &f : required_files) {
  89 + if (!FileExists(data_dir + "/" + f)) {
  90 + SHERPA_ONNX_LOGE(
  91 +          "'%s/%s' does not exist. Please check --zipvoice-data-dir",
  92 + data_dir.c_str(), f.c_str());
  93 + return false;
  94 + }
  95 + }
  96 + }
  97 +
  98 + if (!pinyin_dict.empty() && !FileExists(pinyin_dict)) {
  99 + SHERPA_ONNX_LOGE("--zipvoice-pinyin-dict: '%s' does not exist",
  100 + pinyin_dict.c_str());
  101 + return false;
  102 + }
  103 +
  104 + if (feat_scale <= 0) {
  105 + SHERPA_ONNX_LOGE("--zipvoice-feat-scale must be positive. Given: %f",
  106 + feat_scale);
  107 + return false;
  108 + }
  109 +
  110 + if (t_shift < 0) {
  111 + SHERPA_ONNX_LOGE("--zipvoice-t-shift must be non-negative. Given: %f",
  112 + t_shift);
  113 + return false;
  114 + }
  115 +
  116 + if (target_rms <= 0) {
  117 + SHERPA_ONNX_LOGE("--zipvoice-target-rms must be positive. Given: %f",
  118 + target_rms);
  119 + return false;
  120 + }
  121 +
  122 + if (guidance_scale <= 0) {
  123 + SHERPA_ONNX_LOGE("--zipvoice-guidance-scale must be positive. Given: %f",
  124 + guidance_scale);
  125 + return false;
  126 + }
  127 +
  128 + return true;
  129 +}
  130 +
  131 +std::string OfflineTtsZipvoiceModelConfig::ToString() const {
  132 + std::ostringstream os;
  133 +
  134 + os << "OfflineTtsZipvoiceModelConfig(";
  135 + os << "tokens=\"" << tokens << "\", ";
  136 + os << "text_model=\"" << text_model << "\", ";
  137 + os << "flow_matching_model=\"" << flow_matching_model << "\", ";
  138 + os << "vocoder=\"" << vocoder << "\", ";
  139 + os << "data_dir=\"" << data_dir << "\", ";
  140 + os << "pinyin_dict=\"" << pinyin_dict << "\", ";
  141 + os << "feat_scale=" << feat_scale << ", ";
  142 + os << "t_shift=" << t_shift << ", ";
  143 + os << "target_rms=" << target_rms << ", ";
  144 + os << "guidance_scale=" << guidance_scale << ")";
  145 +
  146 + return os.str();
  147 +}
  148 +
  149 +} // namespace sherpa_onnx
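
A minimal usage sketch of this config (the file names below are placeholders, and the snippet assumes a program built against sherpa-onnx):

#include <cstdio>

#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h"

int main() {
  sherpa_onnx::OfflineTtsZipvoiceModelConfig config(
      "tokens.txt", "text_encoder.onnx", "fm_decoder.onnx", "vocoder.onnx",
      "espeak-ng-data", "pinyin.raw");
  if (!config.Validate()) {  // checks file existence and numeric ranges
    std::fprintf(stderr, "Invalid config: %s\n", config.ToString().c_str());
    return 1;
  }
  return 0;
}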
  1 +// sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
  7 +
  8 +#include <cstdint>
  9 +#include <string>
  10 +
  11 +#include "sherpa-onnx/csrc/parse-options.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +struct OfflineTtsZipvoiceModelConfig {
  16 + std::string tokens;
  17 + std::string text_model;
  18 + std::string flow_matching_model;
  19 + std::string vocoder;
  20 +
  21 +  // data_dir is for piper-phonemize, which uses espeak-ng
  23 + std::string data_dir;
  24 +
  25 + // Used for converting Chinese characters to pinyin
  26 + std::string pinyin_dict;
  27 +
  28 + float feat_scale = 0.1;
  29 + float t_shift = 0.5;
  30 + float target_rms = 0.1;
  31 + float guidance_scale = 1.0;
  32 +
  33 + OfflineTtsZipvoiceModelConfig() = default;
  34 +
  35 + OfflineTtsZipvoiceModelConfig(
  36 + const std::string &tokens, const std::string &text_model,
  37 + const std::string &flow_matching_model, const std::string &vocoder,
  38 + const std::string &data_dir, const std::string &pinyin_dict,
  39 + float feat_scale = 0.1, float t_shift = 0.5, float target_rms = 0.1,
  40 + float guidance_scale = 1.0)
  41 + : tokens(tokens),
  42 + text_model(text_model),
  43 + flow_matching_model(flow_matching_model),
  44 + vocoder(vocoder),
  45 + data_dir(data_dir),
  46 + pinyin_dict(pinyin_dict),
  47 + feat_scale(feat_scale),
  48 + t_shift(t_shift),
  49 + target_rms(target_rms),
  50 + guidance_scale(guidance_scale) {}
  51 +
  52 + void Register(ParseOptions *po);
  53 + bool Validate() const;
  54 +
  55 + std::string ToString() const;
  56 +};
  57 +
  58 +} // namespace sherpa_onnx
  59 +
  60 +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-tts-zipvoice-model-meta-data.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_META_DATA_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_META_DATA_H_
  7 +
  8 +#include <cstdint>
  9 +#include <string>
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +// If you are not sure what each field means, please
  14 +// have a look at the Python file in the model directory that
  15 +// you have downloaded.
  16 +struct OfflineTtsZipvoiceModelMetaData {
  17 + int32_t version = 1;
  18 + int32_t feat_dim = 100;
  19 + int32_t sample_rate = 24000;
  20 + int32_t n_fft = 1024;
  21 + int32_t hop_length = 256;
  22 + int32_t window_length = 1024;
  23 + int32_t num_mels = 100;
  24 + int32_t use_espeak = 1;
  25 + int32_t use_pinyin = 1;
  26 +};
  27 +
  28 +} // namespace sherpa_onnx
  29 +
  30 +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_META_DATA_H_
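
These fields are filled from the custom metadata maps embedded in the ONNX models (see Init() in offline-tts-zipvoice-model.cc below). A sketch of reading one integer field with a default, roughly what the SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT macro from macros.h does (ReadInt32WithDefault is an illustrative helper, not part of sherpa-onnx):

#include <cstdint>
#include <string>

#include "onnxruntime_cxx_api.h"

int32_t ReadInt32WithDefault(const Ort::ModelMetadata &meta, const char *key,
                             int32_t default_value) {
  Ort::AllocatorWithDefaultOptions allocator;
  // Returns a null pointer if the key is absent from the metadata map.
  Ort::AllocatedStringPtr value =
      meta.LookupCustomMetadataMapAllocated(key, allocator);
  if (!value) {
    return default_value;
  }
  return static_cast<int32_t>(std::stoi(value.get()));
}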
  1 +// sherpa-onnx/csrc/offline-tts-zipvoice-model.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model.h"
  6 +
  7 +#include <algorithm>
  8 +#include <cstring>  // std::memcpy
  8 +#include <iostream>
  9 +#include <random>
  10 +#include <sstream>  // std::ostringstream
  10 +#include <string>
  11 +#include <utility>
  12 +#include <vector>
  13 +
  14 +#if __ANDROID_API__ >= 9
  15 +#include "android/asset_manager.h"
  16 +#include "android/asset_manager_jni.h"
  17 +#endif
  18 +
  19 +#if __OHOS__
  20 +#include "rawfile/raw_file_manager.h"
  21 +#endif
  22 +
  23 +#include "sherpa-onnx/csrc/file-utils.h"
  24 +#include "sherpa-onnx/csrc/macros.h"
  25 +#include "sherpa-onnx/csrc/onnx-utils.h"
  26 +#include "sherpa-onnx/csrc/session.h"
  27 +#include "sherpa-onnx/csrc/text-utils.h"
  28 +
  29 +namespace sherpa_onnx {
  30 +
  31 +class OfflineTtsZipvoiceModel::Impl {
  32 + public:
  33 + explicit Impl(const OfflineTtsModelConfig &config)
  34 + : config_(config),
  35 + env_(ORT_LOGGING_LEVEL_ERROR),
  36 + sess_opts_(GetSessionOptions(config)),
  37 + allocator_{} {
  38 + auto text_buf = ReadFile(config.zipvoice.text_model);
  39 + auto fm_buf = ReadFile(config.zipvoice.flow_matching_model);
  40 + Init(text_buf.data(), text_buf.size(), fm_buf.data(), fm_buf.size());
  41 + }
  42 +
  43 + template <typename Manager>
  44 + Impl(Manager *mgr, const OfflineTtsModelConfig &config)
  45 + : config_(config),
  46 + env_(ORT_LOGGING_LEVEL_ERROR),
  47 + sess_opts_(GetSessionOptions(config)),
  48 + allocator_{} {
  49 + auto text_buf = ReadFile(mgr, config.zipvoice.text_model);
  50 + auto fm_buf = ReadFile(mgr, config.zipvoice.flow_matching_model);
  51 + Init(text_buf.data(), text_buf.size(), fm_buf.data(), fm_buf.size());
  52 + }
  53 +
  54 + const OfflineTtsZipvoiceModelMetaData &GetMetaData() const {
  55 + return meta_data_;
  56 + }
  57 +
  58 + Ort::Value Run(Ort::Value tokens, Ort::Value prompt_tokens,
  59 + Ort::Value prompt_features, float speed, int32_t num_steps) {
  60 + auto memory_info =
  61 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  62 +
  63 + std::vector<int64_t> tokens_shape =
  64 + tokens.GetTensorTypeAndShapeInfo().GetShape();
  65 + int64_t batch_size = tokens_shape[0];
  66 + if (batch_size != 1) {
  67 +      SHERPA_ONNX_LOGE("Only batch_size == 1 is supported. Given: %d",
  68 + static_cast<int32_t>(batch_size));
  69 + exit(-1);
  70 + }
  71 +
  72 + std::vector<int64_t> prompt_feat_shape =
  73 + prompt_features.GetTensorTypeAndShapeInfo().GetShape();
  74 +
  75 + int64_t prompt_feat_len = prompt_feat_shape[1];
  76 + int64_t prompt_feat_len_shape = 1;
  77 + Ort::Value prompt_feat_len_tensor = Ort::Value::CreateTensor<int64_t>(
  78 + memory_info, &prompt_feat_len, 1, &prompt_feat_len_shape, 1);
  79 +
  80 + int64_t speed_shape = 1;
  81 + Ort::Value speed_tensor = Ort::Value::CreateTensor<float>(
  82 + memory_info, &speed, 1, &speed_shape, 1);
  83 +
  84 + std::vector<Ort::Value> text_inputs;
  85 + text_inputs.reserve(4);
  86 + text_inputs.push_back(std::move(tokens));
  87 + text_inputs.push_back(std::move(prompt_tokens));
  88 + text_inputs.push_back(std::move(prompt_feat_len_tensor));
  89 + text_inputs.push_back(std::move(speed_tensor));
  90 +
  91 + // forward text-encoder
  92 + auto text_out =
  93 + text_sess_->Run({}, text_input_names_ptr_.data(), text_inputs.data(),
  94 + text_inputs.size(), text_output_names_ptr_.data(),
  95 + text_output_names_ptr_.size());
  96 +
  97 + Ort::Value &text_condition = text_out[0];
  98 +
  99 + std::vector<int64_t> text_cond_shape =
  100 + text_condition.GetTensorTypeAndShapeInfo().GetShape();
  101 + int64_t num_frames = text_cond_shape[1];
  102 +
  103 + int64_t feat_dim = meta_data_.feat_dim;
  104 +
  105 + std::vector<float> x_data(batch_size * num_frames * feat_dim);
  106 + std::default_random_engine rng(std::random_device{}());
  107 + std::normal_distribution<float> norm(0, 1);
  108 + for (auto &v : x_data) v = norm(rng);
  109 + std::vector<int64_t> x_shape = {batch_size, num_frames, feat_dim};
  110 + Ort::Value x = Ort::Value::CreateTensor<float>(
  111 + memory_info, x_data.data(), x_data.size(), x_shape.data(),
  112 + x_shape.size());
  113 +
  114 + std::vector<float> speech_cond_data(batch_size * num_frames * feat_dim,
  115 + 0.0f);
  116 + const float *src = prompt_features.GetTensorData<float>();
  117 + float *dst = speech_cond_data.data();
  118 + std::memcpy(dst, src,
  119 + batch_size * prompt_feat_len * feat_dim * sizeof(float));
  120 + std::vector<int64_t> speech_cond_shape = {batch_size, num_frames, feat_dim};
  121 + Ort::Value speech_condition = Ort::Value::CreateTensor<float>(
  122 + memory_info, speech_cond_data.data(), speech_cond_data.size(),
  123 + speech_cond_shape.data(), speech_cond_shape.size());
  124 +
  125 + float t_shift = config_.zipvoice.t_shift;
  126 + float guidance_scale = config_.zipvoice.guidance_scale;
  127 +
  128 + std::vector<float> timesteps(num_steps + 1);
  129 + for (int32_t i = 0; i <= num_steps; ++i) {
  130 + float t = static_cast<float>(i) / num_steps;
  131 + timesteps[i] = t_shift * t / (1.0f + (t_shift - 1.0f) * t);
  132 + }
  133 +
  134 + int64_t guidance_scale_shape = 1;
  135 + Ort::Value guidance_scale_tensor = Ort::Value::CreateTensor<float>(
  136 + memory_info, &guidance_scale, 1, &guidance_scale_shape, 1);
  137 +
  138 + std::vector<Ort::Value> fm_inputs;
  139 + fm_inputs.reserve(5);
  140 + // fm_inputs[0] is t tensor, will set in for loop
  141 +    // fm_inputs[0] is the time tensor; it is set inside the loop below
  142 + fm_inputs.push_back(std::move(x));
  143 + fm_inputs.push_back(std::move(text_condition));
  144 + fm_inputs.push_back(std::move(speech_condition));
  145 + fm_inputs.push_back(std::move(guidance_scale_tensor));
  146 +
  147 + for (int32_t step = 0; step < num_steps; ++step) {
  148 + float t_val = timesteps[step];
  149 + int64_t t_shape = 1;
  150 + Ort::Value t_tensor =
  151 + Ort::Value::CreateTensor<float>(memory_info, &t_val, 1, &t_shape, 1);
  152 + fm_inputs[0] = std::move(t_tensor);
  153 + auto fm_out = fm_sess_->Run(
  154 + {}, fm_input_names_ptr_.data(), fm_inputs.data(), fm_inputs.size(),
  155 + fm_output_names_ptr_.data(), fm_output_names_ptr_.size());
  156 + Ort::Value &v = fm_out[0];
  157 +
  158 + float delta_t = timesteps[step + 1] - timesteps[step];
  159 + float *x_ptr = fm_inputs[1].GetTensorMutableData<float>();
  160 + const float *v_ptr = v.GetTensorData<float>();
  161 + int64_t N = batch_size * num_frames * feat_dim;
  162 + for (int64_t i = 0; i < N; ++i) {
  163 + x_ptr[i] += v_ptr[i] * delta_t;
  164 + }
  165 + }
  166 +
  167 +    int64_t keep_frames = num_frames - prompt_feat_len;
  168 +    x = std::move(fm_inputs[1]);
  169 +    const float *x_ptr = x.GetTensorData<float>();
  170 +    std::vector<int64_t> out_shape = {batch_size, keep_frames, feat_dim};
  171 +    // Let the allocator own the output buffer; wrapping a local
  172 +    // std::vector would leave the returned tensor dangling.
  173 +    Ort::Value out = Ort::Value::CreateTensor<float>(
  174 +        allocator_, out_shape.data(), out_shape.size());
  175 +    float *out_ptr = out.GetTensorMutableData<float>();
  176 +    for (int64_t b = 0; b < batch_size; ++b) {
  177 +      std::memcpy(out_ptr + b * keep_frames * feat_dim,
  178 +                  x_ptr + (b * num_frames + prompt_feat_len) * feat_dim,
  179 +                  keep_frames * feat_dim * sizeof(float));
  180 +    }
  181 +    return out;
  182 +  }
  181 +
  182 + private:
  183 + void Init(void *text_model_data, size_t text_model_data_length,
  184 + void *fm_model_data, size_t fm_model_data_length) {
  185 + // Init text-encoder model
  186 + text_sess_ = std::make_unique<Ort::Session>(
  187 + env_, text_model_data, text_model_data_length, sess_opts_);
  188 + GetInputNames(text_sess_.get(), &text_input_names_, &text_input_names_ptr_);
  189 + GetOutputNames(text_sess_.get(), &text_output_names_,
  190 + &text_output_names_ptr_);
  191 +
  192 + // Init flow-matching model
  193 + fm_sess_ = std::make_unique<Ort::Session>(env_, fm_model_data,
  194 + fm_model_data_length, sess_opts_);
  195 + GetInputNames(fm_sess_.get(), &fm_input_names_, &fm_input_names_ptr_);
  196 + GetOutputNames(fm_sess_.get(), &fm_output_names_, &fm_output_names_ptr_);
  197 +
  198 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  199 +
  200 + Ort::ModelMetadata meta_data = text_sess_->GetModelMetadata();
  201 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_espeak, "use_espeak",
  202 + 1);
  203 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_pinyin, "use_pinyin",
  204 + 1);
  205 +
  206 + meta_data = fm_sess_->GetModelMetadata();
  207 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 1);
  208 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.feat_dim, "feat_dim",
  209 + 100);
  210 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.sample_rate,
  211 + "sample_rate", 24000);
  212 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.n_fft, "n_fft", 1024);
  213 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.hop_length, "hop_length",
  214 + 256);
  215 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.window_length,
  216 + "window_length", 1024);
  217 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.num_mels, "num_mels",
  218 + 100);
  219 +
  220 + if (config_.debug) {
  221 + std::ostringstream os;
  222 +
  223 + os << "---zipvoice text-encoder model---\n";
  224 + Ort::ModelMetadata text_meta_data = text_sess_->GetModelMetadata();
  225 + PrintModelMetadata(os, text_meta_data);
  226 +
  227 + os << "----------input names----------\n";
  228 + int32_t i = 0;
  229 + for (const auto &s : text_input_names_) {
  230 + os << i << " " << s << "\n";
  231 + ++i;
  232 + }
  233 + os << "----------output names----------\n";
  234 + i = 0;
  235 + for (const auto &s : text_output_names_) {
  236 + os << i << " " << s << "\n";
  237 + ++i;
  238 + }
  239 +
  240 + os << "---zipvoice flow-matching model---\n";
  241 + PrintModelMetadata(os, meta_data);
  242 +
  243 + os << "----------input names----------\n";
  244 + i = 0;
  245 + for (const auto &s : fm_input_names_) {
  246 + os << i << " " << s << "\n";
  247 + ++i;
  248 + }
  249 + os << "----------output names----------\n";
  250 + i = 0;
  251 + for (const auto &s : fm_output_names_) {
  252 + os << i << " " << s << "\n";
  253 + ++i;
  254 + }
  255 +
  256 +#if __OHOS__
  257 + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
  258 +#else
  259 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  260 +#endif
  261 + }
  262 + }
  263 +
  264 + private:
  265 + OfflineTtsModelConfig config_;
  266 + Ort::Env env_;
  267 + Ort::SessionOptions sess_opts_;
  268 + Ort::AllocatorWithDefaultOptions allocator_;
  269 +
  270 + std::unique_ptr<Ort::Session> text_sess_;
  271 + std::unique_ptr<Ort::Session> fm_sess_;
  272 +
  273 + std::vector<std::string> text_input_names_;
  274 + std::vector<const char *> text_input_names_ptr_;
  275 +
  276 + std::vector<std::string> text_output_names_;
  277 + std::vector<const char *> text_output_names_ptr_;
  278 +
  279 + std::vector<std::string> fm_input_names_;
  280 + std::vector<const char *> fm_input_names_ptr_;
  281 +
  282 + std::vector<std::string> fm_output_names_;
  283 + std::vector<const char *> fm_output_names_ptr_;
  284 +
  285 + OfflineTtsZipvoiceModelMetaData meta_data_;
  286 +};
  287 +
  288 +OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel(
  289 + const OfflineTtsModelConfig &config)
  290 + : impl_(std::make_unique<Impl>(config)) {}
  291 +
  292 +template <typename Manager>
  293 +OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel(
  294 + Manager *mgr, const OfflineTtsModelConfig &config)
  295 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  296 +
  297 +OfflineTtsZipvoiceModel::~OfflineTtsZipvoiceModel() = default;
  298 +
  299 +const OfflineTtsZipvoiceModelMetaData &OfflineTtsZipvoiceModel::GetMetaData()
  300 + const {
  301 + return impl_->GetMetaData();
  302 +}
  303 +
  304 +Ort::Value OfflineTtsZipvoiceModel::Run(Ort::Value tokens,
  305 + Ort::Value prompt_tokens,
  306 + Ort::Value prompt_features,
  307 + float speed /*= 1.0*/,
  308 + int32_t num_steps /*= 16*/) const {
  309 + return impl_->Run(std::move(tokens), std::move(prompt_tokens),
  310 + std::move(prompt_features), speed, num_steps);
  311 +}
  312 +
  313 +#if __ANDROID_API__ >= 9
  314 +template OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel(
  315 + AAssetManager *mgr, const OfflineTtsModelConfig &config);
  316 +#endif
  317 +
  318 +#if __OHOS__
  319 +template OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel(
  320 + NativeResourceManager *mgr, const OfflineTtsModelConfig &config);
  321 +#endif
  322 +
  323 +} // namespace sherpa_onnx
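
The sampling loop in Run() above is a plain Euler ODE solver over a shifted time grid, t' = t_shift * t / (1 + (t_shift - 1) * t), advancing x += v * dt at every step. A self-contained sketch, with a toy velocity field standing in for the flow-matching ONNX model:

#include <cstddef>
#include <cstdint>
#include <vector>

// Toy velocity field; in the real code v comes from the fm_decoder model.
std::vector<float> Velocity(const std::vector<float> &x, float /*t*/) {
  std::vector<float> v(x.size());
  for (std::size_t i = 0; i != x.size(); ++i) {
    v[i] = -x[i];  // illustrative dynamics only
  }
  return v;
}

// Euler integration over the shifted timestep schedule used above.
void EulerSolve(std::vector<float> *x, int32_t num_steps, float t_shift) {
  std::vector<float> ts(num_steps + 1);
  for (int32_t i = 0; i <= num_steps; ++i) {
    float t = static_cast<float>(i) / num_steps;
    ts[i] = t_shift * t / (1.0f + (t_shift - 1.0f) * t);  // t' schedule
  }
  for (int32_t step = 0; step != num_steps; ++step) {
    std::vector<float> v = Velocity(*x, ts[step]);
    float dt = ts[step + 1] - ts[step];
    for (std::size_t i = 0; i != x->size(); ++i) {
      (*x)[i] += v[i] * dt;  // x_{k+1} = x_k + v(x_k, t_k) * (t_{k+1} - t_k)
    }
  }
}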
  1 +// sherpa-onnx/csrc/offline-tts-zipvoice-model.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +
  11 +#include "onnxruntime_cxx_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/offline-tts-model-config.h"
  13 +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-meta-data.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +class OfflineTtsZipvoiceModel {
  18 + public:
  19 + ~OfflineTtsZipvoiceModel();
  20 +
  21 + explicit OfflineTtsZipvoiceModel(const OfflineTtsModelConfig &config);
  22 +
  23 + template <typename Manager>
  24 + OfflineTtsZipvoiceModel(Manager *mgr, const OfflineTtsModelConfig &config);
  25 +
  26 + // Return a float32 tensor containing the mel
  27 +  // of shape (batch_size, num_frames, mel_dim)
  28 + Ort::Value Run(Ort::Value tokens, Ort::Value prompt_tokens,
  29 + Ort::Value prompt_features, float speed,
  30 + int32_t num_steps) const;
  31 +
  32 + const OfflineTtsZipvoiceModelMetaData &GetMetaData() const;
  33 +
  34 + private:
  35 + class Impl;
  36 + std::unique_ptr<Impl> impl_;
  37 +};
  38 +
  39 +} // namespace sherpa_onnx
  40 +
  41 +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_H_
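
Run() expects int64 token tensors of shape (1, num_tokens); batch sizes other than 1 are rejected. A sketch of building such a tensor (the id values would come from tokens.txt via the frontend; the caller must keep the id buffer alive while the tensor is in use, since ORT only wraps it):

#include <array>
#include <cstdint>
#include <vector>

#include "onnxruntime_cxx_api.h"

// Build a (1, num_tokens) int64 tensor over a caller-owned id buffer.
Ort::Value MakeTokenTensor(std::vector<int64_t> *ids) {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  std::array<int64_t, 2> shape = {1, static_cast<int64_t>(ids->size())};
  // The tensor wraps *ids without copying; *ids must outlive the tensor.
  return Ort::Value::CreateTensor<int64_t>(memory_info, ids->data(),
                                           ids->size(), shape.data(),
                                           shape.size());
}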
@@ -196,6 +196,46 @@ GeneratedAudio OfflineTts::Generate( @@ -196,6 +196,46 @@ GeneratedAudio OfflineTts::Generate(
196 #endif 196 #endif
197 } 197 }
198 198
  199 +GeneratedAudio OfflineTts::Generate(
  200 + const std::string &text, const std::string &prompt_text,
  201 + const std::vector<float> &prompt_samples, int32_t sample_rate,
  202 + float speed /*=1.0*/, int32_t num_steps /*=4*/,
  203 + GeneratedAudioCallback callback /*=nullptr*/) const {
  204 +#if !defined(_WIN32)
  205 + return impl_->Generate(text, prompt_text, prompt_samples, sample_rate, speed,
  206 + num_steps, std::move(callback));
  207 +#else
  208 + static bool printed = false;
  209 + auto utf8_text = text;
  210 + if (IsGB2312(text)) {
  211 + utf8_text = Gb2312ToUtf8(text);
  212 + if (!printed) {
  213 + SHERPA_ONNX_LOGE("Detected GB2312 encoded text! Converting it to UTF8.");
  214 + printed = true;
  215 + }
  216 + }
  217 + auto utf8_prompt_text = prompt_text;
  218 + if (IsGB2312(prompt_text)) {
  219 + utf8_prompt_text = Gb2312ToUtf8(prompt_text);
  220 + if (!printed) {
  221 + SHERPA_ONNX_LOGE(
  222 + "Detected GB2312 encoded prompt text! Converting it to UTF8.");
  223 + printed = true;
  224 + }
  225 + }
  226 + if (IsUtf8(utf8_text) && IsUtf8(utf8_prompt_text)) {
  227 + return impl_->Generate(utf8_text, utf8_prompt_text, prompt_samples,
  228 + sample_rate, speed, num_steps, std::move(callback));
  229 + } else {
  230 + SHERPA_ONNX_LOGE(
  231 +        "A non-UTF-8 encoded string was received. You will not get the "
  232 +        "expected results!");
  233 + return impl_->Generate(utf8_text, utf8_prompt_text, prompt_samples,
  234 + sample_rate, speed, num_steps, std::move(callback));
  235 + }
  236 +#endif
  237 +}
  238 +
199 int32_t OfflineTts::SampleRate() const { return impl_->SampleRate(); } 239 int32_t OfflineTts::SampleRate() const { return impl_->SampleRate(); }
200 240
201 int32_t OfflineTts::NumSpeakers() const { return impl_->NumSpeakers(); } 241 int32_t OfflineTts::NumSpeakers() const { return impl_->NumSpeakers(); }
@@ -95,6 +95,26 @@ class OfflineTts { @@ -95,6 +95,26 @@ class OfflineTts {
95 float speed = 1.0, 95 float speed = 1.0,
96 GeneratedAudioCallback callback = nullptr) const; 96 GeneratedAudioCallback callback = nullptr) const;
97 97
  98 + // @param text The string to be synthesized.
  99 +  // @param prompt_text The transcript of `prompt_samples`.
  100 +  // @param prompt_samples The prompt audio samples (mono PCM floats in [-1,1]).
  101 +  // @param sample_rate The sample rate of `prompt_samples` in Hz.
  102 +  // @param speed The speed for the generated speech. E.g., 2 means 2x faster.
  103 +  // @param num_steps The number of flow-matching steps used to generate the
  104 +  //                  audio.
  105 +  // @param callback If not NULL, it is called whenever config.max_num_sentences
  106 +  //                  sentences have been processed. Note that the `samples`
  107 +  //                  pointer passed to the callback may be invalidated after
  108 +  //                  the callback returns, so the caller must not keep a
  109 +  //                  reference to it; copy the data if the samples are needed
  110 +  //                  later. The callback is invoked in the current thread.
  111 + GeneratedAudio Generate(const std::string &text,
  112 + const std::string &prompt_text,
  113 + const std::vector<float> &prompt_samples,
  114 + int32_t sample_rate, float speed = 1.0,
  115 + int32_t num_steps = 4,
  116 + GeneratedAudioCallback callback = nullptr) const;
  117 +
98 // Return the sample rate of the generated audio 118 // Return the sample rate of the generated audio
99 int32_t SampleRate() const; 119 int32_t SampleRate() const;
100 120
  1 +// sherpa-onnx/csrc/sherpa-onnx-offline-zeroshot-tts.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include <chrono> // NOLINT
  6 +#include <fstream>
  7 +
  8 +#include "sherpa-onnx/csrc/offline-tts.h"
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +#include "sherpa-onnx/csrc/wave-reader.h"
  11 +#include "sherpa-onnx/csrc/wave-writer.h"
  12 +
  13 +static int32_t AudioCallback(const float * /*samples*/, int32_t n,
  14 + float progress) {
  15 +  printf("num_samples=%d, progress=%f\n", n, progress);
  16 + return 1;
  17 +}
  18 +
  19 +int main(int32_t argc, char *argv[]) {
  20 + const char *kUsageMessage = R"usage(
  21 +Offline/Non-streaming zero-shot text-to-speech with sherpa-onnx
  22 +
  23 +Usage example:
  24 +
  25 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2
  26 +tar xf sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2
  27 +
  28 +./bin/sherpa-onnx-offline-zeroshot-tts \
  29 + --zipvoice-flow-matching-model=sherpa-onnx-zipvoice-distill-zh-en-emilia/fm_decoder.onnx \
  30 + --zipvoice-text-model=sherpa-onnx-zipvoice-distill-zh-en-emilia/text_encoder.onnx \
  31 + --zipvoice-data-dir=sherpa-onnx-zipvoice-distill-zh-en-emilia/espeak-ng-data \
  32 + --zipvoice-pinyin-dict=sherpa-onnx-zipvoice-distill-zh-en-emilia/pinyin.raw \
  33 + --zipvoice-tokens=sherpa-onnx-zipvoice-distill-zh-en-emilia/tokens.txt \
  34 + --zipvoice-vocoder=sherpa-onnx-zipvoice-distill-zh-en-emilia/vocos_24khz.onnx \
  35 + --prompt-audio=sherpa-onnx-zipvoice-distill-zh-en-emilia/prompt.wav \
  36 + --num-steps=4 \
  37 + --num-threads=4 \
  38 + --prompt-text="周日被我射熄火了,所以今天是周一。" \
  39 +  "我是中国人民的儿子,我爱我的祖国。我的祖国是一个伟大的国家,拥有五千年的文明史。"
  40 +
  41 +It will generate a file ./generated.wav as specified by --output-filename.
  42 +)usage";
  43 +
  44 + sherpa_onnx::ParseOptions po(kUsageMessage);
  45 + std::string output_filename = "./generated.wav";
  46 +
  47 + int32_t num_steps = 4;
  48 + float speed = 1.0;
  49 + std::string prompt_text;
  50 + std::string prompt_audio;
  51 +
  52 + po.Register("output-filename", &output_filename,
  53 + "Path to save the generated audio");
  54 +
  55 + po.Register("num-steps", &num_steps,
  56 + "Number of inference steps for ZipVoice (default: 4)");
  57 +
  58 + po.Register("speed", &speed,
  59 + "Speech speed for ZipVoice (default: 1.0, larger=faster, "
  60 + "smaller=slower)");
  61 +
  62 +  po.Register("prompt-text", &prompt_text, "The transcript of the prompt audio.");
  63 +
  64 + po.Register("prompt-audio", &prompt_audio,
  65 +              "The prompt audio file (single-channel PCM).");
  66 +
  67 + sherpa_onnx::OfflineTtsConfig config;
  68 +
  69 + config.Register(&po);
  70 + po.Read(argc, argv);
  71 +
  72 + if (po.NumArgs() == 0) {
  73 + fprintf(stderr, "Error: Please provide the text to generate audio.\n\n");
  74 + po.PrintUsage();
  75 + exit(EXIT_FAILURE);
  76 + }
  77 +
  78 + if (po.NumArgs() > 1) {
  79 + fprintf(stderr,
  80 +            "Error: Only one positional argument is accepted. Please use "
  81 +            "single quotes to wrap your text.\n");
  82 + po.PrintUsage();
  83 + exit(EXIT_FAILURE);
  84 + }
  85 +
  86 + if (config.model.debug) {
  87 + fprintf(stderr, "%s\n", config.model.ToString().c_str());
  88 + }
  89 +
  90 + if (!config.Validate()) {
  91 + fprintf(stderr, "Errors in config!\n");
  92 + exit(EXIT_FAILURE);
  93 + }
  94 +
  95 + if (prompt_text.empty() || prompt_audio.empty()) {
  96 + fprintf(stderr, "Please provide both --prompt-text and --prompt-audio\n");
  97 + exit(EXIT_FAILURE);
  98 + }
  99 +
  100 + sherpa_onnx::OfflineTts tts(config);
  101 +
  102 + int32_t sample_rate = -1;
  103 + bool is_ok = false;
  104 + const std::vector<float> prompt_samples =
  105 + sherpa_onnx::ReadWave(prompt_audio, &sample_rate, &is_ok);
  106 +
  107 + if (!is_ok) {
  108 + fprintf(stderr, "Failed to read '%s'\n", prompt_audio.c_str());
  109 + return -1;
  110 + }
  111 +
  112 + const auto begin = std::chrono::steady_clock::now();
  113 + auto audio = tts.Generate(po.GetArg(1), prompt_text, prompt_samples,
  114 + sample_rate, speed, num_steps, AudioCallback);
  115 + const auto end = std::chrono::steady_clock::now();
  116 +
  117 + if (audio.samples.empty()) {
  118 + fprintf(
  119 + stderr,
  120 + "Error in generating audio. Please read previous error messages.\n");
  121 + exit(EXIT_FAILURE);
  122 + }
  123 +
  124 + float elapsed_seconds =
  125 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  126 + .count() /
  127 + 1000.;
  128 + float duration = audio.samples.size() / static_cast<float>(audio.sample_rate);
  129 +
  130 + float rtf = elapsed_seconds / duration;
  131 + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  132 + fprintf(stderr, "Audio duration: %.3f s\n", duration);
  133 + fprintf(stderr, "Real-time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds,
  134 + duration, rtf);
  135 +
  136 + bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
  137 + audio.samples.data(), audio.samples.size());
  138 + if (!ok) {
  139 + fprintf(stderr, "Failed to write wave to %s\n", output_filename.c_str());
  140 + exit(EXIT_FAILURE);
  141 + }
  142 +
  143 + fprintf(stderr, "The text is: %s.\n", po.GetArg(1).c_str());
  144 + fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str());
  145 +
  146 + return 0;
  147 +}
@@ -4,6 +4,9 @@ @@ -4,6 +4,9 @@
4 4
5 #include "sherpa-onnx/csrc/text-utils.h" 5 #include "sherpa-onnx/csrc/text-utils.h"
6 6
  7 +#include <regex>
  8 +#include <sstream>
  9 +
7 #include "gtest/gtest.h" 10 #include "gtest/gtest.h"
8 11
9 namespace sherpa_onnx { 12 namespace sherpa_onnx {
@@ -55,7 +58,6 @@ TEST(RemoveInvalidUtf8Sequences, Case1) { @@ -55,7 +58,6 @@ TEST(RemoveInvalidUtf8Sequences, Case1) {
55 EXPECT_EQ(s.size() + 4, v.size()); 58 EXPECT_EQ(s.size() + 4, v.size());
56 } 59 }
57 60
58 -  
59 // Tests for sanitizeUtf8 61 // Tests for sanitizeUtf8
60 TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) { 62 TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) {
61 std::string input = "Valid UTF-8 🌍"; 63 std::string input = "Valid UTF-8 🌍";
@@ -724,4 +724,62 @@ std::vector<std::string> SplitString(const std::string &s, int32_t chunk_size) { @@ -724,4 +724,62 @@ std::vector<std::string> SplitString(const std::string &s, int32_t chunk_size) {
724 return ans; 724 return ans;
725 } 725 }
726 726
  727 +std::u32string Utf8ToUtf32(const std::string &str) {
  728 + std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
  729 + return conv.from_bytes(str);
  730 +}
  731 +
  732 +std::string Utf32ToUtf8(const std::u32string &str) {
  733 + std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
  734 + return conv.to_bytes(str);
  735 +}
  736 +
  737 +// Helper: Convert ASCII chars in a std::string to uppercase (leaves non-ASCII
  738 +// unchanged)
  739 +std::string ToUpperAscii(const std::string &str) {
  740 + std::string out = str;
  741 + for (char &c : out) {
  742 + unsigned char uc = static_cast<unsigned char>(c);
  743 + if (uc >= 'a' && uc <= 'z') {
  744 + c = static_cast<char>(uc - 'a' + 'A');
  745 + }
  746 + }
  747 + return out;
  748 +}
  749 +
  750 +// Helper: Convert ASCII chars in a std::string to lowercase (leaves non-ASCII
  751 +// unchanged)
  752 +std::string ToLowerAscii(const std::string &str) {
  753 + std::string out = str;
  754 + for (char &c : out) {
  755 + unsigned char uc = static_cast<unsigned char>(c);
  756 + if (uc >= 'A' && uc <= 'Z') {
  757 + c = static_cast<char>(uc - 'A' + 'a');
  758 + }
  759 + }
  760 + return out;
  761 +}
  762 +
  763 +// Detect if a codepoint is a CJK character
  764 +bool IsCJK(char32_t cp) {
  765 + return (cp >= 0x1100 && cp <= 0x11FF) || (cp >= 0x2E80 && cp <= 0xA4CF) ||
  766 + (cp >= 0xA840 && cp <= 0xD7AF) || (cp >= 0xF900 && cp <= 0xFAFF) ||
  767 + (cp >= 0xFE30 && cp <= 0xFE4F) || (cp >= 0xFF65 && cp <= 0xFFDC) ||
  768 + (cp >= 0x20000 && cp <= 0x2FFFF);
  769 +}
  770 +
  771 +bool ContainsCJK(const std::string &text) {
  772 + std::u32string utf32_text = Utf8ToUtf32(text);
  773 + return ContainsCJK(utf32_text);
  774 +}
  775 +
  776 +bool ContainsCJK(const std::u32string &text) {
  777 + for (char32_t cp : text) {
  778 + if (IsCJK(cp)) {
  779 + return true;
  780 + }
  781 + }
  782 + return false;
  783 +}
  784 +
727 } // namespace sherpa_onnx 785 } // namespace sherpa_onnx
@@ -149,6 +149,29 @@ bool EndsWith(const std::string &haystack, const std::string &needle); @@ -149,6 +149,29 @@ bool EndsWith(const std::string &haystack, const std::string &needle);
149 149
150 std::vector<std::string> SplitString(const std::string &s, int32_t chunk_size); 150 std::vector<std::string> SplitString(const std::string &s, int32_t chunk_size);
151 151
  152 +// Converts a UTF-8 std::string to a UTF-32 std::u32string
  153 +std::u32string Utf8ToUtf32(const std::string &str);
  154 +
  155 +// Converts a UTF-32 std::u32string to a UTF-8 std::string
  156 +std::string Utf32ToUtf8(const std::u32string &str);
  157 +
  158 +// Helper: Convert ASCII chars in a std::string to uppercase (leaves non-ASCII
  159 +// unchanged)
  160 +std::string ToUpperAscii(const std::string &str);
  161 +
  162 +// Helper: Convert ASCII chars in a std::string to lowercase (leaves non-ASCII
  163 +// unchanged)
  164 +std::string ToLowerAscii(const std::string &str);
  165 +
  166 +// Detect if a codepoint is a CJK character
  167 +bool IsCJK(char32_t cp);
  168 +
  169 +bool ContainsCJK(const std::string &text);
  170 +
  171 +bool ContainsCJK(const std::u32string &text);
  172 +
  173 +bool StringToBool(const std::string &s);
  174 +
152 } // namespace sherpa_onnx 175 } // namespace sherpa_onnx
153 176
154 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ 177 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
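
A small usage sketch of the new helpers (assumes linking against sherpa-onnx; the literals are illustrative):

#include <cassert>
#include <string>

#include "sherpa-onnx/csrc/text-utils.h"

int main() {
  using namespace sherpa_onnx;
  assert(ContainsCJK("hello 你好"));  // mixed English/Chinese
  assert(!ContainsCJK("hello"));      // pure ASCII
  // ASCII case conversion leaves CJK characters untouched.
  assert(ToUpperAscii("ZipVoice 你好") == "ZIPVOICE 你好");
  assert(ToLowerAscii("ZIPVOICE") == "zipvoice");
  return 0;
}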
@@ -74,7 +74,18 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -74,7 +74,18 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
74 } 74 }
75 75
76 std::unique_ptr<Vocoder> Vocoder::Create(const OfflineTtsModelConfig &config) { 76 std::unique_ptr<Vocoder> Vocoder::Create(const OfflineTtsModelConfig &config) {
77 - auto buffer = ReadFile(config.matcha.vocoder); 77 + std::vector<char> buffer;
  78 + if (!config.matcha.vocoder.empty()) {
  79 + SHERPA_ONNX_LOGE("Using matcha vocoder: %s", config.matcha.vocoder.c_str());
  80 + buffer = ReadFile(config.matcha.vocoder);
  81 + } else if (!config.zipvoice.vocoder.empty()) {
  82 + SHERPA_ONNX_LOGE("Using zipvoice vocoder: %s",
  83 + config.zipvoice.vocoder.c_str());
  84 + buffer = ReadFile(config.zipvoice.vocoder);
  85 + } else {
  86 + SHERPA_ONNX_LOGE("No vocoder model provided in the config!");
  87 + exit(-1);
  88 + }
78 auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); 89 auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
79 90
80 switch (model_type) { 91 switch (model_type) {
@@ -94,7 +105,19 @@ std::unique_ptr<Vocoder> Vocoder::Create(const OfflineTtsModelConfig &config) { @@ -94,7 +105,19 @@ std::unique_ptr<Vocoder> Vocoder::Create(const OfflineTtsModelConfig &config) {
94 template <typename Manager> 105 template <typename Manager>
95 std::unique_ptr<Vocoder> Vocoder::Create(Manager *mgr, 106 std::unique_ptr<Vocoder> Vocoder::Create(Manager *mgr,
96 const OfflineTtsModelConfig &config) { 107 const OfflineTtsModelConfig &config) {
97 - auto buffer = ReadFile(mgr, config.matcha.vocoder); 108 + std::vector<char> buffer;
  109 + if (!config.matcha.vocoder.empty()) {
  110 + SHERPA_ONNX_LOGE("Using matcha vocoder: %s", config.matcha.vocoder.c_str());
  111 + buffer = ReadFile(mgr, config.matcha.vocoder);
  112 + } else if (!config.zipvoice.vocoder.empty()) {
  113 + SHERPA_ONNX_LOGE("Using zipvoice vocoder: %s",
  114 + config.zipvoice.vocoder.c_str());
  115 + buffer = ReadFile(mgr, config.zipvoice.vocoder);
  116 + } else {
  117 + SHERPA_ONNX_LOGE("No vocoder model provided in the config!");
  118 + return nullptr;
  119 + }
  120 +
98 auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); 121 auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
99 122
100 switch (model_type) { 123 switch (model_type) {
@@ -42,8 +42,16 @@ class VocosVocoder::Impl { @@ -42,8 +42,16 @@ class VocosVocoder::Impl {
42 env_(ORT_LOGGING_LEVEL_ERROR), 42 env_(ORT_LOGGING_LEVEL_ERROR),
43 sess_opts_(GetSessionOptions(config.num_threads, config.provider)), 43 sess_opts_(GetSessionOptions(config.num_threads, config.provider)),
44 allocator_{} { 44 allocator_{} {
45 - auto buf = ReadFile(config.matcha.vocoder);  
46 - Init(buf.data(), buf.size()); 45 + std::vector<char> buffer;
  46 + if (!config.matcha.vocoder.empty()) {
  47 + buffer = ReadFile(config.matcha.vocoder);
  48 + } else if (!config.zipvoice.vocoder.empty()) {
  49 + buffer = ReadFile(config.zipvoice.vocoder);
  50 + } else {
  51 + SHERPA_ONNX_LOGE("No vocoder model provided in the config!");
  52 + exit(-1);
  53 + }
  54 + Init(buffer.data(), buffer.size());
47 } 55 }
48 56
49 template <typename Manager> 57 template <typename Manager>
@@ -52,8 +60,16 @@ class VocosVocoder::Impl { @@ -52,8 +60,16 @@ class VocosVocoder::Impl {
52 env_(ORT_LOGGING_LEVEL_ERROR), 60 env_(ORT_LOGGING_LEVEL_ERROR),
53 sess_opts_(GetSessionOptions(config.num_threads, config.provider)), 61 sess_opts_(GetSessionOptions(config.num_threads, config.provider)),
54 allocator_{} { 62 allocator_{} {
55 - auto buf = ReadFile(mgr, config.matcha.vocoder);  
56 - Init(buf.data(), buf.size()); 63 + std::vector<char> buffer;
  64 + if (!config.matcha.vocoder.empty()) {
  65 + buffer = ReadFile(mgr, config.matcha.vocoder);
  66 + } else if (!config.zipvoice.vocoder.empty()) {
  67 + buffer = ReadFile(mgr, config.zipvoice.vocoder);
  68 + } else {
  69 + SHERPA_ONNX_LOGE("No vocoder model provided in the config!");
  70 + exit(-1);
  71 + }
  72 + Init(buffer.data(), buffer.size());
57 } 73 }
58 74
59 std::vector<float> Run(Ort::Value mel) const { 75 std::vector<float> Run(Ort::Value mel) const {
@@ -72,6 +72,7 @@ if(SHERPA_ONNX_ENABLE_TTS) @@ -72,6 +72,7 @@ if(SHERPA_ONNX_ENABLE_TTS)
72 offline-tts-matcha-model-config.cc 72 offline-tts-matcha-model-config.cc
73 offline-tts-model-config.cc 73 offline-tts-model-config.cc
74 offline-tts-vits-model-config.cc 74 offline-tts-vits-model-config.cc
  75 + offline-tts-zipvoice-model-config.cc
75 offline-tts.cc 76 offline-tts.cc
76 ) 77 )
77 endif() 78 endif()
@@ -11,6 +11,7 @@ @@ -11,6 +11,7 @@
11 #include "sherpa-onnx/python/csrc/offline-tts-kokoro-model-config.h" 11 #include "sherpa-onnx/python/csrc/offline-tts-kokoro-model-config.h"
12 #include "sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h" 12 #include "sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h"
13 #include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h" 13 #include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h"
  14 +#include "sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.h"
14 15
15 namespace sherpa_onnx { 16 namespace sherpa_onnx {
16 17
@@ -18,6 +19,7 @@ void PybindOfflineTtsModelConfig(py::module *m) { @@ -18,6 +19,7 @@ void PybindOfflineTtsModelConfig(py::module *m) {
18 PybindOfflineTtsVitsModelConfig(m); 19 PybindOfflineTtsVitsModelConfig(m);
19 PybindOfflineTtsMatchaModelConfig(m); 20 PybindOfflineTtsMatchaModelConfig(m);
20 PybindOfflineTtsKokoroModelConfig(m); 21 PybindOfflineTtsKokoroModelConfig(m);
  22 + PybindOfflineTtsZipvoiceModelConfig(m);
21 PybindOfflineTtsKittenModelConfig(m); 23 PybindOfflineTtsKittenModelConfig(m);
22 24
23 using PyClass = OfflineTtsModelConfig; 25 using PyClass = OfflineTtsModelConfig;
@@ -27,17 +29,20 @@ void PybindOfflineTtsModelConfig(py::module *m) { @@ -27,17 +29,20 @@ void PybindOfflineTtsModelConfig(py::module *m) {
27 .def(py::init<const OfflineTtsVitsModelConfig &, 29 .def(py::init<const OfflineTtsVitsModelConfig &,
28 const OfflineTtsMatchaModelConfig &, 30 const OfflineTtsMatchaModelConfig &,
29 const OfflineTtsKokoroModelConfig &, 31 const OfflineTtsKokoroModelConfig &,
  32 + const OfflineTtsZipvoiceModelConfig &,
30 const OfflineTtsKittenModelConfig &, int32_t, bool, 33 const OfflineTtsKittenModelConfig &, int32_t, bool,
31 const std::string &>(), 34 const std::string &>(),
32 py::arg("vits") = OfflineTtsVitsModelConfig{}, 35 py::arg("vits") = OfflineTtsVitsModelConfig{},
33 py::arg("matcha") = OfflineTtsMatchaModelConfig{}, 36 py::arg("matcha") = OfflineTtsMatchaModelConfig{},
34 py::arg("kokoro") = OfflineTtsKokoroModelConfig{}, 37 py::arg("kokoro") = OfflineTtsKokoroModelConfig{},
  38 + py::arg("zipvoice") = OfflineTtsZipvoiceModelConfig{},
35 py::arg("kitten") = OfflineTtsKittenModelConfig{}, 39 py::arg("kitten") = OfflineTtsKittenModelConfig{},
36 py::arg("num_threads") = 1, py::arg("debug") = false, 40 py::arg("num_threads") = 1, py::arg("debug") = false,
37 py::arg("provider") = "cpu") 41 py::arg("provider") = "cpu")
38 .def_readwrite("vits", &PyClass::vits) 42 .def_readwrite("vits", &PyClass::vits)
39 .def_readwrite("matcha", &PyClass::matcha) 43 .def_readwrite("matcha", &PyClass::matcha)
40 .def_readwrite("kokoro", &PyClass::kokoro) 44 .def_readwrite("kokoro", &PyClass::kokoro)
  45 + .def_readwrite("zipvoice", &PyClass::zipvoice)
41 .def_readwrite("kitten", &PyClass::kitten) 46 .def_readwrite("kitten", &PyClass::kitten)
42 .def_readwrite("num_threads", &PyClass::num_threads) 47 .def_readwrite("num_threads", &PyClass::num_threads)
43 .def_readwrite("debug", &PyClass::debug) 48 .def_readwrite("debug", &PyClass::debug)
  1 +// sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void PybindOfflineTtsZipvoiceModelConfig(py::module *m) {
  14 + using PyClass = OfflineTtsZipvoiceModelConfig;
  15 +
  16 + py::class_<PyClass>(*m, "OfflineTtsZipvoiceModelConfig")
  17 + .def(py::init<>())
  18 + .def(py::init<const std::string &, const std::string &,
  19 + const std::string &, const std::string &,
  20 + const std::string &, const std::string &, float, float,
  21 + float, float>(),
  22 + py::arg("tokens"), py::arg("text_model"),
  23 + py::arg("flow_matching_model"), py::arg("vocoder"),
  24 + py::arg("data_dir") = "", py::arg("pinyin_dict") = "",
  25 + py::arg("feat_scale") = 0.1, py::arg("t_shift") = 0.5,
  26 + py::arg("target_rms") = 0.1, py::arg("guidance_scale") = 1.0)
  27 + .def_readwrite("tokens", &PyClass::tokens)
  28 + .def_readwrite("text_model", &PyClass::text_model)
  29 + .def_readwrite("flow_matching_model", &PyClass::flow_matching_model)
  30 + .def_readwrite("vocoder", &PyClass::vocoder)
  31 + .def_readwrite("data_dir", &PyClass::data_dir)
  32 + .def_readwrite("pinyin_dict", &PyClass::pinyin_dict)
  33 + .def_readwrite("feat_scale", &PyClass::feat_scale)
  34 + .def_readwrite("t_shift", &PyClass::t_shift)
  35 + .def_readwrite("target_rms", &PyClass::target_rms)
  36 + .def_readwrite("guidance_scale", &PyClass::guidance_scale)
  37 + .def("__str__", &PyClass::ToString)
  38 + .def("validate", &PyClass::Validate);
  39 +}
  40 +
  41 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineTtsZipvoiceModelConfig(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_
@@ -84,6 +84,41 @@ void PybindOfflineTts(py::module *m) { @@ -84,6 +84,41 @@ void PybindOfflineTts(py::module *m) {
84 }, 84 },
85 py::arg("text"), py::arg("sid") = 0, py::arg("speed") = 1.0, 85 py::arg("text"), py::arg("sid") = 0, py::arg("speed") = 1.0,
86 py::arg("callback") = py::none(), 86 py::arg("callback") = py::none(),
  87 + py::call_guard<py::gil_scoped_release>())
  88 + .def(
  89 + "generate",
  90 + [](const PyClass &self, const std::string &text,
  91 + const std::string &prompt_text,
  92 + const std::vector<float> &prompt_samples, int32_t sample_rate,
  93 + float speed, int32_t num_steps,
  94 + std::function<int32_t(py::array_t<float>, float)> callback)
  95 + -> GeneratedAudio {
  96 + if (!callback) {
  97 + return self.Generate(text, prompt_text, prompt_samples,
  98 + sample_rate, speed, num_steps);
  99 + }
  100 +
  101 + std::function<int32_t(const float *, int32_t, float)>
  102 + callback_wrapper = [callback](const float *samples, int32_t n,
  103 + float progress) {
  104 +              // CAUTION(fangjun): we have to copy the samples since they
  105 +              // are freed once the callback returns.
  106 +
  107 + pybind11::gil_scoped_acquire acquire;
  108 +
  109 + pybind11::array_t<float> array(n);
  110 + py::buffer_info buf = array.request();
  111 + auto p = static_cast<float *>(buf.ptr);
  112 + std::copy(samples, samples + n, p);
  113 + return callback(array, progress);
  114 + };
  115 +
  116 + return self.Generate(text, prompt_text, prompt_samples, sample_rate,
  117 + speed, num_steps, callback_wrapper);
  118 + },
  119 + py::arg("text"), py::arg("prompt_text"), py::arg("prompt_samples"),
  120 + py::arg("sample_rate"), py::arg("speed") = 1.0,
  121 + py::arg("num_steps") = 4, py::arg("callback") = py::none(),
87 py::call_guard<py::gil_scoped_release>()); 122 py::call_guard<py::gil_scoped_release>());
88 } 123 }
89 124
@@ -49,6 +49,7 @@ from sherpa_onnx.lib._sherpa_onnx import ( @@ -49,6 +49,7 @@ from sherpa_onnx.lib._sherpa_onnx import (
49 OfflineTtsMatchaModelConfig, 49 OfflineTtsMatchaModelConfig,
50 OfflineTtsModelConfig, 50 OfflineTtsModelConfig,
51 OfflineTtsVitsModelConfig, 51 OfflineTtsVitsModelConfig,
  52 + OfflineTtsZipvoiceModelConfig,
52 OfflineWenetCtcModelConfig, 53 OfflineWenetCtcModelConfig,
53 OfflineWhisperModelConfig, 54 OfflineWhisperModelConfig,
54 OfflineZipformerAudioTaggingModelConfig, 55 OfflineZipformerAudioTaggingModelConfig,