Add Zipvoice (#2487)
Committed by GitHub
Co-authored-by: yaozengwei <yaozengwei@outlook.com>
Showing 33 changed files with 2301 additions and 18 deletions
CMakeLists.txt
| @@ -372,6 +372,7 @@ endif() | @@ -372,6 +372,7 @@ endif() | ||
| 372 | include(kaldi-native-fbank) | 372 | include(kaldi-native-fbank) |
| 373 | include(kaldi-decoder) | 373 | include(kaldi-decoder) |
| 374 | include(onnxruntime) | 374 | include(onnxruntime) |
| 375 | +include(cppinyin) | ||
| 375 | include(simple-sentencepiece) | 376 | include(simple-sentencepiece) |
| 376 | set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR}) | 377 | set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR}) |
| 377 | message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}") | 378 | message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}") |
cmake/cppinyin.cmake
0 → 100644
| 1 | +function(download_cppinyin) | ||
| 2 | + include(FetchContent) | ||
| 3 | + | ||
| 4 | + set(cppinyin_URL "https://github.com/pkufool/cppinyin/archive/refs/tags/v0.10.tar.gz") | ||
| 5 | + set(cppinyin_URL2 "https://gh-proxy.com/https://github.com/pkufool/cppinyin/archive/refs/tags/v0.10.tar.gz") | ||
| 6 | + set(cppinyin_HASH "SHA256=abe6584d7ee56829e8f4b5fbda3b50ecdf49a13be8e413a78d1b0d5d5c019982") | ||
| 7 | + | ||
| 8 | + # If you don't have access to the Internet, | ||
| 9 | + # please pre-download cppinyin | ||
| 10 | + set(possible_file_locations | ||
| 11 | + $ENV{HOME}/Downloads/cppinyin-0.10.tar.gz | ||
| 12 | + ${CMAKE_SOURCE_DIR}/cppinyin-0.10.tar.gz | ||
| 13 | + ${CMAKE_BINARY_DIR}/cppinyin-0.10.tar.gz | ||
| 14 | + /tmp/cppinyin-0.10.tar.gz | ||
| 15 | + /star-fj/fangjun/download/github/cppinyin-0.10.tar.gz | ||
| 16 | + ) | ||
| 17 | + | ||
| 18 | + foreach(f IN LISTS possible_file_locations) | ||
| 19 | + if(EXISTS ${f}) | ||
| 20 | + set(cppinyin_URL "${f}") | ||
| 21 | + file(TO_CMAKE_PATH "${cppinyin_URL}" cppinyin_URL) | ||
| 22 | + message(STATUS "Found local downloaded cppinyin: ${cppinyin_URL}") | ||
| 23 | + set(cppinyin_URL2) | ||
| 24 | + break() | ||
| 25 | + endif() | ||
| 26 | + endforeach() | ||
| 27 | + | ||
| 28 | + set(CPPINYIN_ENABLE_TESTS OFF CACHE BOOL "" FORCE) | ||
| 29 | + set(CPPINYIN_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | ||
| 30 | + | ||
| 31 | + FetchContent_Declare(cppinyin | ||
| 32 | + URL | ||
| 33 | + ${cppinyin_URL} | ||
| 34 | + ${cppinyin_URL2} | ||
| 35 | + URL_HASH | ||
| 36 | + ${cppinyin_HASH} | ||
| 37 | + ) | ||
| 38 | + | ||
| 39 | + FetchContent_GetProperties(cppinyin) | ||
| 40 | + if(NOT cppinyin_POPULATED) | ||
| 41 | + message(STATUS "Downloading cppinyin ${cppinyin_URL}") | ||
| 42 | + FetchContent_Populate(cppinyin) | ||
| 43 | + endif() | ||
| 44 | + message(STATUS "cppinyin is downloaded to ${cppinyin_SOURCE_DIR}") | ||
| 45 | + | ||
| 46 | + if(BUILD_SHARED_LIBS) | ||
| 47 | + set(_build_shared_libs_bak ${BUILD_SHARED_LIBS}) | ||
| 48 | + set(BUILD_SHARED_LIBS OFF) | ||
| 49 | + endif() | ||
| 50 | + | ||
| 51 | + add_subdirectory(${cppinyin_SOURCE_DIR} ${cppinyin_BINARY_DIR} EXCLUDE_FROM_ALL) | ||
| 52 | + | ||
| 53 | + if(_build_shared_libs_bak) | ||
| 54 | + set_target_properties(cppinyin_core | ||
| 55 | + PROPERTIES | ||
| 56 | + POSITION_INDEPENDENT_CODE ON | ||
| 57 | + C_VISIBILITY_PRESET hidden | ||
| 58 | + CXX_VISIBILITY_PRESET hidden | ||
| 59 | + ) | ||
| 60 | + set(BUILD_SHARED_LIBS ON) | ||
| 61 | + endif() | ||
| 62 | + | ||
| 63 | + target_include_directories(cppinyin_core | ||
| 64 | + PUBLIC | ||
| 65 | + ${cppinyin_SOURCE_DIR}/ | ||
| 66 | + ) | ||
| 67 | + | ||
| 68 | + if(NOT BUILD_SHARED_LIBS) | ||
| 69 | + install(TARGETS cppinyin_core DESTINATION lib) | ||
| 70 | + endif() | ||
| 71 | + | ||
| 72 | +endfunction() | ||
| 73 | + | ||
| 74 | +download_cppinyin() |
| 1 | #!/usr/bin/env python3 | 1 | #!/usr/bin/env python3 |
| 2 | # | 2 | # |
| 3 | -# Copyright (c) 2023 Xiaomi Corporation | 3 | +# Copyright (c) 2023-2025 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | """ | 5 | """ |
| 6 | This file demonstrates how to use sherpa-onnx Python API to generate audio | 6 | This file demonstrates how to use sherpa-onnx Python API to generate audio |
| @@ -453,7 +453,9 @@ def main(): | @@ -453,7 +453,9 @@ def main(): | ||
| 453 | end = time.time() | 453 | end = time.time() |
| 454 | 454 | ||
| 455 | if len(audio.samples) == 0: | 455 | if len(audio.samples) == 0: |
| 456 | - print("Error in generating audios. Please read previous error messages.") | 456 | + print( |
| 457 | + "Error in generating audios. Please read previous error messages." | ||
| 458 | + ) | ||
| 457 | return | 459 | return |
| 458 | 460 | ||
| 459 | elapsed_seconds = end - start | 461 | elapsed_seconds = end - start |
| @@ -470,7 +472,9 @@ def main(): | @@ -470,7 +472,9 @@ def main(): | ||
| 470 | print(f"The text is '{args.text}'") | 472 | print(f"The text is '{args.text}'") |
| 471 | print(f"Elapsed seconds: {elapsed_seconds:.3f}") | 473 | print(f"Elapsed seconds: {elapsed_seconds:.3f}") |
| 472 | print(f"Audio duration in seconds: {audio_duration:.3f}") | 474 | print(f"Audio duration in seconds: {audio_duration:.3f}") |
| 473 | - print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}") | 475 | + print( |
| 476 | + f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}" | ||
| 477 | + ) | ||
| 474 | 478 | ||
| 475 | 479 | ||
| 476 | if __name__ == "__main__": | 480 | if __name__ == "__main__": |
python-api-examples/offline-zeroshot-tts.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# | ||
| 3 | +# Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +""" | ||
| 6 | +This file demonstrates how to use sherpa-onnx Python API to generate audio | ||
| 7 | +from text with a prompt, i.e., zero-shot text-to-speech. | ||
| 8 | + | ||
| 9 | +Usage: | ||
| 10 | + | ||
| 11 | +Example (zipvoice) | ||
| 12 | + | ||
| 13 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2 | ||
| 14 | +tar xf sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2 | ||
| 15 | + | ||
| 16 | +python3 ./python-api-examples/offline-zeroshot-tts.py \ | ||
| 17 | + --zipvoice-flow-matching-model sherpa-onnx-zipvoice-distill-zh-en-emilia/fm_decoder.onnx \ | ||
| 18 | + --zipvoice-text-model sherpa-onnx-zipvoice-distill-zh-en-emilia/text_encoder.onnx \ | ||
| 19 | + --zipvoice-data-dir sherpa-onnx-zipvoice-distill-zh-en-emilia/espeak-ng-data \ | ||
| 20 | + --zipvoice-pinyin-dict sherpa-onnx-zipvoice-distill-zh-en-emilia/pinyin.raw \ | ||
| 21 | + --zipvoice-tokens sherpa-onnx-zipvoice-distill-zh-en-emilia/tokens.txt \ | ||
| 22 | + --zipvoice-vocoder sherpa-onnx-zipvoice-distill-zh-en-emilia/vocos_24khz.onnx \ | ||
| 23 | + --prompt-audio sherpa-onnx-zipvoice-distill-zh-en-emilia/prompt.wav \ | ||
| 24 | + --zipvoice-num-steps 4 \ | ||
| 25 | + --num-threads 4 \ | ||
| 26 | + --prompt-text "周日被我射熄火了,所以今天是周一。" \ | ||
| 27 | + "我是中国人民的儿子,我爱我的祖国。我得祖国是一个伟大的国家,拥有五千年的文明史。" | ||
| 28 | +""" | ||
| 29 | + | ||
| 30 | +import argparse | ||
| 31 | +import time | ||
| 32 | +import wave | ||
| 33 | +import numpy as np | ||
| 34 | + | ||
| 35 | +from typing import Tuple | ||
| 36 | + | ||
| 37 | +import sherpa_onnx | ||
| 38 | +import soundfile as sf | ||
| 39 | + | ||
| 40 | + | ||
| 41 | +def add_zipvoice_args(parser): | ||
| 42 | + parser.add_argument( | ||
| 43 | + "--zipvoice-tokens", | ||
| 44 | + type=str, | ||
| 45 | + default="", | ||
| 46 | + help="Path to tokens.txt for Zipvoice models.", | ||
| 47 | + ) | ||
| 48 | + | ||
| 49 | + parser.add_argument( | ||
| 50 | + "--zipvoice-text-model", | ||
| 51 | + type=str, | ||
| 52 | + default="", | ||
| 53 | + help="Path to zipvoice text model.", | ||
| 54 | + ) | ||
| 55 | + | ||
| 56 | + parser.add_argument( | ||
| 57 | + "--zipvoice-flow-matching-model", | ||
| 58 | + type=str, | ||
| 59 | + default="", | ||
| 60 | + help="Path to zipvoice flow matching model.", | ||
| 61 | + ) | ||
| 62 | + | ||
| 63 | + parser.add_argument( | ||
| 64 | + "--zipvoice-data-dir", | ||
| 65 | + type=str, | ||
| 66 | + default="", | ||
| 67 | + help="Path to the espeak-ng data directory.", | ||
| 68 | + ) | ||
| 69 | + | ||
| 70 | + parser.add_argument( | ||
| 71 | + "--zipvoice-pinyin-dict", | ||
| 72 | + type=str, | ||
| 73 | + default="", | ||
| 74 | + help="Path to the pinyin dictionary.", | ||
| 75 | + ) | ||
| 76 | + | ||
| 77 | + parser.add_argument( | ||
| 78 | + "--zipvoice-vocoder", | ||
| 79 | + type=str, | ||
| 80 | + default="", | ||
| 81 | + help="Path to the vocos vocoder.", | ||
| 82 | + ) | ||
| 83 | + | ||
| 84 | + parser.add_argument( | ||
| 85 | + "--zipvoice-num-steps", | ||
| 86 | + type=int, | ||
| 87 | + default=4, | ||
| 88 | + help="Number of steps for Zipvoice.", | ||
| 89 | + ) | ||
| 90 | + | ||
| 91 | + parser.add_argument( | ||
| 92 | + "--zipvoice-feat-scale", | ||
| 93 | + type=float, | ||
| 94 | + default=0.1, | ||
| 95 | + help="Scale factor for Zipvoice features.", | ||
| 96 | + ) | ||
| 97 | + | ||
| 98 | + parser.add_argument( | ||
| 99 | + "--zipvoice-t-shift", | ||
| 100 | + type=float, | ||
| 101 | + default=0.5, | ||
| 102 | + help="Shift t to smaller ones if t-shift < 1.0.", | ||
| 103 | + ) | ||
| 104 | + | ||
| 105 | + parser.add_argument( | ||
| 106 | + "--zipvoice-target-rms", | ||
| 107 | + type=float, | ||
| 108 | + default=0.1, | ||
| 109 | + help="Target speech normalization RMS value for Zipvoice.", | ||
| 110 | + ) | ||
| 111 | + | ||
| 112 | + parser.add_argument( | ||
| 113 | + "--zipvoice-guidance-scale", | ||
| 114 | + type=float, | ||
| 115 | + default=1.0, | ||
| 116 | + help="The scale of classifier-free guidance during inference for Zipvoice.", | ||
| 117 | + ) | ||
| 118 | + | ||
| 119 | + | ||
| 120 | +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 121 | + """ | ||
| 122 | + Args: | ||
| 123 | + wave_filename: | ||
| 124 | + Path to a wave file. It should be single channel and each sample should | ||
| 125 | + be 16-bit. Its sample rate does not need to be 16kHz. | ||
| 126 | + Returns: | ||
| 127 | + Return a tuple containing: | ||
| 128 | + - A 1-D array of dtype np.float32 containing the samples, which are | ||
| 129 | + normalized to the range [-1, 1]. | ||
| 130 | + - sample rate of the wave file | ||
| 131 | + """ | ||
| 132 | + | ||
| 133 | + with wave.open(wave_filename) as f: | ||
| 134 | + assert f.getnchannels() == 1, f.getnchannels() | ||
| 135 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 136 | + num_samples = f.getnframes() | ||
| 137 | + samples = f.readframes(num_samples) | ||
| 138 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 139 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 140 | + | ||
| 141 | + samples_float32 = samples_float32 / 32768 | ||
| 142 | + return samples_float32, f.getframerate() | ||
| 143 | + | ||
| 144 | + | ||
| 145 | +def get_args(): | ||
| 146 | + parser = argparse.ArgumentParser( | ||
| 147 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 148 | + ) | ||
| 149 | + | ||
| 150 | + add_zipvoice_args(parser) | ||
| 151 | + | ||
| 152 | + parser.add_argument( | ||
| 153 | + "--tts-rule-fsts", | ||
| 154 | + type=str, | ||
| 155 | + default="", | ||
| 156 | + help="Path to rule.fst", | ||
| 157 | + ) | ||
| 158 | + | ||
| 159 | + parser.add_argument( | ||
| 160 | + "--max-num-sentences", | ||
| 161 | + type=int, | ||
| 162 | + default=1, | ||
| 163 | + help="""Max number of sentences in a batch to avoid OOM if the input | ||
| 164 | + text is very long. Set it to -1 to process all the sentences in a | ||
| 165 | + single batch. A smaller value does not mean it is slower compared | ||
| 166 | + to a larger one on CPU. | ||
| 167 | + """, | ||
| 168 | + ) | ||
| 169 | + | ||
| 170 | + parser.add_argument( | ||
| 171 | + "--output-filename", | ||
| 172 | + type=str, | ||
| 173 | + default="./generated.wav", | ||
| 174 | + help="Path to save generated wave", | ||
| 175 | + ) | ||
| 176 | + | ||
| 177 | + parser.add_argument( | ||
| 178 | + "--debug", | ||
| 179 | + type=bool, | ||
| 180 | + default=False, | ||
| 181 | + help="True to show debug messages", | ||
| 182 | + ) | ||
| 183 | + | ||
| 184 | + parser.add_argument( | ||
| 185 | + "--provider", | ||
| 186 | + type=str, | ||
| 187 | + default="cpu", | ||
| 188 | + help="valid values: cpu, cuda, coreml", | ||
| 189 | + ) | ||
| 190 | + | ||
| 191 | + parser.add_argument( | ||
| 192 | + "--num-threads", | ||
| 193 | + type=int, | ||
| 194 | + default=1, | ||
| 195 | + help="Number of threads for neural network computation", | ||
| 196 | + ) | ||
| 197 | + | ||
| 198 | + parser.add_argument( | ||
| 199 | + "--speed", | ||
| 200 | + type=float, | ||
| 201 | + default=1.0, | ||
| 202 | + help="Speech speed. Larger->faster; smaller->slower", | ||
| 203 | + ) | ||
| 204 | + | ||
| 205 | + parser.add_argument( | ||
| 206 | + "--prompt-text", | ||
| 207 | + type=str, | ||
| 208 | + required=True, | ||
| 209 | + help="The transcription of prompt audio (Zipvoice)", | ||
| 210 | + ) | ||
| 211 | + | ||
| 212 | + parser.add_argument( | ||
| 213 | + "--prompt-audio", | ||
| 214 | + type=str, | ||
| 215 | + required=True, | ||
| 216 | + help="The path to prompt audio (Zipvoice).", | ||
| 217 | + ) | ||
| 218 | + | ||
| 219 | + parser.add_argument( | ||
| 220 | + "text", | ||
| 221 | + type=str, | ||
| 222 | + help="The input text to generate audio for", | ||
| 223 | + ) | ||
| 224 | + | ||
| 225 | + return parser.parse_args() | ||
| 226 | + | ||
| 227 | + | ||
| 228 | +def main(): | ||
| 229 | + args = get_args() | ||
| 230 | + print(args) | ||
| 231 | + | ||
| 232 | + tts_config = sherpa_onnx.OfflineTtsConfig( | ||
| 233 | + model=sherpa_onnx.OfflineTtsModelConfig( | ||
| 234 | + zipvoice=sherpa_onnx.OfflineTtsZipvoiceModelConfig( | ||
| 235 | + tokens=args.zipvoice_tokens, | ||
| 236 | + text_model=args.zipvoice_text_model, | ||
| 237 | + flow_matching_model=args.zipvoice_flow_matching_model, | ||
| 238 | + data_dir=args.zipvoice_data_dir, | ||
| 239 | + pinyin_dict=args.zipvoice_pinyin_dict, | ||
| 240 | + vocoder=args.zipvoice_vocoder, | ||
| 241 | + feat_scale=args.zipvoice_feat_scale, | ||
| 242 | + t_shift=args.zipvoice_t_shift, | ||
| 243 | + target_rms=args.zipvoice_target_rms, | ||
| 244 | + guidance_scale=args.zipvoice_guidance_scale, | ||
| 245 | + ), | ||
| 246 | + provider=args.provider, | ||
| 247 | + debug=args.debug, | ||
| 248 | + num_threads=args.num_threads, | ||
| 249 | + ), | ||
| 250 | + rule_fsts=args.tts_rule_fsts, | ||
| 251 | + max_num_sentences=args.max_num_sentences, | ||
| 252 | + ) | ||
| 253 | + if not tts_config.validate(): | ||
| 254 | + raise ValueError("Please check your config") | ||
| 255 | + | ||
| 256 | + tts = sherpa_onnx.OfflineTts(tts_config) | ||
| 257 | + | ||
| 258 | + start = time.time() | ||
| 259 | + prompt_samples, sample_rate = read_wave(args.prompt_audio) | ||
| 260 | + audio = tts.generate( | ||
| 261 | + args.text, | ||
| 262 | + args.prompt_text, | ||
| 263 | + prompt_samples, | ||
| 264 | + sample_rate, | ||
| 265 | + speed=args.speed, | ||
| 266 | + num_steps=args.zipvoice_num_steps, | ||
| 267 | + ) | ||
| 268 | + end = time.time() | ||
| 269 | + | ||
| 270 | + if len(audio.samples) == 0: | ||
| 271 | + print( | ||
| 272 | + "Error in generating audios. Please read previous error messages." | ||
| 273 | + ) | ||
| 274 | + return | ||
| 275 | + | ||
| 276 | + elapsed_seconds = end - start | ||
| 277 | + audio_duration = len(audio.samples) / audio.sample_rate | ||
| 278 | + real_time_factor = elapsed_seconds / audio_duration | ||
| 279 | + | ||
| 280 | + sf.write( | ||
| 281 | + args.output_filename, | ||
| 282 | + audio.samples, | ||
| 283 | + samplerate=audio.sample_rate, | ||
| 284 | + subtype="PCM_16", | ||
| 285 | + ) | ||
| 286 | + print(f"Saved to {args.output_filename}") | ||
| 287 | + print(f"The text is '{args.text}'") | ||
| 288 | + print(f"Elapsed seconds: {elapsed_seconds:.3f}") | ||
| 289 | + print(f"Audio duration in seconds: {audio_duration:.3f}") | ||
| 290 | + print( | ||
| 291 | + f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}" | ||
| 292 | + ) | ||
| 293 | + | ||
| 294 | + | ||
| 295 | +if __name__ == "__main__": | ||
| 296 | + main() |
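Note on prompt audio: read_wave() in the example above requires a single-channel, 16-bit wave file (its sample rate is arbitrary; the C++ implementation resamples the prompt to the model's sample rate before computing its mel spectrogram). A small, hypothetical helper, not part of this PR, that converts an arbitrary recording into that format using the already-imported soundfile package:

    import soundfile as sf

    # Downmix to mono and save as 16-bit PCM so that read_wave() in
    # offline-zeroshot-tts.py can load it.
    # "my_recording.wav" is a placeholder input file name.
    data, sr = sf.read("my_recording.wav", dtype="float32")
    if data.ndim > 1:
        data = data.mean(axis=1)
    sf.write("prompt.wav", data, sr, subtype="PCM_16")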
sherpa-onnx/csrc/CMakeLists.txt
| @@ -201,6 +201,9 @@ if(SHERPA_ONNX_ENABLE_TTS) | @@ -201,6 +201,9 @@ if(SHERPA_ONNX_ENABLE_TTS) | ||
| 201 | offline-tts-model-config.cc | 201 | offline-tts-model-config.cc |
| 202 | offline-tts-vits-model-config.cc | 202 | offline-tts-vits-model-config.cc |
| 203 | offline-tts-vits-model.cc | 203 | offline-tts-vits-model.cc |
| 204 | + offline-tts-zipvoice-frontend.cc | ||
| 205 | + offline-tts-zipvoice-model.cc | ||
| 206 | + offline-tts-zipvoice-model-config.cc | ||
| 204 | offline-tts.cc | 207 | offline-tts.cc |
| 205 | piper-phonemize-lexicon.cc | 208 | piper-phonemize-lexicon.cc |
| 206 | vocoder.cc | 209 | vocoder.cc |
| @@ -265,6 +268,7 @@ if(ANDROID_NDK) | @@ -265,6 +268,7 @@ if(ANDROID_NDK) | ||
| 265 | endif() | 268 | endif() |
| 266 | 269 | ||
| 267 | target_link_libraries(sherpa-onnx-core | 270 | target_link_libraries(sherpa-onnx-core |
| 271 | + cppinyin_core | ||
| 268 | kaldi-native-fbank-core | 272 | kaldi-native-fbank-core |
| 269 | kaldi-decoder-core | 273 | kaldi-decoder-core |
| 270 | ssentencepiece_core | 274 | ssentencepiece_core |
| @@ -348,6 +352,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -348,6 +352,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 348 | 352 | ||
| 349 | if(SHERPA_ONNX_ENABLE_TTS) | 353 | if(SHERPA_ONNX_ENABLE_TTS) |
| 350 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) | 354 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) |
| 355 | + add_executable(sherpa-onnx-offline-zeroshot-tts sherpa-onnx-offline-zeroshot-tts.cc) | ||
| 351 | endif() | 356 | endif() |
| 352 | 357 | ||
| 353 | if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) | 358 | if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) |
| @@ -370,6 +375,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -370,6 +375,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 370 | if(SHERPA_ONNX_ENABLE_TTS) | 375 | if(SHERPA_ONNX_ENABLE_TTS) |
| 371 | list(APPEND main_exes | 376 | list(APPEND main_exes |
| 372 | sherpa-onnx-offline-tts | 377 | sherpa-onnx-offline-tts |
| 378 | + sherpa-onnx-offline-zeroshot-tts | ||
| 373 | ) | 379 | ) |
| 374 | endif() | 380 | endif() |
| 375 | 381 | ||
| @@ -667,6 +673,7 @@ if(SHERPA_ONNX_ENABLE_TESTS) | @@ -667,6 +673,7 @@ if(SHERPA_ONNX_ENABLE_TESTS) | ||
| 667 | if(SHERPA_ONNX_ENABLE_TTS) | 673 | if(SHERPA_ONNX_ENABLE_TTS) |
| 668 | list(APPEND sherpa_onnx_test_srcs | 674 | list(APPEND sherpa_onnx_test_srcs |
| 669 | cppjieba-test.cc | 675 | cppjieba-test.cc |
| 676 | + offline-tts-zipvoice-frontend-test.cc | ||
| 670 | piper-phonemize-test.cc | 677 | piper-phonemize-test.cc |
| 671 | ) | 678 | ) |
| 672 | endif() | 679 | endif() |
sherpa-onnx/csrc/offline-tts-impl.cc
| @@ -20,6 +20,7 @@ | @@ -20,6 +20,7 @@ | ||
| 20 | #include "sherpa-onnx/csrc/offline-tts-kokoro-impl.h" | 20 | #include "sherpa-onnx/csrc/offline-tts-kokoro-impl.h" |
| 21 | #include "sherpa-onnx/csrc/offline-tts-matcha-impl.h" | 21 | #include "sherpa-onnx/csrc/offline-tts-matcha-impl.h" |
| 22 | #include "sherpa-onnx/csrc/offline-tts-vits-impl.h" | 22 | #include "sherpa-onnx/csrc/offline-tts-vits-impl.h" |
| 23 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-impl.h" | ||
| 23 | 24 | ||
| 24 | namespace sherpa_onnx { | 25 | namespace sherpa_onnx { |
| 25 | 26 | ||
| @@ -41,6 +42,9 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create( | @@ -41,6 +42,9 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create( | ||
| 41 | return std::make_unique<OfflineTtsVitsImpl>(config); | 42 | return std::make_unique<OfflineTtsVitsImpl>(config); |
| 42 | } else if (!config.model.matcha.acoustic_model.empty()) { | 43 | } else if (!config.model.matcha.acoustic_model.empty()) { |
| 43 | return std::make_unique<OfflineTtsMatchaImpl>(config); | 44 | return std::make_unique<OfflineTtsMatchaImpl>(config); |
| 45 | + } else if (!config.model.zipvoice.text_model.empty() && | ||
| 46 | + !config.model.zipvoice.flow_matching_model.empty()) { | ||
| 47 | + return std::make_unique<OfflineTtsZipvoiceImpl>(config); | ||
| 44 | } else if (!config.model.kokoro.model.empty()) { | 48 | } else if (!config.model.kokoro.model.empty()) { |
| 45 | return std::make_unique<OfflineTtsKokoroImpl>(config); | 49 | return std::make_unique<OfflineTtsKokoroImpl>(config); |
| 46 | } else if (!config.model.kitten.model.empty()) { | 50 | } else if (!config.model.kitten.model.empty()) { |
| @@ -59,6 +63,9 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create( | @@ -59,6 +63,9 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create( | ||
| 59 | return std::make_unique<OfflineTtsVitsImpl>(mgr, config); | 63 | return std::make_unique<OfflineTtsVitsImpl>(mgr, config); |
| 60 | } else if (!config.model.matcha.acoustic_model.empty()) { | 64 | } else if (!config.model.matcha.acoustic_model.empty()) { |
| 61 | return std::make_unique<OfflineTtsMatchaImpl>(mgr, config); | 65 | return std::make_unique<OfflineTtsMatchaImpl>(mgr, config); |
| 66 | + } else if (!config.model.zipvoice.text_model.empty() && | ||
| 67 | + !config.model.zipvoice.flow_matching_model.empty()) { | ||
| 68 | + return std::make_unique<OfflineTtsZipvoiceImpl>(mgr, config); | ||
| 62 | } else if (!config.model.kokoro.model.empty()) { | 69 | } else if (!config.model.kokoro.model.empty()) { |
| 63 | return std::make_unique<OfflineTtsKokoroImpl>(mgr, config); | 70 | return std::make_unique<OfflineTtsKokoroImpl>(mgr, config); |
| 64 | } else if (!config.model.kitten.model.empty()) { | 71 | } else if (!config.model.kitten.model.empty()) { |
sherpa-onnx/csrc/offline-tts-impl.h
| @@ -6,9 +6,11 @@ | @@ -6,9 +6,11 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ | 6 | #define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ |
| 7 | 7 | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | +#include <stdexcept> | ||
| 9 | #include <string> | 10 | #include <string> |
| 10 | #include <vector> | 11 | #include <vector> |
| 11 | 12 | ||
| 13 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | #include "sherpa-onnx/csrc/offline-tts.h" | 14 | #include "sherpa-onnx/csrc/offline-tts.h" |
| 13 | 15 | ||
| 14 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| @@ -25,14 +27,29 @@ class OfflineTtsImpl { | @@ -25,14 +27,29 @@ class OfflineTtsImpl { | ||
| 25 | 27 | ||
| 26 | virtual GeneratedAudio Generate( | 28 | virtual GeneratedAudio Generate( |
| 27 | const std::string &text, int64_t sid = 0, float speed = 1.0, | 29 | const std::string &text, int64_t sid = 0, float speed = 1.0, |
| 28 | - GeneratedAudioCallback callback = nullptr) const = 0; | 30 | + GeneratedAudioCallback callback = nullptr) const { |
| 31 | + throw std::runtime_error( | ||
| 32 | + "OfflineTtsImpl backend does not support non-zero-shot Generate()"); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + virtual GeneratedAudio Generate( | ||
| 36 | + const std::string &text, const std::string &prompt_text, | ||
| 37 | + const std::vector<float> &prompt_samples, int32_t sample_rate, | ||
| 38 | + float speed = 1.0, int32_t num_step = 4, | ||
| 39 | + GeneratedAudioCallback callback = nullptr) const { | ||
| 40 | + throw std::runtime_error( | ||
| 41 | + "OfflineTtsImpl backend does not support zero-shot Generate()"); | ||
| 42 | + } | ||
| 29 | 43 | ||
| 30 | // Return the sample rate of the generated audio | 44 | // Return the sample rate of the generated audio |
| 31 | virtual int32_t SampleRate() const = 0; | 45 | virtual int32_t SampleRate() const = 0; |
| 32 | 46 | ||
| 33 | // Number of supported speakers. | 47 | // Number of supported speakers. |
| 34 | // If it supports only a single speaker, then it return 0 or 1. | 48 | // If it supports only a single speaker, then it return 0 or 1. |
| 35 | - virtual int32_t NumSpeakers() const = 0; | 49 | + virtual int32_t NumSpeakers() const { |
| 50 | + throw std::runtime_error( | ||
| 51 | + "Zero-shot OfflineTts does not support NumSpeakers()"); | ||
| 52 | + } | ||
| 36 | 53 | ||
| 37 | std::vector<int64_t> AddBlank(const std::vector<int64_t> &x, | 54 | std::vector<int64_t> AddBlank(const std::vector<int64_t> &x, |
| 38 | int32_t blank_id = 0) const; | 55 | int32_t blank_id = 0) const; |
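The hunk above replaces the two previously pure-virtual methods with throwing defaults and adds a zero-shot Generate() overload that takes the prompt transcription, the prompt samples, and their sample rate. Through the Python bindings this corresponds to the call used in python-api-examples/offline-zeroshot-tts.py; a condensed sketch is shown below (model file names are placeholders, and soundfile replaces the example's wave-based reader):

    import sherpa_onnx
    import soundfile as sf

    config = sherpa_onnx.OfflineTtsConfig(
        model=sherpa_onnx.OfflineTtsModelConfig(
            zipvoice=sherpa_onnx.OfflineTtsZipvoiceModelConfig(
                tokens="tokens.txt",                  # placeholder paths
                text_model="text_encoder.onnx",
                flow_matching_model="fm_decoder.onnx",
                data_dir="espeak-ng-data",
                pinyin_dict="pinyin.raw",
                vocoder="vocos_24khz.onnx",
            ),
            num_threads=2,
        ),
    )
    tts = sherpa_onnx.OfflineTts(config)

    # The prompt must be mono float32 samples; its transcription goes into prompt_text.
    prompt_samples, prompt_sr = sf.read("prompt.wav", dtype="float32")
    audio = tts.generate(
        "text to synthesize",       # text
        "transcription of prompt",  # prompt_text
        prompt_samples,
        prompt_sr,
        speed=1.0,
        num_steps=4,
    )
    sf.write("generated.wav", audio.samples, samplerate=audio.sample_rate)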
sherpa-onnx/csrc/offline-tts-matcha-impl.h
| @@ -140,8 +140,8 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl { | @@ -140,8 +140,8 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl { | ||
| 140 | tn_list_.push_back( | 140 | tn_list_.push_back( |
| 141 | std::make_unique<kaldifst::TextNormalizer>(std::move(r))); | 141 | std::make_unique<kaldifst::TextNormalizer>(std::move(r))); |
| 142 | } // for (; !reader->Done(); reader->Next()) | 142 | } // for (; !reader->Done(); reader->Next()) |
| 143 | - } // for (const auto &f : files) | ||
| 144 | - } // if (!config.rule_fars.empty()) | 143 | + } // for (const auto &f : files) |
| 144 | + } // if (!config.rule_fars.empty()) | ||
| 145 | } | 145 | } |
| 146 | 146 | ||
| 147 | int32_t SampleRate() const override { | 147 | int32_t SampleRate() const override { |
sherpa-onnx/csrc/offline-tts-model-config.cc
| @@ -12,6 +12,7 @@ void OfflineTtsModelConfig::Register(ParseOptions *po) { | @@ -12,6 +12,7 @@ void OfflineTtsModelConfig::Register(ParseOptions *po) { | ||
| 12 | vits.Register(po); | 12 | vits.Register(po); |
| 13 | matcha.Register(po); | 13 | matcha.Register(po); |
| 14 | kokoro.Register(po); | 14 | kokoro.Register(po); |
| 15 | + zipvoice.Register(po); | ||
| 15 | kitten.Register(po); | 16 | kitten.Register(po); |
| 16 | 17 | ||
| 17 | po->Register("num-threads", &num_threads, | 18 | po->Register("num-threads", &num_threads, |
| @@ -38,6 +39,10 @@ bool OfflineTtsModelConfig::Validate() const { | @@ -38,6 +39,10 @@ bool OfflineTtsModelConfig::Validate() const { | ||
| 38 | return matcha.Validate(); | 39 | return matcha.Validate(); |
| 39 | } | 40 | } |
| 40 | 41 | ||
| 42 | + if (!zipvoice.flow_matching_model.empty()) { | ||
| 43 | + return zipvoice.Validate(); | ||
| 44 | + } | ||
| 45 | + | ||
| 41 | if (!kokoro.model.empty()) { | 46 | if (!kokoro.model.empty()) { |
| 42 | return kokoro.Validate(); | 47 | return kokoro.Validate(); |
| 43 | } | 48 | } |
| @@ -58,6 +63,7 @@ std::string OfflineTtsModelConfig::ToString() const { | @@ -58,6 +63,7 @@ std::string OfflineTtsModelConfig::ToString() const { | ||
| 58 | os << "vits=" << vits.ToString() << ", "; | 63 | os << "vits=" << vits.ToString() << ", "; |
| 59 | os << "matcha=" << matcha.ToString() << ", "; | 64 | os << "matcha=" << matcha.ToString() << ", "; |
| 60 | os << "kokoro=" << kokoro.ToString() << ", "; | 65 | os << "kokoro=" << kokoro.ToString() << ", "; |
| 66 | + os << "zipvoice=" << zipvoice.ToString() << ", "; | ||
| 61 | os << "kitten=" << kitten.ToString() << ", "; | 67 | os << "kitten=" << kitten.ToString() << ", "; |
| 62 | os << "num_threads=" << num_threads << ", "; | 68 | os << "num_threads=" << num_threads << ", "; |
| 63 | os << "debug=" << (debug ? "True" : "False") << ", "; | 69 | os << "debug=" << (debug ? "True" : "False") << ", "; |
sherpa-onnx/csrc/offline-tts-model-config.h
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "sherpa-onnx/csrc/offline-tts-kokoro-model-config.h" | 11 | #include "sherpa-onnx/csrc/offline-tts-kokoro-model-config.h" |
| 12 | #include "sherpa-onnx/csrc/offline-tts-matcha-model-config.h" | 12 | #include "sherpa-onnx/csrc/offline-tts-matcha-model-config.h" |
| 13 | #include "sherpa-onnx/csrc/offline-tts-vits-model-config.h" | 13 | #include "sherpa-onnx/csrc/offline-tts-vits-model-config.h" |
| 14 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h" | ||
| 14 | #include "sherpa-onnx/csrc/parse-options.h" | 15 | #include "sherpa-onnx/csrc/parse-options.h" |
| 15 | 16 | ||
| 16 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| @@ -19,6 +20,7 @@ struct OfflineTtsModelConfig { | @@ -19,6 +20,7 @@ struct OfflineTtsModelConfig { | ||
| 19 | OfflineTtsVitsModelConfig vits; | 20 | OfflineTtsVitsModelConfig vits; |
| 20 | OfflineTtsMatchaModelConfig matcha; | 21 | OfflineTtsMatchaModelConfig matcha; |
| 21 | OfflineTtsKokoroModelConfig kokoro; | 22 | OfflineTtsKokoroModelConfig kokoro; |
| 23 | + OfflineTtsZipvoiceModelConfig zipvoice; | ||
| 22 | OfflineTtsKittenModelConfig kitten; | 24 | OfflineTtsKittenModelConfig kitten; |
| 23 | 25 | ||
| 24 | int32_t num_threads = 1; | 26 | int32_t num_threads = 1; |
| @@ -30,12 +32,14 @@ struct OfflineTtsModelConfig { | @@ -30,12 +32,14 @@ struct OfflineTtsModelConfig { | ||
| 30 | OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits, | 32 | OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits, |
| 31 | const OfflineTtsMatchaModelConfig &matcha, | 33 | const OfflineTtsMatchaModelConfig &matcha, |
| 32 | const OfflineTtsKokoroModelConfig &kokoro, | 34 | const OfflineTtsKokoroModelConfig &kokoro, |
| 35 | + const OfflineTtsZipvoiceModelConfig &zipvoice, | ||
| 33 | const OfflineTtsKittenModelConfig &kitten, | 36 | const OfflineTtsKittenModelConfig &kitten, |
| 34 | int32_t num_threads, bool debug, | 37 | int32_t num_threads, bool debug, |
| 35 | const std::string &provider) | 38 | const std::string &provider) |
| 36 | : vits(vits), | 39 | : vits(vits), |
| 37 | matcha(matcha), | 40 | matcha(matcha), |
| 38 | kokoro(kokoro), | 41 | kokoro(kokoro), |
| 42 | + zipvoice(zipvoice), | ||
| 39 | kitten(kitten), | 43 | kitten(kitten), |
| 40 | num_threads(num_threads), | 44 | num_threads(num_threads), |
| 41 | debug(debug), | 45 | debug(debug), |
sherpa-onnx/csrc/offline-tts-zipvoice-frontend-test.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-zipvoice-frontend-test.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h" | ||
| 6 | + | ||
| 7 | +#include "espeak-ng/speak_lib.h" | ||
| 8 | +#include "gtest/gtest.h" | ||
| 9 | +#include "phoneme_ids.hpp" | ||
| 10 | +#include "phonemize.hpp" | ||
| 11 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +TEST(ZipVoiceFrontend, Case1) { | ||
| 17 | + std::string data_dir = "../zipvoice/espeak-ng-data"; | ||
| 18 | + if (!FileExists(data_dir + "/en_dict")) { | ||
| 19 | + SHERPA_ONNX_LOGE("%s/en_dict does not exist. Skipping test", | ||
| 20 | + data_dir.c_str()); | ||
| 21 | + return; | ||
| 22 | + } | ||
| 23 | + | ||
| 24 | + if (!FileExists(data_dir + "/phontab")) { | ||
| 25 | + SHERPA_ONNX_LOGE("%s/phontab does not exist. Skipping test", | ||
| 26 | + data_dir.c_str()); | ||
| 27 | + return; | ||
| 28 | + } | ||
| 29 | + | ||
| 30 | + if (!FileExists(data_dir + "/phonindex")) { | ||
| 31 | + SHERPA_ONNX_LOGE("%s/phonindex does not exist. Skipping test", | ||
| 32 | + data_dir.c_str()); | ||
| 33 | + return; | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + if (!FileExists(data_dir + "/phondata")) { | ||
| 37 | + SHERPA_ONNX_LOGE("%s/phondata does not exist. Skipping test", | ||
| 38 | + data_dir.c_str()); | ||
| 39 | + return; | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + if (!FileExists(data_dir + "/intonations")) { | ||
| 43 | + SHERPA_ONNX_LOGE("%s/intonations does not exist. Skipping test", | ||
| 44 | + data_dir.c_str()); | ||
| 45 | + return; | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + std::string pinyin_dict = data_dir + "/../pinyin.dict"; | ||
| 49 | + if (!FileExists(pinyin_dict)) { | ||
| 50 | + SHERPA_ONNX_LOGE("%s does not exist. Skipping test", pinyin_dict.c_str()); | ||
| 51 | + return; | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + std::string tokens_file = data_dir + "/../tokens.txt"; | ||
| 55 | + if (!FileExists(tokens_file)) { | ||
| 56 | + SHERPA_ONNX_LOGE("%s does not exist. Skipping test", tokens_file.c_str()); | ||
| 57 | + return; | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + auto frontend = OfflineTtsZipvoiceFrontend( | ||
| 61 | + tokens_file, data_dir, pinyin_dict, | ||
| 62 | + OfflineTtsZipvoiceModelMetaData{.use_espeak = true, .use_pinyin = true}, | ||
| 63 | + true); | ||
| 64 | + | ||
| 65 | + std::string text = "how are you doing?"; | ||
| 66 | + std::vector<sherpa_onnx::TokenIDs> ans = | ||
| 67 | + frontend.ConvertTextToTokenIds(text, "en-us"); | ||
| 68 | + | ||
| 69 | + text = "这是第一句。这是第二句。"; | ||
| 70 | + ans = frontend.ConvertTextToTokenIds(text, "en-us"); | ||
| 71 | + | ||
| 72 | + text = | ||
| 73 | + "这是第一句。这是第二句。<pin1><yin2>测试 [S1]and hello " | ||
| 74 | + "world[S2]这是第三句。"; | ||
| 75 | + ans = frontend.ConvertTextToTokenIds(text, "en-us"); | ||
| 76 | +} | ||
| 77 | + | ||
| 78 | +} // namespace sherpa_onnx |
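The last test string exercises the inline markup understood by the frontend (implemented in offline-tts-zipvoice-frontend.cc below): <...> spans such as <pin1> force a specific pinyin pronunciation, [...] spans such as [S1] are passed through as upper-case tags, Chinese runs are converted with cppinyin, and Latin runs go through espeak-ng. A rough, illustration-only Python sketch of the part splitting and classification performed by ConvertTextToTokenIds() (the C++ code additionally merges adjacent parts of compatible types before tokenizing):

    import re

    # Mirrors the regex in ConvertTextToTokenIds(): match <...>, [...], or a single character.
    PART_PATTERN = re.compile(r"[<\[].*?[>\]]|.", re.DOTALL)

    def classify(part: str) -> str:
        if len(part) == 1 and part.isascii() and part.isalpha():
            return "en"
        if len(part) > 1 and part[0] == "<" and part[-1] == ">":
            return "pinyin"          # e.g. <ha3>
        if len(part) > 1 and part[0] == "[" and part[-1] == "]":
            return "tag"             # e.g. [S1]
        if any("\u4e00" <= ch <= "\u9fff" for ch in part):  # simplified CJK test
            return "zh"
        return "other"               # numbers, punctuation, spaces, ...

    text = "这是第一句。<pin1><yin2>测试 [S1]and hello world[S2]"
    print([(p, classify(p)) for p in PART_PATTERN.findall(text)])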
sherpa-onnx/csrc/offline-tts-zipvoice-frontend.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-zipvoice-frontend.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include <algorithm> | ||
| 6 | +#include <cctype> | ||
| 7 | +#include <codecvt> | ||
| 8 | +#include <fstream> | ||
| 9 | +#include <locale> | ||
| 10 | +#include <regex> // NOLINT | ||
| 11 | +#include <sstream> | ||
| 12 | +#include <strstream> | ||
| 13 | +#include <utility> | ||
| 14 | + | ||
| 15 | +#if __ANDROID_API__ >= 9 | ||
| 16 | +#include "android/asset_manager.h" | ||
| 17 | +#include "android/asset_manager_jni.h" | ||
| 18 | +#endif | ||
| 19 | + | ||
| 20 | +#if __OHOS__ | ||
| 21 | +#include "rawfile/raw_file_manager.h" | ||
| 22 | +#endif | ||
| 23 | + | ||
| 24 | +#include "cppinyin/csrc/cppinyin.h" | ||
| 25 | +#include "espeak-ng/speak_lib.h" | ||
| 26 | +#include "phoneme_ids.hpp" | ||
| 27 | +#include "phonemize.hpp" | ||
| 28 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 29 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 30 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h" | ||
| 31 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 32 | + | ||
| 33 | +namespace sherpa_onnx { | ||
| 34 | + | ||
| 35 | +void CallPhonemizeEspeak(const std::string &text, | ||
| 36 | + piper::eSpeakPhonemeConfig &config, // NOLINT | ||
| 37 | + std::vector<std::vector<piper::Phoneme>> *phonemes); | ||
| 38 | + | ||
| 39 | +static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) { | ||
| 40 | + std::unordered_map<std::string, int32_t> token2id; | ||
| 41 | + | ||
| 42 | + std::string line; | ||
| 43 | + std::string sym; | ||
| 44 | + int32_t id = 0; | ||
| 45 | + while (std::getline(is, line)) { | ||
| 46 | + std::istringstream iss(line); | ||
| 47 | + iss >> sym; | ||
| 48 | + if (iss.eof()) { | ||
| 49 | + id = atoi(sym.c_str()); | ||
| 50 | + sym = " "; | ||
| 51 | + } else { | ||
| 52 | + iss >> id; | ||
| 53 | + } | ||
| 54 | + // eat the trailing \r\n on windows | ||
| 55 | + iss >> std::ws; | ||
| 56 | + if (!iss.eof()) { | ||
| 57 | + SHERPA_ONNX_LOGE("Error when reading tokens: %s", line.c_str()); | ||
| 58 | + exit(-1); | ||
| 59 | + } | ||
| 60 | + | ||
| 61 | + if (token2id.count(sym)) { | ||
| 62 | + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", | ||
| 63 | + sym.c_str(), line.c_str(), token2id.at(sym)); | ||
| 64 | + exit(-1); | ||
| 65 | + } | ||
| 66 | + token2id.insert({sym, id}); | ||
| 67 | + } | ||
| 68 | + return token2id; | ||
| 69 | +} | ||
| 70 | + | ||
| 71 | +static std::string MapPunctuations( | ||
| 72 | + const std::string &text, | ||
| 73 | + const std::unordered_map<std::string, std::string> &punct_map) { | ||
| 74 | + std::string result = text; | ||
| 75 | + for (const auto &kv : punct_map) { | ||
| 76 | + // Replace all occurrences of kv.first with kv.second | ||
| 77 | + size_t pos = 0; | ||
| 78 | + while ((pos = result.find(kv.first, pos)) != std::string::npos) { | ||
| 79 | + result.replace(pos, kv.first.length(), kv.second); | ||
| 80 | + pos += kv.second.length(); | ||
| 81 | + } | ||
| 82 | + } | ||
| 83 | + return result; | ||
| 84 | +} | ||
| 85 | + | ||
| 86 | +static void ProcessPinyin( | ||
| 87 | + const std::string &pinyin, const cppinyin::PinyinEncoder *pinyin_encoder, | ||
| 88 | + const std::unordered_map<std::string, int32_t> &token2id, | ||
| 89 | + std::vector<int64_t> *tokens_ids, std::vector<std::string> *tokens) { | ||
| 90 | + auto initial = pinyin_encoder->ToInitial(pinyin); | ||
| 91 | + if (!initial.empty()) { | ||
| 92 | + // append '0' to fix the conflict with espeak token | ||
| 93 | + initial = initial + "0"; | ||
| 94 | + if (token2id.count(initial)) { | ||
| 95 | + tokens_ids->push_back(token2id.at(initial)); | ||
| 96 | + tokens->push_back(initial); | ||
| 97 | + } else { | ||
| 98 | + SHERPA_ONNX_LOGE("Skip unknown initial %s", initial.c_str()); | ||
| 99 | + } | ||
| 100 | + } | ||
| 101 | + auto final_t = pinyin_encoder->ToFinal(pinyin); | ||
| 102 | + if (!final_t.empty()) { | ||
| 103 | + if (!std::isdigit(final_t.back())) { | ||
| 104 | + final_t = final_t + "5"; // use 5 for neutral tone | ||
| 105 | + } | ||
| 106 | + if (token2id.count(final_t)) { | ||
| 107 | + tokens_ids->push_back(token2id.at(final_t)); | ||
| 108 | + tokens->push_back(final_t); | ||
| 109 | + } else { | ||
| 110 | + SHERPA_ONNX_LOGE("Skip unknown final %s", final_t.c_str()); | ||
| 111 | + } | ||
| 112 | + } | ||
| 113 | +} | ||
| 114 | + | ||
| 115 | +static void TokenizeZh(const std::string &words, | ||
| 116 | + const cppinyin::PinyinEncoder *pinyin_encoder, | ||
| 117 | + const std::unordered_map<std::string, int32_t> &token2id, | ||
| 118 | + std::vector<int64_t> *token_ids, | ||
| 119 | + std::vector<std::string> *tokens) { | ||
| 120 | + std::vector<std::string> pinyins; | ||
| 121 | + pinyin_encoder->Encode(words, &pinyins, "number" /*tone*/, false /*partial*/); | ||
| 122 | + for (const auto &pinyin : pinyins) { | ||
| 123 | + if (pinyin_encoder->ValidPinyin(pinyin, "number" /*tone*/)) { | ||
| 124 | + ProcessPinyin(pinyin, pinyin_encoder, token2id, token_ids, tokens); | ||
| 125 | + } else { | ||
| 126 | + auto wstext = ToWideString(pinyin); | ||
| 127 | + for (auto &wc : wstext) { | ||
| 128 | + auto c = ToString(std::wstring(1, wc)); | ||
| 129 | + if (token2id.count(c)) { | ||
| 130 | + token_ids->push_back(token2id.at(c)); | ||
| 131 | + tokens->push_back(c); | ||
| 132 | + } else { | ||
| 133 | + SHERPA_ONNX_LOGE("Skip unknown character %s", c.c_str()); | ||
| 134 | + } | ||
| 135 | + } | ||
| 136 | + } | ||
| 137 | + } | ||
| 138 | +} | ||
| 139 | + | ||
| 140 | +static void TokenizeEn(const std::string &words, | ||
| 141 | + const std::unordered_map<std::string, int32_t> &token2id, | ||
| 142 | + const std::string &voice, | ||
| 143 | + std::vector<int64_t> *token_ids, | ||
| 144 | + std::vector<std::string> *tokens) { | ||
| 145 | + piper::eSpeakPhonemeConfig config; | ||
| 146 | + // ./bin/espeak-ng-bin --path ./install/share/espeak-ng-data/ --voices | ||
| 147 | + // to list available voices | ||
| 148 | + config.voice = voice; // e.g., voice is en-us | ||
| 149 | + | ||
| 150 | + std::vector<std::vector<piper::Phoneme>> phonemes; | ||
| 151 | + | ||
| 152 | + CallPhonemizeEspeak(words, config, &phonemes); | ||
| 153 | + | ||
| 154 | + for (const auto &p : phonemes) { | ||
| 155 | + for (const auto &ph : p) { | ||
| 156 | + auto token = Utf32ToUtf8(std::u32string(1, ph)); | ||
| 157 | + if (token2id.count(token)) { | ||
| 158 | + token_ids->push_back(token2id.at(token)); | ||
| 159 | + tokens->push_back(token); | ||
| 160 | + } else { | ||
| 161 | + SHERPA_ONNX_LOGE("Skip unknown phoneme %s", token.c_str()); | ||
| 162 | + } | ||
| 163 | + } | ||
| 164 | + } | ||
| 165 | +} | ||
| 166 | + | ||
| 167 | +static void TokenizeTag( | ||
| 168 | + const std::string &words, | ||
| 169 | + const std::unordered_map<std::string, int32_t> &token2id, | ||
| 170 | + std::vector<int64_t> *tokens_ids, std::vector<std::string> *tokens) { | ||
| 171 | + // in zipvoice tags are all in upper case | ||
| 172 | + std::string tag = ToUpperAscii(words); | ||
| 173 | + if (token2id.count(tag)) { | ||
| 174 | + tokens_ids->push_back(token2id.at(tag)); | ||
| 175 | + tokens->push_back(tag); | ||
| 176 | + } else { | ||
| 177 | + SHERPA_ONNX_LOGE("Skip unknown tag %s", tag.c_str()); | ||
| 178 | + } | ||
| 179 | +} | ||
| 180 | + | ||
| 181 | +static void TokenizePinyin( | ||
| 182 | + const std::string &words, const cppinyin::PinyinEncoder *pinyin_encoder, | ||
| 183 | + const std::unordered_map<std::string, int32_t> &token2id, | ||
| 184 | + std::vector<int64_t> *tokens_ids, std::vector<std::string> *tokens) { | ||
| 185 | + // words are in the form of <ha3>, <ha4> | ||
| 186 | + std::string pinyin = words.substr(1, words.size() - 2); | ||
| 187 | + if (!pinyin.empty()) { | ||
| 188 | + if (pinyin[pinyin.size() - 1] == '5') { | ||
| 189 | + pinyin = pinyin.substr(0, pinyin.size() - 1); // remove the tone | ||
| 190 | + } | ||
| 191 | + if (pinyin_encoder->ValidPinyin(pinyin, "number" /*tone*/)) { | ||
| 192 | + ProcessPinyin(pinyin, pinyin_encoder, token2id, tokens_ids, tokens); | ||
| 193 | + } else { | ||
| 194 | + SHERPA_ONNX_LOGE("Invalid pinyin %s", pinyin.c_str()); | ||
| 195 | + } | ||
| 196 | + } | ||
| 197 | +} | ||
| 198 | + | ||
| 199 | +OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend( | ||
| 200 | + const std::string &tokens, const std::string &data_dir, | ||
| 201 | + const std::string &pinyin_dict, | ||
| 202 | + const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug) | ||
| 203 | + : debug_(debug), meta_data_(meta_data) { | ||
| 204 | + std::ifstream is(tokens); | ||
| 205 | + token2id_ = ReadTokens(is); | ||
| 206 | + if (meta_data_.use_pinyin) { | ||
| 207 | + pinyin_encoder_ = std::make_unique<cppinyin::PinyinEncoder>(pinyin_dict); | ||
| 208 | + } else { | ||
| 209 | + pinyin_encoder_ = nullptr; | ||
| 210 | + } | ||
| 211 | + if (meta_data_.use_espeak) { | ||
| 212 | + // We should copy the directory of espeak-ng-data from the asset to | ||
| 213 | + // some internal or external storage and then pass the directory to | ||
| 214 | + // data_dir. | ||
| 215 | + InitEspeak(data_dir); | ||
| 216 | + } | ||
| 217 | +} | ||
| 218 | + | ||
| 219 | +template <typename Manager> | ||
| 220 | +OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend( | ||
| 221 | + Manager *mgr, const std::string &tokens, const std::string &data_dir, | ||
| 222 | + const std::string &pinyin_dict, | ||
| 223 | + const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug) | ||
| 224 | + : debug_(debug), meta_data_(meta_data) { | ||
| 225 | + auto buf = ReadFile(mgr, tokens); | ||
| 226 | + std::istrstream is(buf.data(), buf.size()); | ||
| 227 | + token2id_ = ReadTokens(is); | ||
| 228 | + if (meta_data_.use_pinyin) { | ||
| 229 | + auto buf = ReadFile(mgr, pinyin_dict); | ||
| 230 | + std::istringstream iss(std::string(buf.begin(), buf.end())); | ||
| 231 | + pinyin_encoder_ = std::make_unique<cppinyin::PinyinEncoder>(iss); | ||
| 232 | + } else { | ||
| 233 | + pinyin_encoder_ = nullptr; | ||
| 234 | + } | ||
| 235 | + if (meta_data_.use_espeak) { | ||
| 236 | + // We should copy the directory of espeak-ng-data from the asset to | ||
| 237 | + // some internal or external storage and then pass the directory to | ||
| 238 | + // data_dir. | ||
| 239 | + InitEspeak(data_dir); | ||
| 240 | + } | ||
| 241 | +} | ||
| 242 | + | ||
| 243 | +std::vector<TokenIDs> OfflineTtsZipvoiceFrontend::ConvertTextToTokenIds( | ||
| 244 | + const std::string &_text, const std::string &voice) const { | ||
| 245 | + std::string text = _text; | ||
| 246 | + if (meta_data_.use_espeak) { | ||
| 247 | + text = ToLowerAscii(_text); | ||
| 248 | + } | ||
| 249 | + | ||
| 250 | + text = MapPunctuations(text, punct_map_); | ||
| 251 | + | ||
| 252 | + auto wstext = ToWideString(text); | ||
| 253 | + | ||
| 254 | + std::vector<std::string> parts; | ||
| 255 | + // Match <...>, [...], or single character | ||
| 256 | + std::wregex part_pattern(LR"([<\[].*?[>\]]|.)"); | ||
| 257 | + auto words_begin = | ||
| 258 | + std::wsregex_iterator(wstext.begin(), wstext.end(), part_pattern); | ||
| 259 | + auto words_end = std::wsregex_iterator(); | ||
| 260 | + for (std::wsregex_iterator i = words_begin; i != words_end; ++i) { | ||
| 261 | + parts.push_back(ToString(i->str())); | ||
| 262 | + } | ||
| 263 | + | ||
| 264 | + // types are en, zh, tag, pinyin, other | ||
| 265 | + // tag is [...] | ||
| 266 | + // pinyin is <...> | ||
| 267 | + // other is any other text that does not match the above, typically numbers | ||
| 268 | + // and punctuation | ||
| 269 | + std::vector<std::string> types; | ||
| 270 | + for (auto &word : parts) { | ||
| 271 | + if (word.size() == 1 && std::isalpha(word[0])) { | ||
| 272 | + // single character, e.g., 'a', 'b', 'c' | ||
| 273 | + types.push_back("en"); | ||
| 274 | + } else if (word.size() > 1 && word[0] == '<' && word.back() == '>') { | ||
| 275 | + // e.g., <ha3>, <ha4> | ||
| 276 | + types.push_back("pinyin"); | ||
| 277 | + } else if (word.size() > 1 && word[0] == '[' && word.back() == ']') { | ||
| 278 | + types.push_back("tag"); | ||
| 279 | + } else if (ContainsCJK(word)) { // word contains at least one CJK character | ||
| 280 | + types.push_back("zh"); | ||
| 281 | + } else { | ||
| 282 | + types.push_back("other"); | ||
| 283 | + } | ||
| 284 | + } | ||
| 285 | + | ||
| 286 | + std::vector<std::pair<std::string, std::string>> parts_with_types; | ||
| 287 | + std::ostringstream oss; | ||
| 288 | + std::string t_lang; | ||
| 289 | + oss.str(""); | ||
| 290 | + std::ostringstream debug_oss; | ||
| 291 | + if (debug_) { | ||
| 292 | + debug_oss << "Text : " << _text << ", Parts with types: \n"; | ||
| 293 | + } | ||
| 294 | + for (int32_t i = 0; i < types.size(); ++i) { | ||
| 295 | + if (i == 0) { | ||
| 296 | + oss << parts[i]; | ||
| 297 | + t_lang = types[i]; | ||
| 298 | + } else { | ||
| 299 | + if (t_lang == "other" && (types[i] != "tag" && types[i] != "pinyin")) { | ||
| 300 | + // combine into current type if the previous part is "other" | ||
| 301 | + // do not combine with "tag" or "pinyin" | ||
| 302 | + oss << parts[i]; | ||
| 303 | + t_lang = types[i]; | ||
| 304 | + } else { | ||
| 305 | + if ((t_lang == types[i] || types[i] == "other") && t_lang != "pinyin" && | ||
| 306 | + t_lang != "tag") { | ||
| 307 | + // same language or other, continue | ||
| 308 | + // do not combine other into "pinyin" or "tag" | ||
| 309 | + oss << parts[i]; | ||
| 310 | + } else { | ||
| 311 | + // different language, start a new sentence | ||
| 312 | + std::string part = oss.str(); | ||
| 313 | + oss.str(""); | ||
| 314 | + parts_with_types.emplace_back(part, t_lang); | ||
| 315 | + if (debug_) { | ||
| 316 | + debug_oss << "(" << part << ", " << t_lang << "),"; | ||
| 317 | + } | ||
| 318 | + oss << parts[i]; | ||
| 319 | + t_lang = types[i]; | ||
| 320 | + } | ||
| 321 | + } | ||
| 322 | + } | ||
| 323 | + } | ||
| 324 | + | ||
| 325 | + std::string part = oss.str(); | ||
| 326 | + oss.str(""); | ||
| 327 | + parts_with_types.emplace_back(part, t_lang); | ||
| 328 | + if (debug_) { | ||
| 329 | + debug_oss << "(" << part << ", " << t_lang << ")\n"; | ||
| 330 | + SHERPA_ONNX_LOGE("%s", debug_oss.str().c_str()); | ||
| 331 | + debug_oss.str(""); | ||
| 332 | + } | ||
| 333 | + | ||
| 334 | + std::vector<int64_t> token_ids; | ||
| 335 | + std::vector<std::string> tokens; // for debugging | ||
| 336 | + for (const auto &pt : parts_with_types) { | ||
| 337 | + if (pt.second == "zh") { | ||
| 338 | + TokenizeZh(pt.first, pinyin_encoder_.get(), token2id_, &token_ids, | ||
| 339 | + &tokens); | ||
| 340 | + } else if (pt.second == "en") { | ||
| 341 | + TokenizeEn(pt.first, token2id_, voice, &token_ids, &tokens); | ||
| 342 | + } else if (pt.second == "pinyin") { | ||
| 343 | + TokenizePinyin(pt.first, pinyin_encoder_.get(), token2id_, &token_ids, | ||
| 344 | + &tokens); | ||
| 345 | + } else if (pt.second == "tag") { | ||
| 346 | + TokenizeTag(pt.first, token2id_, &token_ids, &tokens); | ||
| 347 | + } else { | ||
| 348 | + SHERPA_ONNX_LOGE("Unexpected type: %s", pt.second.c_str()); | ||
| 349 | + exit(-1); | ||
| 350 | + } | ||
| 351 | + } | ||
| 352 | + if (debug_) { | ||
| 353 | + debug_oss << "Tokens and IDs: \n"; | ||
| 354 | + for (int32_t i = 0; i < tokens.size(); i++) { | ||
| 355 | + debug_oss << "(" << tokens[i] << ", " << token_ids[i] << "),"; | ||
| 356 | + } | ||
| 357 | + debug_oss << "\n"; | ||
| 358 | + SHERPA_ONNX_LOGE("%s", debug_oss.str().c_str()); | ||
| 359 | + } | ||
| 360 | + | ||
| 361 | + std::vector<TokenIDs> ans; | ||
| 362 | + ans.push_back(TokenIDs(std::move(token_ids))); | ||
| 363 | + return ans; | ||
| 364 | +} | ||
| 365 | + | ||
| 366 | +#if __ANDROID_API__ >= 9 | ||
| 367 | +template OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend( | ||
| 368 | + AAssetManager *mgr, const std::string &tokens, const std::string &data_dir, | ||
| 369 | + const std::string &pinyin_dict, | ||
| 370 | + const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug); | ||
| 371 | + | ||
| 372 | +#endif | ||
| 373 | + | ||
| 374 | +#if __OHOS__ | ||
| 375 | +template OfflineTtsZipvoiceFrontend::OfflineTtsZipvoiceFrontend( | ||
| 376 | + NativeResourceManager *mgr, const std::string &tokens, | ||
| 377 | + const std::string &data_dir, const std::string &pinyin_dict, | ||
| 378 | + const OfflineTtsZipvoiceModelMetaData &meta_data, bool debug); | ||
| 379 | + | ||
| 380 | +#endif | ||
| 381 | + | ||
| 382 | +} // namespace sherpa_onnx |
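For reference, ReadTokens() above implies the tokens.txt layout expected by --zipvoice-tokens: one "symbol id" pair per line, with a line that carries only an ID denoting the space symbol. A hedged Python equivalent of the parser, shown only to document the format:

    def read_tokens(path: str) -> dict:
        token2id = {}
        with open(path, encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if len(fields) == 1:
                    sym, idx = " ", int(fields[0])  # ID only: the symbol is a space
                elif len(fields) == 2:
                    sym, idx = fields[0], int(fields[1])
                else:
                    raise ValueError(f"Bad line in tokens.txt: {line!r}")
                if sym in token2id:
                    raise ValueError(f"Duplicated token: {sym!r}")
                token2id[sym] = idx
        return token2id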
sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_FRONTEND_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_FRONTEND_H_ | ||
| 7 | +#include <cstdint> | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <unordered_map> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#include "cppinyin/csrc/cppinyin.h" | ||
| 14 | +#include "sherpa-onnx/csrc/offline-tts-frontend.h" | ||
| 15 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-meta-data.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +class OfflineTtsZipvoiceFrontend : public OfflineTtsFrontend { | ||
| 20 | + public: | ||
| 21 | + OfflineTtsZipvoiceFrontend(const std::string &tokens, | ||
| 22 | + const std::string &data_dir, | ||
| 23 | + const std::string &pinyin_dict, | ||
| 24 | + const OfflineTtsZipvoiceModelMetaData &meta_data, | ||
| 25 | + bool debug = false); | ||
| 26 | + | ||
| 27 | + template <typename Manager> | ||
| 28 | + OfflineTtsZipvoiceFrontend(Manager *mgr, const std::string &tokens, | ||
| 29 | + const std::string &data_dir, | ||
| 30 | + const std::string &pinyin_dict, | ||
| 31 | + const OfflineTtsZipvoiceModelMetaData &meta_data, | ||
| 32 | + bool debug = false); | ||
| 33 | + | ||
| 34 | + /** Convert a string to token IDs. | ||
| 35 | + * | ||
| 36 | + * @param text The input text. | ||
| 37 | + * Example 1: "This is the first sample sentence; this is the | ||
| 38 | + * second one." Example 2: "这是第一句。这是第二句。" | ||
| 39 | + * @param voice Optional. It is for espeak-ng. | ||
| 40 | + * | ||
| 41 | + * @return Return a vector-of-vector of token IDs. Each subvector contains | ||
| 42 | + * a sentence that can be processed independently. | ||
| 43 | + * If a frontend does not support splitting the text into | ||
| 44 | + * sentences, the resulting vector contains only one subvector. | ||
| 45 | + */ | ||
| 46 | + std::vector<TokenIDs> ConvertTextToTokenIds( | ||
| 47 | + const std::string &text, const std::string &voice = "") const override; | ||
| 48 | + | ||
| 49 | + private: | ||
| 50 | + bool debug_ = false; | ||
| 51 | + std::unordered_map<std::string, int32_t> token2id_; | ||
| 52 | + const std::unordered_map<std::string, std::string> punct_map_ = { | ||
| 53 | + {",", ","}, {"。", "."}, {"!", "!"}, {"?", "?"}, {";", ";"}, | ||
| 54 | + {":", ":"}, {"、", ","}, {"‘", "'"}, {"“", "\""}, {"”", "\""}, | ||
| 55 | + {"’", "'"}, {"⋯", "…"}, {"···", "…"}, {"・・・", "…"}, {"...", "…"}}; | ||
| 56 | + OfflineTtsZipvoiceModelMetaData meta_data_; | ||
| 57 | + std::unique_ptr<cppinyin::PinyinEncoder> pinyin_encoder_; | ||
| 58 | +}; | ||
| 59 | + | ||
| 60 | +} // namespace sherpa_onnx | ||
| 61 | + | ||
| 62 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_FRONTEND_H_ |
sherpa-onnx/csrc/offline-tts-zipvoice-impl.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-zipvoice-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_ | ||
| 6 | + | ||
| 7 | +#include <cmath> | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <strstream> | ||
| 11 | +#include <utility> | ||
| 12 | +#include <vector> | ||
| 13 | + | ||
| 14 | +#include "kaldi-native-fbank/csrc/mel-computations.h" | ||
| 15 | +#include "kaldi-native-fbank/csrc/stft.h" | ||
| 16 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 17 | +#include "sherpa-onnx/csrc/offline-tts-frontend.h" | ||
| 18 | +#include "sherpa-onnx/csrc/offline-tts-impl.h" | ||
| 19 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-frontend.h" | ||
| 20 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h" | ||
| 21 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model.h" | ||
| 22 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 23 | +#include "sherpa-onnx/csrc/resample.h" | ||
| 24 | +#include "sherpa-onnx/csrc/vocoder.h" | ||
| 25 | + | ||
| 26 | +namespace sherpa_onnx { | ||
| 27 | + | ||
| 28 | +class OfflineTtsZipvoiceImpl : public OfflineTtsImpl { | ||
| 29 | + public: | ||
| 30 | + explicit OfflineTtsZipvoiceImpl(const OfflineTtsConfig &config) | ||
| 31 | + : config_(config), | ||
| 32 | + model_(std::make_unique<OfflineTtsZipvoiceModel>(config.model)), | ||
| 33 | + vocoder_(Vocoder::Create(config.model)) { | ||
| 34 | + InitFrontend(); | ||
| 35 | + } | ||
| 36 | + | ||
| 37 | + template <typename Manager> | ||
| 38 | + OfflineTtsZipvoiceImpl(Manager *mgr, const OfflineTtsConfig &config) | ||
| 39 | + : config_(config), | ||
| 40 | + model_(std::make_unique<OfflineTtsZipvoiceModel>(mgr, config.model)), | ||
| 41 | + vocoder_(Vocoder::Create(mgr, config.model)) { | ||
| 42 | + InitFrontend(mgr); | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + int32_t SampleRate() const override { | ||
| 46 | + return model_->GetMetaData().sample_rate; | ||
| 47 | + } | ||
| 48 | + | ||
| 49 | + GeneratedAudio Generate( | ||
| 50 | + const std::string &text, const std::string &prompt_text, | ||
| 51 | + const std::vector<float> &prompt_samples, int32_t sample_rate, | ||
| 52 | + float speed, int32_t num_steps, | ||
| 53 | + GeneratedAudioCallback callback = nullptr) const override { | ||
| 54 | + std::vector<TokenIDs> text_token_ids = | ||
| 55 | + frontend_->ConvertTextToTokenIds(text); | ||
| 56 | + | ||
| 57 | + std::vector<TokenIDs> prompt_token_ids = | ||
| 58 | + frontend_->ConvertTextToTokenIds(prompt_text); | ||
| 59 | + | ||
| 60 | + if (text_token_ids.empty() || | ||
| 61 | + (text_token_ids.size() == 1 && text_token_ids[0].tokens.empty())) { | ||
| 62 | +#if __OHOS__ | ||
| 63 | + SHERPA_ONNX_LOGE("Failed to convert '%{public}s' to token IDs", | ||
| 64 | + text.c_str()); | ||
| 65 | +#else | ||
| 66 | + SHERPA_ONNX_LOGE("Failed to convert '%s' to token IDs", text.c_str()); | ||
| 67 | +#endif | ||
| 68 | + return {}; | ||
| 69 | + } | ||
| 70 | + | ||
| 71 | + if (prompt_token_ids.empty() || | ||
| 72 | + (prompt_token_ids.size() == 1 && prompt_token_ids[0].tokens.empty())) { | ||
| 73 | +#if __OHOS__ | ||
| 74 | + SHERPA_ONNX_LOGE( | ||
| 75 | + "Failed to convert prompt text '%{public}s' to token IDs", | ||
| 76 | + prompt_text.c_str()); | ||
| 77 | +#else | ||
| 78 | + SHERPA_ONNX_LOGE("Failed to convert prompt text '%s' to token IDs", | ||
| 79 | + prompt_text.c_str()); | ||
| 80 | +#endif | ||
| 81 | + return {}; | ||
| 82 | + } | ||
| 83 | + | ||
| 84 | + // we assume batch size is 1 | ||
| 85 | + std::vector<int64_t> tokens = text_token_ids[0].tokens; | ||
| 86 | + std::vector<int64_t> prompt_tokens = prompt_token_ids[0].tokens; | ||
| 87 | + | ||
| 88 | + return Process(tokens, prompt_tokens, prompt_samples, sample_rate, speed, | ||
| 89 | + num_steps); | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + private: | ||
| 93 | + template <typename Manager> | ||
| 94 | + void InitFrontend(Manager *mgr) { | ||
| 95 | + const auto &meta_data = model_->GetMetaData(); | ||
| 96 | + frontend_ = std::make_unique<OfflineTtsZipvoiceFrontend>( | ||
| 97 | + mgr, config_.model.zipvoice.tokens, config_.model.zipvoice.data_dir, | ||
| 98 | + config_.model.zipvoice.pinyin_dict, meta_data, config_.model.debug); | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + void InitFrontend() { | ||
| 102 | + const auto &meta_data = model_->GetMetaData(); | ||
| 103 | + | ||
| 104 | + if (meta_data.use_pinyin && config_.model.zipvoice.pinyin_dict.empty()) { | ||
| 105 | + SHERPA_ONNX_LOGE( | ||
| 106 | + "Please provide --zipvoice-pinyin-dict for converting Chinese into " | ||
| 107 | + "pinyin."); | ||
| 108 | + exit(-1); | ||
| 109 | + } | ||
| 110 | + if (meta_data.use_espeak && config_.model.zipvoice.data_dir.empty()) { | ||
| 111 | + SHERPA_ONNX_LOGE("Please provide --zipvoice-data-dir for espeak-ng."); | ||
| 112 | + exit(-1); | ||
| 113 | + } | ||
| 114 | + frontend_ = std::make_unique<OfflineTtsZipvoiceFrontend>( | ||
| 115 | + config_.model.zipvoice.tokens, config_.model.zipvoice.data_dir, | ||
| 116 | + config_.model.zipvoice.pinyin_dict, meta_data, config_.model.debug); | ||
| 117 | + } | ||
| 118 | + | ||
| 119 | + std::vector<int32_t> ComputeMelSpectrogram( | ||
| 120 | + const std::vector<float> &_samples, int32_t sample_rate, | ||
| 121 | + std::vector<float> *prompt_features) const { | ||
| 122 | + const auto &meta = model_->GetMetaData(); | ||
| 123 | + if (sample_rate != meta.sample_rate) { | ||
| 124 | + SHERPA_ONNX_LOGE( | ||
| 125 | + "Creating a resampler:\n" | ||
| 126 | + " in_sample_rate: %d\n" | ||
| 127 | + " output_sample_rate: %d\n", | ||
| 128 | + sample_rate, static_cast<int32_t>(meta.sample_rate)); | ||
| 129 | + | ||
| 130 | + float min_freq = std::min<int32_t>(sample_rate, meta.sample_rate); | ||
| 131 | + float lowpass_cutoff = 0.99 * 0.5 * min_freq; | ||
| 132 | + | ||
| 133 | + int32_t lowpass_filter_width = 6; | ||
| 134 | + auto resampler = std::make_unique<LinearResample>( | ||
| 135 | + sample_rate, meta.sample_rate, lowpass_cutoff, lowpass_filter_width); | ||
| 136 | + std::vector<float> samples; | ||
| 137 | + resampler->Resample(_samples.data(), _samples.size(), true, &samples); | ||
| 138 | + return ComputeMelSpectrogram(samples, prompt_features); | ||
| 139 | + } else { | ||
| 140 | + // Use the original samples if the sample rate matches | ||
| 141 | + return ComputeMelSpectrogram(_samples, prompt_features); | ||
| 142 | + } | ||
| 143 | + } | ||
| 144 | + | ||
| 145 | + std::vector<int32_t> ComputeMelSpectrogram( | ||
| 146 | + const std::vector<float> &samples, | ||
| 147 | + std::vector<float> *prompt_features) const { | ||
| 148 | + const auto &meta = model_->GetMetaData(); | ||
| 149 | + | ||
| 150 | + int32_t sample_rate = meta.sample_rate; | ||
| 151 | + int32_t n_fft = meta.n_fft; | ||
| 152 | + int32_t hop_length = meta.hop_length; | ||
| 153 | + int32_t win_length = meta.window_length; | ||
| 154 | + int32_t num_mels = meta.num_mels; | ||
| 155 | + | ||
| 156 | + knf::StftConfig stft_config; | ||
| 157 | + stft_config.n_fft = n_fft; | ||
| 158 | + stft_config.hop_length = hop_length; | ||
| 159 | + stft_config.win_length = win_length; | ||
| 160 | + stft_config.window_type = "hann"; | ||
| 161 | + stft_config.center = true; | ||
| 162 | + | ||
| 163 | + knf::Stft stft(stft_config); | ||
| 164 | + auto stft_result = stft.Compute(samples.data(), samples.size()); | ||
| 165 | + int32_t num_frames = stft_result.num_frames; | ||
| 166 | + int32_t fft_bins = n_fft / 2 + 1; | ||
| 167 | + | ||
| 168 | + knf::FrameExtractionOptions frame_opts; | ||
| 169 | + frame_opts.samp_freq = sample_rate; | ||
| 170 | + frame_opts.frame_length_ms = win_length * 1000.0f / sample_rate; | ||
| 171 | + frame_opts.frame_shift_ms = hop_length * 1000.0f / sample_rate; | ||
| 172 | + frame_opts.window_type = "hanning"; | ||
| 173 | + | ||
| 174 | + knf::MelBanksOptions mel_opts; | ||
| 175 | + mel_opts.num_bins = num_mels; | ||
| 176 | + mel_opts.low_freq = 0; | ||
| 177 | + mel_opts.high_freq = sample_rate / 2; | ||
| 178 | + mel_opts.is_librosa = true; | ||
| 179 | + mel_opts.use_slaney_mel_scale = false; | ||
| 180 | + mel_opts.norm = ""; | ||
| 181 | + | ||
| 182 | + knf::MelBanks mel_banks(mel_opts, frame_opts, 1.0f); | ||
| 183 | + | ||
| 184 | + prompt_features->clear(); | ||
| 185 | + prompt_features->reserve(num_frames * num_mels); | ||
| 186 | + | ||
| 187 | + for (int32_t i = 0; i < num_frames; ++i) { | ||
| 188 | + std::vector<float> magnitude_spectrum(fft_bins); | ||
| 189 | + for (int32_t k = 0; k < fft_bins; ++k) { | ||
| 190 | + float real = stft_result.real[i * fft_bins + k]; | ||
| 191 | + float imag = stft_result.imag[i * fft_bins + k]; | ||
| 192 | + magnitude_spectrum[k] = std::sqrt(real * real + imag * imag); | ||
| 193 | + } | ||
| 194 | + std::vector<float> mel_features(num_mels, 0.0f); | ||
| 195 | + mel_banks.Compute(magnitude_spectrum.data(), mel_features.data()); | ||
| 196 | + for (auto &v : mel_features) { | ||
| 197 | + v = std::log(v + 1e-10f); | ||
| 198 | + } | ||
| 199 | + // Append this frame's mel features to the flattened output buffer | ||
| 200 | + prompt_features->insert(prompt_features->end(), mel_features.begin(), | ||
| 201 | + mel_features.end()); | ||
| 202 | + } | ||
| 203 | + if (num_frames == 0) { | ||
| 204 | + SHERPA_ONNX_LOGE("No frames extracted from the prompt audio"); | ||
| 205 | + return {0, 0}; | ||
| 206 | + } else { | ||
| 207 | + return {num_frames, num_mels}; | ||
| 208 | + } | ||
| 209 | + } | ||
| 210 | + | ||
| 211 | + GeneratedAudio Process(const std::vector<int64_t> &tokens, | ||
| 212 | + const std::vector<int64_t> &prompt_tokens, | ||
| 213 | + const std::vector<float> &prompt_samples, | ||
| 214 | + int32_t sample_rate, float speed, | ||
| 215 | + int32_t num_steps) const { | ||
| 216 | + auto memory_info = | ||
| 217 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 218 | + | ||
| 219 | + std::array<int64_t, 2> tokens_shape = {1, | ||
| 220 | + static_cast<int64_t>(tokens.size())}; | ||
| 221 | + Ort::Value tokens_tensor = Ort::Value::CreateTensor( | ||
| 222 | + memory_info, const_cast<int64_t *>(tokens.data()), tokens.size(), | ||
| 223 | + tokens_shape.data(), tokens_shape.size()); | ||
| 224 | + | ||
| 225 | + std::array<int64_t, 2> prompt_tokens_shape = { | ||
| 226 | + 1, static_cast<int64_t>(prompt_tokens.size())}; | ||
| 227 | + Ort::Value prompt_tokens_tensor = Ort::Value::CreateTensor( | ||
| 228 | + memory_info, const_cast<int64_t *>(prompt_tokens.data()), | ||
| 229 | + prompt_tokens.size(), prompt_tokens_shape.data(), | ||
| 230 | + prompt_tokens_shape.size()); | ||
| 231 | + | ||
| 232 | + float target_rms = config_.model.zipvoice.target_rms; | ||
| 233 | + float feat_scale = config_.model.zipvoice.feat_scale; | ||
| 234 | + | ||
| 235 | + // Scale prompt_samples | ||
| 236 | + std::vector<float> prompt_samples_scaled = prompt_samples; | ||
| 237 | + float prompt_rms = 0.0f; | ||
| 238 | + double sum_sq = 0.0; | ||
| 239 | + // Compute RMS of prompt_samples | ||
| 240 | + for (float s : prompt_samples_scaled) { | ||
| 241 | + sum_sq += s * s; | ||
| 242 | + } | ||
| 243 | + prompt_rms = std::sqrt(sum_sq / prompt_samples_scaled.size()); | ||
| 244 | + if (prompt_rms < target_rms && prompt_rms > 0.0f) { | ||
| 245 | + float scale = target_rms / static_cast<float>(prompt_rms); | ||
| 246 | + for (auto &s : prompt_samples_scaled) { | ||
| 247 | + s *= scale; | ||
| 248 | + } | ||
| 249 | + } | ||
| 250 | + | ||
| 251 | + std::vector<float> prompt_features; | ||
| 252 | + auto res_shape = ComputeMelSpectrogram(prompt_samples_scaled, sample_rate, | ||
| 253 | + &prompt_features); | ||
| 254 | + | ||
| 255 | + int32_t num_frames = res_shape[0]; | ||
| 256 | + int32_t mel_dim = res_shape[1]; | ||
| 257 | + | ||
| 258 | + if (feat_scale != 1.0f) { | ||
| 259 | + for (auto &item : prompt_features) { | ||
| 260 | + item *= feat_scale; | ||
| 261 | + } | ||
| 262 | + } | ||
| 263 | + | ||
| 264 | + std::array<int64_t, 3> shape = {1, num_frames, mel_dim}; | ||
| 265 | + auto prompt_features_tensor = Ort::Value::CreateTensor( | ||
| 266 | + memory_info, prompt_features.data(), prompt_features.size(), | ||
| 267 | + shape.data(), shape.size()); | ||
| 268 | + | ||
| 269 | + Ort::Value mel = | ||
| 270 | + model_->Run(std::move(tokens_tensor), std::move(prompt_tokens_tensor), | ||
| 271 | + std::move(prompt_features_tensor), speed, num_steps); | ||
| 272 | + | ||
| 273 | + // Assume mel_shape = {1, T, C} | ||
| 274 | + std::vector<int64_t> mel_shape = mel.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 275 | + int64_t T = mel_shape[1], C = mel_shape[2]; | ||
| 276 | + | ||
| 277 | + float *mel_data = mel.GetTensorMutableData<float>(); | ||
| 278 | + std::vector<float> mel_permuted(C * T); | ||
| 279 | + | ||
| 280 | + for (int64_t c = 0; c < C; ++c) { | ||
| 281 | + for (int64_t t = 0; t < T; ++t) { | ||
| 282 | + int64_t src_idx = t * C + c; // src: [T, C] (row major) | ||
| 283 | + int64_t dst_idx = c * T + t; // dst: [C, T] (row major) | ||
| 284 | + mel_permuted[dst_idx] = mel_data[src_idx] / feat_scale; | ||
| 285 | + } | ||
| 286 | + } | ||
| 287 | + | ||
| 288 | + std::array<int64_t, 3> new_shape = {1, C, T}; | ||
| 289 | + Ort::Value mel_new = Ort::Value::CreateTensor<float>( | ||
| 290 | + memory_info, mel_permuted.data(), mel_permuted.size(), new_shape.data(), | ||
| 291 | + new_shape.size()); | ||
| 292 | + | ||
| 293 | + GeneratedAudio ans; | ||
| 294 | + ans.samples = vocoder_->Run(std::move(mel_new)); | ||
| 295 | + ans.sample_rate = model_->GetMetaData().sample_rate; | ||
| 296 | + | ||
| 297 | + if (prompt_rms < target_rms && prompt_rms > 0.0f) { | ||
| 298 | + float scale = prompt_rms / target_rms; | ||
| 299 | + for (auto &s : ans.samples) { | ||
| 300 | + s *= scale; | ||
| 301 | + } | ||
| 302 | + } | ||
| 303 | + return ans; | ||
| 304 | + } | ||
| 305 | + | ||
| 306 | + private: | ||
| 307 | + OfflineTtsConfig config_; | ||
| 308 | + std::unique_ptr<OfflineTtsZipvoiceModel> model_; | ||
| 309 | + std::unique_ptr<Vocoder> vocoder_; | ||
| 310 | + std::unique_ptr<OfflineTtsFrontend> frontend_; | ||
| 311 | +}; | ||
| 312 | + | ||
| 313 | +} // namespace sherpa_onnx | ||
| 314 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_ |
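A note on the loudness handling in `Process` above (a restatement of the code, not an addition to it): the prompt is normalized towards `target_rms` before mel extraction, and the inverse scale is applied to the generated samples so the output keeps the prompt's original level. With $r$ the measured prompt RMS and $r_t$ the configured `--zipvoice-target-rms`, for $0 < r < r_t$:

$$\tilde p[n] = \frac{r_t}{r}\, p[n], \qquad \hat y[n] = \frac{r}{r_t}\, y[n],$$

where $p$ is the prompt passed to `ComputeMelSpectrogram`, $y$ is the vocoder output, and $\hat y$ is the audio returned to the caller; if $r \ge r_t$, neither scaling is applied.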
| 1 | +// sherpa-onnx/csrc/offline-tts-zipvoice-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <vector> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void OfflineTtsZipvoiceModelConfig::Register(ParseOptions *po) { | ||
| 15 | + po->Register("zipvoice-tokens", &tokens, | ||
| 16 | + "Path to tokens.txt for ZipVoice models"); | ||
| 17 | + po->Register("zipvoice-data-dir", &data_dir, | ||
| 18 | + "Path to the directory containing dict for espeak-ng."); | ||
| 19 | + po->Register("zipvoice-pinyin-dict", &pinyin_dict, | ||
| 20 | + "Path to the pinyin dictionary for cppinyin (i.e converting " | ||
| 21 | + "Chinese into phones)."); | ||
| 22 | + po->Register("zipvoice-text-model", &text_model, | ||
| 23 | + "Path to zipvoice text model"); | ||
| 24 | + po->Register("zipvoice-flow-matching-model", &flow_matching_model, | ||
| 25 | + "Path to zipvoice flow-matching model"); | ||
| 26 | + po->Register("zipvoice-vocoder", &vocoder, "Path to zipvoice vocoder"); | ||
| 27 | + po->Register("zipvoice-feat-scale", &feat_scale, | ||
| 28 | + "Feature scale for ZipVoice (default: 0.1)"); | ||
| 29 | + po->Register("zipvoice-t-shift", &t_shift, | ||
| 30 | + "Shift t to smaller ones if t_shift < 1.0 (default: 0.5)"); | ||
| 31 | + po->Register( | ||
| 32 | + "zipvoice-target-rms", &target_rms, | ||
| 33 | + "Target speech normalization rms value for ZipVoice (default: 0.1)"); | ||
| 34 | + po->Register( | ||
| 35 | + "zipvoice-guidance-scale", &guidance_scale, | ||
| 36 | + "The scale of classifier-free guidance during inference for ZipVoice " | ||
| 37 | + "(default: 1.0)"); | ||
| 38 | +} | ||
| 39 | + | ||
| 40 | +bool OfflineTtsZipvoiceModelConfig::Validate() const { | ||
| 41 | + if (tokens.empty()) { | ||
| 42 | + SHERPA_ONNX_LOGE("Please provide --zipvoice-tokens"); | ||
| 43 | + return false; | ||
| 44 | + } | ||
| 45 | + if (!FileExists(tokens)) { | ||
| 46 | + SHERPA_ONNX_LOGE("--zipvoice-tokens: '%s' does not exist", tokens.c_str()); | ||
| 47 | + return false; | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + if (text_model.empty()) { | ||
| 51 | + SHERPA_ONNX_LOGE("Please provide --zipvoice-text-model"); | ||
| 52 | + return false; | ||
| 53 | + } | ||
| 54 | + if (!FileExists(text_model)) { | ||
| 55 | + SHERPA_ONNX_LOGE("--zipvoice-text-model: '%s' does not exist", | ||
| 56 | + text_model.c_str()); | ||
| 57 | + return false; | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + if (flow_matching_model.empty()) { | ||
| 61 | + SHERPA_ONNX_LOGE("Please provide --zipvoice-flow-matching-model"); | ||
| 62 | + return false; | ||
| 63 | + } | ||
| 64 | + if (!FileExists(flow_matching_model)) { | ||
| 65 | + SHERPA_ONNX_LOGE("--zipvoice-flow-matching-model: '%s' does not exist", | ||
| 66 | + flow_matching_model.c_str()); | ||
| 67 | + return false; | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + if (vocoder.empty()) { | ||
| 71 | + SHERPA_ONNX_LOGE("Please provide --zipvoice-vocoder"); | ||
| 72 | + return false; | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + if (!FileExists(vocoder)) { | ||
| 76 | + SHERPA_ONNX_LOGE("--zipvoice-vocoder: '%s' does not exist", | ||
| 77 | + vocoder.c_str()); | ||
| 78 | + return false; | ||
| 79 | + } | ||
| 80 | + | ||
| 81 | + if (!data_dir.empty()) { | ||
| 82 | + std::vector<std::string> required_files = { | ||
| 83 | + "phontab", | ||
| 84 | + "phonindex", | ||
| 85 | + "phondata", | ||
| 86 | + "intonations", | ||
| 87 | + }; | ||
| 88 | + for (const auto &f : required_files) { | ||
| 89 | + if (!FileExists(data_dir + "/" + f)) { | ||
| 90 | + SHERPA_ONNX_LOGE( | ||
| 91 | + "'%s/%s' does not exist. Please check zipvoice-data-dir", | ||
| 92 | + data_dir.c_str(), f.c_str()); | ||
| 93 | + return false; | ||
| 94 | + } | ||
| 95 | + } | ||
| 96 | + } | ||
| 97 | + | ||
| 98 | + if (!pinyin_dict.empty() && !FileExists(pinyin_dict)) { | ||
| 99 | + SHERPA_ONNX_LOGE("--zipvoice-pinyin-dict: '%s' does not exist", | ||
| 100 | + pinyin_dict.c_str()); | ||
| 101 | + return false; | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | + if (feat_scale <= 0) { | ||
| 105 | + SHERPA_ONNX_LOGE("--zipvoice-feat-scale must be positive. Given: %f", | ||
| 106 | + feat_scale); | ||
| 107 | + return false; | ||
| 108 | + } | ||
| 109 | + | ||
| 110 | + if (t_shift < 0) { | ||
| 111 | + SHERPA_ONNX_LOGE("--zipvoice-t-shift must be non-negative. Given: %f", | ||
| 112 | + t_shift); | ||
| 113 | + return false; | ||
| 114 | + } | ||
| 115 | + | ||
| 116 | + if (target_rms <= 0) { | ||
| 117 | + SHERPA_ONNX_LOGE("--zipvoice-target-rms must be positive. Given: %f", | ||
| 118 | + target_rms); | ||
| 119 | + return false; | ||
| 120 | + } | ||
| 121 | + | ||
| 122 | + if (guidance_scale <= 0) { | ||
| 123 | + SHERPA_ONNX_LOGE("--zipvoice-guidance-scale must be positive. Given: %f", | ||
| 124 | + guidance_scale); | ||
| 125 | + return false; | ||
| 126 | + } | ||
| 127 | + | ||
| 128 | + return true; | ||
| 129 | +} | ||
| 130 | + | ||
| 131 | +std::string OfflineTtsZipvoiceModelConfig::ToString() const { | ||
| 132 | + std::ostringstream os; | ||
| 133 | + | ||
| 134 | + os << "OfflineTtsZipvoiceModelConfig("; | ||
| 135 | + os << "tokens=\"" << tokens << "\", "; | ||
| 136 | + os << "text_model=\"" << text_model << "\", "; | ||
| 137 | + os << "flow_matching_model=\"" << flow_matching_model << "\", "; | ||
| 138 | + os << "vocoder=\"" << vocoder << "\", "; | ||
| 139 | + os << "data_dir=\"" << data_dir << "\", "; | ||
| 140 | + os << "pinyin_dict=\"" << pinyin_dict << "\", "; | ||
| 141 | + os << "feat_scale=" << feat_scale << ", "; | ||
| 142 | + os << "t_shift=" << t_shift << ", "; | ||
| 143 | + os << "target_rms=" << target_rms << ", "; | ||
| 144 | + os << "guidance_scale=" << guidance_scale << ")"; | ||
| 145 | + | ||
| 146 | + return os.str(); | ||
| 147 | +} | ||
| 148 | + | ||
| 149 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include <cstdint> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +struct OfflineTtsZipvoiceModelConfig { | ||
| 16 | + std::string tokens; | ||
| 17 | + std::string text_model; | ||
| 18 | + std::string flow_matching_model; | ||
| 19 | + std::string vocoder; | ||
| 20 | + | ||
| 21 | + // Path to the espeak-ng data directory. It is required when the | ||
| 22 | + // model metadata has use_espeak == 1. | ||
| 23 | + std::string data_dir; | ||
| 24 | + | ||
| 25 | + // Used for converting Chinese characters to pinyin | ||
| 26 | + std::string pinyin_dict; | ||
| 27 | + | ||
| 28 | + float feat_scale = 0.1; | ||
| 29 | + float t_shift = 0.5; | ||
| 30 | + float target_rms = 0.1; | ||
| 31 | + float guidance_scale = 1.0; | ||
| 32 | + | ||
| 33 | + OfflineTtsZipvoiceModelConfig() = default; | ||
| 34 | + | ||
| 35 | + OfflineTtsZipvoiceModelConfig( | ||
| 36 | + const std::string &tokens, const std::string &text_model, | ||
| 37 | + const std::string &flow_matching_model, const std::string &vocoder, | ||
| 38 | + const std::string &data_dir, const std::string &pinyin_dict, | ||
| 39 | + float feat_scale = 0.1, float t_shift = 0.5, float target_rms = 0.1, | ||
| 40 | + float guidance_scale = 1.0) | ||
| 41 | + : tokens(tokens), | ||
| 42 | + text_model(text_model), | ||
| 43 | + flow_matching_model(flow_matching_model), | ||
| 44 | + vocoder(vocoder), | ||
| 45 | + data_dir(data_dir), | ||
| 46 | + pinyin_dict(pinyin_dict), | ||
| 47 | + feat_scale(feat_scale), | ||
| 48 | + t_shift(t_shift), | ||
| 49 | + target_rms(target_rms), | ||
| 50 | + guidance_scale(guidance_scale) {} | ||
| 51 | + | ||
| 52 | + void Register(ParseOptions *po); | ||
| 53 | + bool Validate() const; | ||
| 54 | + | ||
| 55 | + std::string ToString() const; | ||
| 56 | +}; | ||
| 57 | + | ||
| 58 | +} // namespace sherpa_onnx | ||
| 59 | + | ||
| 60 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_ |
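The struct above is normally filled in from the command-line flags registered in `offline-tts-zipvoice-model-config.cc`, but it can also be populated directly. A minimal sketch follows; the helper name is illustrative and the file paths are assumptions taken from the usage example of `sherpa-onnx-offline-zeroshot-tts.cc` later in this PR:

```cpp
#include "sherpa-onnx/csrc/offline-tts.h"

// Illustrative helper; adjust the paths to wherever the released
// sherpa-onnx-zipvoice-distill-zh-en-emilia model has been unpacked.
sherpa_onnx::OfflineTtsConfig MakeZipvoiceConfig() {
  sherpa_onnx::OfflineTtsConfig config;
  auto &z = config.model.zipvoice;
  z.tokens = "sherpa-onnx-zipvoice-distill-zh-en-emilia/tokens.txt";
  z.text_model = "sherpa-onnx-zipvoice-distill-zh-en-emilia/text_encoder.onnx";
  z.flow_matching_model =
      "sherpa-onnx-zipvoice-distill-zh-en-emilia/fm_decoder.onnx";
  z.vocoder = "sherpa-onnx-zipvoice-distill-zh-en-emilia/vocos_24khz.onnx";
  z.data_dir = "sherpa-onnx-zipvoice-distill-zh-en-emilia/espeak-ng-data";
  z.pinyin_dict = "sherpa-onnx-zipvoice-distill-zh-en-emilia/pinyin.raw";
  config.model.num_threads = 4;
  return config;
}
```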
| 1 | +// sherpa-onnx/csrc/offline-tts-zipvoice-model-meta-data.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_META_DATA_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_META_DATA_H_ | ||
| 7 | + | ||
| 8 | +#include <cstdint> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +// If you are not sure what each field means, please | ||
| 14 | +// have a look at the Python file in the model directory that | ||
| 15 | +// you have downloaded. | ||
| 16 | +struct OfflineTtsZipvoiceModelMetaData { | ||
| 17 | + int32_t version = 1; | ||
| 18 | + int32_t feat_dim = 100; | ||
| 19 | + int32_t sample_rate = 24000; | ||
| 20 | + int32_t n_fft = 1024; | ||
| 21 | + int32_t hop_length = 256; | ||
| 22 | + int32_t window_length = 1024; | ||
| 23 | + int32_t num_mels = 100; | ||
| 24 | + int32_t use_espeak = 1; | ||
| 25 | + int32_t use_pinyin = 1; | ||
| 26 | +}; | ||
| 27 | + | ||
| 28 | +} // namespace sherpa_onnx | ||
| 29 | + | ||
| 30 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_META_DATA_H_ |
| 1 | +// sherpa-onnx/csrc/offline-tts-zipvoice-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <cstring> | ||
| 9 | +#include <random> | ||
| 10 | +#include <string> | ||
| 11 | +#include <utility> | ||
| 12 | +#include <vector> | ||
| 13 | + | ||
| 14 | +#if __ANDROID_API__ >= 9 | ||
| 15 | +#include "android/asset_manager.h" | ||
| 16 | +#include "android/asset_manager_jni.h" | ||
| 17 | +#endif | ||
| 18 | + | ||
| 19 | +#if __OHOS__ | ||
| 20 | +#include "rawfile/raw_file_manager.h" | ||
| 21 | +#endif | ||
| 22 | + | ||
| 23 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 24 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 25 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 26 | +#include "sherpa-onnx/csrc/session.h" | ||
| 27 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 28 | + | ||
| 29 | +namespace sherpa_onnx { | ||
| 30 | + | ||
| 31 | +class OfflineTtsZipvoiceModel::Impl { | ||
| 32 | + public: | ||
| 33 | + explicit Impl(const OfflineTtsModelConfig &config) | ||
| 34 | + : config_(config), | ||
| 35 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 36 | + sess_opts_(GetSessionOptions(config)), | ||
| 37 | + allocator_{} { | ||
| 38 | + auto text_buf = ReadFile(config.zipvoice.text_model); | ||
| 39 | + auto fm_buf = ReadFile(config.zipvoice.flow_matching_model); | ||
| 40 | + Init(text_buf.data(), text_buf.size(), fm_buf.data(), fm_buf.size()); | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + template <typename Manager> | ||
| 44 | + Impl(Manager *mgr, const OfflineTtsModelConfig &config) | ||
| 45 | + : config_(config), | ||
| 46 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 47 | + sess_opts_(GetSessionOptions(config)), | ||
| 48 | + allocator_{} { | ||
| 49 | + auto text_buf = ReadFile(mgr, config.zipvoice.text_model); | ||
| 50 | + auto fm_buf = ReadFile(mgr, config.zipvoice.flow_matching_model); | ||
| 51 | + Init(text_buf.data(), text_buf.size(), fm_buf.data(), fm_buf.size()); | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + const OfflineTtsZipvoiceModelMetaData &GetMetaData() const { | ||
| 55 | + return meta_data_; | ||
| 56 | + } | ||
| 57 | + | ||
| 58 | + Ort::Value Run(Ort::Value tokens, Ort::Value prompt_tokens, | ||
| 59 | + Ort::Value prompt_features, float speed, int32_t num_steps) { | ||
| 60 | + auto memory_info = | ||
| 61 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 62 | + | ||
| 63 | + std::vector<int64_t> tokens_shape = | ||
| 64 | + tokens.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 65 | + int64_t batch_size = tokens_shape[0]; | ||
| 66 | + if (batch_size != 1) { | ||
| 67 | + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", | ||
| 68 | + static_cast<int32_t>(batch_size)); | ||
| 69 | + exit(-1); | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + std::vector<int64_t> prompt_feat_shape = | ||
| 73 | + prompt_features.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 74 | + | ||
| 75 | + int64_t prompt_feat_len = prompt_feat_shape[1]; | ||
| 76 | + int64_t prompt_feat_len_shape = 1; | ||
| 77 | + Ort::Value prompt_feat_len_tensor = Ort::Value::CreateTensor<int64_t>( | ||
| 78 | + memory_info, &prompt_feat_len, 1, &prompt_feat_len_shape, 1); | ||
| 79 | + | ||
| 80 | + int64_t speed_shape = 1; | ||
| 81 | + Ort::Value speed_tensor = Ort::Value::CreateTensor<float>( | ||
| 82 | + memory_info, &speed, 1, &speed_shape, 1); | ||
| 83 | + | ||
| 84 | + std::vector<Ort::Value> text_inputs; | ||
| 85 | + text_inputs.reserve(4); | ||
| 86 | + text_inputs.push_back(std::move(tokens)); | ||
| 87 | + text_inputs.push_back(std::move(prompt_tokens)); | ||
| 88 | + text_inputs.push_back(std::move(prompt_feat_len_tensor)); | ||
| 89 | + text_inputs.push_back(std::move(speed_tensor)); | ||
| 90 | + | ||
| 91 | + // forward text-encoder | ||
| 92 | + auto text_out = | ||
| 93 | + text_sess_->Run({}, text_input_names_ptr_.data(), text_inputs.data(), | ||
| 94 | + text_inputs.size(), text_output_names_ptr_.data(), | ||
| 95 | + text_output_names_ptr_.size()); | ||
| 96 | + | ||
| 97 | + Ort::Value &text_condition = text_out[0]; | ||
| 98 | + | ||
| 99 | + std::vector<int64_t> text_cond_shape = | ||
| 100 | + text_condition.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 101 | + int64_t num_frames = text_cond_shape[1]; | ||
| 102 | + | ||
| 103 | + int64_t feat_dim = meta_data_.feat_dim; | ||
| 104 | + | ||
| 105 | + std::vector<float> x_data(batch_size * num_frames * feat_dim); | ||
| 106 | + std::default_random_engine rng(std::random_device{}()); | ||
| 107 | + std::normal_distribution<float> norm(0, 1); | ||
| 108 | + for (auto &v : x_data) v = norm(rng); | ||
| 109 | + std::vector<int64_t> x_shape = {batch_size, num_frames, feat_dim}; | ||
| 110 | + Ort::Value x = Ort::Value::CreateTensor<float>( | ||
| 111 | + memory_info, x_data.data(), x_data.size(), x_shape.data(), | ||
| 112 | + x_shape.size()); | ||
| 113 | + | ||
| 114 | + std::vector<float> speech_cond_data(batch_size * num_frames * feat_dim, | ||
| 115 | + 0.0f); | ||
| 116 | + const float *src = prompt_features.GetTensorData<float>(); | ||
| 117 | + float *dst = speech_cond_data.data(); | ||
| 118 | + std::memcpy(dst, src, | ||
| 119 | + batch_size * prompt_feat_len * feat_dim * sizeof(float)); | ||
| 120 | + std::vector<int64_t> speech_cond_shape = {batch_size, num_frames, feat_dim}; | ||
| 121 | + Ort::Value speech_condition = Ort::Value::CreateTensor<float>( | ||
| 122 | + memory_info, speech_cond_data.data(), speech_cond_data.size(), | ||
| 123 | + speech_cond_shape.data(), speech_cond_shape.size()); | ||
| 124 | + | ||
| 125 | + float t_shift = config_.zipvoice.t_shift; | ||
| 126 | + float guidance_scale = config_.zipvoice.guidance_scale; | ||
| 127 | + | ||
| 128 | + std::vector<float> timesteps(num_steps + 1); | ||
| 129 | + for (int32_t i = 0; i <= num_steps; ++i) { | ||
| 130 | + float t = static_cast<float>(i) / num_steps; | ||
| 131 | + timesteps[i] = t_shift * t / (1.0f + (t_shift - 1.0f) * t); | ||
| 132 | + } | ||
| 133 | + | ||
| 134 | + int64_t guidance_scale_shape = 1; | ||
| 135 | + Ort::Value guidance_scale_tensor = Ort::Value::CreateTensor<float>( | ||
| 136 | + memory_info, &guidance_scale, 1, &guidance_scale_shape, 1); | ||
| 137 | + | ||
| 138 | + std::vector<Ort::Value> fm_inputs; | ||
| 139 | + fm_inputs.reserve(5); | ||
| 140 | + // fm_inputs[0] is the t tensor; it is set inside the loop below | ||
| 141 | + fm_inputs.emplace_back(nullptr); | ||
| 142 | + fm_inputs.push_back(std::move(x)); | ||
| 143 | + fm_inputs.push_back(std::move(text_condition)); | ||
| 144 | + fm_inputs.push_back(std::move(speech_condition)); | ||
| 145 | + fm_inputs.push_back(std::move(guidance_scale_tensor)); | ||
| 146 | + | ||
| 147 | + for (int32_t step = 0; step < num_steps; ++step) { | ||
| 148 | + float t_val = timesteps[step]; | ||
| 149 | + int64_t t_shape = 1; | ||
| 150 | + Ort::Value t_tensor = | ||
| 151 | + Ort::Value::CreateTensor<float>(memory_info, &t_val, 1, &t_shape, 1); | ||
| 152 | + fm_inputs[0] = std::move(t_tensor); | ||
| 153 | + auto fm_out = fm_sess_->Run( | ||
| 154 | + {}, fm_input_names_ptr_.data(), fm_inputs.data(), fm_inputs.size(), | ||
| 155 | + fm_output_names_ptr_.data(), fm_output_names_ptr_.size()); | ||
| 156 | + Ort::Value &v = fm_out[0]; | ||
| 157 | + | ||
| 158 | + float delta_t = timesteps[step + 1] - timesteps[step]; | ||
| 159 | + float *x_ptr = fm_inputs[1].GetTensorMutableData<float>(); | ||
| 160 | + const float *v_ptr = v.GetTensorData<float>(); | ||
| 161 | + int64_t N = batch_size * num_frames * feat_dim; | ||
| 162 | + for (int64_t i = 0; i < N; ++i) { | ||
| 163 | + x_ptr[i] += v_ptr[i] * delta_t; | ||
| 164 | + } | ||
| 165 | + } | ||
| 166 | + | ||
| 167 | + int64_t keep_frames = num_frames - prompt_feat_len; | ||
| 168 | + x = std::move(fm_inputs[1]); | ||
| 169 | + const float *x_ptr = x.GetTensorData<float>(); | ||
| 170 | + std::vector<int64_t> out_shape = {batch_size, keep_frames, feat_dim}; | ||
| 171 | + // Allocate via allocator_ so that the returned tensor owns its memory | ||
| 172 | + Ort::Value ans = Ort::Value::CreateTensor<float>(allocator_, out_shape.data(), out_shape.size()); | ||
| 173 | + float *out_ptr = ans.GetTensorMutableData<float>(); | ||
| 174 | + for (int64_t b = 0; b < batch_size; ++b) { | ||
| 175 | + std::memcpy(out_ptr + b * keep_frames * feat_dim, | ||
| 176 | + x_ptr + (b * num_frames + prompt_feat_len) * feat_dim, | ||
| 177 | + keep_frames * feat_dim * sizeof(float)); | ||
| 178 | + } | ||
| 179 | + return ans; | ||
| 180 | + } | ||
| 181 | + | ||
| 182 | + private: | ||
| 183 | + void Init(void *text_model_data, size_t text_model_data_length, | ||
| 184 | + void *fm_model_data, size_t fm_model_data_length) { | ||
| 185 | + // Init text-encoder model | ||
| 186 | + text_sess_ = std::make_unique<Ort::Session>( | ||
| 187 | + env_, text_model_data, text_model_data_length, sess_opts_); | ||
| 188 | + GetInputNames(text_sess_.get(), &text_input_names_, &text_input_names_ptr_); | ||
| 189 | + GetOutputNames(text_sess_.get(), &text_output_names_, | ||
| 190 | + &text_output_names_ptr_); | ||
| 191 | + | ||
| 192 | + // Init flow-matching model | ||
| 193 | + fm_sess_ = std::make_unique<Ort::Session>(env_, fm_model_data, | ||
| 194 | + fm_model_data_length, sess_opts_); | ||
| 195 | + GetInputNames(fm_sess_.get(), &fm_input_names_, &fm_input_names_ptr_); | ||
| 196 | + GetOutputNames(fm_sess_.get(), &fm_output_names_, &fm_output_names_ptr_); | ||
| 197 | + | ||
| 198 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 199 | + | ||
| 200 | + Ort::ModelMetadata meta_data = text_sess_->GetModelMetadata(); | ||
| 201 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_espeak, "use_espeak", | ||
| 202 | + 1); | ||
| 203 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_pinyin, "use_pinyin", | ||
| 204 | + 1); | ||
| 205 | + | ||
| 206 | + meta_data = fm_sess_->GetModelMetadata(); | ||
| 207 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 1); | ||
| 208 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.feat_dim, "feat_dim", | ||
| 209 | + 100); | ||
| 210 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.sample_rate, | ||
| 211 | + "sample_rate", 24000); | ||
| 212 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.n_fft, "n_fft", 1024); | ||
| 213 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.hop_length, "hop_length", | ||
| 214 | + 256); | ||
| 215 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.window_length, | ||
| 216 | + "window_length", 1024); | ||
| 217 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.num_mels, "num_mels", | ||
| 218 | + 100); | ||
| 219 | + | ||
| 220 | + if (config_.debug) { | ||
| 221 | + std::ostringstream os; | ||
| 222 | + | ||
| 223 | + os << "---zipvoice text-encoder model---\n"; | ||
| 224 | + Ort::ModelMetadata text_meta_data = text_sess_->GetModelMetadata(); | ||
| 225 | + PrintModelMetadata(os, text_meta_data); | ||
| 226 | + | ||
| 227 | + os << "----------input names----------\n"; | ||
| 228 | + int32_t i = 0; | ||
| 229 | + for (const auto &s : text_input_names_) { | ||
| 230 | + os << i << " " << s << "\n"; | ||
| 231 | + ++i; | ||
| 232 | + } | ||
| 233 | + os << "----------output names----------\n"; | ||
| 234 | + i = 0; | ||
| 235 | + for (const auto &s : text_output_names_) { | ||
| 236 | + os << i << " " << s << "\n"; | ||
| 237 | + ++i; | ||
| 238 | + } | ||
| 239 | + | ||
| 240 | + os << "---zipvoice flow-matching model---\n"; | ||
| 241 | + PrintModelMetadata(os, meta_data); | ||
| 242 | + | ||
| 243 | + os << "----------input names----------\n"; | ||
| 244 | + i = 0; | ||
| 245 | + for (const auto &s : fm_input_names_) { | ||
| 246 | + os << i << " " << s << "\n"; | ||
| 247 | + ++i; | ||
| 248 | + } | ||
| 249 | + os << "----------output names----------\n"; | ||
| 250 | + i = 0; | ||
| 251 | + for (const auto &s : fm_output_names_) { | ||
| 252 | + os << i << " " << s << "\n"; | ||
| 253 | + ++i; | ||
| 254 | + } | ||
| 255 | + | ||
| 256 | +#if __OHOS__ | ||
| 257 | + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); | ||
| 258 | +#else | ||
| 259 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 260 | +#endif | ||
| 261 | + } | ||
| 262 | + } | ||
| 263 | + | ||
| 264 | + private: | ||
| 265 | + OfflineTtsModelConfig config_; | ||
| 266 | + Ort::Env env_; | ||
| 267 | + Ort::SessionOptions sess_opts_; | ||
| 268 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 269 | + | ||
| 270 | + std::unique_ptr<Ort::Session> text_sess_; | ||
| 271 | + std::unique_ptr<Ort::Session> fm_sess_; | ||
| 272 | + | ||
| 273 | + std::vector<std::string> text_input_names_; | ||
| 274 | + std::vector<const char *> text_input_names_ptr_; | ||
| 275 | + | ||
| 276 | + std::vector<std::string> text_output_names_; | ||
| 277 | + std::vector<const char *> text_output_names_ptr_; | ||
| 278 | + | ||
| 279 | + std::vector<std::string> fm_input_names_; | ||
| 280 | + std::vector<const char *> fm_input_names_ptr_; | ||
| 281 | + | ||
| 282 | + std::vector<std::string> fm_output_names_; | ||
| 283 | + std::vector<const char *> fm_output_names_ptr_; | ||
| 284 | + | ||
| 285 | + OfflineTtsZipvoiceModelMetaData meta_data_; | ||
| 286 | +}; | ||
| 287 | + | ||
| 288 | +OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel( | ||
| 289 | + const OfflineTtsModelConfig &config) | ||
| 290 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 291 | + | ||
| 292 | +template <typename Manager> | ||
| 293 | +OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel( | ||
| 294 | + Manager *mgr, const OfflineTtsModelConfig &config) | ||
| 295 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 296 | + | ||
| 297 | +OfflineTtsZipvoiceModel::~OfflineTtsZipvoiceModel() = default; | ||
| 298 | + | ||
| 299 | +const OfflineTtsZipvoiceModelMetaData &OfflineTtsZipvoiceModel::GetMetaData() | ||
| 300 | + const { | ||
| 301 | + return impl_->GetMetaData(); | ||
| 302 | +} | ||
| 303 | + | ||
| 304 | +Ort::Value OfflineTtsZipvoiceModel::Run(Ort::Value tokens, | ||
| 305 | + Ort::Value prompt_tokens, | ||
| 306 | + Ort::Value prompt_features, | ||
| 307 | + float speed /*= 1.0*/, | ||
| 308 | + int32_t num_steps /*= 16*/) const { | ||
| 309 | + return impl_->Run(std::move(tokens), std::move(prompt_tokens), | ||
| 310 | + std::move(prompt_features), speed, num_steps); | ||
| 311 | +} | ||
| 312 | + | ||
| 313 | +#if __ANDROID_API__ >= 9 | ||
| 314 | +template OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel( | ||
| 315 | + AAssetManager *mgr, const OfflineTtsModelConfig &config); | ||
| 316 | +#endif | ||
| 317 | + | ||
| 318 | +#if __OHOS__ | ||
| 319 | +template OfflineTtsZipvoiceModel::OfflineTtsZipvoiceModel( | ||
| 320 | + NativeResourceManager *mgr, const OfflineTtsModelConfig &config); | ||
| 321 | +#endif | ||
| 322 | + | ||
| 323 | +} // namespace sherpa_onnx |
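For reference, the sampling loop in `Impl::Run` above is a plain Euler integration over a shifted timestep schedule; the equations below restate the loop (interpreting the flow-matching model's output as a guided velocity field $v_\theta$ is an assumption based on this code, not on separate ZipVoice documentation):

$$\tau_i = \frac{i}{N}, \qquad t_i = \frac{s\,\tau_i}{1 + (s - 1)\,\tau_i}, \qquad i = 0, \dots, N,$$

$$x_0 \sim \mathcal{N}(0, I), \qquad x_{i+1} = x_i + v_\theta(x_i, t_i;\ \text{text},\ \text{speech})\,(t_{i+1} - t_i),$$

where $N$ is `num_steps`, $s$ is `--zipvoice-t-shift`, and the classifier-free guidance scale is passed to the model as an extra input rather than applied in this loop. The prompt frames are then stripped from $x_N$ before the mel is returned.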
| 1 | +// sherpa-onnx/csrc/offline-tts-zipvoice-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/offline-tts-model-config.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-meta-data.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class OfflineTtsZipvoiceModel { | ||
| 18 | + public: | ||
| 19 | + ~OfflineTtsZipvoiceModel(); | ||
| 20 | + | ||
| 21 | + explicit OfflineTtsZipvoiceModel(const OfflineTtsModelConfig &config); | ||
| 22 | + | ||
| 23 | + template <typename Manager> | ||
| 24 | + OfflineTtsZipvoiceModel(Manager *mgr, const OfflineTtsModelConfig &config); | ||
| 25 | + | ||
| 26 | + // Return a float32 tensor containing the mel of the generated part | ||
| 27 | + // (prompt frames excluded), of shape (batch_size, num_frames, mel_dim) | ||
| 28 | + Ort::Value Run(Ort::Value tokens, Ort::Value prompt_tokens, | ||
| 29 | + Ort::Value prompt_features, float speed, | ||
| 30 | + int32_t num_steps) const; | ||
| 31 | + | ||
| 32 | + const OfflineTtsZipvoiceModelMetaData &GetMetaData() const; | ||
| 33 | + | ||
| 34 | + private: | ||
| 35 | + class Impl; | ||
| 36 | + std::unique_ptr<Impl> impl_; | ||
| 37 | +}; | ||
| 38 | + | ||
| 39 | +} // namespace sherpa_onnx | ||
| 40 | + | ||
| 41 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_H_ |
| @@ -196,6 +196,46 @@ GeneratedAudio OfflineTts::Generate( | @@ -196,6 +196,46 @@ GeneratedAudio OfflineTts::Generate( | ||
| 196 | #endif | 196 | #endif |
| 197 | } | 197 | } |
| 198 | 198 | ||
| 199 | +GeneratedAudio OfflineTts::Generate( | ||
| 200 | + const std::string &text, const std::string &prompt_text, | ||
| 201 | + const std::vector<float> &prompt_samples, int32_t sample_rate, | ||
| 202 | + float speed /*=1.0*/, int32_t num_steps /*=4*/, | ||
| 203 | + GeneratedAudioCallback callback /*=nullptr*/) const { | ||
| 204 | +#if !defined(_WIN32) | ||
| 205 | + return impl_->Generate(text, prompt_text, prompt_samples, sample_rate, speed, | ||
| 206 | + num_steps, std::move(callback)); | ||
| 207 | +#else | ||
| 208 | + static bool printed = false; | ||
| 209 | + auto utf8_text = text; | ||
| 210 | + if (IsGB2312(text)) { | ||
| 211 | + utf8_text = Gb2312ToUtf8(text); | ||
| 212 | + if (!printed) { | ||
| 213 | + SHERPA_ONNX_LOGE("Detected GB2312 encoded text! Converting it to UTF8."); | ||
| 214 | + printed = true; | ||
| 215 | + } | ||
| 216 | + } | ||
| 217 | + auto utf8_prompt_text = prompt_text; | ||
| 218 | + if (IsGB2312(prompt_text)) { | ||
| 219 | + utf8_prompt_text = Gb2312ToUtf8(prompt_text); | ||
| 220 | + if (!printed) { | ||
| 221 | + SHERPA_ONNX_LOGE( | ||
| 222 | + "Detected GB2312 encoded prompt text! Converting it to UTF8."); | ||
| 223 | + printed = true; | ||
| 224 | + } | ||
| 225 | + } | ||
| 226 | + if (IsUtf8(utf8_text) && IsUtf8(utf8_prompt_text)) { | ||
| 227 | + return impl_->Generate(utf8_text, utf8_prompt_text, prompt_samples, | ||
| 228 | + sample_rate, speed, num_steps, std::move(callback)); | ||
| 229 | + } else { | ||
| 230 | + SHERPA_ONNX_LOGE( | ||
| 231 | + "Non UTF8 encoded string is received. You would not get expected " | ||
| 232 | + "results!"); | ||
| 233 | + return impl_->Generate(utf8_text, utf8_prompt_text, prompt_samples, | ||
| 234 | + sample_rate, speed, num_steps, std::move(callback)); | ||
| 235 | + } | ||
| 236 | +#endif | ||
| 237 | +} | ||
| 238 | + | ||
| 199 | int32_t OfflineTts::SampleRate() const { return impl_->SampleRate(); } | 239 | int32_t OfflineTts::SampleRate() const { return impl_->SampleRate(); } |
| 200 | 240 | ||
| 201 | int32_t OfflineTts::NumSpeakers() const { return impl_->NumSpeakers(); } | 241 | int32_t OfflineTts::NumSpeakers() const { return impl_->NumSpeakers(); } |
| @@ -95,6 +95,26 @@ class OfflineTts { | @@ -95,6 +95,26 @@ class OfflineTts { | ||
| 95 | float speed = 1.0, | 95 | float speed = 1.0, |
| 96 | GeneratedAudioCallback callback = nullptr) const; | 96 | GeneratedAudioCallback callback = nullptr) const; |
| 97 | 97 | ||
| 98 | + // @param text The string to be synthesized. | ||
| 99 | + // @param prompt_text The transcript of `prompt_samples`. | ||
| 100 | + // @param prompt_samples The prompt audio samples (mono PCM floats in [-1,1]). | ||
| 101 | + // @param sample_rate The sample rate of `prompt_samples` in Hz. | ||
| 102 | + // @param speed The speed for the generated speech. E.g., 2 means 2x faster. | ||
| 103 | + // @param num_steps The number of flow-matching steps used to generate the audio. | ||
| 104 | + // @param callback If not NULL, it is called whenever config.max_num_sentences | ||
| 105 | + // sentences have been processed. Note that the passed | ||
| 106 | + // pointer `samples` for the callback might be invalidated | ||
| 107 | + // after the callback returns, so the caller should not | ||
| 108 | + // keep a reference to it. The caller can copy the data if | ||
| 109 | + // they want to access the samples after the callback | ||
| 110 | + // returns. The callback is called in the current thread. | ||
| 111 | + GeneratedAudio Generate(const std::string &text, | ||
| 112 | + const std::string &prompt_text, | ||
| 113 | + const std::vector<float> &prompt_samples, | ||
| 114 | + int32_t sample_rate, float speed = 1.0, | ||
| 115 | + int32_t num_steps = 4, | ||
| 116 | + GeneratedAudioCallback callback = nullptr) const; | ||
| 117 | + | ||
| 98 | // Return the sample rate of the generated audio | 118 | // Return the sample rate of the generated audio |
| 99 | int32_t SampleRate() const; | 119 | int32_t SampleRate() const; |
| 100 | 120 |
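A minimal calling sketch for the zero-shot overload declared above (not part of the diff; it simply strings together the APIs added in this PR, and the helper name is illustrative). It assumes `config` has been filled in as in the sketch after `offline-tts-zipvoice-model-config.h` and that the prompt transcript matches the prompt wave:

```cpp
#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"

void Synthesize(const sherpa_onnx::OfflineTtsConfig &config) {
  sherpa_onnx::OfflineTts tts(config);

  int32_t prompt_sample_rate = -1;
  bool is_ok = false;
  std::vector<float> prompt_samples =
      sherpa_onnx::ReadWave("prompt.wav", &prompt_sample_rate, &is_ok);
  if (!is_ok) {
    return;
  }

  auto audio = tts.Generate(
      "我爱我的祖国。",                      // text to synthesize
      "周日被我射熄火了,所以今天是周一。",  // transcript of prompt.wav
      prompt_samples, prompt_sample_rate,
      /*speed=*/1.0f, /*num_steps=*/4);

  sherpa_onnx::WriteWave("generated.wav", audio.sample_rate,
                         audio.samples.data(), audio.samples.size());
}
```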
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-offline-zeroshot-tts.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include <chrono> // NOLINT | ||
| 6 | +#include <fstream> | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/csrc/offline-tts.h" | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | +#include "sherpa-onnx/csrc/wave-reader.h" | ||
| 11 | +#include "sherpa-onnx/csrc/wave-writer.h" | ||
| 12 | + | ||
| 13 | +static int32_t AudioCallback(const float * /*samples*/, int32_t n, | ||
| 14 | + float progress) { | ||
| 15 | + printf("sample=%d, progress=%f\n", n, progress); | ||
| 16 | + return 1; | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +int main(int32_t argc, char *argv[]) { | ||
| 20 | + const char *kUsageMessage = R"usage( | ||
| 21 | +Offline/Non-streaming zero-shot text-to-speech with sherpa-onnx | ||
| 22 | + | ||
| 23 | +Usage example: | ||
| 24 | + | ||
| 25 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2 | ||
| 26 | +tar xf sherpa-onnx-zipvoice-distill-zh-en-emilia.tar.bz2 | ||
| 27 | + | ||
| 28 | +./bin/sherpa-onnx-offline-zeroshot-tts \ | ||
| 29 | + --zipvoice-flow-matching-model=sherpa-onnx-zipvoice-distill-zh-en-emilia/fm_decoder.onnx \ | ||
| 30 | + --zipvoice-text-model=sherpa-onnx-zipvoice-distill-zh-en-emilia/text_encoder.onnx \ | ||
| 31 | + --zipvoice-data-dir=sherpa-onnx-zipvoice-distill-zh-en-emilia/espeak-ng-data \ | ||
| 32 | + --zipvoice-pinyin-dict=sherpa-onnx-zipvoice-distill-zh-en-emilia/pinyin.raw \ | ||
| 33 | + --zipvoice-tokens=sherpa-onnx-zipvoice-distill-zh-en-emilia/tokens.txt \ | ||
| 34 | + --zipvoice-vocoder=sherpa-onnx-zipvoice-distill-zh-en-emilia/vocos_24khz.onnx \ | ||
| 35 | + --prompt-audio=sherpa-onnx-zipvoice-distill-zh-en-emilia/prompt.wav \ | ||
| 36 | + --num-steps=4 \ | ||
| 37 | + --num-threads=4 \ | ||
| 38 | + --prompt-text="周日被我射熄火了,所以今天是周一。" \ | ||
| 39 | + "我是中国人民的儿子,我爱我的祖国。我得祖国是一个伟大的国家,拥有五千年的文明史。" | ||
| 40 | + | ||
| 41 | +It will generate a file ./generated.wav as specified by --output-filename. | ||
| 42 | +)usage"; | ||
| 43 | + | ||
| 44 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 45 | + std::string output_filename = "./generated.wav"; | ||
| 46 | + | ||
| 47 | + int32_t num_steps = 4; | ||
| 48 | + float speed = 1.0; | ||
| 49 | + std::string prompt_text; | ||
| 50 | + std::string prompt_audio; | ||
| 51 | + | ||
| 52 | + po.Register("output-filename", &output_filename, | ||
| 53 | + "Path to save the generated audio"); | ||
| 54 | + | ||
| 55 | + po.Register("num-steps", &num_steps, | ||
| 56 | + "Number of inference steps for ZipVoice (default: 4)"); | ||
| 57 | + | ||
| 58 | + po.Register("speed", &speed, | ||
| 59 | + "Speech speed for ZipVoice (default: 1.0, larger=faster, " | ||
| 60 | + "smaller=slower)"); | ||
| 61 | + | ||
| 62 | + po.Register("prompt-text", &prompt_text, "The transcribe of prompt_samples."); | ||
| 63 | + | ||
| 64 | + po.Register("prompt-audio", &prompt_audio, | ||
| 65 | + "The prompt audio file, single channel pcm. "); | ||
| 66 | + | ||
| 67 | + sherpa_onnx::OfflineTtsConfig config; | ||
| 68 | + | ||
| 69 | + config.Register(&po); | ||
| 70 | + po.Read(argc, argv); | ||
| 71 | + | ||
| 72 | + if (po.NumArgs() == 0) { | ||
| 73 | + fprintf(stderr, "Error: Please provide the text to generate audio.\n\n"); | ||
| 74 | + po.PrintUsage(); | ||
| 75 | + exit(EXIT_FAILURE); | ||
| 76 | + } | ||
| 77 | + | ||
| 78 | + if (po.NumArgs() > 1) { | ||
| 79 | + fprintf(stderr, | ||
| 80 | + "Error: Accept only one positional argument. Please use single " | ||
| 81 | + "quotes to wrap your text\n"); | ||
| 82 | + po.PrintUsage(); | ||
| 83 | + exit(EXIT_FAILURE); | ||
| 84 | + } | ||
| 85 | + | ||
| 86 | + if (config.model.debug) { | ||
| 87 | + fprintf(stderr, "%s\n", config.model.ToString().c_str()); | ||
| 88 | + } | ||
| 89 | + | ||
| 90 | + if (!config.Validate()) { | ||
| 91 | + fprintf(stderr, "Errors in config!\n"); | ||
| 92 | + exit(EXIT_FAILURE); | ||
| 93 | + } | ||
| 94 | + | ||
| 95 | + if (prompt_text.empty() || prompt_audio.empty()) { | ||
| 96 | + fprintf(stderr, "Please provide both --prompt-text and --prompt-audio\n"); | ||
| 97 | + exit(EXIT_FAILURE); | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + sherpa_onnx::OfflineTts tts(config); | ||
| 101 | + | ||
| 102 | + int32_t sample_rate = -1; | ||
| 103 | + bool is_ok = false; | ||
| 104 | + const std::vector<float> prompt_samples = | ||
| 105 | + sherpa_onnx::ReadWave(prompt_audio, &sample_rate, &is_ok); | ||
| 106 | + | ||
| 107 | + if (!is_ok) { | ||
| 108 | + fprintf(stderr, "Failed to read '%s'\n", prompt_audio.c_str()); | ||
| 109 | + return -1; | ||
| 110 | + } | ||
| 111 | + | ||
| 112 | + const auto begin = std::chrono::steady_clock::now(); | ||
| 113 | + auto audio = tts.Generate(po.GetArg(1), prompt_text, prompt_samples, | ||
| 114 | + sample_rate, speed, num_steps, AudioCallback); | ||
| 115 | + const auto end = std::chrono::steady_clock::now(); | ||
| 116 | + | ||
| 117 | + if (audio.samples.empty()) { | ||
| 118 | + fprintf( | ||
| 119 | + stderr, | ||
| 120 | + "Error in generating audio. Please read previous error messages.\n"); | ||
| 121 | + exit(EXIT_FAILURE); | ||
| 122 | + } | ||
| 123 | + | ||
| 124 | + float elapsed_seconds = | ||
| 125 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 126 | + .count() / | ||
| 127 | + 1000.; | ||
| 128 | + float duration = audio.samples.size() / static_cast<float>(audio.sample_rate); | ||
| 129 | + | ||
| 130 | + float rtf = elapsed_seconds / duration; | ||
| 131 | + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||
| 132 | + fprintf(stderr, "Audio duration: %.3f s\n", duration); | ||
| 133 | + fprintf(stderr, "Real-time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds, | ||
| 134 | + duration, rtf); | ||
| 135 | + | ||
| 136 | + bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate, | ||
| 137 | + audio.samples.data(), audio.samples.size()); | ||
| 138 | + if (!ok) { | ||
| 139 | + fprintf(stderr, "Failed to write wave to %s\n", output_filename.c_str()); | ||
| 140 | + exit(EXIT_FAILURE); | ||
| 141 | + } | ||
| 142 | + | ||
| 143 | + fprintf(stderr, "The text is: %s.\n", po.GetArg(1).c_str()); | ||
| 144 | + fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str()); | ||
| 145 | + | ||
| 146 | + return 0; | ||
| 147 | +} |
| @@ -4,6 +4,9 @@ | @@ -4,6 +4,9 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/text-utils.h" | 5 | #include "sherpa-onnx/csrc/text-utils.h" |
| 6 | 6 | ||
| 7 | +#include <regex> | ||
| 8 | +#include <sstream> | ||
| 9 | + | ||
| 7 | #include "gtest/gtest.h" | 10 | #include "gtest/gtest.h" |
| 8 | 11 | ||
| 9 | namespace sherpa_onnx { | 12 | namespace sherpa_onnx { |
| @@ -55,7 +58,6 @@ TEST(RemoveInvalidUtf8Sequences, Case1) { | @@ -55,7 +58,6 @@ TEST(RemoveInvalidUtf8Sequences, Case1) { | ||
| 55 | EXPECT_EQ(s.size() + 4, v.size()); | 58 | EXPECT_EQ(s.size() + 4, v.size()); |
| 56 | } | 59 | } |
| 57 | 60 | ||
| 58 | - | ||
| 59 | // Tests for sanitizeUtf8 | 61 | // Tests for sanitizeUtf8 |
| 60 | TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) { | 62 | TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) { |
| 61 | std::string input = "Valid UTF-8 🌍"; | 63 | std::string input = "Valid UTF-8 🌍"; |
| @@ -82,7 +84,7 @@ TEST(RemoveInvalidUtf8Sequences, MultipleInvalidBytes) { | @@ -82,7 +84,7 @@ TEST(RemoveInvalidUtf8Sequences, MultipleInvalidBytes) { | ||
| 82 | 84 | ||
| 83 | TEST(RemoveInvalidUtf8Sequences, BreakingCase_SpaceFollowedByInvalidByte) { | 85 | TEST(RemoveInvalidUtf8Sequences, BreakingCase_SpaceFollowedByInvalidByte) { |
| 84 | std::string input = "\x20\xC4"; // Space followed by an invalid byte | 86 | std::string input = "\x20\xC4"; // Space followed by an invalid byte |
| 85 | - std::string expected = " "; // 0xC4 removed | 87 | + std::string expected = " "; // 0xC4 removed |
| 86 | EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | 88 | EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); |
| 87 | } | 89 | } |
| 88 | 90 | ||
| @@ -99,19 +101,19 @@ TEST(RemoveInvalidUtf8Sequences, MixedValidAndInvalidBytes) { | @@ -99,19 +101,19 @@ TEST(RemoveInvalidUtf8Sequences, MixedValidAndInvalidBytes) { | ||
| 99 | 101 | ||
| 100 | TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte) { | 102 | TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte) { |
| 101 | std::string input = "\x20\xC4"; // Space (0x20) followed by invalid (0xC4) | 103 | std::string input = "\x20\xC4"; // Space (0x20) followed by invalid (0xC4) |
| 102 | - std::string expected = " "; // Space remains, 0xC4 is removed | 104 | + std::string expected = " "; // Space remains, 0xC4 is removed |
| 103 | EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | 105 | EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); |
| 104 | } | 106 | } |
| 105 | 107 | ||
| 106 | TEST(RemoveInvalidUtf8Sequences, RemoveTruncatedC4) { | 108 | TEST(RemoveInvalidUtf8Sequences, RemoveTruncatedC4) { |
| 107 | std::string input = "Hello \xc4 world"; // Invalid `0xC4` | 109 | std::string input = "Hello \xc4 world"; // Invalid `0xC4` |
| 108 | - std::string expected = "Hello world"; // `0xC4` should be removed | 110 | + std::string expected = "Hello world"; // `0xC4` should be removed |
| 109 | EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | 111 | EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); |
| 110 | } | 112 | } |
| 111 | 113 | ||
| 112 | TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte_Breaking) { | 114 | TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte_Breaking) { |
| 113 | std::string input = "\x20\xc4"; // Space followed by invalid `0xc4` | 115 | std::string input = "\x20\xc4"; // Space followed by invalid `0xc4` |
| 114 | - std::string expected = " "; // `0xc4` should be removed, space remains | 116 | + std::string expected = " "; // `0xc4` should be removed, space remains |
| 115 | EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); | 117 | EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); |
| 116 | } | 118 | } |
| 117 | 119 |
| @@ -724,4 +724,62 @@ std::vector<std::string> SplitString(const std::string &s, int32_t chunk_size) { | @@ -724,4 +724,62 @@ std::vector<std::string> SplitString(const std::string &s, int32_t chunk_size) { | ||
| 724 | return ans; | 724 | return ans; |
| 725 | } | 725 | } |
| 726 | 726 | ||
| 727 | +std::u32string Utf8ToUtf32(const std::string &str) { | ||
| 728 | + std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv; | ||
| 729 | + return conv.from_bytes(str); | ||
| 730 | +} | ||
| 731 | + | ||
| 732 | +std::string Utf32ToUtf8(const std::u32string &str) { | ||
| 733 | + std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv; | ||
| 734 | + return conv.to_bytes(str); | ||
| 735 | +} | ||
| 736 | + | ||
| 737 | +// Helper: Convert ASCII chars in a std::string to uppercase (leaves non-ASCII | ||
| 738 | +// unchanged) | ||
| 739 | +std::string ToUpperAscii(const std::string &str) { | ||
| 740 | + std::string out = str; | ||
| 741 | + for (char &c : out) { | ||
| 742 | + unsigned char uc = static_cast<unsigned char>(c); | ||
| 743 | + if (uc >= 'a' && uc <= 'z') { | ||
| 744 | + c = static_cast<char>(uc - 'a' + 'A'); | ||
| 745 | + } | ||
| 746 | + } | ||
| 747 | + return out; | ||
| 748 | +} | ||
| 749 | + | ||
| 750 | +// Helper: Convert ASCII chars in a std::string to lowercase (leaves non-ASCII | ||
| 751 | +// unchanged) | ||
| 752 | +std::string ToLowerAscii(const std::string &str) { | ||
| 753 | + std::string out = str; | ||
| 754 | + for (char &c : out) { | ||
| 755 | + unsigned char uc = static_cast<unsigned char>(c); | ||
| 756 | + if (uc >= 'A' && uc <= 'Z') { | ||
| 757 | + c = static_cast<char>(uc - 'A' + 'a'); | ||
| 758 | + } | ||
| 759 | + } | ||
| 760 | + return out; | ||
| 761 | +} | ||
| 762 | + | ||
| 763 | +// Detect if a codepoint is a CJK character | ||
| 764 | +bool IsCJK(char32_t cp) { | ||
| 765 | + return (cp >= 0x1100 && cp <= 0x11FF) || (cp >= 0x2E80 && cp <= 0xA4CF) || | ||
| 766 | + (cp >= 0xA840 && cp <= 0xD7AF) || (cp >= 0xF900 && cp <= 0xFAFF) || | ||
| 767 | + (cp >= 0xFE30 && cp <= 0xFE4F) || (cp >= 0xFF65 && cp <= 0xFFDC) || | ||
| 768 | + (cp >= 0x20000 && cp <= 0x2FFFF); | ||
| 769 | +} | ||
| 770 | + | ||
| 771 | +bool ContainsCJK(const std::string &text) { | ||
| 772 | + std::u32string utf32_text = Utf8ToUtf32(text); | ||
| 773 | + return ContainsCJK(utf32_text); | ||
| 774 | +} | ||
| 775 | + | ||
| 776 | +bool ContainsCJK(const std::u32string &text) { | ||
| 777 | + for (char32_t cp : text) { | ||
| 778 | + if (IsCJK(cp)) { | ||
| 779 | + return true; | ||
| 780 | + } | ||
| 781 | + } | ||
| 782 | + return false; | ||
| 783 | +} | ||
| 784 | + | ||
| 727 | } // namespace sherpa_onnx | 785 | } // namespace sherpa_onnx |
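A small usage sketch for the helpers added above (illustration only; the function name is made up here, and how the ZipVoice frontend actually routes text between the pinyin and espeak paths is not shown in this hunk):

```cpp
#include <string>

#include "sherpa-onnx/csrc/text-utils.h"

void InspectText(const std::string &text) {
  if (sherpa_onnx::ContainsCJK(text)) {
    // Work on code points rather than bytes for the CJK portion
    std::u32string u32 = sherpa_onnx::Utf8ToUtf32(text);
    // ... per-character handling would go here ...
    std::string round_trip = sherpa_onnx::Utf32ToUtf8(u32);
    (void)round_trip;
  } else {
    // Pure ASCII text can be case-normalized directly
    std::string lower = sherpa_onnx::ToLowerAscii(text);
    (void)lower;
  }
}
```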
| @@ -149,6 +149,29 @@ bool EndsWith(const std::string &haystack, const std::string &needle); | @@ -149,6 +149,29 @@ bool EndsWith(const std::string &haystack, const std::string &needle); | ||
| 149 | 149 | ||
| 150 | std::vector<std::string> SplitString(const std::string &s, int32_t chunk_size); | 150 | std::vector<std::string> SplitString(const std::string &s, int32_t chunk_size); |
| 151 | 151 | ||
| 152 | +// Converts a UTF-8 std::string to a UTF-32 std::u32string | ||
| 153 | +std::u32string Utf8ToUtf32(const std::string &str); | ||
| 154 | + | ||
| 155 | +// Converts a UTF-32 std::u32string to a UTF-8 std::string | ||
| 156 | +std::string Utf32ToUtf8(const std::u32string &str); | ||
| 157 | + | ||
| 158 | +// Helper: Convert ASCII chars in a std::string to uppercase (leaves non-ASCII | ||
| 159 | +// unchanged) | ||
| 160 | +std::string ToUpperAscii(const std::string &str); | ||
| 161 | + | ||
| 162 | +// Helper: Convert ASCII chars in a std::string to lowercase (leaves non-ASCII | ||
| 163 | +// unchanged) | ||
| 164 | +std::string ToLowerAscii(const std::string &str); | ||
| 165 | + | ||
| 166 | +// Detect if a codepoint is a CJK character | ||
| 167 | +bool IsCJK(char32_t cp); | ||
| 168 | + | ||
| 169 | +bool ContainsCJK(const std::string &text); | ||
| 170 | + | ||
| 171 | +bool ContainsCJK(const std::u32string &text); | ||
| 172 | + | ||
| 173 | +bool StringToBool(const std::string &s); | ||
| 174 | + | ||
| 152 | } // namespace sherpa_onnx | 175 | } // namespace sherpa_onnx |
| 153 | 176 | ||
| 154 | #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ | 177 | #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ |
| @@ -74,7 +74,18 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -74,7 +74,18 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 74 | } | 74 | } |
| 75 | 75 | ||
| 76 | std::unique_ptr<Vocoder> Vocoder::Create(const OfflineTtsModelConfig &config) { | 76 | std::unique_ptr<Vocoder> Vocoder::Create(const OfflineTtsModelConfig &config) { |
| 77 | - auto buffer = ReadFile(config.matcha.vocoder); | 77 | + std::vector<char> buffer; |
| 78 | + if (!config.matcha.vocoder.empty()) { | ||
| 79 | + SHERPA_ONNX_LOGE("Using matcha vocoder: %s", config.matcha.vocoder.c_str()); | ||
| 80 | + buffer = ReadFile(config.matcha.vocoder); | ||
| 81 | + } else if (!config.zipvoice.vocoder.empty()) { | ||
| 82 | + SHERPA_ONNX_LOGE("Using zipvoice vocoder: %s", | ||
| 83 | + config.zipvoice.vocoder.c_str()); | ||
| 84 | + buffer = ReadFile(config.zipvoice.vocoder); | ||
| 85 | + } else { | ||
| 86 | + SHERPA_ONNX_LOGE("No vocoder model provided in the config!"); | ||
| 87 | + exit(-1); | ||
| 88 | + } | ||
| 78 | auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | 89 | auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); |
| 79 | 90 | ||
| 80 | switch (model_type) { | 91 | switch (model_type) { |
| @@ -94,7 +105,19 @@ std::unique_ptr<Vocoder> Vocoder::Create(const OfflineTtsModelConfig &config) { | @@ -94,7 +105,19 @@ std::unique_ptr<Vocoder> Vocoder::Create(const OfflineTtsModelConfig &config) { | ||
| 94 | template <typename Manager> | 105 | template <typename Manager> |
| 95 | std::unique_ptr<Vocoder> Vocoder::Create(Manager *mgr, | 106 | std::unique_ptr<Vocoder> Vocoder::Create(Manager *mgr, |
| 96 | const OfflineTtsModelConfig &config) { | 107 | const OfflineTtsModelConfig &config) { |
| 97 | - auto buffer = ReadFile(mgr, config.matcha.vocoder); | 108 | + std::vector<char> buffer; |
| 109 | + if (!config.matcha.vocoder.empty()) { | ||
| 110 | + SHERPA_ONNX_LOGE("Using matcha vocoder: %s", config.matcha.vocoder.c_str()); | ||
| 111 | + buffer = ReadFile(mgr, config.matcha.vocoder); | ||
| 112 | + } else if (!config.zipvoice.vocoder.empty()) { | ||
| 113 | + SHERPA_ONNX_LOGE("Using zipvoice vocoder: %s", | ||
| 114 | + config.zipvoice.vocoder.c_str()); | ||
| 115 | + buffer = ReadFile(mgr, config.zipvoice.vocoder); | ||
| 116 | + } else { | ||
| 117 | + SHERPA_ONNX_LOGE("No vocoder model provided in the config!"); | ||
| 118 | + return nullptr; | ||
| 119 | + } | ||
| 120 | + | ||
| 98 | auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | 121 | auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); |
| 99 | 122 | ||
| 100 | switch (model_type) { | 123 | switch (model_type) { |
| @@ -42,8 +42,16 @@ class VocosVocoder::Impl { | @@ -42,8 +42,16 @@ class VocosVocoder::Impl { | ||
| 42 | env_(ORT_LOGGING_LEVEL_ERROR), | 42 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 43 | sess_opts_(GetSessionOptions(config.num_threads, config.provider)), | 43 | sess_opts_(GetSessionOptions(config.num_threads, config.provider)), |
| 44 | allocator_{} { | 44 | allocator_{} { |
| 45 | - auto buf = ReadFile(config.matcha.vocoder); | ||
| 46 | - Init(buf.data(), buf.size()); | 45 | + std::vector<char> buffer; |
| 46 | + if (!config.matcha.vocoder.empty()) { | ||
| 47 | + buffer = ReadFile(config.matcha.vocoder); | ||
| 48 | + } else if (!config.zipvoice.vocoder.empty()) { | ||
| 49 | + buffer = ReadFile(config.zipvoice.vocoder); | ||
| 50 | + } else { | ||
| 51 | + SHERPA_ONNX_LOGE("No vocoder model provided in the config!"); | ||
| 52 | + exit(-1); | ||
| 53 | + } | ||
| 54 | + Init(buffer.data(), buffer.size()); | ||
| 47 | } | 55 | } |
| 48 | 56 | ||
| 49 | template <typename Manager> | 57 | template <typename Manager> |
| @@ -52,8 +60,16 @@ class VocosVocoder::Impl { | @@ -52,8 +60,16 @@ class VocosVocoder::Impl { | ||
| 52 | env_(ORT_LOGGING_LEVEL_ERROR), | 60 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 53 | sess_opts_(GetSessionOptions(config.num_threads, config.provider)), | 61 | sess_opts_(GetSessionOptions(config.num_threads, config.provider)), |
| 54 | allocator_{} { | 62 | allocator_{} { |
| 55 | - auto buf = ReadFile(mgr, config.matcha.vocoder); | ||
| 56 | - Init(buf.data(), buf.size()); | 63 | + std::vector<char> buffer; |
| 64 | + if (!config.matcha.vocoder.empty()) { | ||
| 65 | + buffer = ReadFile(mgr, config.matcha.vocoder); | ||
| 66 | + } else if (!config.zipvoice.vocoder.empty()) { | ||
| 67 | + buffer = ReadFile(mgr, config.zipvoice.vocoder); | ||
| 68 | + } else { | ||
| 69 | + SHERPA_ONNX_LOGE("No vocoder model provided in the config!"); | ||
| 70 | + exit(-1); | ||
| 71 | + } | ||
| 72 | + Init(buffer.data(), buffer.size()); | ||
| 57 | } | 73 | } |
| 58 | 74 | ||
| 59 | std::vector<float> Run(Ort::Value mel) const { | 75 | std::vector<float> Run(Ort::Value mel) const { |
| @@ -72,6 +72,7 @@ if(SHERPA_ONNX_ENABLE_TTS) | @@ -72,6 +72,7 @@ if(SHERPA_ONNX_ENABLE_TTS) | ||
| 72 | offline-tts-matcha-model-config.cc | 72 | offline-tts-matcha-model-config.cc |
| 73 | offline-tts-model-config.cc | 73 | offline-tts-model-config.cc |
| 74 | offline-tts-vits-model-config.cc | 74 | offline-tts-vits-model-config.cc |
| 75 | + offline-tts-zipvoice-model-config.cc | ||
| 75 | offline-tts.cc | 76 | offline-tts.cc |
| 76 | ) | 77 | ) |
| 77 | endif() | 78 | endif() |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "sherpa-onnx/python/csrc/offline-tts-kokoro-model-config.h" | 11 | #include "sherpa-onnx/python/csrc/offline-tts-kokoro-model-config.h" |
| 12 | #include "sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h" | 12 | #include "sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h" |
| 13 | #include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h" | 13 | #include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h" |
| 14 | +#include "sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.h" | ||
| 14 | 15 | ||
| 15 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 16 | 17 | ||
| @@ -18,6 +19,7 @@ void PybindOfflineTtsModelConfig(py::module *m) { | @@ -18,6 +19,7 @@ void PybindOfflineTtsModelConfig(py::module *m) { | ||
| 18 | PybindOfflineTtsVitsModelConfig(m); | 19 | PybindOfflineTtsVitsModelConfig(m); |
| 19 | PybindOfflineTtsMatchaModelConfig(m); | 20 | PybindOfflineTtsMatchaModelConfig(m); |
| 20 | PybindOfflineTtsKokoroModelConfig(m); | 21 | PybindOfflineTtsKokoroModelConfig(m); |
| 22 | + PybindOfflineTtsZipvoiceModelConfig(m); | ||
| 21 | PybindOfflineTtsKittenModelConfig(m); | 23 | PybindOfflineTtsKittenModelConfig(m); |
| 22 | 24 | ||
| 23 | using PyClass = OfflineTtsModelConfig; | 25 | using PyClass = OfflineTtsModelConfig; |
| @@ -27,17 +29,20 @@ void PybindOfflineTtsModelConfig(py::module *m) { | @@ -27,17 +29,20 @@ void PybindOfflineTtsModelConfig(py::module *m) { | ||
| 27 | .def(py::init<const OfflineTtsVitsModelConfig &, | 29 | .def(py::init<const OfflineTtsVitsModelConfig &, |
| 28 | const OfflineTtsMatchaModelConfig &, | 30 | const OfflineTtsMatchaModelConfig &, |
| 29 | const OfflineTtsKokoroModelConfig &, | 31 | const OfflineTtsKokoroModelConfig &, |
| 32 | + const OfflineTtsZipvoiceModelConfig &, | ||
| 30 | const OfflineTtsKittenModelConfig &, int32_t, bool, | 33 | const OfflineTtsKittenModelConfig &, int32_t, bool, |
| 31 | const std::string &>(), | 34 | const std::string &>(), |
| 32 | py::arg("vits") = OfflineTtsVitsModelConfig{}, | 35 | py::arg("vits") = OfflineTtsVitsModelConfig{}, |
| 33 | py::arg("matcha") = OfflineTtsMatchaModelConfig{}, | 36 | py::arg("matcha") = OfflineTtsMatchaModelConfig{}, |
| 34 | py::arg("kokoro") = OfflineTtsKokoroModelConfig{}, | 37 | py::arg("kokoro") = OfflineTtsKokoroModelConfig{}, |
| 38 | + py::arg("zipvoice") = OfflineTtsZipvoiceModelConfig{}, | ||
| 35 | py::arg("kitten") = OfflineTtsKittenModelConfig{}, | 39 | py::arg("kitten") = OfflineTtsKittenModelConfig{}, |
| 36 | py::arg("num_threads") = 1, py::arg("debug") = false, | 40 | py::arg("num_threads") = 1, py::arg("debug") = false, |
| 37 | py::arg("provider") = "cpu") | 41 | py::arg("provider") = "cpu") |
| 38 | .def_readwrite("vits", &PyClass::vits) | 42 | .def_readwrite("vits", &PyClass::vits) |
| 39 | .def_readwrite("matcha", &PyClass::matcha) | 43 | .def_readwrite("matcha", &PyClass::matcha) |
| 40 | .def_readwrite("kokoro", &PyClass::kokoro) | 44 | .def_readwrite("kokoro", &PyClass::kokoro) |
| 45 | + .def_readwrite("zipvoice", &PyClass::zipvoice) | ||
| 41 | .def_readwrite("kitten", &PyClass::kitten) | 46 | .def_readwrite("kitten", &PyClass::kitten) |
| 42 | .def_readwrite("num_threads", &PyClass::num_threads) | 47 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 43 | .def_readwrite("debug", &PyClass::debug) | 48 | .def_readwrite("debug", &PyClass::debug) |
| 1 | +// sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-tts-zipvoice-model-config.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void PybindOfflineTtsZipvoiceModelConfig(py::module *m) { | ||
| 14 | + using PyClass = OfflineTtsZipvoiceModelConfig; | ||
| 15 | + | ||
| 16 | + py::class_<PyClass>(*m, "OfflineTtsZipvoiceModelConfig") | ||
| 17 | + .def(py::init<>()) | ||
| 18 | + .def(py::init<const std::string &, const std::string &, | ||
| 19 | + const std::string &, const std::string &, | ||
| 20 | + const std::string &, const std::string &, float, float, | ||
| 21 | + float, float>(), | ||
| 22 | + py::arg("tokens"), py::arg("text_model"), | ||
| 23 | + py::arg("flow_matching_model"), py::arg("vocoder"), | ||
| 24 | + py::arg("data_dir") = "", py::arg("pinyin_dict") = "", | ||
| 25 | + py::arg("feat_scale") = 0.1, py::arg("t_shift") = 0.5, | ||
| 26 | + py::arg("target_rms") = 0.1, py::arg("guidance_scale") = 1.0) | ||
| 27 | + .def_readwrite("tokens", &PyClass::tokens) | ||
| 28 | + .def_readwrite("text_model", &PyClass::text_model) | ||
| 29 | + .def_readwrite("flow_matching_model", &PyClass::flow_matching_model) | ||
| 30 | + .def_readwrite("vocoder", &PyClass::vocoder) | ||
| 31 | + .def_readwrite("data_dir", &PyClass::data_dir) | ||
| 32 | + .def_readwrite("pinyin_dict", &PyClass::pinyin_dict) | ||
| 33 | + .def_readwrite("feat_scale", &PyClass::feat_scale) | ||
| 34 | + .def_readwrite("t_shift", &PyClass::t_shift) | ||
| 35 | + .def_readwrite("target_rms", &PyClass::target_rms) | ||
| 36 | + .def_readwrite("guidance_scale", &PyClass::guidance_scale) | ||
| 37 | + .def("__str__", &PyClass::ToString) | ||
| 38 | + .def("validate", &PyClass::Validate); | ||
| 39 | +} | ||
| 40 | + | ||
| 41 | +} // namespace sherpa_onnx |
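The constructor bound above takes four required model paths plus optional parameters whose defaults are given in the `py::arg` list. A construction sketch from Python (paths are placeholders; with non-existent files `validate()` is expected to fail):

```python
import sherpa_onnx

zipvoice = sherpa_onnx.OfflineTtsZipvoiceModelConfig(
    tokens="tokens.txt",                       # required, placeholder path
    text_model="text_model.onnx",              # required, placeholder path
    flow_matching_model="flow_matching.onnx",  # required, placeholder path
    vocoder="vocos.onnx",                      # required, placeholder path
    data_dir="",            # optional, defaults to ""
    pinyin_dict="",         # optional, defaults to ""
    feat_scale=0.1,         # default
    t_shift=0.5,            # default
    target_rms=0.1,         # default
    guidance_scale=1.0,     # default
)
print(zipvoice)           # __str__ is bound to ToString()
ok = zipvoice.validate()  # bound to Validate(); checks the configured files
```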
| 1 | +// sherpa-onnx/python/csrc/offline-tts-zipvoice-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOfflineTtsZipvoiceModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +}  // namespace sherpa_onnx | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_ZIPVOICE_MODEL_CONFIG_H_ |
| @@ -84,6 +84,41 @@ void PybindOfflineTts(py::module *m) { | @@ -84,6 +84,41 @@ void PybindOfflineTts(py::module *m) { | ||
| 84 | }, | 84 | }, |
| 85 | py::arg("text"), py::arg("sid") = 0, py::arg("speed") = 1.0, | 85 | py::arg("text"), py::arg("sid") = 0, py::arg("speed") = 1.0, |
| 86 | py::arg("callback") = py::none(), | 86 | py::arg("callback") = py::none(), |
| 87 | + py::call_guard<py::gil_scoped_release>()) | ||
| 88 | + .def( | ||
| 89 | + "generate", | ||
| 90 | + [](const PyClass &self, const std::string &text, | ||
| 91 | + const std::string &prompt_text, | ||
| 92 | + const std::vector<float> &prompt_samples, int32_t sample_rate, | ||
| 93 | + float speed, int32_t num_steps, | ||
| 94 | + std::function<int32_t(py::array_t<float>, float)> callback) | ||
| 95 | + -> GeneratedAudio { | ||
| 96 | + if (!callback) { | ||
| 97 | + return self.Generate(text, prompt_text, prompt_samples, | ||
| 98 | + sample_rate, speed, num_steps); | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + std::function<int32_t(const float *, int32_t, float)> | ||
| 102 | + callback_wrapper = [callback](const float *samples, int32_t n, | ||
| 103 | + float progress) { | ||
| 104 | + // CAUTION(fangjun): we have to copy samples since it is | ||
| 105 | + // freed once the call back returns. | ||
| 106 | + | ||
| 107 | + pybind11::gil_scoped_acquire acquire; | ||
| 108 | + | ||
| 109 | + pybind11::array_t<float> array(n); | ||
| 110 | + py::buffer_info buf = array.request(); | ||
| 111 | + auto p = static_cast<float *>(buf.ptr); | ||
| 112 | + std::copy(samples, samples + n, p); | ||
| 113 | + return callback(array, progress); | ||
| 114 | + }; | ||
| 115 | + | ||
| 116 | + return self.Generate(text, prompt_text, prompt_samples, sample_rate, | ||
| 117 | + speed, num_steps, callback_wrapper); | ||
| 118 | + }, | ||
| 119 | + py::arg("text"), py::arg("prompt_text"), py::arg("prompt_samples"), | ||
| 120 | + py::arg("sample_rate"), py::arg("speed") = 1.0, | ||
| 121 | + py::arg("num_steps") = 4, py::arg("callback") = py::none(), | ||
| 87 | py::call_guard<py::gil_scoped_release>()); | 122 | py::call_guard<py::gil_scoped_release>()); |
| 88 | } | 123 | } |
| 89 | 124 |
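The second `generate` overload bound above is the ZipVoice-style prompted path: alongside the text to synthesize, it takes the transcript of a prompt recording, the prompt samples, and their sample rate, and can stream audio chunks through an optional callback. A usage sketch under the assumption that `tts` is an `OfflineTts` already configured with ZipVoice models (soundfile is used here only to load the prompt wav):

```python
import sherpa_onnx
import soundfile as sf

# Assumption: tts was created earlier as sherpa_onnx.OfflineTts(...) with a
# model config pointing at ZipVoice models; prompt.wav is assumed to be mono.
prompt_samples, prompt_sample_rate = sf.read("prompt.wav", dtype="float32")

def on_audio(samples, progress):
    # samples is a copy (the binding above copies before the buffer is freed);
    # in other sherpa-onnx TTS callbacks returning 0 stops generation and 1
    # continues, which is assumed to hold here as well.
    print(f"got {len(samples)} samples, progress={progress:.2f}")
    return 1

audio = tts.generate(
    text="Text to synthesize in the prompt speaker's voice.",
    prompt_text="Transcript of prompt.wav.",
    prompt_samples=prompt_samples,
    sample_rate=prompt_sample_rate,
    speed=1.0,    # default
    num_steps=4,  # default number of flow-matching steps
    callback=on_audio,
)
```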
| @@ -49,6 +49,7 @@ from sherpa_onnx.lib._sherpa_onnx import ( | @@ -49,6 +49,7 @@ from sherpa_onnx.lib._sherpa_onnx import ( | ||
| 49 | OfflineTtsMatchaModelConfig, | 49 | OfflineTtsMatchaModelConfig, |
| 50 | OfflineTtsModelConfig, | 50 | OfflineTtsModelConfig, |
| 51 | OfflineTtsVitsModelConfig, | 51 | OfflineTtsVitsModelConfig, |
| 52 | + OfflineTtsZipvoiceModelConfig, | ||
| 52 | OfflineWenetCtcModelConfig, | 53 | OfflineWenetCtcModelConfig, |
| 53 | OfflineWhisperModelConfig, | 54 | OfflineWhisperModelConfig, |
| 54 | OfflineZipformerAudioTaggingModelConfig, | 55 | OfflineZipformerAudioTaggingModelConfig, |