正在显示
51 个修改的文件
包含
693 行增加
和
156 行删除
| @@ -63,10 +63,16 @@ jobs: | @@ -63,10 +63,16 @@ jobs: | ||
| 63 | echo "pwd: $PWD" | 63 | echo "pwd: $PWD" |
| 64 | ls -lh ../scripts/melo-tts | 64 | ls -lh ../scripts/melo-tts |
| 65 | 65 | ||
| 66 | + rm -rf ./ | ||
| 67 | + | ||
| 66 | cp -v ../scripts/melo-tts/*.onnx . | 68 | cp -v ../scripts/melo-tts/*.onnx . |
| 67 | cp -v ../scripts/melo-tts/lexicon.txt . | 69 | cp -v ../scripts/melo-tts/lexicon.txt . |
| 68 | cp -v ../scripts/melo-tts/tokens.txt . | 70 | cp -v ../scripts/melo-tts/tokens.txt . |
| 71 | + cp -v ../scripts/melo-tts/README.md . | ||
| 72 | + | ||
| 73 | + curl -SL -O https://raw.githubusercontent.com/myshell-ai/MeloTTS/main/LICENSE | ||
| 69 | 74 | ||
| 75 | + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/new_heteronym.fst | ||
| 70 | curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst | 76 | curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst |
| 71 | curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst | 77 | curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst |
| 72 | curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst | 78 | curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst |
| @@ -77,6 +83,10 @@ jobs: | @@ -77,6 +83,10 @@ jobs: | ||
| 77 | git lfs track "*.onnx" | 83 | git lfs track "*.onnx" |
| 78 | git add . | 84 | git add . |
| 79 | 85 | ||
| 86 | + ls -lh | ||
| 87 | + | ||
| 88 | + git status | ||
| 89 | + | ||
| 80 | git commit -m "add models" | 90 | git commit -m "add models" |
| 81 | git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true | 91 | git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true |
| 82 | 92 |
| @@ -39,10 +39,14 @@ jobs: | @@ -39,10 +39,14 @@ jobs: | ||
| 39 | cd build | 39 | cd build |
| 40 | cmake \ | 40 | cmake \ |
| 41 | -A x64 \ | 41 | -A x64 \ |
| 42 | - -D CMAKE_BUILD_TYPE=Release \ | ||
| 43 | - -D BUILD_SHARED_LIBS=ON \ | 42 | + -DBUILD_SHARED_LIBS=ON \ |
| 44 | -D SHERPA_ONNX_ENABLE_JNI=ON \ | 43 | -D SHERPA_ONNX_ENABLE_JNI=ON \ |
| 45 | - -D CMAKE_INSTALL_PREFIX=./install \ | 44 | + -DCMAKE_INSTALL_PREFIX=./install \ |
| 45 | + -DCMAKE_BUILD_TYPE=Release \ | ||
| 46 | + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ | ||
| 47 | + -DBUILD_ESPEAK_NG_EXE=OFF \ | ||
| 48 | + -DSHERPA_ONNX_BUILD_C_API_EXAMPLES=OFF \ | ||
| 49 | + -DSHERPA_ONNX_ENABLE_BINARY=ON \ | ||
| 46 | .. | 50 | .. |
| 47 | 51 | ||
| 48 | - name: Build sherpa-onnx for windows | 52 | - name: Build sherpa-onnx for windows |
| @@ -11,7 +11,7 @@ project(sherpa-onnx) | @@ -11,7 +11,7 @@ project(sherpa-onnx) | ||
| 11 | # ./nodejs-addon-examples | 11 | # ./nodejs-addon-examples |
| 12 | # ./dart-api-examples/ | 12 | # ./dart-api-examples/ |
| 13 | # ./CHANGELOG.md | 13 | # ./CHANGELOG.md |
| 14 | -set(SHERPA_ONNX_VERSION "1.10.15") | 14 | +set(SHERPA_ONNX_VERSION "1.10.16") |
| 15 | 15 | ||
| 16 | # Disable warning about | 16 | # Disable warning about |
| 17 | # | 17 | # |
| @@ -5,7 +5,7 @@ description: > | @@ -5,7 +5,7 @@ description: > | ||
| 5 | 5 | ||
| 6 | publish_to: 'none' | 6 | publish_to: 'none' |
| 7 | 7 | ||
| 8 | -version: 1.10.14 | 8 | +version: 1.10.16 |
| 9 | 9 | ||
| 10 | topics: | 10 | topics: |
| 11 | - speech-recognition | 11 | - speech-recognition |
| @@ -30,7 +30,7 @@ dependencies: | @@ -30,7 +30,7 @@ dependencies: | ||
| 30 | record: ^5.1.0 | 30 | record: ^5.1.0 |
| 31 | url_launcher: ^6.2.6 | 31 | url_launcher: ^6.2.6 |
| 32 | 32 | ||
| 33 | - sherpa_onnx: ^1.10.15 | 33 | + sherpa_onnx: ^1.10.16 |
| 34 | # sherpa_onnx: | 34 | # sherpa_onnx: |
| 35 | # path: ../../flutter/sherpa_onnx | 35 | # path: ../../flutter/sherpa_onnx |
| 36 | 36 |
| @@ -5,7 +5,7 @@ description: > | @@ -5,7 +5,7 @@ description: > | ||
| 5 | 5 | ||
| 6 | publish_to: 'none' # Remove this line if you wish to publish to pub.dev | 6 | publish_to: 'none' # Remove this line if you wish to publish to pub.dev |
| 7 | 7 | ||
| 8 | -version: 1.0.0 | 8 | +version: 1.10.16 |
| 9 | 9 | ||
| 10 | environment: | 10 | environment: |
| 11 | sdk: '>=3.4.0 <4.0.0' | 11 | sdk: '>=3.4.0 <4.0.0' |
| @@ -17,7 +17,7 @@ dependencies: | @@ -17,7 +17,7 @@ dependencies: | ||
| 17 | cupertino_icons: ^1.0.6 | 17 | cupertino_icons: ^1.0.6 |
| 18 | path_provider: ^2.1.3 | 18 | path_provider: ^2.1.3 |
| 19 | path: ^1.9.0 | 19 | path: ^1.9.0 |
| 20 | - sherpa_onnx: ^1.10.15 | 20 | + sherpa_onnx: ^1.10.16 |
| 21 | url_launcher: ^6.2.6 | 21 | url_launcher: ^6.2.6 |
| 22 | audioplayers: ^5.0.0 | 22 | audioplayers: ^5.0.0 |
| 23 | 23 |
| @@ -17,7 +17,7 @@ topics: | @@ -17,7 +17,7 @@ topics: | ||
| 17 | - voice-activity-detection | 17 | - voice-activity-detection |
| 18 | 18 | ||
| 19 | # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec | 19 | # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec |
| 20 | -version: 1.10.15 | 20 | +version: 1.10.16 |
| 21 | 21 | ||
| 22 | homepage: https://github.com/k2-fsa/sherpa-onnx | 22 | homepage: https://github.com/k2-fsa/sherpa-onnx |
| 23 | 23 | ||
| @@ -30,19 +30,19 @@ dependencies: | @@ -30,19 +30,19 @@ dependencies: | ||
| 30 | flutter: | 30 | flutter: |
| 31 | sdk: flutter | 31 | sdk: flutter |
| 32 | 32 | ||
| 33 | - sherpa_onnx_android: ^1.10.15 | 33 | + sherpa_onnx_android: ^1.10.16 |
| 34 | # path: ../sherpa_onnx_android | 34 | # path: ../sherpa_onnx_android |
| 35 | 35 | ||
| 36 | - sherpa_onnx_macos: ^1.10.15 | 36 | + sherpa_onnx_macos: ^1.10.16 |
| 37 | # path: ../sherpa_onnx_macos | 37 | # path: ../sherpa_onnx_macos |
| 38 | 38 | ||
| 39 | - sherpa_onnx_linux: ^1.10.15 | 39 | + sherpa_onnx_linux: ^1.10.16 |
| 40 | # path: ../sherpa_onnx_linux | 40 | # path: ../sherpa_onnx_linux |
| 41 | # | 41 | # |
| 42 | - sherpa_onnx_windows: ^1.10.15 | 42 | + sherpa_onnx_windows: ^1.10.16 |
| 43 | # path: ../sherpa_onnx_windows | 43 | # path: ../sherpa_onnx_windows |
| 44 | 44 | ||
| 45 | - sherpa_onnx_ios: ^1.10.15 | 45 | + sherpa_onnx_ios: ^1.10.16 |
| 46 | # sherpa_onnx_ios: | 46 | # sherpa_onnx_ios: |
| 47 | # path: ../sherpa_onnx_ios | 47 | # path: ../sherpa_onnx_ios |
| 48 | 48 |
| @@ -7,7 +7,7 @@ | @@ -7,7 +7,7 @@ | ||
| 7 | # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c | 7 | # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c |
| 8 | Pod::Spec.new do |s| | 8 | Pod::Spec.new do |s| |
| 9 | s.name = 'sherpa_onnx_ios' | 9 | s.name = 'sherpa_onnx_ios' |
| 10 | - s.version = '1.10.15' | 10 | + s.version = '1.10.16' |
| 11 | s.summary = 'A new Flutter FFI plugin project.' | 11 | s.summary = 'A new Flutter FFI plugin project.' |
| 12 | s.description = <<-DESC | 12 | s.description = <<-DESC |
| 13 | A new Flutter FFI plugin project. | 13 | A new Flutter FFI plugin project. |
| @@ -4,7 +4,7 @@ | @@ -4,7 +4,7 @@ | ||
| 4 | # | 4 | # |
| 5 | Pod::Spec.new do |s| | 5 | Pod::Spec.new do |s| |
| 6 | s.name = 'sherpa_onnx_macos' | 6 | s.name = 'sherpa_onnx_macos' |
| 7 | - s.version = '1.10.15' | 7 | + s.version = '1.10.16' |
| 8 | s.summary = 'sherpa-onnx Flutter FFI plugin project.' | 8 | s.summary = 'sherpa-onnx Flutter FFI plugin project.' |
| 9 | s.description = <<-DESC | 9 | s.description = <<-DESC |
| 10 | sherpa-onnx Flutter FFI plugin project. | 10 | sherpa-onnx Flutter FFI plugin project. |
| @@ -78,6 +78,10 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt | @@ -78,6 +78,10 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt | ||
| 78 | git diff | 78 | git diff |
| 79 | popd | 79 | popd |
| 80 | 80 | ||
| 81 | +if [[ $model_dir == vits-melo-tts-zh_en ]]; then | ||
| 82 | + lang=zh_en | ||
| 83 | +fi | ||
| 84 | + | ||
| 81 | for arch in arm64-v8a armeabi-v7a x86_64 x86; do | 85 | for arch in arm64-v8a armeabi-v7a x86_64 x86; do |
| 82 | log "------------------------------------------------------------" | 86 | log "------------------------------------------------------------" |
| 83 | log "build tts apk for $arch" | 87 | log "build tts apk for $arch" |
| @@ -76,6 +76,10 @@ sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt | @@ -76,6 +76,10 @@ sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt | ||
| 76 | git diff | 76 | git diff |
| 77 | popd | 77 | popd |
| 78 | 78 | ||
| 79 | +if [[ $model_dir == vits-melo-tts-zh_en ]]; then | ||
| 80 | + lang=zh_en | ||
| 81 | +fi | ||
| 82 | + | ||
| 79 | for arch in arm64-v8a armeabi-v7a x86_64 x86; do | 83 | for arch in arm64-v8a armeabi-v7a x86_64 x86; do |
| 80 | log "------------------------------------------------------------" | 84 | log "------------------------------------------------------------" |
| 81 | log "build tts apk for $arch" | 85 | log "build tts apk for $arch" |
| @@ -313,6 +313,11 @@ def get_vits_models() -> List[TtsModel]: | @@ -313,6 +313,11 @@ def get_vits_models() -> List[TtsModel]: | ||
| 313 | lang="zh", | 313 | lang="zh", |
| 314 | ), | 314 | ), |
| 315 | TtsModel( | 315 | TtsModel( |
| 316 | + model_dir="vits-melo-tts-zh_en", | ||
| 317 | + model_name="model.onnx", | ||
| 318 | + lang="zh", | ||
| 319 | + ), | ||
| 320 | + TtsModel( | ||
| 316 | model_dir="vits-zh-hf-fanchen-C", | 321 | model_dir="vits-zh-hf-fanchen-C", |
| 317 | model_name="vits-zh-hf-fanchen-C.onnx", | 322 | model_name="vits-zh-hf-fanchen-C.onnx", |
| 318 | lang="zh", | 323 | lang="zh", |
| @@ -339,18 +344,21 @@ def get_vits_models() -> List[TtsModel]: | @@ -339,18 +344,21 @@ def get_vits_models() -> List[TtsModel]: | ||
| 339 | ), | 344 | ), |
| 340 | ] | 345 | ] |
| 341 | 346 | ||
| 342 | - rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"] | 347 | + rule_fsts = ["phone.fst", "date.fst", "number.fst"] |
| 343 | for m in chinese_models: | 348 | for m in chinese_models: |
| 344 | s = [f"{m.model_dir}/{r}" for r in rule_fsts] | 349 | s = [f"{m.model_dir}/{r}" for r in rule_fsts] |
| 345 | - if "vits-zh-hf" in m.model_dir or "sherpa-onnx-vits-zh-ll" == m.model_dir: | 350 | + if ( |
| 351 | + "vits-zh-hf" in m.model_dir | ||
| 352 | + or "sherpa-onnx-vits-zh-ll" == m.model_dir | ||
| 353 | + or "melo-tts" in m.model_dir | ||
| 354 | + ): | ||
| 346 | s = s[:-1] | 355 | s = s[:-1] |
| 347 | m.dict_dir = m.model_dir + "/dict" | 356 | m.dict_dir = m.model_dir + "/dict" |
| 357 | + else: | ||
| 358 | + m.rule_fars = f"{m.model_dir}/rule.far" | ||
| 348 | 359 | ||
| 349 | m.rule_fsts = ",".join(s) | 360 | m.rule_fsts = ",".join(s) |
| 350 | 361 | ||
| 351 | - if "vits-zh-hf" not in m.model_dir and "zh-ll" not in m.model_dir: | ||
| 352 | - m.rule_fars = f"{m.model_dir}/rule.far" | ||
| 353 | - | ||
| 354 | all_models = chinese_models + [ | 362 | all_models = chinese_models + [ |
| 355 | TtsModel( | 363 | TtsModel( |
| 356 | model_dir="vits-cantonese-hf-xiaomaiiwn", | 364 | model_dir="vits-cantonese-hf-xiaomaiiwn", |
| @@ -17,7 +17,7 @@ topics: | @@ -17,7 +17,7 @@ topics: | ||
| 17 | - voice-activity-detection | 17 | - voice-activity-detection |
| 18 | 18 | ||
| 19 | # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec | 19 | # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec |
| 20 | -version: 1.10.15 | 20 | +version: 1.10.16 |
| 21 | 21 | ||
| 22 | homepage: https://github.com/k2-fsa/sherpa-onnx | 22 | homepage: https://github.com/k2-fsa/sherpa-onnx |
| 23 | 23 |
| @@ -6,9 +6,6 @@ from typing import List, Optional | @@ -6,9 +6,6 @@ from typing import List, Optional | ||
| 6 | 6 | ||
| 7 | import jinja2 | 7 | import jinja2 |
| 8 | 8 | ||
| 9 | -# pip install iso639-lang | ||
| 10 | -from iso639 import Lang | ||
| 11 | - | ||
| 12 | 9 | ||
| 13 | def get_args(): | 10 | def get_args(): |
| 14 | parser = argparse.ArgumentParser() | 11 | parser = argparse.ArgumentParser() |
| @@ -37,13 +34,6 @@ class TtsModel: | @@ -37,13 +34,6 @@ class TtsModel: | ||
| 37 | data_dir: Optional[str] = None | 34 | data_dir: Optional[str] = None |
| 38 | dict_dir: Optional[str] = None | 35 | dict_dir: Optional[str] = None |
| 39 | is_char: bool = False | 36 | is_char: bool = False |
| 40 | - lang_iso_639_3: str = "" | ||
| 41 | - | ||
| 42 | - | ||
| 43 | -def convert_lang_to_iso_639_3(models: List[TtsModel]): | ||
| 44 | - for m in models: | ||
| 45 | - if m.lang_iso_639_3 == "": | ||
| 46 | - m.lang_iso_639_3 = Lang(m.lang).pt3 | ||
| 47 | 37 | ||
| 48 | 38 | ||
| 49 | def get_coqui_models() -> List[TtsModel]: | 39 | def get_coqui_models() -> List[TtsModel]: |
| @@ -313,6 +303,11 @@ def get_vits_models() -> List[TtsModel]: | @@ -313,6 +303,11 @@ def get_vits_models() -> List[TtsModel]: | ||
| 313 | lang="zh", | 303 | lang="zh", |
| 314 | ), | 304 | ), |
| 315 | TtsModel( | 305 | TtsModel( |
| 306 | + model_dir="vits-melo-tts-zh_en", | ||
| 307 | + model_name="model.onnx", | ||
| 308 | + lang="zh_en", | ||
| 309 | + ), | ||
| 310 | + TtsModel( | ||
| 316 | model_dir="vits-zh-hf-fanchen-C", | 311 | model_dir="vits-zh-hf-fanchen-C", |
| 317 | model_name="vits-zh-hf-fanchen-C.onnx", | 312 | model_name="vits-zh-hf-fanchen-C.onnx", |
| 318 | lang="zh", | 313 | lang="zh", |
| @@ -332,26 +327,33 @@ def get_vits_models() -> List[TtsModel]: | @@ -332,26 +327,33 @@ def get_vits_models() -> List[TtsModel]: | ||
| 332 | model_name="vits-zh-hf-fanchen-unity.onnx", | 327 | model_name="vits-zh-hf-fanchen-unity.onnx", |
| 333 | lang="zh", | 328 | lang="zh", |
| 334 | ), | 329 | ), |
| 330 | + TtsModel( | ||
| 331 | + model_dir="sherpa-onnx-vits-zh-ll", | ||
| 332 | + model_name="model.onnx", | ||
| 333 | + lang="zh", | ||
| 334 | + ), | ||
| 335 | ] | 335 | ] |
| 336 | 336 | ||
| 337 | - rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"] | 337 | + rule_fsts = ["phone.fst", "date.fst", "number.fst"] |
| 338 | for m in chinese_models: | 338 | for m in chinese_models: |
| 339 | s = [f"{m.model_dir}/{r}" for r in rule_fsts] | 339 | s = [f"{m.model_dir}/{r}" for r in rule_fsts] |
| 340 | - if "vits-zh-hf" in m.model_dir: | 340 | + if ( |
| 341 | + "vits-zh-hf" in m.model_dir | ||
| 342 | + or "sherpa-onnx-vits-zh-ll" == m.model_dir | ||
| 343 | + or "melo-tts" in m.model_dir | ||
| 344 | + ): | ||
| 341 | s = s[:-1] | 345 | s = s[:-1] |
| 342 | m.dict_dir = m.model_dir + "/dict" | 346 | m.dict_dir = m.model_dir + "/dict" |
| 347 | + else: | ||
| 348 | + m.rule_fars = f"{m.model_dir}/rule.far" | ||
| 343 | 349 | ||
| 344 | m.rule_fsts = ",".join(s) | 350 | m.rule_fsts = ",".join(s) |
| 345 | 351 | ||
| 346 | - if "vits-zh-hf" not in m.model_dir: | ||
| 347 | - m.rule_fars = f"{m.model_dir}/rule.far" | ||
| 348 | - | ||
| 349 | all_models = chinese_models + [ | 352 | all_models = chinese_models + [ |
| 350 | TtsModel( | 353 | TtsModel( |
| 351 | model_dir="vits-cantonese-hf-xiaomaiiwn", | 354 | model_dir="vits-cantonese-hf-xiaomaiiwn", |
| 352 | model_name="vits-cantonese-hf-xiaomaiiwn.onnx", | 355 | model_name="vits-cantonese-hf-xiaomaiiwn.onnx", |
| 353 | lang="cantonese", | 356 | lang="cantonese", |
| 354 | - lang_iso_639_3="yue", | ||
| 355 | rule_fsts="vits-cantonese-hf-xiaomaiiwn/rule.fst", | 357 | rule_fsts="vits-cantonese-hf-xiaomaiiwn/rule.fst", |
| 356 | ), | 358 | ), |
| 357 | # English (US) | 359 | # English (US) |
| @@ -374,7 +376,6 @@ def main(): | @@ -374,7 +376,6 @@ def main(): | ||
| 374 | all_model_list += get_piper_models() | 376 | all_model_list += get_piper_models() |
| 375 | all_model_list += get_mimic3_models() | 377 | all_model_list += get_mimic3_models() |
| 376 | all_model_list += get_coqui_models() | 378 | all_model_list += get_coqui_models() |
| 377 | - convert_lang_to_iso_639_3(all_model_list) | ||
| 378 | 379 | ||
| 379 | num_models = len(all_model_list) | 380 | num_models = len(all_model_list) |
| 380 | 381 |
scripts/melo-tts/README.md
0 → 100644
| @@ -8,7 +8,6 @@ from melo.text import language_id_map, language_tone_start_map | @@ -8,7 +8,6 @@ from melo.text import language_id_map, language_tone_start_map | ||
| 8 | from melo.text.chinese import pinyin_to_symbol_map | 8 | from melo.text.chinese import pinyin_to_symbol_map |
| 9 | from melo.text.english import eng_dict, refine_syllables | 9 | from melo.text.english import eng_dict, refine_syllables |
| 10 | from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict | 10 | from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict |
| 11 | -from melo.text.symbols import language_tone_start_map | ||
| 12 | 11 | ||
| 13 | for k, v in pinyin_to_symbol_map.items(): | 12 | for k, v in pinyin_to_symbol_map.items(): |
| 14 | if isinstance(v, list): | 13 | if isinstance(v, list): |
| @@ -82,6 +81,7 @@ def generate_tokens(symbol_list): | @@ -82,6 +81,7 @@ def generate_tokens(symbol_list): | ||
| 82 | def generate_lexicon(): | 81 | def generate_lexicon(): |
| 83 | word_dict = pinyin_dict.pinyin_dict | 82 | word_dict = pinyin_dict.pinyin_dict |
| 84 | phrases = phrases_dict.phrases_dict | 83 | phrases = phrases_dict.phrases_dict |
| 84 | + eng_dict["kaldi"] = [["K", "AH0"], ["L", "D", "IH0"]] | ||
| 85 | with open("lexicon.txt", "w", encoding="utf-8") as f: | 85 | with open("lexicon.txt", "w", encoding="utf-8") as f: |
| 86 | for word in eng_dict: | 86 | for word in eng_dict: |
| 87 | phones, tones = refine_syllables(eng_dict[word]) | 87 | phones, tones = refine_syllables(eng_dict[word]) |
| @@ -237,9 +237,11 @@ def main(): | @@ -237,9 +237,11 @@ def main(): | ||
| 237 | meta_data = { | 237 | meta_data = { |
| 238 | "model_type": "melo-vits", | 238 | "model_type": "melo-vits", |
| 239 | "comment": "melo", | 239 | "comment": "melo", |
| 240 | + "version": 2, | ||
| 240 | "language": "Chinese + English", | 241 | "language": "Chinese + English", |
| 241 | "add_blank": int(model.hps.data.add_blank), | 242 | "add_blank": int(model.hps.data.add_blank), |
| 242 | "n_speakers": 1, | 243 | "n_speakers": 1, |
| 244 | + "jieba": 1, | ||
| 243 | "sample_rate": model.hps.data.sampling_rate, | 245 | "sample_rate": model.hps.data.sampling_rate, |
| 244 | "bert_dim": 1024, | 246 | "bert_dim": 1024, |
| 245 | "ja_bert_dim": 768, | 247 | "ja_bert_dim": 768, |
| @@ -12,7 +12,7 @@ function install() { | @@ -12,7 +12,7 @@ function install() { | ||
| 12 | cd MeloTTS | 12 | cd MeloTTS |
| 13 | pip install -r ./requirements.txt | 13 | pip install -r ./requirements.txt |
| 14 | 14 | ||
| 15 | - pip install soundfile onnx onnxruntime | 15 | + pip install soundfile onnx==1.15.0 onnxruntime==1.16.3 |
| 16 | 16 | ||
| 17 | python3 -m unidic download | 17 | python3 -m unidic download |
| 18 | popd | 18 | popd |
| @@ -135,28 +135,11 @@ class OnnxModel: | @@ -135,28 +135,11 @@ class OnnxModel: | ||
| 135 | def main(): | 135 | def main(): |
| 136 | lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt") | 136 | lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt") |
| 137 | 137 | ||
| 138 | - text = "永远相信,美好的事情即将发生。" | 138 | + text = "这是一个使用 next generation kaldi 的 text to speech 中英文例子. Thank you! 你觉得如何呢? are you ok? Fantastic! How about you?" |
| 139 | s = jieba.cut(text, HMM=True) | 139 | s = jieba.cut(text, HMM=True) |
| 140 | 140 | ||
| 141 | phones, tones = lexicon.convert(s) | 141 | phones, tones = lexicon.convert(s) |
| 142 | 142 | ||
| 143 | - en_text = "how are you ?".split() | ||
| 144 | - | ||
| 145 | - phones_en, tones_en = lexicon.convert(en_text) | ||
| 146 | - phones += [0] | ||
| 147 | - tones += [0] | ||
| 148 | - | ||
| 149 | - phones += phones_en | ||
| 150 | - tones += tones_en | ||
| 151 | - | ||
| 152 | - text = "多音字测试, 银行,行不行?长沙长大" | ||
| 153 | - s = jieba.cut(text, HMM=True) | ||
| 154 | - | ||
| 155 | - phones2, tones2 = lexicon.convert(s) | ||
| 156 | - | ||
| 157 | - phones += phones2 | ||
| 158 | - tones += tones2 | ||
| 159 | - | ||
| 160 | model = OnnxModel("./model.onnx") | 143 | model = OnnxModel("./model.onnx") |
| 161 | 144 | ||
| 162 | if model.add_blank: | 145 | if model.add_blank: |
| @@ -422,10 +422,10 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig( | @@ -422,10 +422,10 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig( | ||
| 422 | 422 | ||
| 423 | void SherpaOnnxOfflineRecognizerSetConfig( | 423 | void SherpaOnnxOfflineRecognizerSetConfig( |
| 424 | const SherpaOnnxOfflineRecognizer *recognizer, | 424 | const SherpaOnnxOfflineRecognizer *recognizer, |
| 425 | - const SherpaOnnxOfflineRecognizerConfig *config){ | 425 | + const SherpaOnnxOfflineRecognizerConfig *config) { |
| 426 | sherpa_onnx::OfflineRecognizerConfig recognizer_config = | 426 | sherpa_onnx::OfflineRecognizerConfig recognizer_config = |
| 427 | convertConfig(config); | 427 | convertConfig(config); |
| 428 | - recognizer->impl->SetConfig(recognizer_config); | 428 | + recognizer->impl->SetConfig(recognizer_config); |
| 429 | } | 429 | } |
| 430 | 430 | ||
| 431 | void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) { | 431 | void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) { |
| @@ -478,7 +478,7 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( | @@ -478,7 +478,7 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( | ||
| 478 | pText[text.size()] = 0; | 478 | pText[text.size()] = 0; |
| 479 | r->text = pText; | 479 | r->text = pText; |
| 480 | 480 | ||
| 481 | - //lang | 481 | + // lang |
| 482 | const auto &lang = result.lang; | 482 | const auto &lang = result.lang; |
| 483 | char *c_lang = new char[lang.size() + 1]; | 483 | char *c_lang = new char[lang.size() + 1]; |
| 484 | std::copy(lang.begin(), lang.end(), c_lang); | 484 | std::copy(lang.begin(), lang.end(), c_lang); |
| @@ -1317,7 +1317,7 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches( | @@ -1317,7 +1317,7 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches( | ||
| 1317 | } | 1317 | } |
| 1318 | delete[] r->matches; | 1318 | delete[] r->matches; |
| 1319 | delete r; | 1319 | delete r; |
| 1320 | -}; | 1320 | +} |
| 1321 | 1321 | ||
| 1322 | int32_t SherpaOnnxSpeakerEmbeddingManagerVerify( | 1322 | int32_t SherpaOnnxSpeakerEmbeddingManagerVerify( |
| 1323 | const SherpaOnnxSpeakerEmbeddingManager *p, const char *name, | 1323 | const SherpaOnnxSpeakerEmbeddingManager *p, const char *name, |
| @@ -496,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams( | @@ -496,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams( | ||
| 496 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { | 496 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { |
| 497 | const char *text; | 497 | const char *text; |
| 498 | 498 | ||
| 499 | - // Pointer to continuous memory which holds timestamps | 499 | + // Pointer to continuous memory which holds timestamps |
| 500 | // | 500 | // |
| 501 | // It is NULL if the model does not support timestamps | 501 | // It is NULL if the model does not support timestamps |
| 502 | float *timestamps; | 502 | float *timestamps; |
| @@ -525,9 +525,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { | @@ -525,9 +525,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { | ||
| 525 | */ | 525 | */ |
| 526 | const char *json; | 526 | const char *json; |
| 527 | 527 | ||
| 528 | - //return recognized language | 528 | + // return recognized language |
| 529 | const char *lang; | 529 | const char *lang; |
| 530 | - | ||
| 531 | } SherpaOnnxOfflineRecognizerResult; | 530 | } SherpaOnnxOfflineRecognizerResult; |
| 532 | 531 | ||
| 533 | /// Get the result of the offline stream. | 532 | /// Get the result of the offline stream. |
| @@ -142,7 +142,9 @@ if(SHERPA_ONNX_ENABLE_TTS) | @@ -142,7 +142,9 @@ if(SHERPA_ONNX_ENABLE_TTS) | ||
| 142 | list(APPEND sources | 142 | list(APPEND sources |
| 143 | jieba-lexicon.cc | 143 | jieba-lexicon.cc |
| 144 | lexicon.cc | 144 | lexicon.cc |
| 145 | + melo-tts-lexicon.cc | ||
| 145 | offline-tts-character-frontend.cc | 146 | offline-tts-character-frontend.cc |
| 147 | + offline-tts-frontend.cc | ||
| 146 | offline-tts-impl.cc | 148 | offline-tts-impl.cc |
| 147 | offline-tts-model-config.cc | 149 | offline-tts-model-config.cc |
| 148 | offline-tts-vits-model-config.cc | 150 | offline-tts-vits-model-config.cc |
| @@ -33,7 +33,7 @@ TEST(CppJieBa, Case1) { | @@ -33,7 +33,7 @@ TEST(CppJieBa, Case1) { | ||
| 33 | std::vector<std::string> words; | 33 | std::vector<std::string> words; |
| 34 | std::vector<cppjieba::Word> jiebawords; | 34 | std::vector<cppjieba::Word> jiebawords; |
| 35 | 35 | ||
| 36 | - std::string s = "他来到了网易杭研大厦"; | 36 | + std::string s = "他来到了网易杭研大厦。How are you?"; |
| 37 | std::cout << s << std::endl; | 37 | std::cout << s << std::endl; |
| 38 | std::cout << "[demo] Cut With HMM" << std::endl; | 38 | std::cout << "[demo] Cut With HMM" << std::endl; |
| 39 | jieba.Cut(s, words, true); | 39 | jieba.Cut(s, words, true); |
| @@ -17,6 +17,7 @@ namespace sherpa_onnx { | @@ -17,6 +17,7 @@ namespace sherpa_onnx { | ||
| 17 | 17 | ||
| 18 | // implemented in ./lexicon.cc | 18 | // implemented in ./lexicon.cc |
| 19 | std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is); | 19 | std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is); |
| 20 | + | ||
| 20 | std::vector<int32_t> ConvertTokensToIds( | 21 | std::vector<int32_t> ConvertTokensToIds( |
| 21 | const std::unordered_map<std::string, int32_t> &token2id, | 22 | const std::unordered_map<std::string, int32_t> &token2id, |
| 22 | const std::vector<std::string> &tokens); | 23 | const std::vector<std::string> &tokens); |
| @@ -53,8 +54,7 @@ class JiebaLexicon::Impl { | @@ -53,8 +54,7 @@ class JiebaLexicon::Impl { | ||
| 53 | } | 54 | } |
| 54 | } | 55 | } |
| 55 | 56 | ||
| 56 | - std::vector<std::vector<int64_t>> ConvertTextToTokenIds( | ||
| 57 | - const std::string &text) const { | 57 | + std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &text) const { |
| 58 | // see | 58 | // see |
| 59 | // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244 | 59 | // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244 |
| 60 | std::regex punct_re{":|、|;"}; | 60 | std::regex punct_re{":|、|;"}; |
| @@ -87,7 +87,7 @@ class JiebaLexicon::Impl { | @@ -87,7 +87,7 @@ class JiebaLexicon::Impl { | ||
| 87 | SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str()); | 87 | SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str()); |
| 88 | } | 88 | } |
| 89 | 89 | ||
| 90 | - std::vector<std::vector<int64_t>> ans; | 90 | + std::vector<TokenIDs> ans; |
| 91 | std::vector<int64_t> this_sentence; | 91 | std::vector<int64_t> this_sentence; |
| 92 | 92 | ||
| 93 | int32_t blank = token2id_.at(" "); | 93 | int32_t blank = token2id_.at(" "); |
| @@ -217,7 +217,7 @@ JiebaLexicon::JiebaLexicon(const std::string &lexicon, | @@ -217,7 +217,7 @@ JiebaLexicon::JiebaLexicon(const std::string &lexicon, | ||
| 217 | : impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data, | 217 | : impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data, |
| 218 | debug)) {} | 218 | debug)) {} |
| 219 | 219 | ||
| 220 | -std::vector<std::vector<int64_t>> JiebaLexicon::ConvertTextToTokenIds( | 220 | +std::vector<TokenIDs> JiebaLexicon::ConvertTextToTokenIds( |
| 221 | const std::string &text, const std::string & /*unused_voice = ""*/) const { | 221 | const std::string &text, const std::string & /*unused_voice = ""*/) const { |
| 222 | return impl_->ConvertTextToTokenIds(text); | 222 | return impl_->ConvertTextToTokenIds(text); |
| 223 | } | 223 | } |
| @@ -10,11 +10,6 @@ | @@ -10,11 +10,6 @@ | ||
| 10 | #include <unordered_map> | 10 | #include <unordered_map> |
| 11 | #include <vector> | 11 | #include <vector> |
| 12 | 12 | ||
| 13 | -#if __ANDROID_API__ >= 9 | ||
| 14 | -#include "android/asset_manager.h" | ||
| 15 | -#include "android/asset_manager_jni.h" | ||
| 16 | -#endif | ||
| 17 | - | ||
| 18 | #include "sherpa-onnx/csrc/offline-tts-frontend.h" | 13 | #include "sherpa-onnx/csrc/offline-tts-frontend.h" |
| 19 | #include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" | 14 | #include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" |
| 20 | 15 | ||
| @@ -27,13 +22,7 @@ class JiebaLexicon : public OfflineTtsFrontend { | @@ -27,13 +22,7 @@ class JiebaLexicon : public OfflineTtsFrontend { | ||
| 27 | const std::string &dict_dir, | 22 | const std::string &dict_dir, |
| 28 | const OfflineTtsVitsModelMetaData &meta_data, bool debug); | 23 | const OfflineTtsVitsModelMetaData &meta_data, bool debug); |
| 29 | 24 | ||
| 30 | -#if __ANDROID_API__ >= 9 | ||
| 31 | - JiebaLexicon(AAssetManager *mgr, const std::string &lexicon, | ||
| 32 | - const std::string &tokens, const std::string &dict_dir, | ||
| 33 | - const OfflineTtsVitsModelMetaData &meta_data); | ||
| 34 | -#endif | ||
| 35 | - | ||
| 36 | - std::vector<std::vector<int64_t>> ConvertTextToTokenIds( | 25 | + std::vector<TokenIDs> ConvertTextToTokenIds( |
| 37 | const std::string &text, | 26 | const std::string &text, |
| 38 | const std::string &unused_voice = "") const override; | 27 | const std::string &unused_voice = "") const override; |
| 39 | 28 |
| @@ -172,7 +172,7 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, | @@ -172,7 +172,7 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, | ||
| 172 | } | 172 | } |
| 173 | #endif | 173 | #endif |
| 174 | 174 | ||
| 175 | -std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds( | 175 | +std::vector<TokenIDs> Lexicon::ConvertTextToTokenIds( |
| 176 | const std::string &text, const std::string & /*voice*/ /*= ""*/) const { | 176 | const std::string &text, const std::string & /*voice*/ /*= ""*/) const { |
| 177 | switch (language_) { | 177 | switch (language_) { |
| 178 | case Language::kChinese: | 178 | case Language::kChinese: |
| @@ -187,7 +187,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds( | @@ -187,7 +187,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds( | ||
| 187 | return {}; | 187 | return {}; |
| 188 | } | 188 | } |
| 189 | 189 | ||
| 190 | -std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( | 190 | +std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsChinese( |
| 191 | const std::string &_text) const { | 191 | const std::string &_text) const { |
| 192 | std::string text(_text); | 192 | std::string text(_text); |
| 193 | ToLowerCase(&text); | 193 | ToLowerCase(&text); |
| @@ -209,7 +209,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( | @@ -209,7 +209,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( | ||
| 209 | fprintf(stderr, "\n"); | 209 | fprintf(stderr, "\n"); |
| 210 | } | 210 | } |
| 211 | 211 | ||
| 212 | - std::vector<std::vector<int64_t>> ans; | 212 | + std::vector<TokenIDs> ans; |
| 213 | std::vector<int64_t> this_sentence; | 213 | std::vector<int64_t> this_sentence; |
| 214 | 214 | ||
| 215 | int32_t blank = -1; | 215 | int32_t blank = -1; |
| @@ -288,7 +288,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( | @@ -288,7 +288,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese( | ||
| 288 | return ans; | 288 | return ans; |
| 289 | } | 289 | } |
| 290 | 290 | ||
| 291 | -std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese( | 291 | +std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsNotChinese( |
| 292 | const std::string &_text) const { | 292 | const std::string &_text) const { |
| 293 | std::string text(_text); | 293 | std::string text(_text); |
| 294 | ToLowerCase(&text); | 294 | ToLowerCase(&text); |
| @@ -311,7 +311,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese( | @@ -311,7 +311,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese( | ||
| 311 | 311 | ||
| 312 | int32_t blank = token2id_.at(" "); | 312 | int32_t blank = token2id_.at(" "); |
| 313 | 313 | ||
| 314 | - std::vector<std::vector<int64_t>> ans; | 314 | + std::vector<TokenIDs> ans; |
| 315 | std::vector<int64_t> this_sentence; | 315 | std::vector<int64_t> this_sentence; |
| 316 | 316 | ||
| 317 | for (const auto &w : words) { | 317 | for (const auto &w : words) { |
| @@ -36,14 +36,14 @@ class Lexicon : public OfflineTtsFrontend { | @@ -36,14 +36,14 @@ class Lexicon : public OfflineTtsFrontend { | ||
| 36 | const std::string &language, bool debug = false); | 36 | const std::string &language, bool debug = false); |
| 37 | #endif | 37 | #endif |
| 38 | 38 | ||
| 39 | - std::vector<std::vector<int64_t>> ConvertTextToTokenIds( | 39 | + std::vector<TokenIDs> ConvertTextToTokenIds( |
| 40 | const std::string &text, const std::string &voice = "") const override; | 40 | const std::string &text, const std::string &voice = "") const override; |
| 41 | 41 | ||
| 42 | private: | 42 | private: |
| 43 | - std::vector<std::vector<int64_t>> ConvertTextToTokenIdsNotChinese( | 43 | + std::vector<TokenIDs> ConvertTextToTokenIdsNotChinese( |
| 44 | const std::string &text) const; | 44 | const std::string &text) const; |
| 45 | 45 | ||
| 46 | - std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese( | 46 | + std::vector<TokenIDs> ConvertTextToTokenIdsChinese( |
| 47 | const std::string &text) const; | 47 | const std::string &text) const; |
| 48 | 48 | ||
| 49 | void InitLanguage(const std::string &lang); | 49 | void InitLanguage(const std::string &lang); |
sherpa-onnx/csrc/melo-tts-lexicon.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/melo-tts-lexicon.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/melo-tts-lexicon.h" | ||
| 6 | + | ||
| 7 | +#include <fstream> | ||
| 8 | +#include <regex> // NOLINT | ||
| 9 | +#include <utility> | ||
| 10 | + | ||
| 11 | +#include "cppjieba/Jieba.hpp" | ||
| 12 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 13 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 14 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +// implemented in ./lexicon.cc | ||
| 19 | +std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is); | ||
| 20 | + | ||
| 21 | +std::vector<int32_t> ConvertTokensToIds( | ||
| 22 | + const std::unordered_map<std::string, int32_t> &token2id, | ||
| 23 | + const std::vector<std::string> &tokens); | ||
| 24 | + | ||
| 25 | +class MeloTtsLexicon::Impl { | ||
| 26 | + public: | ||
| 27 | + Impl(const std::string &lexicon, const std::string &tokens, | ||
| 28 | + const std::string &dict_dir, | ||
| 29 | + const OfflineTtsVitsModelMetaData &meta_data, bool debug) | ||
| 30 | + : meta_data_(meta_data), debug_(debug) { | ||
| 31 | + std::string dict = dict_dir + "/jieba.dict.utf8"; | ||
| 32 | + std::string hmm = dict_dir + "/hmm_model.utf8"; | ||
| 33 | + std::string user_dict = dict_dir + "/user.dict.utf8"; | ||
| 34 | + std::string idf = dict_dir + "/idf.utf8"; | ||
| 35 | + std::string stop_word = dict_dir + "/stop_words.utf8"; | ||
| 36 | + | ||
| 37 | + AssertFileExists(dict); | ||
| 38 | + AssertFileExists(hmm); | ||
| 39 | + AssertFileExists(user_dict); | ||
| 40 | + AssertFileExists(idf); | ||
| 41 | + AssertFileExists(stop_word); | ||
| 42 | + | ||
| 43 | + jieba_ = | ||
| 44 | + std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word); | ||
| 45 | + | ||
| 46 | + { | ||
| 47 | + std::ifstream is(tokens); | ||
| 48 | + InitTokens(is); | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + { | ||
| 52 | + std::ifstream is(lexicon); | ||
| 53 | + InitLexicon(is); | ||
| 54 | + } | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const { | ||
| 58 | + std::string text = ToLowerCase(_text); | ||
| 59 | + // see | ||
| 60 | + // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244 | ||
| 61 | + std::regex punct_re{":|、|;"}; | ||
| 62 | + std::string s = std::regex_replace(text, punct_re, ","); | ||
| 63 | + | ||
| 64 | + std::regex punct_re2("。"); | ||
| 65 | + s = std::regex_replace(s, punct_re2, "."); | ||
| 66 | + | ||
| 67 | + std::regex punct_re3("?"); | ||
| 68 | + s = std::regex_replace(s, punct_re3, "?"); | ||
| 69 | + | ||
| 70 | + std::regex punct_re4("!"); | ||
| 71 | + s = std::regex_replace(s, punct_re4, "!"); | ||
| 72 | + | ||
| 73 | + std::vector<std::string> words; | ||
| 74 | + bool is_hmm = true; | ||
| 75 | + jieba_->Cut(text, words, is_hmm); | ||
| 76 | + | ||
| 77 | + if (debug_) { | ||
| 78 | + SHERPA_ONNX_LOGE("input text: %s", text.c_str()); | ||
| 79 | + SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str()); | ||
| 80 | + | ||
| 81 | + std::ostringstream os; | ||
| 82 | + std::string sep = ""; | ||
| 83 | + for (const auto &w : words) { | ||
| 84 | + os << sep << w; | ||
| 85 | + sep = "_"; | ||
| 86 | + } | ||
| 87 | + | ||
| 88 | + SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str()); | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + std::vector<TokenIDs> ans; | ||
| 92 | + TokenIDs this_sentence; | ||
| 93 | + | ||
| 94 | + int32_t blank = token2id_.at("_"); | ||
| 95 | + for (const auto &w : words) { | ||
| 96 | + auto ids = ConvertWordToIds(w); | ||
| 97 | + if (ids.tokens.empty()) { | ||
| 98 | + SHERPA_ONNX_LOGE("Ignore OOV '%s'", w.c_str()); | ||
| 99 | + continue; | ||
| 100 | + } | ||
| 101 | + | ||
| 102 | + this_sentence.tokens.insert(this_sentence.tokens.end(), | ||
| 103 | + ids.tokens.begin(), ids.tokens.end()); | ||
| 104 | + this_sentence.tones.insert(this_sentence.tones.end(), ids.tones.begin(), | ||
| 105 | + ids.tones.end()); | ||
| 106 | + | ||
| 107 | + if (w == "." || w == "!" || w == "?" || w == ",") { | ||
| 108 | + ans.push_back(std::move(this_sentence)); | ||
| 109 | + this_sentence = {}; | ||
| 110 | + } | ||
| 111 | + } // for (const auto &w : words) | ||
| 112 | + | ||
| 113 | + if (!this_sentence.tokens.empty()) { | ||
| 114 | + ans.push_back(std::move(this_sentence)); | ||
| 115 | + } | ||
| 116 | + | ||
| 117 | + return ans; | ||
| 118 | + } | ||
| 119 | + | ||
| 120 | + private: | ||
| 121 | + TokenIDs ConvertWordToIds(const std::string &w) const { | ||
| 122 | + if (word2ids_.count(w)) { | ||
| 123 | + return word2ids_.at(w); | ||
| 124 | + } | ||
| 125 | + | ||
| 126 | + if (token2id_.count(w)) { | ||
| 127 | + return {{token2id_.at(w)}, {0}}; | ||
| 128 | + } | ||
| 129 | + | ||
| 130 | + TokenIDs ans; | ||
| 131 | + | ||
| 132 | + std::vector<std::string> words = SplitUtf8(w); | ||
| 133 | + for (const auto &word : words) { | ||
| 134 | + if (word2ids_.count(word)) { | ||
| 135 | + auto ids = ConvertWordToIds(word); | ||
| 136 | + ans.tokens.insert(ans.tokens.end(), ids.tokens.begin(), | ||
| 137 | + ids.tokens.end()); | ||
| 138 | + ans.tones.insert(ans.tones.end(), ids.tones.begin(), ids.tones.end()); | ||
| 139 | + } | ||
| 140 | + } | ||
| 141 | + | ||
| 142 | + return ans; | ||
| 143 | + } | ||
| 144 | + | ||
| 145 | + void InitTokens(std::istream &is) { | ||
| 146 | + token2id_ = ReadTokens(is); | ||
| 147 | + token2id_[" "] = token2id_["_"]; | ||
| 148 | + | ||
| 149 | + std::vector<std::pair<std::string, std::string>> puncts = { | ||
| 150 | + {",", ","}, {".", "。"}, {"!", "!"}, {"?", "?"}}; | ||
| 151 | + | ||
| 152 | + for (const auto &p : puncts) { | ||
| 153 | + if (token2id_.count(p.first) && !token2id_.count(p.second)) { | ||
| 154 | + token2id_[p.second] = token2id_[p.first]; | ||
| 155 | + } | ||
| 156 | + | ||
| 157 | + if (!token2id_.count(p.first) && token2id_.count(p.second)) { | ||
| 158 | + token2id_[p.first] = token2id_[p.second]; | ||
| 159 | + } | ||
| 160 | + } | ||
| 161 | + | ||
| 162 | + if (!token2id_.count("、") && token2id_.count(",")) { | ||
| 163 | + token2id_["、"] = token2id_[","]; | ||
| 164 | + } | ||
| 165 | + } | ||
| 166 | + | ||
| 167 | + void InitLexicon(std::istream &is) { | ||
| 168 | + std::string word; | ||
| 169 | + std::vector<std::string> token_list; | ||
| 170 | + | ||
| 171 | + std::vector<std::string> phone_list; | ||
| 172 | + std::vector<int64_t> tone_list; | ||
| 173 | + | ||
| 174 | + std::string line; | ||
| 175 | + std::string phone; | ||
| 176 | + int32_t line_num = 0; | ||
| 177 | + | ||
| 178 | + while (std::getline(is, line)) { | ||
| 179 | + ++line_num; | ||
| 180 | + | ||
| 181 | + std::istringstream iss(line); | ||
| 182 | + | ||
| 183 | + token_list.clear(); | ||
| 184 | + phone_list.clear(); | ||
| 185 | + tone_list.clear(); | ||
| 186 | + | ||
| 187 | + iss >> word; | ||
| 188 | + ToLowerCase(&word); | ||
| 189 | + | ||
| 190 | + if (word2ids_.count(word)) { | ||
| 191 | + SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.", | ||
| 192 | + word.c_str(), line_num, line.c_str()); | ||
| 193 | + continue; | ||
| 194 | + } | ||
| 195 | + | ||
| 196 | + while (iss >> phone) { | ||
| 197 | + token_list.push_back(std::move(phone)); | ||
| 198 | + } | ||
| 199 | + | ||
| 200 | + if ((token_list.size() & 1) != 0) { | ||
| 201 | + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str()); | ||
| 202 | + exit(-1); | ||
| 203 | + } | ||
| 204 | + | ||
| 205 | + int32_t num_phones = token_list.size() / 2; | ||
| 206 | + phone_list.reserve(num_phones); | ||
| 207 | + tone_list.reserve(num_phones); | ||
| 208 | + | ||
| 209 | + for (int32_t i = 0; i != num_phones; ++i) { | ||
| 210 | + phone_list.push_back(std::move(token_list[i])); | ||
| 211 | + tone_list.push_back(std::stoi(token_list[i + num_phones], nullptr)); | ||
| 212 | + if (tone_list.back() < 0 || tone_list.back() > 50) { | ||
| 213 | + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str()); | ||
| 214 | + exit(-1); | ||
| 215 | + } | ||
| 216 | + } | ||
| 217 | + | ||
| 218 | + std::vector<int32_t> ids = ConvertTokensToIds(token2id_, phone_list); | ||
| 219 | + if (ids.empty()) { | ||
| 220 | + continue; | ||
| 221 | + } | ||
| 222 | + | ||
| 223 | + if (ids.size() != num_phones) { | ||
| 224 | + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str()); | ||
| 225 | + exit(-1); | ||
| 226 | + } | ||
| 227 | + | ||
| 228 | + std::vector<int64_t> ids64{ids.begin(), ids.end()}; | ||
| 229 | + | ||
| 230 | + word2ids_.insert( | ||
| 231 | + {std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}}); | ||
| 232 | + } | ||
| 233 | + | ||
| 234 | + word2ids_["呣"] = word2ids_["母"]; | ||
| 235 | + word2ids_["嗯"] = word2ids_["恩"]; | ||
| 236 | + } | ||
| 237 | + | ||
| 238 | + private: | ||
| 239 | + // lexicon.txt is saved in word2ids_ | ||
| 240 | + std::unordered_map<std::string, TokenIDs> word2ids_; | ||
| 241 | + | ||
| 242 | + // tokens.txt is saved in token2id_ | ||
| 243 | + std::unordered_map<std::string, int32_t> token2id_; | ||
| 244 | + | ||
| 245 | + OfflineTtsVitsModelMetaData meta_data_; | ||
| 246 | + | ||
| 247 | + std::unique_ptr<cppjieba::Jieba> jieba_; | ||
| 248 | + bool debug_ = false; | ||
| 249 | +}; | ||
| 250 | + | ||
| 251 | +MeloTtsLexicon::~MeloTtsLexicon() = default; | ||
| 252 | + | ||
| 253 | +MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon, | ||
| 254 | + const std::string &tokens, | ||
| 255 | + const std::string &dict_dir, | ||
| 256 | + const OfflineTtsVitsModelMetaData &meta_data, | ||
| 257 | + bool debug) | ||
| 258 | + : impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data, | ||
| 259 | + debug)) {} | ||
| 260 | + | ||
| 261 | +std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds( | ||
| 262 | + const std::string &text, const std::string & /*unused_voice = ""*/) const { | ||
| 263 | + return impl_->ConvertTextToTokenIds(text); | ||
| 264 | +} | ||
| 265 | + | ||
| 266 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/melo-tts-lexicon.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/melo-tts-lexicon.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <unordered_map> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#include "sherpa-onnx/csrc/offline-tts-frontend.h" | ||
| 14 | +#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +class MeloTtsLexicon : public OfflineTtsFrontend { | ||
| 19 | + public: | ||
| 20 | + ~MeloTtsLexicon() override; | ||
| 21 | + MeloTtsLexicon(const std::string &lexicon, const std::string &tokens, | ||
| 22 | + const std::string &dict_dir, | ||
| 23 | + const OfflineTtsVitsModelMetaData &meta_data, bool debug); | ||
| 24 | + | ||
| 25 | + std::vector<TokenIDs> ConvertTextToTokenIds( | ||
| 26 | + const std::string &text, | ||
| 27 | + const std::string &unused_voice = "") const override; | ||
| 28 | + | ||
| 29 | + private: | ||
| 30 | + class Impl; | ||
| 31 | + std::unique_ptr<Impl> impl_; | ||
| 32 | +}; | ||
| 33 | + | ||
| 34 | +} // namespace sherpa_onnx | ||
| 35 | + | ||
| 36 | +#endif // SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_ |
| @@ -94,8 +94,7 @@ OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend( | @@ -94,8 +94,7 @@ OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend( | ||
| 94 | 94 | ||
| 95 | #endif | 95 | #endif |
| 96 | 96 | ||
| 97 | -std::vector<std::vector<int64_t>> | ||
| 98 | -OfflineTtsCharacterFrontend::ConvertTextToTokenIds( | 97 | +std::vector<TokenIDs> OfflineTtsCharacterFrontend::ConvertTextToTokenIds( |
| 99 | const std::string &_text, const std::string & /*voice = ""*/) const { | 98 | const std::string &_text, const std::string & /*voice = ""*/) const { |
| 100 | // see | 99 | // see |
| 101 | // https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87 | 100 | // https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87 |
| @@ -112,7 +111,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds( | @@ -112,7 +111,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds( | ||
| 112 | std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv; | 111 | std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv; |
| 113 | std::u32string s = conv.from_bytes(text); | 112 | std::u32string s = conv.from_bytes(text); |
| 114 | 113 | ||
| 115 | - std::vector<std::vector<int64_t>> ans; | 114 | + std::vector<TokenIDs> ans; |
| 116 | 115 | ||
| 117 | std::vector<int64_t> this_sentence; | 116 | std::vector<int64_t> this_sentence; |
| 118 | if (add_blank) { | 117 | if (add_blank) { |
| @@ -41,7 +41,7 @@ class OfflineTtsCharacterFrontend : public OfflineTtsFrontend { | @@ -41,7 +41,7 @@ class OfflineTtsCharacterFrontend : public OfflineTtsFrontend { | ||
| 41 | * If a frontend does not support splitting the text into | 41 | * If a frontend does not support splitting the text into |
| 42 | * sentences, the resulting vector contains only one subvector. | 42 | * sentences, the resulting vector contains only one subvector. |
| 43 | */ | 43 | */ |
| 44 | - std::vector<std::vector<int64_t>> ConvertTextToTokenIds( | 44 | + std::vector<TokenIDs> ConvertTextToTokenIds( |
| 45 | const std::string &text, const std::string &voice = "") const override; | 45 | const std::string &text, const std::string &voice = "") const override; |
| 46 | 46 | ||
| 47 | private: | 47 | private: |
sherpa-onnx/csrc/offline-tts-frontend.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-tts-frontend.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-tts-frontend.h" | ||
| 6 | + | ||
| 7 | +#include <sstream> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +std::string TokenIDs::ToString() const { | ||
| 13 | + std::ostringstream os; | ||
| 14 | + os << "TokenIDs("; | ||
| 15 | + os << "tokens=["; | ||
| 16 | + std::string sep; | ||
| 17 | + for (auto i : tokens) { | ||
| 18 | + os << sep << i; | ||
| 19 | + sep = ", "; | ||
| 20 | + } | ||
| 21 | + os << "], "; | ||
| 22 | + | ||
| 23 | + os << "tones=["; | ||
| 24 | + sep = {}; | ||
| 25 | + for (auto i : tones) { | ||
| 26 | + os << sep << i; | ||
| 27 | + sep = ", "; | ||
| 28 | + } | ||
| 29 | + os << "]"; | ||
| 30 | + os << ")"; | ||
| 31 | + return os.str(); | ||
| 32 | +} | ||
| 33 | + | ||
| 34 | +} // namespace sherpa_onnx |
| @@ -8,8 +8,28 @@ | @@ -8,8 +8,28 @@ | ||
| 8 | #include <string> | 8 | #include <string> |
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | + | ||
| 11 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 12 | 14 | ||
| 15 | +struct TokenIDs { | ||
| 16 | + TokenIDs() = default; | ||
| 17 | + | ||
| 18 | + /*implicit*/ TokenIDs(const std::vector<int64_t> &tokens) // NOLINT | ||
| 19 | + : tokens{tokens} {} | ||
| 20 | + | ||
| 21 | + TokenIDs(const std::vector<int64_t> &tokens, | ||
| 22 | + const std::vector<int64_t> &tones) | ||
| 23 | + : tokens{tokens}, tones{tones} {} | ||
| 24 | + | ||
| 25 | + std::string ToString() const; | ||
| 26 | + | ||
| 27 | + std::vector<int64_t> tokens; | ||
| 28 | + | ||
| 29 | + // Used only in MeloTTS | ||
| 30 | + std::vector<int64_t> tones; | ||
| 31 | +}; | ||
| 32 | + | ||
| 13 | class OfflineTtsFrontend { | 33 | class OfflineTtsFrontend { |
| 14 | public: | 34 | public: |
| 15 | virtual ~OfflineTtsFrontend() = default; | 35 | virtual ~OfflineTtsFrontend() = default; |
| @@ -26,7 +46,7 @@ class OfflineTtsFrontend { | @@ -26,7 +46,7 @@ class OfflineTtsFrontend { | ||
| 26 | * If a frontend does not support splitting the text into sentences, | 46 | * If a frontend does not support splitting the text into sentences, |
| 27 | * the resulting vector contains only one subvector. | 47 | * the resulting vector contains only one subvector. |
| 28 | */ | 48 | */ |
| 29 | - virtual std::vector<std::vector<int64_t>> ConvertTextToTokenIds( | 49 | + virtual std::vector<TokenIDs> ConvertTextToTokenIds( |
| 30 | const std::string &text, const std::string &voice = "") const = 0; | 50 | const std::string &text, const std::string &voice = "") const = 0; |
| 31 | }; | 51 | }; |
| 32 | 52 |
| @@ -22,6 +22,7 @@ | @@ -22,6 +22,7 @@ | ||
| 22 | #include "sherpa-onnx/csrc/jieba-lexicon.h" | 22 | #include "sherpa-onnx/csrc/jieba-lexicon.h" |
| 23 | #include "sherpa-onnx/csrc/lexicon.h" | 23 | #include "sherpa-onnx/csrc/lexicon.h" |
| 24 | #include "sherpa-onnx/csrc/macros.h" | 24 | #include "sherpa-onnx/csrc/macros.h" |
| 25 | +#include "sherpa-onnx/csrc/melo-tts-lexicon.h" | ||
| 25 | #include "sherpa-onnx/csrc/offline-tts-character-frontend.h" | 26 | #include "sherpa-onnx/csrc/offline-tts-character-frontend.h" |
| 26 | #include "sherpa-onnx/csrc/offline-tts-frontend.h" | 27 | #include "sherpa-onnx/csrc/offline-tts-frontend.h" |
| 27 | #include "sherpa-onnx/csrc/offline-tts-impl.h" | 28 | #include "sherpa-onnx/csrc/offline-tts-impl.h" |
| @@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 174 | } | 175 | } |
| 175 | } | 176 | } |
| 176 | 177 | ||
| 177 | - std::vector<std::vector<int64_t>> x = | 178 | + std::vector<TokenIDs> token_ids = |
| 178 | frontend_->ConvertTextToTokenIds(text, meta_data.voice); | 179 | frontend_->ConvertTextToTokenIds(text, meta_data.voice); |
| 179 | 180 | ||
| 180 | - if (x.empty() || (x.size() == 1 && x[0].empty())) { | 181 | + if (token_ids.empty() || |
| 182 | + (token_ids.size() == 1 && token_ids[0].tokens.empty())) { | ||
| 181 | SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); | 183 | SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); |
| 182 | return {}; | 184 | return {}; |
| 183 | } | 185 | } |
| 184 | 186 | ||
| 187 | + std::vector<std::vector<int64_t>> x; | ||
| 188 | + std::vector<std::vector<int64_t>> tones; | ||
| 189 | + | ||
| 190 | + x.reserve(token_ids.size()); | ||
| 191 | + | ||
| 192 | + for (auto &i : token_ids) { | ||
| 193 | + x.push_back(std::move(i.tokens)); | ||
| 194 | + } | ||
| 195 | + | ||
| 196 | + if (!token_ids[0].tones.empty()) { | ||
| 197 | + tones.reserve(token_ids.size()); | ||
| 198 | + for (auto &i : token_ids) { | ||
| 199 | + tones.push_back(std::move(i.tones)); | ||
| 200 | + } | ||
| 201 | + } | ||
| 202 | + | ||
| 185 | // TODO(fangjun): add blank inside the frontend, not here | 203 | // TODO(fangjun): add blank inside the frontend, not here |
| 186 | if (meta_data.add_blank && config_.model.vits.data_dir.empty() && | 204 | if (meta_data.add_blank && config_.model.vits.data_dir.empty() && |
| 187 | meta_data.frontend != "characters") { | 205 | meta_data.frontend != "characters") { |
| 188 | for (auto &k : x) { | 206 | for (auto &k : x) { |
| 189 | k = AddBlank(k); | 207 | k = AddBlank(k); |
| 190 | } | 208 | } |
| 209 | + | ||
| 210 | + for (auto &k : tones) { | ||
| 211 | + k = AddBlank(k); | ||
| 212 | + } | ||
| 191 | } | 213 | } |
| 192 | 214 | ||
| 193 | int32_t x_size = static_cast<int32_t>(x.size()); | 215 | int32_t x_size = static_cast<int32_t>(x.size()); |
| 194 | 216 | ||
| 195 | if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { | 217 | if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { |
| 196 | - auto ans = Process(x, sid, speed); | 218 | + auto ans = Process(x, tones, sid, speed); |
| 197 | if (callback) { | 219 | if (callback) { |
| 198 | callback(ans.samples.data(), ans.samples.size(), 1.0); | 220 | callback(ans.samples.data(), ans.samples.size(), 1.0); |
| 199 | } | 221 | } |
| @@ -202,9 +224,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -202,9 +224,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 202 | 224 | ||
| 203 | // the input text is too long, we process sentences within it in batches | 225 | // the input text is too long, we process sentences within it in batches |
| 204 | // to avoid OOM. Batch size is config_.max_num_sentences | 226 | // to avoid OOM. Batch size is config_.max_num_sentences |
| 205 | - std::vector<std::vector<int64_t>> batch; | 227 | + std::vector<std::vector<int64_t>> batch_x; |
| 228 | + std::vector<std::vector<int64_t>> batch_tones; | ||
| 229 | + | ||
| 206 | int32_t batch_size = config_.max_num_sentences; | 230 | int32_t batch_size = config_.max_num_sentences; |
| 207 | - batch.reserve(config_.max_num_sentences); | 231 | + batch_x.reserve(config_.max_num_sentences); |
| 232 | + batch_tones.reserve(config_.max_num_sentences); | ||
| 208 | int32_t num_batches = x_size / batch_size; | 233 | int32_t num_batches = x_size / batch_size; |
| 209 | 234 | ||
| 210 | if (config_.model.debug) { | 235 | if (config_.model.debug) { |
| @@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 221 | int32_t k = 0; | 246 | int32_t k = 0; |
| 222 | 247 | ||
| 223 | for (int32_t b = 0; b != num_batches && should_continue; ++b) { | 248 | for (int32_t b = 0; b != num_batches && should_continue; ++b) { |
| 224 | - batch.clear(); | 249 | + batch_x.clear(); |
| 250 | + batch_tones.clear(); | ||
| 225 | for (int32_t i = 0; i != batch_size; ++i, ++k) { | 251 | for (int32_t i = 0; i != batch_size; ++i, ++k) { |
| 226 | - batch.push_back(std::move(x[k])); | 252 | + batch_x.push_back(std::move(x[k])); |
| 253 | + | ||
| 254 | + if (!tones.empty()) { | ||
| 255 | + batch_tones.push_back(std::move(tones[k])); | ||
| 256 | + } | ||
| 227 | } | 257 | } |
| 228 | 258 | ||
| 229 | - auto audio = Process(batch, sid, speed); | 259 | + auto audio = Process(batch_x, batch_tones, sid, speed); |
| 230 | ans.sample_rate = audio.sample_rate; | 260 | ans.sample_rate = audio.sample_rate; |
| 231 | ans.samples.insert(ans.samples.end(), audio.samples.begin(), | 261 | ans.samples.insert(ans.samples.end(), audio.samples.begin(), |
| 232 | audio.samples.end()); | 262 | audio.samples.end()); |
| @@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 239 | } | 269 | } |
| 240 | } | 270 | } |
| 241 | 271 | ||
| 242 | - batch.clear(); | 272 | + batch_x.clear(); |
| 273 | + batch_tones.clear(); | ||
| 243 | while (k < static_cast<int32_t>(x.size()) && should_continue) { | 274 | while (k < static_cast<int32_t>(x.size()) && should_continue) { |
| 244 | - batch.push_back(std::move(x[k])); | 275 | + batch_x.push_back(std::move(x[k])); |
| 276 | + if (!tones.empty()) { | ||
| 277 | + batch_tones.push_back(std::move(tones[k])); | ||
| 278 | + } | ||
| 279 | + | ||
| 245 | ++k; | 280 | ++k; |
| 246 | } | 281 | } |
| 247 | 282 | ||
| 248 | - if (!batch.empty()) { | ||
| 249 | - auto audio = Process(batch, sid, speed); | 283 | + if (!batch_x.empty()) { |
| 284 | + auto audio = Process(batch_x, batch_tones, sid, speed); | ||
| 250 | ans.sample_rate = audio.sample_rate; | 285 | ans.sample_rate = audio.sample_rate; |
| 251 | ans.samples.insert(ans.samples.end(), audio.samples.begin(), | 286 | ans.samples.insert(ans.samples.end(), audio.samples.begin(), |
| 252 | audio.samples.end()); | 287 | audio.samples.end()); |
| @@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 308 | if (meta_data.frontend == "characters") { | 343 | if (meta_data.frontend == "characters") { |
| 309 | frontend_ = std::make_unique<OfflineTtsCharacterFrontend>( | 344 | frontend_ = std::make_unique<OfflineTtsCharacterFrontend>( |
| 310 | config_.model.vits.tokens, meta_data); | 345 | config_.model.vits.tokens, meta_data); |
| 346 | + } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty() && | ||
| 347 | + meta_data.is_melo_tts) { | ||
| 348 | + frontend_ = std::make_unique<MeloTtsLexicon>( | ||
| 349 | + config_.model.vits.lexicon, config_.model.vits.tokens, | ||
| 350 | + config_.model.vits.dict_dir, model_->GetMetaData(), | ||
| 351 | + config_.model.debug); | ||
| 311 | } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) { | 352 | } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) { |
| 312 | frontend_ = std::make_unique<JiebaLexicon>( | 353 | frontend_ = std::make_unique<JiebaLexicon>( |
| 313 | config_.model.vits.lexicon, config_.model.vits.tokens, | 354 | config_.model.vits.lexicon, config_.model.vits.tokens, |
| @@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 344 | } | 385 | } |
| 345 | 386 | ||
| 346 | GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens, | 387 | GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens, |
| 388 | + const std::vector<std::vector<int64_t>> &tones, | ||
| 347 | int32_t sid, float speed) const { | 389 | int32_t sid, float speed) const { |
| 348 | int32_t num_tokens = 0; | 390 | int32_t num_tokens = 0; |
| 349 | for (const auto &k : tokens) { | 391 | for (const auto &k : tokens) { |
| @@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 356 | x.insert(x.end(), k.begin(), k.end()); | 398 | x.insert(x.end(), k.begin(), k.end()); |
| 357 | } | 399 | } |
| 358 | 400 | ||
| 401 | + std::vector<int64_t> tone_list; | ||
| 402 | + if (!tones.empty()) { | ||
| 403 | + tone_list.reserve(num_tokens); | ||
| 404 | + for (const auto &k : tones) { | ||
| 405 | + tone_list.insert(tone_list.end(), k.begin(), k.end()); | ||
| 406 | + } | ||
| 407 | + } | ||
| 408 | + | ||
| 359 | auto memory_info = | 409 | auto memory_info = |
| 360 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 410 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); |
| 361 | 411 | ||
| @@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 363 | Ort::Value x_tensor = Ort::Value::CreateTensor( | 413 | Ort::Value x_tensor = Ort::Value::CreateTensor( |
| 364 | memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); | 414 | memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); |
| 365 | 415 | ||
| 366 | - Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed); | 416 | + Ort::Value tones_tensor{nullptr}; |
| 417 | + if (!tones.empty()) { | ||
| 418 | + tones_tensor = Ort::Value::CreateTensor(memory_info, tone_list.data(), | ||
| 419 | + tone_list.size(), x_shape.data(), | ||
| 420 | + x_shape.size()); | ||
| 421 | + } | ||
| 422 | + | ||
| 423 | + Ort::Value audio{nullptr}; | ||
| 424 | + if (tones.empty()) { | ||
| 425 | + audio = model_->Run(std::move(x_tensor), sid, speed); | ||
| 426 | + } else { | ||
| 427 | + audio = | ||
| 428 | + model_->Run(std::move(x_tensor), std::move(tones_tensor), sid, speed); | ||
| 429 | + } | ||
| 367 | 430 | ||
| 368 | std::vector<int64_t> audio_shape = | 431 | std::vector<int64_t> audio_shape = |
| 369 | audio.GetTensorTypeAndShapeInfo().GetShape(); | 432 | audio.GetTensorTypeAndShapeInfo().GetShape(); |
| @@ -21,6 +21,7 @@ struct OfflineTtsVitsModelMetaData { | @@ -21,6 +21,7 @@ struct OfflineTtsVitsModelMetaData { | ||
| 21 | bool is_piper = false; | 21 | bool is_piper = false; |
| 22 | bool is_coqui = false; | 22 | bool is_coqui = false; |
| 23 | bool is_icefall = false; | 23 | bool is_icefall = false; |
| 24 | + bool is_melo_tts = false; | ||
| 24 | 25 | ||
| 25 | // for Chinese TTS models from | 26 | // for Chinese TTS models from |
| 26 | // https://github.com/Plachtaa/VITS-fast-fine-tuning | 27 | // https://github.com/Plachtaa/VITS-fast-fine-tuning |
| @@ -33,6 +34,10 @@ struct OfflineTtsVitsModelMetaData { | @@ -33,6 +34,10 @@ struct OfflineTtsVitsModelMetaData { | ||
| 33 | int32_t use_eos_bos = 0; | 34 | int32_t use_eos_bos = 0; |
| 34 | int32_t pad_id = 0; | 35 | int32_t pad_id = 0; |
| 35 | 36 | ||
| 37 | + // for melo tts | ||
| 38 | + int32_t speaker_id = 0; | ||
| 39 | + int32_t version = 0; | ||
| 40 | + | ||
| 36 | std::string punctuations; | 41 | std::string punctuations; |
| 37 | std::string language; | 42 | std::string language; |
| 38 | std::string voice; | 43 | std::string voice; |
| @@ -45,6 +45,64 @@ class OfflineTtsVitsModel::Impl { | @@ -45,6 +45,64 @@ class OfflineTtsVitsModel::Impl { | ||
| 45 | return RunVits(std::move(x), sid, speed); | 45 | return RunVits(std::move(x), sid, speed); |
| 46 | } | 46 | } |
| 47 | 47 | ||
| 48 | + Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) { | ||
| 49 | + // For MeloTTS, we hardcode sid to the one contained in the meta data | ||
| 50 | + sid = meta_data_.speaker_id; | ||
| 51 | + | ||
| 52 | + auto memory_info = | ||
| 53 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 54 | + | ||
| 55 | + std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 56 | + if (x_shape[0] != 1) { | ||
| 57 | + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", | ||
| 58 | + static_cast<int32_t>(x_shape[0])); | ||
| 59 | + exit(-1); | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + int64_t len = x_shape[1]; | ||
| 63 | + int64_t len_shape = 1; | ||
| 64 | + | ||
| 65 | + Ort::Value x_length = | ||
| 66 | + Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1); | ||
| 67 | + | ||
| 68 | + int64_t scale_shape = 1; | ||
| 69 | + float noise_scale = config_.vits.noise_scale; | ||
| 70 | + float length_scale = config_.vits.length_scale; | ||
| 71 | + float noise_scale_w = config_.vits.noise_scale_w; | ||
| 72 | + | ||
| 73 | + if (speed != 1 && speed > 0) { | ||
| 74 | + length_scale = 1. / speed; | ||
| 75 | + } | ||
| 76 | + | ||
| 77 | + Ort::Value noise_scale_tensor = | ||
| 78 | + Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1); | ||
| 79 | + | ||
| 80 | + Ort::Value length_scale_tensor = Ort::Value::CreateTensor( | ||
| 81 | + memory_info, &length_scale, 1, &scale_shape, 1); | ||
| 82 | + | ||
| 83 | + Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor( | ||
| 84 | + memory_info, &noise_scale_w, 1, &scale_shape, 1); | ||
| 85 | + | ||
| 86 | + Ort::Value sid_tensor = | ||
| 87 | + Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1); | ||
| 88 | + | ||
| 89 | + std::vector<Ort::Value> inputs; | ||
| 90 | + inputs.reserve(7); | ||
| 91 | + inputs.push_back(std::move(x)); | ||
| 92 | + inputs.push_back(std::move(x_length)); | ||
| 93 | + inputs.push_back(std::move(tones)); | ||
| 94 | + inputs.push_back(std::move(sid_tensor)); | ||
| 95 | + inputs.push_back(std::move(noise_scale_tensor)); | ||
| 96 | + inputs.push_back(std::move(length_scale_tensor)); | ||
| 97 | + inputs.push_back(std::move(noise_scale_w_tensor)); | ||
| 98 | + | ||
| 99 | + auto out = | ||
| 100 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 101 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 102 | + | ||
| 103 | + return std::move(out[0]); | ||
| 104 | + } | ||
| 105 | + | ||
| 48 | const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; } | 106 | const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; } |
| 49 | 107 | ||
| 50 | private: | 108 | private: |
| @@ -83,6 +141,10 @@ class OfflineTtsVitsModel::Impl { | @@ -83,6 +141,10 @@ class OfflineTtsVitsModel::Impl { | ||
| 83 | SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); | 141 | SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); |
| 84 | SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank", | 142 | SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank", |
| 85 | 0); | 143 | 0); |
| 144 | + | ||
| 145 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.speaker_id, "speaker_id", | ||
| 146 | + 0); | ||
| 147 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 0); | ||
| 86 | SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers"); | 148 | SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers"); |
| 87 | SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations, | 149 | SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations, |
| 88 | "punctuation", ""); | 150 | "punctuation", ""); |
| @@ -115,6 +177,22 @@ class OfflineTtsVitsModel::Impl { | @@ -115,6 +177,22 @@ class OfflineTtsVitsModel::Impl { | ||
| 115 | if (comment.find("icefall") != std::string::npos) { | 177 | if (comment.find("icefall") != std::string::npos) { |
| 116 | meta_data_.is_icefall = true; | 178 | meta_data_.is_icefall = true; |
| 117 | } | 179 | } |
| 180 | + | ||
| 181 | + if (comment.find("melo") != std::string::npos) { | ||
| 182 | + meta_data_.is_melo_tts = true; | ||
| 183 | + int32_t expected_version = 2; | ||
| 184 | + if (meta_data_.version < expected_version) { | ||
| 185 | + SHERPA_ONNX_LOGE( | ||
| 186 | + "Please download the latest MeloTTS model and retry. Current " | ||
| 187 | + "version: %d. Expected version: %d", | ||
| 188 | + meta_data_.version, expected_version); | ||
| 189 | + exit(-1); | ||
| 190 | + } | ||
| 191 | + | ||
| 192 | + // NOTE(fangjun): | ||
| 193 | + // version 0 is the first version | ||
| 194 | + // version 2: add jieba=1 to the metadata | ||
| 195 | + } | ||
| 118 | } | 196 | } |
| 119 | 197 | ||
| 120 | Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) { | 198 | Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) { |
| @@ -269,6 +347,12 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/, | @@ -269,6 +347,12 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/, | ||
| 269 | return impl_->Run(std::move(x), sid, speed); | 347 | return impl_->Run(std::move(x), sid, speed); |
| 270 | } | 348 | } |
| 271 | 349 | ||
| 350 | +Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, Ort::Value tones, | ||
| 351 | + int64_t sid /*= 0*/, | ||
| 352 | + float speed /*= 1.0*/) { | ||
| 353 | + return impl_->Run(std::move(x), std::move(tones), sid, speed); | ||
| 354 | +} | ||
| 355 | + | ||
| 272 | const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const { | 356 | const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const { |
| 273 | return impl_->GetMetaData(); | 357 | return impl_->GetMetaData(); |
| 274 | } | 358 | } |
| @@ -40,6 +40,10 @@ class OfflineTtsVitsModel { | @@ -40,6 +40,10 @@ class OfflineTtsVitsModel { | ||
| 40 | */ | 40 | */ |
| 41 | Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0); | 41 | Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0); |
| 42 | 42 | ||
| 43 | + // This is for MeloTTS | ||
| 44 | + Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid = 0, | ||
| 45 | + float speed = 1.0); | ||
| 46 | + | ||
| 43 | const OfflineTtsVitsModelMetaData &GetMetaData() const; | 47 | const OfflineTtsVitsModelMetaData &GetMetaData() const; |
| 44 | 48 | ||
| 45 | private: | 49 | private: |
| @@ -5,8 +5,8 @@ | @@ -5,8 +5,8 @@ | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ |
| 6 | #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ | 6 | #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ |
| 7 | 7 | ||
| 8 | -#include <vector> | ||
| 9 | #include <string> | 8 | #include <string> |
| 9 | +#include <vector> | ||
| 10 | 10 | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 11 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | #include "sherpa-onnx/csrc/offline-whisper-model-config.h" | 12 | #include "sherpa-onnx/csrc/offline-whisper-model-config.h" |
| @@ -36,7 +36,6 @@ class OfflineWhisperDecoder { | @@ -36,7 +36,6 @@ class OfflineWhisperDecoder { | ||
| 36 | Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; | 36 | Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; |
| 37 | 37 | ||
| 38 | virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0; | 38 | virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0; |
| 39 | - | ||
| 40 | }; | 39 | }; |
| 41 | 40 | ||
| 42 | } // namespace sherpa_onnx | 41 | } // namespace sherpa_onnx |
| @@ -12,7 +12,8 @@ | @@ -12,7 +12,8 @@ | ||
| 12 | 12 | ||
| 13 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 14 | 14 | ||
| 15 | -void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) { | 15 | +void OfflineWhisperGreedySearchDecoder::SetConfig( |
| 16 | + const OfflineWhisperModelConfig &config) { | ||
| 16 | config_ = config; | 17 | config_ = config; |
| 17 | } | 18 | } |
| 18 | 19 | ||
| @@ -135,9 +136,9 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | @@ -135,9 +136,9 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | ||
| 135 | 136 | ||
| 136 | const auto &id2lang = model_->GetID2Lang(); | 137 | const auto &id2lang = model_->GetID2Lang(); |
| 137 | if (id2lang.count(initial_tokens[1])) { | 138 | if (id2lang.count(initial_tokens[1])) { |
| 138 | - ans[0].lang = id2lang.at(initial_tokens[1]); | 139 | + ans[0].lang = id2lang.at(initial_tokens[1]); |
| 139 | } else { | 140 | } else { |
| 140 | - ans[0].lang = ""; | 141 | + ans[0].lang = ""; |
| 141 | } | 142 | } |
| 142 | 143 | ||
| 143 | ans[0].tokens = std::move(predicted_tokens); | 144 | ans[0].tokens = std::move(predicted_tokens); |
| @@ -153,15 +153,21 @@ Ort::Value View(Ort::Value *v) { | @@ -153,15 +153,21 @@ Ort::Value View(Ort::Value *v) { | ||
| 153 | } | 153 | } |
| 154 | } | 154 | } |
| 155 | 155 | ||
| 156 | +template <typename T /*= float*/> | ||
| 156 | void Print1D(Ort::Value *v) { | 157 | void Print1D(Ort::Value *v) { |
| 157 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | 158 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); |
| 158 | - const float *d = v->GetTensorData<float>(); | 159 | + const T *d = v->GetTensorData<T>(); |
| 160 | + std::ostringstream os; | ||
| 159 | for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { | 161 | for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) { |
| 160 | - fprintf(stderr, "%.3f ", d[i]); | 162 | + os << *d << " "; |
| 161 | } | 163 | } |
| 162 | - fprintf(stderr, "\n"); | 164 | + os << "\n"; |
| 165 | + fprintf(stderr, "%s\n", os.str().c_str()); | ||
| 163 | } | 166 | } |
| 164 | 167 | ||
| 168 | +template void Print1D<int64_t>(Ort::Value *v); | ||
| 169 | +template void Print1D<float>(Ort::Value *v); | ||
| 170 | + | ||
| 165 | template <typename T /*= float*/> | 171 | template <typename T /*= float*/> |
| 166 | void Print2D(Ort::Value *v) { | 172 | void Print2D(Ort::Value *v) { |
| 167 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | 173 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); |
| @@ -69,6 +69,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); | @@ -69,6 +69,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); | ||
| 69 | Ort::Value View(Ort::Value *v); | 69 | Ort::Value View(Ort::Value *v); |
| 70 | 70 | ||
| 71 | // Print a 1-D tensor to stderr | 71 | // Print a 1-D tensor to stderr |
| 72 | +template <typename T = float> | ||
| 72 | void Print1D(Ort::Value *v); | 73 | void Print1D(Ort::Value *v); |
| 73 | 74 | ||
| 74 | // Print a 2-D tensor to stderr | 75 | // Print a 2-D tensor to stderr |
| @@ -214,7 +214,7 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon( | @@ -214,7 +214,7 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon( | ||
| 214 | } | 214 | } |
| 215 | #endif | 215 | #endif |
| 216 | 216 | ||
| 217 | -std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds( | 217 | +std::vector<TokenIDs> PiperPhonemizeLexicon::ConvertTextToTokenIds( |
| 218 | const std::string &text, const std::string &voice /*= ""*/) const { | 218 | const std::string &text, const std::string &voice /*= ""*/) const { |
| 219 | piper::eSpeakPhonemeConfig config; | 219 | piper::eSpeakPhonemeConfig config; |
| 220 | 220 | ||
| @@ -232,7 +232,7 @@ std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds( | @@ -232,7 +232,7 @@ std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds( | ||
| 232 | piper::phonemize_eSpeak(text, config, phonemes); | 232 | piper::phonemize_eSpeak(text, config, phonemes); |
| 233 | } | 233 | } |
| 234 | 234 | ||
| 235 | - std::vector<std::vector<int64_t>> ans; | 235 | + std::vector<TokenIDs> ans; |
| 236 | 236 | ||
| 237 | std::vector<int64_t> phoneme_ids; | 237 | std::vector<int64_t> phoneme_ids; |
| 238 | 238 |
| @@ -30,7 +30,7 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend { | @@ -30,7 +30,7 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend { | ||
| 30 | const OfflineTtsVitsModelMetaData &meta_data); | 30 | const OfflineTtsVitsModelMetaData &meta_data); |
| 31 | #endif | 31 | #endif |
| 32 | 32 | ||
| 33 | - std::vector<std::vector<int64_t>> ConvertTextToTokenIds( | 33 | + std::vector<TokenIDs> ConvertTextToTokenIds( |
| 34 | const std::string &text, const std::string &voice = "") const override; | 34 | const std::string &text, const std::string &voice = "") const override; |
| 35 | 35 | ||
| 36 | private: | 36 | private: |
| @@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { | @@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { | ||
| 31 | api.ReleaseStatus(status); | 31 | api.ReleaseStatus(status); |
| 32 | } | 32 | } |
| 33 | 33 | ||
| 34 | -static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | ||
| 35 | - const std::string &provider_str, | 34 | +static Ort::SessionOptions GetSessionOptionsImpl( |
| 35 | + int32_t num_threads, const std::string &provider_str, | ||
| 36 | const ProviderConfig *provider_config = nullptr) { | 36 | const ProviderConfig *provider_config = nullptr) { |
| 37 | Provider p = StringToProvider(provider_str); | 37 | Provider p = StringToProvider(provider_str); |
| 38 | 38 | ||
| @@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | @@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | ||
| 67 | } | 67 | } |
| 68 | case Provider::kTRT: { | 68 | case Provider::kTRT: { |
| 69 | if (provider_config == nullptr) { | 69 | if (provider_config == nullptr) { |
| 70 | - SHERPA_ONNX_LOGE("Tensorrt support for Online models ony," | ||
| 71 | - "Must be extended for offline and others"); | 70 | + SHERPA_ONNX_LOGE( |
| 71 | + "Tensorrt support for Online models ony," | ||
| 72 | + "Must be extended for offline and others"); | ||
| 72 | exit(1); | 73 | exit(1); |
| 73 | } | 74 | } |
| 74 | auto trt_config = provider_config->trt_config; | 75 | auto trt_config = provider_config->trt_config; |
| @@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | @@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | ||
| 84 | std::to_string(trt_config.trt_max_partition_iterations); | 85 | std::to_string(trt_config.trt_max_partition_iterations); |
| 85 | auto trt_min_subgraph_size = | 86 | auto trt_min_subgraph_size = |
| 86 | std::to_string(trt_config.trt_min_subgraph_size); | 87 | std::to_string(trt_config.trt_min_subgraph_size); |
| 87 | - auto trt_fp16_enable = | ||
| 88 | - std::to_string(trt_config.trt_fp16_enable); | 88 | + auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable); |
| 89 | auto trt_detailed_build_log = | 89 | auto trt_detailed_build_log = |
| 90 | std::to_string(trt_config.trt_detailed_build_log); | 90 | std::to_string(trt_config.trt_detailed_build_log); |
| 91 | auto trt_engine_cache_enable = | 91 | auto trt_engine_cache_enable = |
| 92 | std::to_string(trt_config.trt_engine_cache_enable); | 92 | std::to_string(trt_config.trt_engine_cache_enable); |
| 93 | auto trt_timing_cache_enable = | 93 | auto trt_timing_cache_enable = |
| 94 | std::to_string(trt_config.trt_timing_cache_enable); | 94 | std::to_string(trt_config.trt_timing_cache_enable); |
| 95 | - auto trt_dump_subgraphs = | ||
| 96 | - std::to_string(trt_config.trt_dump_subgraphs); | 95 | + auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs); |
| 97 | std::vector<TrtPairs> trt_options = { | 96 | std::vector<TrtPairs> trt_options = { |
| 98 | - {"device_id", device_id.c_str()}, | ||
| 99 | - {"trt_max_workspace_size", trt_max_workspace_size.c_str()}, | ||
| 100 | - {"trt_max_partition_iterations", trt_max_partition_iterations.c_str()}, | ||
| 101 | - {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()}, | ||
| 102 | - {"trt_fp16_enable", trt_fp16_enable.c_str()}, | ||
| 103 | - {"trt_detailed_build_log", trt_detailed_build_log.c_str()}, | ||
| 104 | - {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()}, | ||
| 105 | - {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()}, | ||
| 106 | - {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()}, | ||
| 107 | - {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()}, | ||
| 108 | - {"trt_dump_subgraphs", trt_dump_subgraphs.c_str()} | ||
| 109 | - }; | 97 | + {"device_id", device_id.c_str()}, |
| 98 | + {"trt_max_workspace_size", trt_max_workspace_size.c_str()}, | ||
| 99 | + {"trt_max_partition_iterations", | ||
| 100 | + trt_max_partition_iterations.c_str()}, | ||
| 101 | + {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()}, | ||
| 102 | + {"trt_fp16_enable", trt_fp16_enable.c_str()}, | ||
| 103 | + {"trt_detailed_build_log", trt_detailed_build_log.c_str()}, | ||
| 104 | + {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()}, | ||
| 105 | + {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()}, | ||
| 106 | + {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()}, | ||
| 107 | + {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()}, | ||
| 108 | + {"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}}; | ||
| 110 | // ToDo : Trt configs | 109 | // ToDo : Trt configs |
| 111 | // "trt_int8_enable" | 110 | // "trt_int8_enable" |
| 112 | // "trt_int8_use_native_calibration_table" | 111 | // "trt_int8_use_native_calibration_table" |
| @@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | @@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | ||
| 151 | 150 | ||
| 152 | if (provider_config != nullptr) { | 151 | if (provider_config != nullptr) { |
| 153 | options.device_id = provider_config->device; | 152 | options.device_id = provider_config->device; |
| 154 | - options.cudnn_conv_algo_search = | ||
| 155 | - OrtCudnnConvAlgoSearch(provider_config->cuda_config | ||
| 156 | - .cudnn_conv_algo_search); | 153 | + options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch( |
| 154 | + provider_config->cuda_config.cudnn_conv_algo_search); | ||
| 157 | } else { | 155 | } else { |
| 158 | options.device_id = 0; | 156 | options.device_id = 0; |
| 159 | // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow | 157 | // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow |
| @@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | @@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | ||
| 219 | 217 | ||
| 220 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { | 218 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { |
| 221 | return GetSessionOptionsImpl(config.num_threads, | 219 | return GetSessionOptionsImpl(config.num_threads, |
| 222 | - config.provider_config.provider, &config.provider_config); | 220 | + config.provider_config.provider, |
| 221 | + &config.provider_config); | ||
| 223 | } | 222 | } |
| 224 | 223 | ||
| 225 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, | 224 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, |
| 226 | - const std::string &model_type) { | 225 | + const std::string &model_type) { |
| 227 | /* | 226 | /* |
| 228 | Transducer models : Only encoder will run with tensorrt, | 227 | Transducer models : Only encoder will run with tensorrt, |
| 229 | decoder and joiner will run with cuda | 228 | decoder and joiner will run with cuda |
| 230 | */ | 229 | */ |
| 231 | - if(config.provider_config.provider == "trt" && | 230 | + if (config.provider_config.provider == "trt" && |
| 232 | (model_type == "decoder" || model_type == "joiner")) { | 231 | (model_type == "decoder" || model_type == "joiner")) { |
| 233 | - return GetSessionOptionsImpl(config.num_threads, | ||
| 234 | - "cuda", &config.provider_config); | 232 | + return GetSessionOptionsImpl(config.num_threads, "cuda", |
| 233 | + &config.provider_config); | ||
| 235 | } | 234 | } |
| 236 | return GetSessionOptionsImpl(config.num_threads, | 235 | return GetSessionOptionsImpl(config.num_threads, |
| 237 | - config.provider_config.provider, &config.provider_config); | 236 | + config.provider_config.provider, |
| 237 | + &config.provider_config); | ||
| 238 | } | 238 | } |
| 239 | 239 | ||
| 240 | Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { | 240 | Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { |
| @@ -5,6 +5,8 @@ | @@ -5,6 +5,8 @@ | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_SESSION_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_SESSION_H_ |
| 6 | #define SHERPA_ONNX_CSRC_SESSION_H_ | 6 | #define SHERPA_ONNX_CSRC_SESSION_H_ |
| 7 | 7 | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 8 | #include "onnxruntime_cxx_api.h" // NOLINT | 10 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 9 | #include "sherpa-onnx/csrc/audio-tagging-model-config.h" | 11 | #include "sherpa-onnx/csrc/audio-tagging-model-config.h" |
| 10 | #include "sherpa-onnx/csrc/offline-lm-config.h" | 12 | #include "sherpa-onnx/csrc/offline-lm-config.h" |
| @@ -25,7 +27,7 @@ namespace sherpa_onnx { | @@ -25,7 +27,7 @@ namespace sherpa_onnx { | ||
| 25 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); | 27 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); |
| 26 | 28 | ||
| 27 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, | 29 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, |
| 28 | - const std::string &model_type); | 30 | + const std::string &model_type); |
| 29 | 31 | ||
| 30 | Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); | 32 | Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); |
| 31 | 33 |
| @@ -11,7 +11,7 @@ | @@ -11,7 +11,7 @@ | ||
| 11 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| 12 | 12 | ||
| 13 | TEST(UTF8, Case1) { | 13 | TEST(UTF8, Case1) { |
| 14 | - std::string hello = "你好, 早上好!世界. hello!。Hallo"; | 14 | + std::string hello = "你好, 早上好!世界. hello!。Hallo! how are you?"; |
| 15 | std::vector<std::string> ss = SplitUtf8(hello); | 15 | std::vector<std::string> ss = SplitUtf8(hello); |
| 16 | for (const auto &s : ss) { | 16 | for (const auto &s : ss) { |
| 17 | std::cout << s << "\n"; | 17 | std::cout << s << "\n"; |
-
请 注册 或 登录 后发表评论