Fangjun Kuang
Committed by GitHub

Add C++ runtime for MeloTTS (#1138)

Showing 51 changed files with 693 additions and 156 deletions.
@@ -63,10 +63,16 @@ jobs: @@ -63,10 +63,16 @@ jobs:
63 echo "pwd: $PWD" 63 echo "pwd: $PWD"
64 ls -lh ../scripts/melo-tts 64 ls -lh ../scripts/melo-tts
65 65
  66 + rm -rf ./
  67 +
66 cp -v ../scripts/melo-tts/*.onnx . 68 cp -v ../scripts/melo-tts/*.onnx .
67 cp -v ../scripts/melo-tts/lexicon.txt . 69 cp -v ../scripts/melo-tts/lexicon.txt .
68 cp -v ../scripts/melo-tts/tokens.txt . 70 cp -v ../scripts/melo-tts/tokens.txt .
  71 + cp -v ../scripts/melo-tts/README.md .
  72 +
  73 + curl -SL -O https://raw.githubusercontent.com/myshell-ai/MeloTTS/main/LICENSE
69 74
  75 + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/new_heteronym.fst
70 curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst 76 curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
71 curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst 77 curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
72 curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst 78 curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst
@@ -77,6 +83,10 @@ jobs: @@ -77,6 +83,10 @@ jobs:
77 git lfs track "*.onnx" 83 git lfs track "*.onnx"
78 git add . 84 git add .
79 85
  86 + ls -lh
  87 +
  88 + git status
  89 +
80 git commit -m "add models" 90 git commit -m "add models"
81 git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true 91 git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true
82 92
@@ -39,10 +39,14 @@ jobs: @@ -39,10 +39,14 @@ jobs:
39 cd build 39 cd build
40 cmake \ 40 cmake \
41 -A x64 \ 41 -A x64 \
42 - -D CMAKE_BUILD_TYPE=Release \  
43 - -D BUILD_SHARED_LIBS=ON \ 42 + -DBUILD_SHARED_LIBS=ON \
44 -D SHERPA_ONNX_ENABLE_JNI=ON \ 43 -D SHERPA_ONNX_ENABLE_JNI=ON \
45 - -D CMAKE_INSTALL_PREFIX=./install \ 44 + -DCMAKE_INSTALL_PREFIX=./install \
  45 + -DCMAKE_BUILD_TYPE=Release \
  46 + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \
  47 + -DBUILD_ESPEAK_NG_EXE=OFF \
  48 + -DSHERPA_ONNX_BUILD_C_API_EXAMPLES=OFF \
  49 + -DSHERPA_ONNX_ENABLE_BINARY=ON \
46 .. 50 ..
47 51
48 - name: Build sherpa-onnx for windows 52 - name: Build sherpa-onnx for windows
  1 +## 1.10.16
  2 +
  3 +* Support zh-en TTS model from MeloTTS.
  4 +
1 ## 1.10.15 5 ## 1.10.15
2 6
3 * Downgrade onnxruntime from v1.18.1 to v1.17.1 7 * Downgrade onnxruntime from v1.18.1 to v1.17.1
@@ -11,7 +11,7 @@ project(sherpa-onnx) @@ -11,7 +11,7 @@ project(sherpa-onnx)
11 # ./nodejs-addon-examples 11 # ./nodejs-addon-examples
12 # ./dart-api-examples/ 12 # ./dart-api-examples/
13 # ./CHANGELOG.md 13 # ./CHANGELOG.md
14 -set(SHERPA_ONNX_VERSION "1.10.15") 14 +set(SHERPA_ONNX_VERSION "1.10.16")
15 15
16 # Disable warning about 16 # Disable warning about
17 # 17 #
@@ -10,7 +10,7 @@ environment: @@ -10,7 +10,7 @@ environment:
10 10
11 # Add regular dependencies here. 11 # Add regular dependencies here.
12 dependencies: 12 dependencies:
13 - sherpa_onnx: ^1.10.15 13 + sherpa_onnx: ^1.10.16
14 path: ^1.9.0 14 path: ^1.9.0
15 args: ^2.5.0 15 args: ^2.5.0
16 16
@@ -11,7 +11,7 @@ environment: @@ -11,7 +11,7 @@ environment:
11 11
12 # Add regular dependencies here. 12 # Add regular dependencies here.
13 dependencies: 13 dependencies:
14 - sherpa_onnx: ^1.10.15 14 + sherpa_onnx: ^1.10.16
15 path: ^1.9.0 15 path: ^1.9.0
16 args: ^2.5.0 16 args: ^2.5.0
17 17
@@ -8,7 +8,7 @@ environment: @@ -8,7 +8,7 @@ environment:
8 8
9 # Add regular dependencies here. 9 # Add regular dependencies here.
10 dependencies: 10 dependencies:
11 - sherpa_onnx: ^1.10.15 11 + sherpa_onnx: ^1.10.16
12 path: ^1.9.0 12 path: ^1.9.0
13 args: ^2.5.0 13 args: ^2.5.0
14 14
@@ -9,7 +9,7 @@ environment: @@ -9,7 +9,7 @@ environment:
9 sdk: ^3.4.0 9 sdk: ^3.4.0
10 10
11 dependencies: 11 dependencies:
12 - sherpa_onnx: ^1.10.15 12 + sherpa_onnx: ^1.10.16
13 path: ^1.9.0 13 path: ^1.9.0
14 args: ^2.5.0 14 args: ^2.5.0
15 15
@@ -5,7 +5,7 @@ description: > @@ -5,7 +5,7 @@ description: >
5 5
6 publish_to: 'none' 6 publish_to: 'none'
7 7
8 -version: 1.10.14 8 +version: 1.10.16
9 9
10 topics: 10 topics:
11 - speech-recognition 11 - speech-recognition
@@ -30,7 +30,7 @@ dependencies: @@ -30,7 +30,7 @@ dependencies:
30 record: ^5.1.0 30 record: ^5.1.0
31 url_launcher: ^6.2.6 31 url_launcher: ^6.2.6
32 32
33 - sherpa_onnx: ^1.10.15 33 + sherpa_onnx: ^1.10.16
34 # sherpa_onnx: 34 # sherpa_onnx:
35 # path: ../../flutter/sherpa_onnx 35 # path: ../../flutter/sherpa_onnx
36 36
@@ -5,7 +5,7 @@ description: > @@ -5,7 +5,7 @@ description: >
5 5
6 publish_to: 'none' # Remove this line if you wish to publish to pub.dev 6 publish_to: 'none' # Remove this line if you wish to publish to pub.dev
7 7
8 -version: 1.0.0 8 +version: 1.10.16
9 9
10 environment: 10 environment:
11 sdk: '>=3.4.0 <4.0.0' 11 sdk: '>=3.4.0 <4.0.0'
@@ -17,7 +17,7 @@ dependencies: @@ -17,7 +17,7 @@ dependencies:
17 cupertino_icons: ^1.0.6 17 cupertino_icons: ^1.0.6
18 path_provider: ^2.1.3 18 path_provider: ^2.1.3
19 path: ^1.9.0 19 path: ^1.9.0
20 - sherpa_onnx: ^1.10.15 20 + sherpa_onnx: ^1.10.16
21 url_launcher: ^6.2.6 21 url_launcher: ^6.2.6
22 audioplayers: ^5.0.0 22 audioplayers: ^5.0.0
23 23
@@ -17,7 +17,7 @@ topics: @@ -17,7 +17,7 @@ topics:
17 - voice-activity-detection 17 - voice-activity-detection
18 18
19 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec 19 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec
20 -version: 1.10.15 20 +version: 1.10.16
21 21
22 homepage: https://github.com/k2-fsa/sherpa-onnx 22 homepage: https://github.com/k2-fsa/sherpa-onnx
23 23
@@ -30,19 +30,19 @@ dependencies: @@ -30,19 +30,19 @@ dependencies:
30 flutter: 30 flutter:
31 sdk: flutter 31 sdk: flutter
32 32
33 - sherpa_onnx_android: ^1.10.15 33 + sherpa_onnx_android: ^1.10.16
34 # path: ../sherpa_onnx_android 34 # path: ../sherpa_onnx_android
35 35
36 - sherpa_onnx_macos: ^1.10.15 36 + sherpa_onnx_macos: ^1.10.16
37 # path: ../sherpa_onnx_macos 37 # path: ../sherpa_onnx_macos
38 38
39 - sherpa_onnx_linux: ^1.10.15 39 + sherpa_onnx_linux: ^1.10.16
40 # path: ../sherpa_onnx_linux 40 # path: ../sherpa_onnx_linux
41 # 41 #
42 - sherpa_onnx_windows: ^1.10.15 42 + sherpa_onnx_windows: ^1.10.16
43 # path: ../sherpa_onnx_windows 43 # path: ../sherpa_onnx_windows
44 44
45 - sherpa_onnx_ios: ^1.10.15 45 + sherpa_onnx_ios: ^1.10.16
46 # sherpa_onnx_ios: 46 # sherpa_onnx_ios:
47 # path: ../sherpa_onnx_ios 47 # path: ../sherpa_onnx_ios
48 48
@@ -7,7 +7,7 @@ @@ -7,7 +7,7 @@
7 # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c 7 # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c
8 Pod::Spec.new do |s| 8 Pod::Spec.new do |s|
9 s.name = 'sherpa_onnx_ios' 9 s.name = 'sherpa_onnx_ios'
10 - s.version = '1.10.15' 10 + s.version = '1.10.16'
11 s.summary = 'A new Flutter FFI plugin project.' 11 s.summary = 'A new Flutter FFI plugin project.'
12 s.description = <<-DESC 12 s.description = <<-DESC
13 A new Flutter FFI plugin project. 13 A new Flutter FFI plugin project.
@@ -4,7 +4,7 @@ @@ -4,7 +4,7 @@
4 # 4 #
5 Pod::Spec.new do |s| 5 Pod::Spec.new do |s|
6 s.name = 'sherpa_onnx_macos' 6 s.name = 'sherpa_onnx_macos'
7 - s.version = '1.10.15' 7 + s.version = '1.10.16'
8 s.summary = 'sherpa-onnx Flutter FFI plugin project.' 8 s.summary = 'sherpa-onnx Flutter FFI plugin project.'
9 s.description = <<-DESC 9 s.description = <<-DESC
10 sherpa-onnx Flutter FFI plugin project. 10 sherpa-onnx Flutter FFI plugin project.
1 { 1 {
2 "dependencies": { 2 "dependencies": {
3 - "sherpa-onnx-node": "^1.10.15" 3 + "sherpa-onnx-node": "^1.10.16"
4 } 4 }
5 } 5 }
@@ -78,6 +78,10 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt @@ -78,6 +78,10 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt
78 git diff 78 git diff
79 popd 79 popd
80 80
  81 +if [[ $model_dir == vits-melo-tts-zh_en ]]; then
  82 + lang=zh_en
  83 +fi
  84 +
81 for arch in arm64-v8a armeabi-v7a x86_64 x86; do 85 for arch in arm64-v8a armeabi-v7a x86_64 x86; do
82 log "------------------------------------------------------------" 86 log "------------------------------------------------------------"
83 log "build tts apk for $arch" 87 log "build tts apk for $arch"
@@ -76,6 +76,10 @@ sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt @@ -76,6 +76,10 @@ sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt
76 git diff 76 git diff
77 popd 77 popd
78 78
  79 +if [[ $model_dir == vits-melo-tts-zh_en ]]; then
  80 + lang=zh_en
  81 +fi
  82 +
79 for arch in arm64-v8a armeabi-v7a x86_64 x86; do 83 for arch in arm64-v8a armeabi-v7a x86_64 x86; do
80 log "------------------------------------------------------------" 84 log "------------------------------------------------------------"
81 log "build tts apk for $arch" 85 log "build tts apk for $arch"
@@ -313,6 +313,11 @@ def get_vits_models() -> List[TtsModel]: @@ -313,6 +313,11 @@ def get_vits_models() -> List[TtsModel]:
313 lang="zh", 313 lang="zh",
314 ), 314 ),
315 TtsModel( 315 TtsModel(
  316 + model_dir="vits-melo-tts-zh_en",
  317 + model_name="model.onnx",
  318 + lang="zh",
  319 + ),
  320 + TtsModel(
316 model_dir="vits-zh-hf-fanchen-C", 321 model_dir="vits-zh-hf-fanchen-C",
317 model_name="vits-zh-hf-fanchen-C.onnx", 322 model_name="vits-zh-hf-fanchen-C.onnx",
318 lang="zh", 323 lang="zh",
@@ -339,18 +344,21 @@ def get_vits_models() -> List[TtsModel]: @@ -339,18 +344,21 @@ def get_vits_models() -> List[TtsModel]:
339 ), 344 ),
340 ] 345 ]
341 346
342 - rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"] 347 + rule_fsts = ["phone.fst", "date.fst", "number.fst"]
343 for m in chinese_models: 348 for m in chinese_models:
344 s = [f"{m.model_dir}/{r}" for r in rule_fsts] 349 s = [f"{m.model_dir}/{r}" for r in rule_fsts]
345 - if "vits-zh-hf" in m.model_dir or "sherpa-onnx-vits-zh-ll" == m.model_dir: 350 + if (
  351 + "vits-zh-hf" in m.model_dir
  352 + or "sherpa-onnx-vits-zh-ll" == m.model_dir
  353 + or "melo-tts" in m.model_dir
  354 + ):
346 s = s[:-1] 355 s = s[:-1]
347 m.dict_dir = m.model_dir + "/dict" 356 m.dict_dir = m.model_dir + "/dict"
  357 + else:
  358 + m.rule_fars = f"{m.model_dir}/rule.far"
348 359
349 m.rule_fsts = ",".join(s) 360 m.rule_fsts = ",".join(s)
350 361
351 - if "vits-zh-hf" not in m.model_dir and "zh-ll" not in m.model_dir:  
352 - m.rule_fars = f"{m.model_dir}/rule.far"  
353 -  
354 all_models = chinese_models + [ 362 all_models = chinese_models + [
355 TtsModel( 363 TtsModel(
356 model_dir="vits-cantonese-hf-xiaomaiiwn", 364 model_dir="vits-cantonese-hf-xiaomaiiwn",
@@ -17,7 +17,7 @@ topics: @@ -17,7 +17,7 @@ topics:
17 - voice-activity-detection 17 - voice-activity-detection
18 18
19 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec 19 # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec
20 -version: 1.10.15 20 +version: 1.10.16
21 21
22 homepage: https://github.com/k2-fsa/sherpa-onnx 22 homepage: https://github.com/k2-fsa/sherpa-onnx
23 23
@@ -6,9 +6,6 @@ from typing import List, Optional @@ -6,9 +6,6 @@ from typing import List, Optional
6 6
7 import jinja2 7 import jinja2
8 8
9 -# pip install iso639-lang  
10 -from iso639 import Lang  
11 -  
12 9
13 def get_args(): 10 def get_args():
14 parser = argparse.ArgumentParser() 11 parser = argparse.ArgumentParser()
@@ -37,13 +34,6 @@ class TtsModel: @@ -37,13 +34,6 @@ class TtsModel:
37 data_dir: Optional[str] = None 34 data_dir: Optional[str] = None
38 dict_dir: Optional[str] = None 35 dict_dir: Optional[str] = None
39 is_char: bool = False 36 is_char: bool = False
40 - lang_iso_639_3: str = ""  
41 -  
42 -  
43 -def convert_lang_to_iso_639_3(models: List[TtsModel]):  
44 - for m in models:  
45 - if m.lang_iso_639_3 == "":  
46 - m.lang_iso_639_3 = Lang(m.lang).pt3  
47 37
48 38
49 def get_coqui_models() -> List[TtsModel]: 39 def get_coqui_models() -> List[TtsModel]:
@@ -313,6 +303,11 @@ def get_vits_models() -> List[TtsModel]: @@ -313,6 +303,11 @@ def get_vits_models() -> List[TtsModel]:
313 lang="zh", 303 lang="zh",
314 ), 304 ),
315 TtsModel( 305 TtsModel(
  306 + model_dir="vits-melo-tts-zh_en",
  307 + model_name="model.onnx",
  308 + lang="zh_en",
  309 + ),
  310 + TtsModel(
316 model_dir="vits-zh-hf-fanchen-C", 311 model_dir="vits-zh-hf-fanchen-C",
317 model_name="vits-zh-hf-fanchen-C.onnx", 312 model_name="vits-zh-hf-fanchen-C.onnx",
318 lang="zh", 313 lang="zh",
@@ -332,26 +327,33 @@ def get_vits_models() -> List[TtsModel]: @@ -332,26 +327,33 @@ def get_vits_models() -> List[TtsModel]:
332 model_name="vits-zh-hf-fanchen-unity.onnx", 327 model_name="vits-zh-hf-fanchen-unity.onnx",
333 lang="zh", 328 lang="zh",
334 ), 329 ),
  330 + TtsModel(
  331 + model_dir="sherpa-onnx-vits-zh-ll",
  332 + model_name="model.onnx",
  333 + lang="zh",
  334 + ),
335 ] 335 ]
336 336
337 - rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"] 337 + rule_fsts = ["phone.fst", "date.fst", "number.fst"]
338 for m in chinese_models: 338 for m in chinese_models:
339 s = [f"{m.model_dir}/{r}" for r in rule_fsts] 339 s = [f"{m.model_dir}/{r}" for r in rule_fsts]
340 - if "vits-zh-hf" in m.model_dir: 340 + if (
  341 + "vits-zh-hf" in m.model_dir
  342 + or "sherpa-onnx-vits-zh-ll" == m.model_dir
  343 + or "melo-tts" in m.model_dir
  344 + ):
341 s = s[:-1] 345 s = s[:-1]
342 m.dict_dir = m.model_dir + "/dict" 346 m.dict_dir = m.model_dir + "/dict"
  347 + else:
  348 + m.rule_fars = f"{m.model_dir}/rule.far"
343 349
344 m.rule_fsts = ",".join(s) 350 m.rule_fsts = ",".join(s)
345 351
346 - if "vits-zh-hf" not in m.model_dir:  
347 - m.rule_fars = f"{m.model_dir}/rule.far"  
348 -  
349 all_models = chinese_models + [ 352 all_models = chinese_models + [
350 TtsModel( 353 TtsModel(
351 model_dir="vits-cantonese-hf-xiaomaiiwn", 354 model_dir="vits-cantonese-hf-xiaomaiiwn",
352 model_name="vits-cantonese-hf-xiaomaiiwn.onnx", 355 model_name="vits-cantonese-hf-xiaomaiiwn.onnx",
353 lang="cantonese", 356 lang="cantonese",
354 - lang_iso_639_3="yue",  
355 rule_fsts="vits-cantonese-hf-xiaomaiiwn/rule.fst", 357 rule_fsts="vits-cantonese-hf-xiaomaiiwn/rule.fst",
356 ), 358 ),
357 # English (US) 359 # English (US)
@@ -374,7 +376,6 @@ def main(): @@ -374,7 +376,6 @@ def main():
374 all_model_list += get_piper_models() 376 all_model_list += get_piper_models()
375 all_model_list += get_mimic3_models() 377 all_model_list += get_mimic3_models()
376 all_model_list += get_coqui_models() 378 all_model_list += get_coqui_models()
377 - convert_lang_to_iso_639_3(all_model_list)  
378 379
379 num_models = len(all_model_list) 380 num_models = len(all_model_list)
380 381
  1 +# Introduction
  2 +
  3 +Models in this directory are converted from
  4 +https://github.com/myshell-ai/MeloTTS
  5 +
  6 +Note there is only a single female speaker in the model.
@@ -8,7 +8,6 @@ from melo.text import language_id_map, language_tone_start_map @@ -8,7 +8,6 @@ from melo.text import language_id_map, language_tone_start_map
8 from melo.text.chinese import pinyin_to_symbol_map 8 from melo.text.chinese import pinyin_to_symbol_map
9 from melo.text.english import eng_dict, refine_syllables 9 from melo.text.english import eng_dict, refine_syllables
10 from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict 10 from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
11 -from melo.text.symbols import language_tone_start_map  
12 11
13 for k, v in pinyin_to_symbol_map.items(): 12 for k, v in pinyin_to_symbol_map.items():
14 if isinstance(v, list): 13 if isinstance(v, list):
@@ -82,6 +81,7 @@ def generate_tokens(symbol_list): @@ -82,6 +81,7 @@ def generate_tokens(symbol_list):
82 def generate_lexicon(): 81 def generate_lexicon():
83 word_dict = pinyin_dict.pinyin_dict 82 word_dict = pinyin_dict.pinyin_dict
84 phrases = phrases_dict.phrases_dict 83 phrases = phrases_dict.phrases_dict
  84 + eng_dict["kaldi"] = [["K", "AH0"], ["L", "D", "IH0"]]
85 with open("lexicon.txt", "w", encoding="utf-8") as f: 85 with open("lexicon.txt", "w", encoding="utf-8") as f:
86 for word in eng_dict: 86 for word in eng_dict:
87 phones, tones = refine_syllables(eng_dict[word]) 87 phones, tones = refine_syllables(eng_dict[word])
@@ -237,9 +237,11 @@ def main(): @@ -237,9 +237,11 @@ def main():
237 meta_data = { 237 meta_data = {
238 "model_type": "melo-vits", 238 "model_type": "melo-vits",
239 "comment": "melo", 239 "comment": "melo",
  240 + "version": 2,
240 "language": "Chinese + English", 241 "language": "Chinese + English",
241 "add_blank": int(model.hps.data.add_blank), 242 "add_blank": int(model.hps.data.add_blank),
242 "n_speakers": 1, 243 "n_speakers": 1,
  244 + "jieba": 1,
243 "sample_rate": model.hps.data.sampling_rate, 245 "sample_rate": model.hps.data.sampling_rate,
244 "bert_dim": 1024, 246 "bert_dim": 1024,
245 "ja_bert_dim": 768, 247 "ja_bert_dim": 768,
@@ -12,7 +12,7 @@ function install() { @@ -12,7 +12,7 @@ function install() {
12 cd MeloTTS 12 cd MeloTTS
13 pip install -r ./requirements.txt 13 pip install -r ./requirements.txt
14 14
15 - pip install soundfile onnx onnxruntime 15 + pip install soundfile onnx==1.15.0 onnxruntime==1.16.3
16 16
17 python3 -m unidic download 17 python3 -m unidic download
18 popd 18 popd
@@ -135,28 +135,11 @@ class OnnxModel: @@ -135,28 +135,11 @@ class OnnxModel:
135 def main(): 135 def main():
136 lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt") 136 lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt")
137 137
138 - text = "永远相信,美好的事情即将发生。" 138 + text = "这是一个使用 next generation kaldi 的 text to speech 中英文例子. Thank you! 你觉得如何呢? are you ok? Fantastic! How about you?"
139 s = jieba.cut(text, HMM=True) 139 s = jieba.cut(text, HMM=True)
140 140
141 phones, tones = lexicon.convert(s) 141 phones, tones = lexicon.convert(s)
142 142
143 - en_text = "how are you ?".split()  
144 -  
145 - phones_en, tones_en = lexicon.convert(en_text)  
146 - phones += [0]  
147 - tones += [0]  
148 -  
149 - phones += phones_en  
150 - tones += tones_en  
151 -  
152 - text = "多音字测试, 银行,行不行?长沙长大"  
153 - s = jieba.cut(text, HMM=True)  
154 -  
155 - phones2, tones2 = lexicon.convert(s)  
156 -  
157 - phones += phones2  
158 - tones += tones2  
159 -  
160 model = OnnxModel("./model.onnx") 143 model = OnnxModel("./model.onnx")
161 144
162 if model.add_blank: 145 if model.add_blank:
@@ -422,10 +422,10 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig( @@ -422,10 +422,10 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
422 422
423 void SherpaOnnxOfflineRecognizerSetConfig( 423 void SherpaOnnxOfflineRecognizerSetConfig(
424 const SherpaOnnxOfflineRecognizer *recognizer, 424 const SherpaOnnxOfflineRecognizer *recognizer,
425 - const SherpaOnnxOfflineRecognizerConfig *config){ 425 + const SherpaOnnxOfflineRecognizerConfig *config) {
426 sherpa_onnx::OfflineRecognizerConfig recognizer_config = 426 sherpa_onnx::OfflineRecognizerConfig recognizer_config =
427 convertConfig(config); 427 convertConfig(config);
428 - recognizer->impl->SetConfig(recognizer_config); 428 + recognizer->impl->SetConfig(recognizer_config);
429 } 429 }
430 430
431 void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) { 431 void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) {
@@ -478,7 +478,7 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( @@ -478,7 +478,7 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
478 pText[text.size()] = 0; 478 pText[text.size()] = 0;
479 r->text = pText; 479 r->text = pText;
480 480
481 - //lang 481 + // lang
482 const auto &lang = result.lang; 482 const auto &lang = result.lang;
483 char *c_lang = new char[lang.size() + 1]; 483 char *c_lang = new char[lang.size() + 1];
484 std::copy(lang.begin(), lang.end(), c_lang); 484 std::copy(lang.begin(), lang.end(), c_lang);
@@ -1317,7 +1317,7 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches( @@ -1317,7 +1317,7 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
1317 } 1317 }
1318 delete[] r->matches; 1318 delete[] r->matches;
1319 delete r; 1319 delete r;
1320 -}; 1320 +}
1321 1321
1322 int32_t SherpaOnnxSpeakerEmbeddingManagerVerify( 1322 int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
1323 const SherpaOnnxSpeakerEmbeddingManager *p, const char *name, 1323 const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
@@ -496,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams( @@ -496,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
496 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { 496 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
497 const char *text; 497 const char *text;
498 498
499 - // Pointer to continuous memory which holds timestamps 499 + // Pointer to continuous memory which holds timestamps
500 // 500 //
501 // It is NULL if the model does not support timestamps 501 // It is NULL if the model does not support timestamps
502 float *timestamps; 502 float *timestamps;
@@ -525,9 +525,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { @@ -525,9 +525,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
525 */ 525 */
526 const char *json; 526 const char *json;
527 527
528 - //return recognized language 528 + // return recognized language
529 const char *lang; 529 const char *lang;
530 -  
531 } SherpaOnnxOfflineRecognizerResult; 530 } SherpaOnnxOfflineRecognizerResult;
532 531
533 /// Get the result of the offline stream. 532 /// Get the result of the offline stream.
@@ -142,7 +142,9 @@ if(SHERPA_ONNX_ENABLE_TTS) @@ -142,7 +142,9 @@ if(SHERPA_ONNX_ENABLE_TTS)
142 list(APPEND sources 142 list(APPEND sources
143 jieba-lexicon.cc 143 jieba-lexicon.cc
144 lexicon.cc 144 lexicon.cc
  145 + melo-tts-lexicon.cc
145 offline-tts-character-frontend.cc 146 offline-tts-character-frontend.cc
  147 + offline-tts-frontend.cc
146 offline-tts-impl.cc 148 offline-tts-impl.cc
147 offline-tts-model-config.cc 149 offline-tts-model-config.cc
148 offline-tts-vits-model-config.cc 150 offline-tts-vits-model-config.cc
@@ -33,7 +33,7 @@ TEST(CppJieBa, Case1) { @@ -33,7 +33,7 @@ TEST(CppJieBa, Case1) {
33 std::vector<std::string> words; 33 std::vector<std::string> words;
34 std::vector<cppjieba::Word> jiebawords; 34 std::vector<cppjieba::Word> jiebawords;
35 35
36 - std::string s = "他来到了网易杭研大厦"; 36 + std::string s = "他来到了网易杭研大厦。How are you?";
37 std::cout << s << std::endl; 37 std::cout << s << std::endl;
38 std::cout << "[demo] Cut With HMM" << std::endl; 38 std::cout << "[demo] Cut With HMM" << std::endl;
39 jieba.Cut(s, words, true); 39 jieba.Cut(s, words, true);
@@ -17,6 +17,7 @@ namespace sherpa_onnx { @@ -17,6 +17,7 @@ namespace sherpa_onnx {
17 17
18 // implemented in ./lexicon.cc 18 // implemented in ./lexicon.cc
19 std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is); 19 std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
  20 +
20 std::vector<int32_t> ConvertTokensToIds( 21 std::vector<int32_t> ConvertTokensToIds(
21 const std::unordered_map<std::string, int32_t> &token2id, 22 const std::unordered_map<std::string, int32_t> &token2id,
22 const std::vector<std::string> &tokens); 23 const std::vector<std::string> &tokens);
@@ -53,8 +54,7 @@ class JiebaLexicon::Impl { @@ -53,8 +54,7 @@ class JiebaLexicon::Impl {
53 } 54 }
54 } 55 }
55 56
56 - std::vector<std::vector<int64_t>> ConvertTextToTokenIds(  
57 - const std::string &text) const { 57 + std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &text) const {
58 // see 58 // see
59 // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244 59 // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
60 std::regex punct_re{":|、|;"}; 60 std::regex punct_re{":|、|;"};
@@ -87,7 +87,7 @@ class JiebaLexicon::Impl { @@ -87,7 +87,7 @@ class JiebaLexicon::Impl {
87 SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str()); 87 SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
88 } 88 }
89 89
90 - std::vector<std::vector<int64_t>> ans; 90 + std::vector<TokenIDs> ans;
91 std::vector<int64_t> this_sentence; 91 std::vector<int64_t> this_sentence;
92 92
93 int32_t blank = token2id_.at(" "); 93 int32_t blank = token2id_.at(" ");
@@ -217,7 +217,7 @@ JiebaLexicon::JiebaLexicon(const std::string &lexicon, @@ -217,7 +217,7 @@ JiebaLexicon::JiebaLexicon(const std::string &lexicon,
217 : impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data, 217 : impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
218 debug)) {} 218 debug)) {}
219 219
220 -std::vector<std::vector<int64_t>> JiebaLexicon::ConvertTextToTokenIds( 220 +std::vector<TokenIDs> JiebaLexicon::ConvertTextToTokenIds(
221 const std::string &text, const std::string & /*unused_voice = ""*/) const { 221 const std::string &text, const std::string & /*unused_voice = ""*/) const {
222 return impl_->ConvertTextToTokenIds(text); 222 return impl_->ConvertTextToTokenIds(text);
223 } 223 }
@@ -10,11 +10,6 @@ @@ -10,11 +10,6 @@
10 #include <unordered_map> 10 #include <unordered_map>
11 #include <vector> 11 #include <vector>
12 12
13 -#if __ANDROID_API__ >= 9  
14 -#include "android/asset_manager.h"  
15 -#include "android/asset_manager_jni.h"  
16 -#endif  
17 -  
18 #include "sherpa-onnx/csrc/offline-tts-frontend.h" 13 #include "sherpa-onnx/csrc/offline-tts-frontend.h"
19 #include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" 14 #include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
20 15
@@ -27,13 +22,7 @@ class JiebaLexicon : public OfflineTtsFrontend { @@ -27,13 +22,7 @@ class JiebaLexicon : public OfflineTtsFrontend {
27 const std::string &dict_dir, 22 const std::string &dict_dir,
28 const OfflineTtsVitsModelMetaData &meta_data, bool debug); 23 const OfflineTtsVitsModelMetaData &meta_data, bool debug);
29 24
30 -#if __ANDROID_API__ >= 9  
31 - JiebaLexicon(AAssetManager *mgr, const std::string &lexicon,  
32 - const std::string &tokens, const std::string &dict_dir,  
33 - const OfflineTtsVitsModelMetaData &meta_data);  
34 -#endif  
35 -  
36 - std::vector<std::vector<int64_t>> ConvertTextToTokenIds( 25 + std::vector<TokenIDs> ConvertTextToTokenIds(
37 const std::string &text, 26 const std::string &text,
38 const std::string &unused_voice = "") const override; 27 const std::string &unused_voice = "") const override;
39 28
@@ -172,7 +172,7 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, @@ -172,7 +172,7 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
172 } 172 }
173 #endif 173 #endif
174 174
175 -std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds( 175 +std::vector<TokenIDs> Lexicon::ConvertTextToTokenIds(
176 const std::string &text, const std::string & /*voice*/ /*= ""*/) const { 176 const std::string &text, const std::string & /*voice*/ /*= ""*/) const {
177 switch (language_) { 177 switch (language_) {
178 case Language::kChinese: 178 case Language::kChinese:
@@ -187,7 +187,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds( @@ -187,7 +187,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
187 return {}; 187 return {};
188 } 188 }
189 189
190 -std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( 190 +std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsChinese(
191 const std::string &_text) const { 191 const std::string &_text) const {
192 std::string text(_text); 192 std::string text(_text);
193 ToLowerCase(&text); 193 ToLowerCase(&text);
@@ -209,7 +209,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( @@ -209,7 +209,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
209 fprintf(stderr, "\n"); 209 fprintf(stderr, "\n");
210 } 210 }
211 211
212 - std::vector<std::vector<int64_t>> ans; 212 + std::vector<TokenIDs> ans;
213 std::vector<int64_t> this_sentence; 213 std::vector<int64_t> this_sentence;
214 214
215 int32_t blank = -1; 215 int32_t blank = -1;
@@ -288,7 +288,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( @@ -288,7 +288,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
288 return ans; 288 return ans;
289 } 289 }
290 290
291 -std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese( 291 +std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsNotChinese(
292 const std::string &_text) const { 292 const std::string &_text) const {
293 std::string text(_text); 293 std::string text(_text);
294 ToLowerCase(&text); 294 ToLowerCase(&text);
@@ -311,7 +311,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese( @@ -311,7 +311,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
311 311
312 int32_t blank = token2id_.at(" "); 312 int32_t blank = token2id_.at(" ");
313 313
314 - std::vector<std::vector<int64_t>> ans; 314 + std::vector<TokenIDs> ans;
315 std::vector<int64_t> this_sentence; 315 std::vector<int64_t> this_sentence;
316 316
317 for (const auto &w : words) { 317 for (const auto &w : words) {
@@ -36,14 +36,14 @@ class Lexicon : public OfflineTtsFrontend { @@ -36,14 +36,14 @@ class Lexicon : public OfflineTtsFrontend {
36 const std::string &language, bool debug = false); 36 const std::string &language, bool debug = false);
37 #endif 37 #endif
38 38
39 - std::vector<std::vector<int64_t>> ConvertTextToTokenIds( 39 + std::vector<TokenIDs> ConvertTextToTokenIds(
40 const std::string &text, const std::string &voice = "") const override; 40 const std::string &text, const std::string &voice = "") const override;
41 41
42 private: 42 private:
43 - std::vector<std::vector<int64_t>> ConvertTextToTokenIdsNotChinese( 43 + std::vector<TokenIDs> ConvertTextToTokenIdsNotChinese(
44 const std::string &text) const; 44 const std::string &text) const;
45 45
46 - std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese( 46 + std::vector<TokenIDs> ConvertTextToTokenIdsChinese(
47 const std::string &text) const; 47 const std::string &text) const;
48 48
49 void InitLanguage(const std::string &lang); 49 void InitLanguage(const std::string &lang);
  1 +// sherpa-onnx/csrc/melo-tts-lexicon.cc
  2 +//
  3 +// Copyright (c) 2022-2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
  6 +
  7 +#include <fstream>
  8 +#include <regex> // NOLINT
  9 +#include <utility>
  10 +
  11 +#include "cppjieba/Jieba.hpp"
  12 +#include "sherpa-onnx/csrc/file-utils.h"
  13 +#include "sherpa-onnx/csrc/macros.h"
  14 +#include "sherpa-onnx/csrc/text-utils.h"
  15 +
  16 +namespace sherpa_onnx {
  17 +
  18 +// implemented in ./lexicon.cc
  19 +std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
  20 +
  21 +std::vector<int32_t> ConvertTokensToIds(
  22 + const std::unordered_map<std::string, int32_t> &token2id,
  23 + const std::vector<std::string> &tokens);
  24 +
  25 +class MeloTtsLexicon::Impl {
  26 + public:
  27 + Impl(const std::string &lexicon, const std::string &tokens,
  28 + const std::string &dict_dir,
  29 + const OfflineTtsVitsModelMetaData &meta_data, bool debug)
  30 + : meta_data_(meta_data), debug_(debug) {
  31 + std::string dict = dict_dir + "/jieba.dict.utf8";
  32 + std::string hmm = dict_dir + "/hmm_model.utf8";
  33 + std::string user_dict = dict_dir + "/user.dict.utf8";
  34 + std::string idf = dict_dir + "/idf.utf8";
  35 + std::string stop_word = dict_dir + "/stop_words.utf8";
  36 +
  37 + AssertFileExists(dict);
  38 + AssertFileExists(hmm);
  39 + AssertFileExists(user_dict);
  40 + AssertFileExists(idf);
  41 + AssertFileExists(stop_word);
  42 +
  43 + jieba_ =
  44 + std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
  45 +
  46 + {
  47 + std::ifstream is(tokens);
  48 + InitTokens(is);
  49 + }
  50 +
  51 + {
  52 + std::ifstream is(lexicon);
  53 + InitLexicon(is);
  54 + }
  55 + }
  56 +
  57 + std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
  58 + std::string text = ToLowerCase(_text);
  59 + // see
  60 + // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
  61 + std::regex punct_re{":|、|;"};
  62 + std::string s = std::regex_replace(text, punct_re, ",");
  63 +
  64 + std::regex punct_re2("。");
  65 + s = std::regex_replace(s, punct_re2, ".");
  66 +
  67 + std::regex punct_re3("?");
  68 + s = std::regex_replace(s, punct_re3, "?");
  69 +
  70 + std::regex punct_re4("!");
  71 + s = std::regex_replace(s, punct_re4, "!");
  72 +
  73 + std::vector<std::string> words;
  74 + bool is_hmm = true;
  75 + jieba_->Cut(text, words, is_hmm);
  76 +
  77 + if (debug_) {
  78 + SHERPA_ONNX_LOGE("input text: %s", text.c_str());
  79 + SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
  80 +
  81 + std::ostringstream os;
  82 + std::string sep = "";
  83 + for (const auto &w : words) {
  84 + os << sep << w;
  85 + sep = "_";
  86 + }
  87 +
  88 + SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
  89 + }
  90 +
  91 + std::vector<TokenIDs> ans;
  92 + TokenIDs this_sentence;
  93 +
  94 + int32_t blank = token2id_.at("_");
  95 + for (const auto &w : words) {
  96 + auto ids = ConvertWordToIds(w);
  97 + if (ids.tokens.empty()) {
  98 + SHERPA_ONNX_LOGE("Ignore OOV '%s'", w.c_str());
  99 + continue;
  100 + }
  101 +
  102 + this_sentence.tokens.insert(this_sentence.tokens.end(),
  103 + ids.tokens.begin(), ids.tokens.end());
  104 + this_sentence.tones.insert(this_sentence.tones.end(), ids.tones.begin(),
  105 + ids.tones.end());
  106 +
  107 + if (w == "." || w == "!" || w == "?" || w == ",") {
  108 + ans.push_back(std::move(this_sentence));
  109 + this_sentence = {};
  110 + }
  111 + } // for (const auto &w : words)
  112 +
  113 + if (!this_sentence.tokens.empty()) {
  114 + ans.push_back(std::move(this_sentence));
  115 + }
  116 +
  117 + return ans;
  118 + }
  119 +
  120 + private:
  121 + TokenIDs ConvertWordToIds(const std::string &w) const {
  122 + if (word2ids_.count(w)) {
  123 + return word2ids_.at(w);
  124 + }
  125 +
  126 + if (token2id_.count(w)) {
  127 + return {{token2id_.at(w)}, {0}};
  128 + }
  129 +
  130 + TokenIDs ans;
  131 +
  132 + std::vector<std::string> words = SplitUtf8(w);
  133 + for (const auto &word : words) {
  134 + if (word2ids_.count(word)) {
  135 + auto ids = ConvertWordToIds(word);
  136 + ans.tokens.insert(ans.tokens.end(), ids.tokens.begin(),
  137 + ids.tokens.end());
  138 + ans.tones.insert(ans.tones.end(), ids.tones.begin(), ids.tones.end());
  139 + }
  140 + }
  141 +
  142 + return ans;
  143 + }
  144 +
  145 + void InitTokens(std::istream &is) {
  146 + token2id_ = ReadTokens(is);
  147 + token2id_[" "] = token2id_["_"];
  148 +
  149 + std::vector<std::pair<std::string, std::string>> puncts = {
  150 + {",", ","}, {".", "。"}, {"!", "!"}, {"?", "?"}};
  151 +
  152 + for (const auto &p : puncts) {
  153 + if (token2id_.count(p.first) && !token2id_.count(p.second)) {
  154 + token2id_[p.second] = token2id_[p.first];
  155 + }
  156 +
  157 + if (!token2id_.count(p.first) && token2id_.count(p.second)) {
  158 + token2id_[p.first] = token2id_[p.second];
  159 + }
  160 + }
  161 +
  162 + if (!token2id_.count("、") && token2id_.count(",")) {
  163 + token2id_["、"] = token2id_[","];
  164 + }
  165 + }
  166 +
  167 + void InitLexicon(std::istream &is) {
  168 + std::string word;
  169 + std::vector<std::string> token_list;
  170 +
  171 + std::vector<std::string> phone_list;
  172 + std::vector<int64_t> tone_list;
  173 +
  174 + std::string line;
  175 + std::string phone;
  176 + int32_t line_num = 0;
  177 +
  178 + while (std::getline(is, line)) {
  179 + ++line_num;
  180 +
  181 + std::istringstream iss(line);
  182 +
  183 + token_list.clear();
  184 + phone_list.clear();
  185 + tone_list.clear();
  186 +
  187 + iss >> word;
  188 + ToLowerCase(&word);
  189 +
  190 + if (word2ids_.count(word)) {
  191 + SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
  192 + word.c_str(), line_num, line.c_str());
  193 + continue;
  194 + }
  195 +
  196 + while (iss >> phone) {
  197 + token_list.push_back(std::move(phone));
  198 + }
  199 +
  200 + if ((token_list.size() & 1) != 0) {
  201 + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
  202 + exit(-1);
  203 + }
  204 +
  205 + int32_t num_phones = token_list.size() / 2;
  206 + phone_list.reserve(num_phones);
  207 + tone_list.reserve(num_phones);
  208 +
  209 + for (int32_t i = 0; i != num_phones; ++i) {
  210 + phone_list.push_back(std::move(token_list[i]));
  211 + tone_list.push_back(std::stoi(token_list[i + num_phones], nullptr));
  212 + if (tone_list.back() < 0 || tone_list.back() > 50) {
  213 + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
  214 + exit(-1);
  215 + }
  216 + }
  217 +
  218 + std::vector<int32_t> ids = ConvertTokensToIds(token2id_, phone_list);
  219 + if (ids.empty()) {
  220 + continue;
  221 + }
  222 +
  223 + if (ids.size() != num_phones) {
  224 + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
  225 + exit(-1);
  226 + }
  227 +
  228 + std::vector<int64_t> ids64{ids.begin(), ids.end()};
  229 +
  230 + word2ids_.insert(
  231 + {std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
  232 + }
  233 +
  234 + word2ids_["呣"] = word2ids_["母"];
  235 + word2ids_["嗯"] = word2ids_["恩"];
  236 + }
  237 +
  238 + private:
  239 + // lexicon.txt is saved in word2ids_
  240 + std::unordered_map<std::string, TokenIDs> word2ids_;
  241 +
  242 + // tokens.txt is saved in token2id_
  243 + std::unordered_map<std::string, int32_t> token2id_;
  244 +
  245 + OfflineTtsVitsModelMetaData meta_data_;
  246 +
  247 + std::unique_ptr<cppjieba::Jieba> jieba_;
  248 + bool debug_ = false;
  249 +};
  250 +
// Out-of-line so that ~unique_ptr<Impl> sees the complete Impl type (pimpl).
MeloTtsLexicon::~MeloTtsLexicon() = default;
  252 +
// Constructor: all work is delegated to the pimpl Impl, which loads the
// jieba dictionaries, the token table, and the lexicon.
MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
                               const std::string &tokens,
                               const std::string &dict_dir,
                               const OfflineTtsVitsModelMetaData &meta_data,
                               bool debug)
    : impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
                                   debug)) {}
  260 +
// OfflineTtsFrontend override. The voice argument is ignored by this
// frontend; the text is forwarded to the implementation.
std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
    const std::string &text, const std::string & /*unused_voice = ""*/) const {
  return impl_->ConvertTextToTokenIds(text);
}
  265 +
  266 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/melo-tts-lexicon.h
  2 +//
  3 +// Copyright (c) 2022-2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
  6 +#define SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <unordered_map>
  11 +#include <vector>
  12 +
  13 +#include "sherpa-onnx/csrc/offline-tts-frontend.h"
  14 +#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
  15 +
  16 +namespace sherpa_onnx {
  17 +
// Text frontend for MeloTTS models. It segments text with jieba and maps
// each word to phone token IDs plus per-phone tone IDs via a lexicon file.
class MeloTtsLexicon : public OfflineTtsFrontend {
 public:
  ~MeloTtsLexicon() override;

  /** @param lexicon Path to lexicon.txt (word -> phones and tones).
   *  @param tokens Path to tokens.txt (phone -> integer id).
   *  @param dict_dir Directory containing jieba dictionary files.
   *  @param meta_data Metadata read from the VITS model.
   *  @param debug True to print debug logs during conversion.
   */
  MeloTtsLexicon(const std::string &lexicon, const std::string &tokens,
                 const std::string &dict_dir,
                 const OfflineTtsVitsModelMetaData &meta_data, bool debug);

  /** Splits text into sentences and converts each to token and tone IDs.
   *  The voice argument is unused by this frontend. */
  std::vector<TokenIDs> ConvertTextToTokenIds(
      const std::string &text,
      const std::string &unused_voice = "") const override;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;  // pimpl: hides cppjieba from the header
};
  33 +
  34 +} // namespace sherpa_onnx
  35 +
  36 +#endif // SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
@@ -94,8 +94,7 @@ OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend( @@ -94,8 +94,7 @@ OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
94 94
95 #endif 95 #endif
96 96
97 -std::vector<std::vector<int64_t>>  
98 -OfflineTtsCharacterFrontend::ConvertTextToTokenIds( 97 +std::vector<TokenIDs> OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
99 const std::string &_text, const std::string & /*voice = ""*/) const { 98 const std::string &_text, const std::string & /*voice = ""*/) const {
100 // see 99 // see
101 // https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87 100 // https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87
@@ -112,7 +111,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds( @@ -112,7 +111,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
112 std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv; 111 std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
113 std::u32string s = conv.from_bytes(text); 112 std::u32string s = conv.from_bytes(text);
114 113
115 - std::vector<std::vector<int64_t>> ans; 114 + std::vector<TokenIDs> ans;
116 115
117 std::vector<int64_t> this_sentence; 116 std::vector<int64_t> this_sentence;
118 if (add_blank) { 117 if (add_blank) {
@@ -41,7 +41,7 @@ class OfflineTtsCharacterFrontend : public OfflineTtsFrontend { @@ -41,7 +41,7 @@ class OfflineTtsCharacterFrontend : public OfflineTtsFrontend {
41 * If a frontend does not support splitting the text into 41 * If a frontend does not support splitting the text into
42 * sentences, the resulting vector contains only one subvector. 42 * sentences, the resulting vector contains only one subvector.
43 */ 43 */
44 - std::vector<std::vector<int64_t>> ConvertTextToTokenIds( 44 + std::vector<TokenIDs> ConvertTextToTokenIds(
45 const std::string &text, const std::string &voice = "") const override; 45 const std::string &text, const std::string &voice = "") const override;
46 46
47 private: 47 private:
  1 +// sherpa-onnx/csrc/offline-tts-frontend.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-tts-frontend.h"
  6 +
  7 +#include <sstream>
  8 +#include <string>
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +std::string TokenIDs::ToString() const {
  13 + std::ostringstream os;
  14 + os << "TokenIDs(";
  15 + os << "tokens=[";
  16 + std::string sep;
  17 + for (auto i : tokens) {
  18 + os << sep << i;
  19 + sep = ", ";
  20 + }
  21 + os << "], ";
  22 +
  23 + os << "tones=[";
  24 + sep = {};
  25 + for (auto i : tones) {
  26 + os << sep << i;
  27 + sep = ", ";
  28 + }
  29 + os << "]";
  30 + os << ")";
  31 + return os.str();
  32 +}
  33 +
  34 +} // namespace sherpa_onnx
@@ -8,8 +8,28 @@ @@ -8,8 +8,28 @@
8 #include <string> 8 #include <string>
9 #include <vector> 9 #include <vector>
10 10
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +
11 namespace sherpa_onnx { 13 namespace sherpa_onnx {
12 14
// Result of converting one sentence of text: phone token IDs and, for
// models that use them, a parallel list of tone IDs.
struct TokenIDs {
  TokenIDs() = default;

  // Intentionally implicit so a plain token list converts to TokenIDs.
  /*implicit*/ TokenIDs(const std::vector<int64_t> &tokens)  // NOLINT
      : tokens{tokens} {}

  TokenIDs(const std::vector<int64_t> &tokens,
           const std::vector<int64_t> &tones)
      : tokens{tokens}, tones{tones} {}

  // Human-readable representation for debugging/logging.
  std::string ToString() const;

  std::vector<int64_t> tokens;

  // Used only in MeloTTS
  std::vector<int64_t> tones;
};
  32 +
13 class OfflineTtsFrontend { 33 class OfflineTtsFrontend {
14 public: 34 public:
15 virtual ~OfflineTtsFrontend() = default; 35 virtual ~OfflineTtsFrontend() = default;
@@ -26,7 +46,7 @@ class OfflineTtsFrontend { @@ -26,7 +46,7 @@ class OfflineTtsFrontend {
26 * If a frontend does not support splitting the text into sentences, 46 * If a frontend does not support splitting the text into sentences,
27 * the resulting vector contains only one subvector. 47 * the resulting vector contains only one subvector.
28 */ 48 */
29 - virtual std::vector<std::vector<int64_t>> ConvertTextToTokenIds( 49 + virtual std::vector<TokenIDs> ConvertTextToTokenIds(
30 const std::string &text, const std::string &voice = "") const = 0; 50 const std::string &text, const std::string &voice = "") const = 0;
31 }; 51 };
32 52
@@ -22,6 +22,7 @@ @@ -22,6 +22,7 @@
22 #include "sherpa-onnx/csrc/jieba-lexicon.h" 22 #include "sherpa-onnx/csrc/jieba-lexicon.h"
23 #include "sherpa-onnx/csrc/lexicon.h" 23 #include "sherpa-onnx/csrc/lexicon.h"
24 #include "sherpa-onnx/csrc/macros.h" 24 #include "sherpa-onnx/csrc/macros.h"
  25 +#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
25 #include "sherpa-onnx/csrc/offline-tts-character-frontend.h" 26 #include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
26 #include "sherpa-onnx/csrc/offline-tts-frontend.h" 27 #include "sherpa-onnx/csrc/offline-tts-frontend.h"
27 #include "sherpa-onnx/csrc/offline-tts-impl.h" 28 #include "sherpa-onnx/csrc/offline-tts-impl.h"
@@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
174 } 175 }
175 } 176 }
176 177
177 - std::vector<std::vector<int64_t>> x = 178 + std::vector<TokenIDs> token_ids =
178 frontend_->ConvertTextToTokenIds(text, meta_data.voice); 179 frontend_->ConvertTextToTokenIds(text, meta_data.voice);
179 180
180 - if (x.empty() || (x.size() == 1 && x[0].empty())) { 181 + if (token_ids.empty() ||
  182 + (token_ids.size() == 1 && token_ids[0].tokens.empty())) {
181 SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); 183 SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
182 return {}; 184 return {};
183 } 185 }
184 186
  187 + std::vector<std::vector<int64_t>> x;
  188 + std::vector<std::vector<int64_t>> tones;
  189 +
  190 + x.reserve(token_ids.size());
  191 +
  192 + for (auto &i : token_ids) {
  193 + x.push_back(std::move(i.tokens));
  194 + }
  195 +
  196 + if (!token_ids[0].tones.empty()) {
  197 + tones.reserve(token_ids.size());
  198 + for (auto &i : token_ids) {
  199 + tones.push_back(std::move(i.tones));
  200 + }
  201 + }
  202 +
185 // TODO(fangjun): add blank inside the frontend, not here 203 // TODO(fangjun): add blank inside the frontend, not here
186 if (meta_data.add_blank && config_.model.vits.data_dir.empty() && 204 if (meta_data.add_blank && config_.model.vits.data_dir.empty() &&
187 meta_data.frontend != "characters") { 205 meta_data.frontend != "characters") {
188 for (auto &k : x) { 206 for (auto &k : x) {
189 k = AddBlank(k); 207 k = AddBlank(k);
190 } 208 }
  209 +
  210 + for (auto &k : tones) {
  211 + k = AddBlank(k);
  212 + }
191 } 213 }
192 214
193 int32_t x_size = static_cast<int32_t>(x.size()); 215 int32_t x_size = static_cast<int32_t>(x.size());
194 216
195 if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { 217 if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
196 - auto ans = Process(x, sid, speed); 218 + auto ans = Process(x, tones, sid, speed);
197 if (callback) { 219 if (callback) {
198 callback(ans.samples.data(), ans.samples.size(), 1.0); 220 callback(ans.samples.data(), ans.samples.size(), 1.0);
199 } 221 }
@@ -202,9 +224,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -202,9 +224,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
202 224
203 // the input text is too long, we process sentences within it in batches 225 // the input text is too long, we process sentences within it in batches
204 // to avoid OOM. Batch size is config_.max_num_sentences 226 // to avoid OOM. Batch size is config_.max_num_sentences
205 - std::vector<std::vector<int64_t>> batch; 227 + std::vector<std::vector<int64_t>> batch_x;
  228 + std::vector<std::vector<int64_t>> batch_tones;
  229 +
206 int32_t batch_size = config_.max_num_sentences; 230 int32_t batch_size = config_.max_num_sentences;
207 - batch.reserve(config_.max_num_sentences); 231 + batch_x.reserve(config_.max_num_sentences);
  232 + batch_tones.reserve(config_.max_num_sentences);
208 int32_t num_batches = x_size / batch_size; 233 int32_t num_batches = x_size / batch_size;
209 234
210 if (config_.model.debug) { 235 if (config_.model.debug) {
@@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
221 int32_t k = 0; 246 int32_t k = 0;
222 247
223 for (int32_t b = 0; b != num_batches && should_continue; ++b) { 248 for (int32_t b = 0; b != num_batches && should_continue; ++b) {
224 - batch.clear(); 249 + batch_x.clear();
  250 + batch_tones.clear();
225 for (int32_t i = 0; i != batch_size; ++i, ++k) { 251 for (int32_t i = 0; i != batch_size; ++i, ++k) {
226 - batch.push_back(std::move(x[k])); 252 + batch_x.push_back(std::move(x[k]));
  253 +
  254 + if (!tones.empty()) {
  255 + batch_tones.push_back(std::move(tones[k]));
  256 + }
227 } 257 }
228 258
229 - auto audio = Process(batch, sid, speed); 259 + auto audio = Process(batch_x, batch_tones, sid, speed);
230 ans.sample_rate = audio.sample_rate; 260 ans.sample_rate = audio.sample_rate;
231 ans.samples.insert(ans.samples.end(), audio.samples.begin(), 261 ans.samples.insert(ans.samples.end(), audio.samples.begin(),
232 audio.samples.end()); 262 audio.samples.end());
@@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
239 } 269 }
240 } 270 }
241 271
242 - batch.clear(); 272 + batch_x.clear();
  273 + batch_tones.clear();
243 while (k < static_cast<int32_t>(x.size()) && should_continue) { 274 while (k < static_cast<int32_t>(x.size()) && should_continue) {
244 - batch.push_back(std::move(x[k])); 275 + batch_x.push_back(std::move(x[k]));
  276 + if (!tones.empty()) {
  277 + batch_tones.push_back(std::move(tones[k]));
  278 + }
  279 +
245 ++k; 280 ++k;
246 } 281 }
247 282
248 - if (!batch.empty()) {  
249 - auto audio = Process(batch, sid, speed); 283 + if (!batch_x.empty()) {
  284 + auto audio = Process(batch_x, batch_tones, sid, speed);
250 ans.sample_rate = audio.sample_rate; 285 ans.sample_rate = audio.sample_rate;
251 ans.samples.insert(ans.samples.end(), audio.samples.begin(), 286 ans.samples.insert(ans.samples.end(), audio.samples.begin(),
252 audio.samples.end()); 287 audio.samples.end());
@@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
308 if (meta_data.frontend == "characters") { 343 if (meta_data.frontend == "characters") {
309 frontend_ = std::make_unique<OfflineTtsCharacterFrontend>( 344 frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
310 config_.model.vits.tokens, meta_data); 345 config_.model.vits.tokens, meta_data);
  346 + } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty() &&
  347 + meta_data.is_melo_tts) {
  348 + frontend_ = std::make_unique<MeloTtsLexicon>(
  349 + config_.model.vits.lexicon, config_.model.vits.tokens,
  350 + config_.model.vits.dict_dir, model_->GetMetaData(),
  351 + config_.model.debug);
311 } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) { 352 } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
312 frontend_ = std::make_unique<JiebaLexicon>( 353 frontend_ = std::make_unique<JiebaLexicon>(
313 config_.model.vits.lexicon, config_.model.vits.tokens, 354 config_.model.vits.lexicon, config_.model.vits.tokens,
@@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
344 } 385 }
345 386
346 GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens, 387 GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens,
  388 + const std::vector<std::vector<int64_t>> &tones,
347 int32_t sid, float speed) const { 389 int32_t sid, float speed) const {
348 int32_t num_tokens = 0; 390 int32_t num_tokens = 0;
349 for (const auto &k : tokens) { 391 for (const auto &k : tokens) {
@@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
356 x.insert(x.end(), k.begin(), k.end()); 398 x.insert(x.end(), k.begin(), k.end());
357 } 399 }
358 400
  401 + std::vector<int64_t> tone_list;
  402 + if (!tones.empty()) {
  403 + tone_list.reserve(num_tokens);
  404 + for (const auto &k : tones) {
  405 + tone_list.insert(tone_list.end(), k.begin(), k.end());
  406 + }
  407 + }
  408 +
359 auto memory_info = 409 auto memory_info =
360 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 410 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
361 411
@@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
363 Ort::Value x_tensor = Ort::Value::CreateTensor( 413 Ort::Value x_tensor = Ort::Value::CreateTensor(
364 memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); 414 memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
365 415
366 - Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed); 416 + Ort::Value tones_tensor{nullptr};
  417 + if (!tones.empty()) {
  418 + tones_tensor = Ort::Value::CreateTensor(memory_info, tone_list.data(),
  419 + tone_list.size(), x_shape.data(),
  420 + x_shape.size());
  421 + }
  422 +
  423 + Ort::Value audio{nullptr};
  424 + if (tones.empty()) {
  425 + audio = model_->Run(std::move(x_tensor), sid, speed);
  426 + } else {
  427 + audio =
  428 + model_->Run(std::move(x_tensor), std::move(tones_tensor), sid, speed);
  429 + }
367 430
368 std::vector<int64_t> audio_shape = 431 std::vector<int64_t> audio_shape =
369 audio.GetTensorTypeAndShapeInfo().GetShape(); 432 audio.GetTensorTypeAndShapeInfo().GetShape();
@@ -21,6 +21,7 @@ struct OfflineTtsVitsModelMetaData { @@ -21,6 +21,7 @@ struct OfflineTtsVitsModelMetaData {
21 bool is_piper = false; 21 bool is_piper = false;
22 bool is_coqui = false; 22 bool is_coqui = false;
23 bool is_icefall = false; 23 bool is_icefall = false;
  24 + bool is_melo_tts = false;
24 25
25 // for Chinese TTS models from 26 // for Chinese TTS models from
26 // https://github.com/Plachtaa/VITS-fast-fine-tuning 27 // https://github.com/Plachtaa/VITS-fast-fine-tuning
@@ -33,6 +34,10 @@ struct OfflineTtsVitsModelMetaData { @@ -33,6 +34,10 @@ struct OfflineTtsVitsModelMetaData {
33 int32_t use_eos_bos = 0; 34 int32_t use_eos_bos = 0;
34 int32_t pad_id = 0; 35 int32_t pad_id = 0;
35 36
  37 + // for melo tts
  38 + int32_t speaker_id = 0;
  39 + int32_t version = 0;
  40 +
36 std::string punctuations; 41 std::string punctuations;
37 std::string language; 42 std::string language;
38 std::string voice; 43 std::string voice;
@@ -45,6 +45,64 @@ class OfflineTtsVitsModel::Impl { @@ -45,6 +45,64 @@ class OfflineTtsVitsModel::Impl {
45 return RunVits(std::move(x), sid, speed); 45 return RunVits(std::move(x), sid, speed);
46 } 46 }
47 47
  // Runs a MeloTTS model, which takes an additional per-token `tones` input.
  //
  // Note: the caller-provided sid is overwritten with the speaker id stored
  // in the model metadata. Only batch_size == 1 is supported; a larger
  // batch aborts the process.
  Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) {
    // For MeloTTS, we hardcode sid to the one contained in the meta data
    sid = meta_data_.speaker_id;

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape();
    if (x_shape[0] != 1) {
      SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
                       static_cast<int32_t>(x_shape[0]));
      exit(-1);
    }

    int64_t len = x_shape[1];
    int64_t len_shape = 1;

    Ort::Value x_length =
        Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);

    int64_t scale_shape = 1;
    float noise_scale = config_.vits.noise_scale;
    float length_scale = config_.vits.length_scale;
    float noise_scale_w = config_.vits.noise_scale_w;

    // length_scale is the inverse of speed: speed > 1 shortens the output.
    if (speed != 1 && speed > 0) {
      length_scale = 1. / speed;
    }

    Ort::Value noise_scale_tensor =
        Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);

    Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
        memory_info, &length_scale, 1, &scale_shape, 1);

    Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
        memory_info, &noise_scale_w, 1, &scale_shape, 1);

    Ort::Value sid_tensor =
        Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);

    // The push order below must match input_names_ptr_, i.e. presumably the
    // exported model's input order: x, x_length, tones, sid, noise_scale,
    // length_scale, noise_scale_w -- TODO confirm against the export script.
    std::vector<Ort::Value> inputs;
    inputs.reserve(7);
    inputs.push_back(std::move(x));
    inputs.push_back(std::move(x_length));
    inputs.push_back(std::move(tones));
    inputs.push_back(std::move(sid_tensor));
    inputs.push_back(std::move(noise_scale_tensor));
    inputs.push_back(std::move(length_scale_tensor));
    inputs.push_back(std::move(noise_scale_w_tensor));

    auto out =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());

    // Return the first output (the generated audio tensor).
    return std::move(out[0]);
  }
  105 +
48 const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; } 106 const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; }
49 107
50 private: 108 private:
@@ -83,6 +141,10 @@ class OfflineTtsVitsModel::Impl { @@ -83,6 +141,10 @@ class OfflineTtsVitsModel::Impl {
83 SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); 141 SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
84 SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank", 142 SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank",
85 0); 143 0);
  144 +
  145 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.speaker_id, "speaker_id",
  146 + 0);
  147 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 0);
86 SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers"); 148 SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers");
87 SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations, 149 SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations,
88 "punctuation", ""); 150 "punctuation", "");
@@ -115,6 +177,22 @@ class OfflineTtsVitsModel::Impl { @@ -115,6 +177,22 @@ class OfflineTtsVitsModel::Impl {
115 if (comment.find("icefall") != std::string::npos) { 177 if (comment.find("icefall") != std::string::npos) {
116 meta_data_.is_icefall = true; 178 meta_data_.is_icefall = true;
117 } 179 }
  180 +
  181 + if (comment.find("melo") != std::string::npos) {
  182 + meta_data_.is_melo_tts = true;
  183 + int32_t expected_version = 2;
  184 + if (meta_data_.version < expected_version) {
  185 + SHERPA_ONNX_LOGE(
  186 + "Please download the latest MeloTTS model and retry. Current "
  187 + "version: %d. Expected version: %d",
  188 + meta_data_.version, expected_version);
  189 + exit(-1);
  190 + }
  191 +
  192 + // NOTE(fangjun):
  193 + // version 0 is the first version
  194 + // version 2: add jieba=1 to the metadata
  195 + }
118 } 196 }
119 197
120 Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) { 198 Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) {
@@ -269,6 +347,12 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/, @@ -269,6 +347,12 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/,
269 return impl_->Run(std::move(x), sid, speed); 347 return impl_->Run(std::move(x), sid, speed);
270 } 348 }
271 349
// MeloTTS overload: forwards the extra per-token `tones` tensor to the
// implementation.
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, Ort::Value tones,
                                    int64_t sid /*= 0*/,
                                    float speed /*= 1.0*/) {
  return impl_->Run(std::move(x), std::move(tones), sid, speed);
}
  355 +
272 const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const { 356 const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const {
273 return impl_->GetMetaData(); 357 return impl_->GetMetaData();
274 } 358 }
@@ -40,6 +40,10 @@ class OfflineTtsVitsModel { @@ -40,6 +40,10 @@ class OfflineTtsVitsModel {
40 */ 40 */
41 Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0); 41 Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0);
42 42
  43 + // This is for MeloTTS
  44 + Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid = 0,
  45 + float speed = 1.0);
  46 +
43 const OfflineTtsVitsModelMetaData &GetMetaData() const; 47 const OfflineTtsVitsModelMetaData &GetMetaData() const;
44 48
45 private: 49 private:
@@ -5,8 +5,8 @@ @@ -5,8 +5,8 @@
5 #ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ 5 #ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
6 #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ 6 #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
7 7
8 -#include <vector>  
9 #include <string> 8 #include <string>
  9 +#include <vector>
10 10
11 #include "onnxruntime_cxx_api.h" // NOLINT 11 #include "onnxruntime_cxx_api.h" // NOLINT
12 #include "sherpa-onnx/csrc/offline-whisper-model-config.h" 12 #include "sherpa-onnx/csrc/offline-whisper-model-config.h"
@@ -36,7 +36,6 @@ class OfflineWhisperDecoder { @@ -36,7 +36,6 @@ class OfflineWhisperDecoder {
36 Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; 36 Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
37 37
38 virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0; 38 virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
39 -  
40 }; 39 };
41 40
42 } // namespace sherpa_onnx 41 } // namespace sherpa_onnx
@@ -12,7 +12,8 @@ @@ -12,7 +12,8 @@
12 12
13 namespace sherpa_onnx { 13 namespace sherpa_onnx {
14 14
15 -void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) { 15 +void OfflineWhisperGreedySearchDecoder::SetConfig(
  16 + const OfflineWhisperModelConfig &config) {
16 config_ = config; 17 config_ = config;
17 } 18 }
18 19
@@ -135,9 +136,9 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, @@ -135,9 +136,9 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
135 136
136 const auto &id2lang = model_->GetID2Lang(); 137 const auto &id2lang = model_->GetID2Lang();
137 if (id2lang.count(initial_tokens[1])) { 138 if (id2lang.count(initial_tokens[1])) {
138 - ans[0].lang = id2lang.at(initial_tokens[1]); 139 + ans[0].lang = id2lang.at(initial_tokens[1]);
139 } else { 140 } else {
140 - ans[0].lang = ""; 141 + ans[0].lang = "";
141 } 142 }
142 143
143 ans[0].tokens = std::move(predicted_tokens); 144 ans[0].tokens = std::move(predicted_tokens);
@@ -153,15 +153,21 @@ Ort::Value View(Ort::Value *v) { @@ -153,15 +153,21 @@ Ort::Value View(Ort::Value *v) {
153 } 153 }
154 } 154 }
155 155
  156 +template <typename T /*= float*/>
156 void Print1D(Ort::Value *v) { 157 void Print1D(Ort::Value *v) {
157 std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); 158 std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
158 - const float *d = v->GetTensorData<float>(); 159 + const T *d = v->GetTensorData<T>();
  160 + std::ostringstream os;
159 for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { 161 for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
160 - fprintf(stderr, "%.3f ", d[i]); 162 + os << *d << " ";
161 } 163 }
162 - fprintf(stderr, "\n"); 164 + os << "\n";
  165 + fprintf(stderr, "%s\n", os.str().c_str());
163 } 166 }
164 167
  168 +template void Print1D<int64_t>(Ort::Value *v);
  169 +template void Print1D<float>(Ort::Value *v);
  170 +
165 template <typename T /*= float*/> 171 template <typename T /*= float*/>
166 void Print2D(Ort::Value *v) { 172 void Print2D(Ort::Value *v) {
167 std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); 173 std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
@@ -69,6 +69,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); @@ -69,6 +69,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
69 Ort::Value View(Ort::Value *v); 69 Ort::Value View(Ort::Value *v);
70 70
71 // Print a 1-D tensor to stderr 71 // Print a 1-D tensor to stderr
  72 +template <typename T = float>
72 void Print1D(Ort::Value *v); 73 void Print1D(Ort::Value *v);
73 74
74 // Print a 2-D tensor to stderr 75 // Print a 2-D tensor to stderr
@@ -214,7 +214,7 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon( @@ -214,7 +214,7 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon(
214 } 214 }
215 #endif 215 #endif
216 216
217 -std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds( 217 +std::vector<TokenIDs> PiperPhonemizeLexicon::ConvertTextToTokenIds(
218 const std::string &text, const std::string &voice /*= ""*/) const { 218 const std::string &text, const std::string &voice /*= ""*/) const {
219 piper::eSpeakPhonemeConfig config; 219 piper::eSpeakPhonemeConfig config;
220 220
@@ -232,7 +232,7 @@ std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds( @@ -232,7 +232,7 @@ std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds(
232 piper::phonemize_eSpeak(text, config, phonemes); 232 piper::phonemize_eSpeak(text, config, phonemes);
233 } 233 }
234 234
235 - std::vector<std::vector<int64_t>> ans; 235 + std::vector<TokenIDs> ans;
236 236
237 std::vector<int64_t> phoneme_ids; 237 std::vector<int64_t> phoneme_ids;
238 238
@@ -30,7 +30,7 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend { @@ -30,7 +30,7 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend {
30 const OfflineTtsVitsModelMetaData &meta_data); 30 const OfflineTtsVitsModelMetaData &meta_data);
31 #endif 31 #endif
32 32
33 - std::vector<std::vector<int64_t>> ConvertTextToTokenIds( 33 + std::vector<TokenIDs> ConvertTextToTokenIds(
34 const std::string &text, const std::string &voice = "") const override; 34 const std::string &text, const std::string &voice = "") const override;
35 35
36 private: 36 private:
@@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { @@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
31 api.ReleaseStatus(status); 31 api.ReleaseStatus(status);
32 } 32 }
33 33
34 -static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,  
35 - const std::string &provider_str, 34 +static Ort::SessionOptions GetSessionOptionsImpl(
  35 + int32_t num_threads, const std::string &provider_str,
36 const ProviderConfig *provider_config = nullptr) { 36 const ProviderConfig *provider_config = nullptr) {
37 Provider p = StringToProvider(provider_str); 37 Provider p = StringToProvider(provider_str);
38 38
@@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, @@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
67 } 67 }
68 case Provider::kTRT: { 68 case Provider::kTRT: {
69 if (provider_config == nullptr) { 69 if (provider_config == nullptr) {
70 - SHERPA_ONNX_LOGE("Tensorrt support for Online models ony,"  
71 - "Must be extended for offline and others"); 70 + SHERPA_ONNX_LOGE(
  71 + "Tensorrt support for Online models ony,"
  72 + "Must be extended for offline and others");
72 exit(1); 73 exit(1);
73 } 74 }
74 auto trt_config = provider_config->trt_config; 75 auto trt_config = provider_config->trt_config;
@@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, @@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
84 std::to_string(trt_config.trt_max_partition_iterations); 85 std::to_string(trt_config.trt_max_partition_iterations);
85 auto trt_min_subgraph_size = 86 auto trt_min_subgraph_size =
86 std::to_string(trt_config.trt_min_subgraph_size); 87 std::to_string(trt_config.trt_min_subgraph_size);
87 - auto trt_fp16_enable =  
88 - std::to_string(trt_config.trt_fp16_enable); 88 + auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable);
89 auto trt_detailed_build_log = 89 auto trt_detailed_build_log =
90 std::to_string(trt_config.trt_detailed_build_log); 90 std::to_string(trt_config.trt_detailed_build_log);
91 auto trt_engine_cache_enable = 91 auto trt_engine_cache_enable =
92 std::to_string(trt_config.trt_engine_cache_enable); 92 std::to_string(trt_config.trt_engine_cache_enable);
93 auto trt_timing_cache_enable = 93 auto trt_timing_cache_enable =
94 std::to_string(trt_config.trt_timing_cache_enable); 94 std::to_string(trt_config.trt_timing_cache_enable);
95 - auto trt_dump_subgraphs =  
96 - std::to_string(trt_config.trt_dump_subgraphs); 95 + auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs);
97 std::vector<TrtPairs> trt_options = { 96 std::vector<TrtPairs> trt_options = {
98 - {"device_id", device_id.c_str()},  
99 - {"trt_max_workspace_size", trt_max_workspace_size.c_str()},  
100 - {"trt_max_partition_iterations", trt_max_partition_iterations.c_str()},  
101 - {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},  
102 - {"trt_fp16_enable", trt_fp16_enable.c_str()},  
103 - {"trt_detailed_build_log", trt_detailed_build_log.c_str()},  
104 - {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},  
105 - {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},  
106 - {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},  
107 - {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},  
108 - {"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}  
109 - }; 97 + {"device_id", device_id.c_str()},
  98 + {"trt_max_workspace_size", trt_max_workspace_size.c_str()},
  99 + {"trt_max_partition_iterations",
  100 + trt_max_partition_iterations.c_str()},
  101 + {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
  102 + {"trt_fp16_enable", trt_fp16_enable.c_str()},
  103 + {"trt_detailed_build_log", trt_detailed_build_log.c_str()},
  104 + {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
  105 + {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
  106 + {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
  107 + {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
  108 + {"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}};
110 // ToDo : Trt configs 109 // ToDo : Trt configs
111 // "trt_int8_enable" 110 // "trt_int8_enable"
112 // "trt_int8_use_native_calibration_table" 111 // "trt_int8_use_native_calibration_table"
@@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, @@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
151 150
152 if (provider_config != nullptr) { 151 if (provider_config != nullptr) {
153 options.device_id = provider_config->device; 152 options.device_id = provider_config->device;
154 - options.cudnn_conv_algo_search =  
155 - OrtCudnnConvAlgoSearch(provider_config->cuda_config  
156 - .cudnn_conv_algo_search); 153 + options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch(
  154 + provider_config->cuda_config.cudnn_conv_algo_search);
157 } else { 155 } else {
158 options.device_id = 0; 156 options.device_id = 0;
159 // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow 157 // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
@@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, @@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
219 217
220 Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { 218 Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
221 return GetSessionOptionsImpl(config.num_threads, 219 return GetSessionOptionsImpl(config.num_threads,
222 - config.provider_config.provider, &config.provider_config); 220 + config.provider_config.provider,
  221 + &config.provider_config);
223 } 222 }
224 223
225 Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, 224 Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
226 - const std::string &model_type) { 225 + const std::string &model_type) {
227 /* 226 /*
228 Transducer models : Only encoder will run with tensorrt, 227 Transducer models : Only encoder will run with tensorrt,
229 decoder and joiner will run with cuda 228 decoder and joiner will run with cuda
230 */ 229 */
231 - if(config.provider_config.provider == "trt" && 230 + if (config.provider_config.provider == "trt" &&
232 (model_type == "decoder" || model_type == "joiner")) { 231 (model_type == "decoder" || model_type == "joiner")) {
233 - return GetSessionOptionsImpl(config.num_threads,  
234 - "cuda", &config.provider_config); 232 + return GetSessionOptionsImpl(config.num_threads, "cuda",
  233 + &config.provider_config);
235 } 234 }
236 return GetSessionOptionsImpl(config.num_threads, 235 return GetSessionOptionsImpl(config.num_threads,
237 - config.provider_config.provider, &config.provider_config); 236 + config.provider_config.provider,
  237 + &config.provider_config);
238 } 238 }
239 239
240 Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { 240 Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
@@ -5,6 +5,8 @@ @@ -5,6 +5,8 @@
5 #ifndef SHERPA_ONNX_CSRC_SESSION_H_ 5 #ifndef SHERPA_ONNX_CSRC_SESSION_H_
6 #define SHERPA_ONNX_CSRC_SESSION_H_ 6 #define SHERPA_ONNX_CSRC_SESSION_H_
7 7
  8 +#include <string>
  9 +
8 #include "onnxruntime_cxx_api.h" // NOLINT 10 #include "onnxruntime_cxx_api.h" // NOLINT
9 #include "sherpa-onnx/csrc/audio-tagging-model-config.h" 11 #include "sherpa-onnx/csrc/audio-tagging-model-config.h"
10 #include "sherpa-onnx/csrc/offline-lm-config.h" 12 #include "sherpa-onnx/csrc/offline-lm-config.h"
@@ -25,7 +27,7 @@ namespace sherpa_onnx { @@ -25,7 +27,7 @@ namespace sherpa_onnx {
25 Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); 27 Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
26 28
27 Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, 29 Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
28 - const std::string &model_type); 30 + const std::string &model_type);
29 31
30 Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); 32 Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
31 33
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 6
7 #include <algorithm> 7 #include <algorithm>
8 #include <unordered_map> 8 #include <unordered_map>
  9 +#include <utility>
9 10
10 #include "Eigen/Dense" 11 #include "Eigen/Dense"
11 #include "sherpa-onnx/csrc/macros.h" 12 #include "sherpa-onnx/csrc/macros.h"
@@ -11,7 +11,7 @@ @@ -11,7 +11,7 @@
11 namespace sherpa_onnx { 11 namespace sherpa_onnx {
12 12
13 TEST(UTF8, Case1) { 13 TEST(UTF8, Case1) {
14 - std::string hello = "你好, 早上好!世界. hello!。Hallo"; 14 + std::string hello = "你好, 早上好!世界. hello!。Hallo! how are you?";
15 std::vector<std::string> ss = SplitUtf8(hello); 15 std::vector<std::string> ss = SplitUtf8(hello);
16 for (const auto &s : ss) { 16 for (const auto &s : ss) {
17 std::cout << s << "\n"; 17 std::cout << s << "\n";