Committed by
GitHub
Add Kotlin API for Matcha-TTS models. (#1668)
正在显示
9 个修改的文件
包含
117 行增加
和
9 行删除
| @@ -125,3 +125,4 @@ sherpa-onnx-moonshine-tiny-en-int8 | @@ -125,3 +125,4 @@ sherpa-onnx-moonshine-tiny-en-int8 | ||
| 125 | sherpa-onnx-moonshine-base-en-int8 | 125 | sherpa-onnx-moonshine-base-en-int8 |
| 126 | harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE | 126 | harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE |
| 127 | harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md | 127 | harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md |
| 128 | +matcha-icefall-zh-baker |
| @@ -105,6 +105,16 @@ function testTts() { | @@ -105,6 +105,16 @@ function testTts() { | ||
| 105 | rm vits-piper-en_US-amy-low.tar.bz2 | 105 | rm vits-piper-en_US-amy-low.tar.bz2 |
| 106 | fi | 106 | fi |
| 107 | 107 | ||
| 108 | + if [ ! -f ./matcha-icefall-zh-baker/model-steps-3.onnx ]; then | ||
| 109 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 | ||
| 110 | + tar xvf matcha-icefall-zh-baker.tar.bz2 | ||
| 111 | + rm matcha-icefall-zh-baker.tar.bz2 | ||
| 112 | + fi | ||
| 113 | + | ||
| 114 | + if [ ! -f ./hifigan_v2.onnx ]; then | ||
| 115 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx | ||
| 116 | + fi | ||
| 117 | + | ||
| 108 | out_filename=test_tts.jar | 118 | out_filename=test_tts.jar |
| 109 | kotlinc-jvm -include-runtime -d $out_filename \ | 119 | kotlinc-jvm -include-runtime -d $out_filename \ |
| 110 | test_tts.kt \ | 120 | test_tts.kt \ |
| 1 | package com.k2fsa.sherpa.onnx | 1 | package com.k2fsa.sherpa.onnx |
| 2 | 2 | ||
| 3 | fun main() { | 3 | fun main() { |
| 4 | - testTts() | 4 | + testVits() |
| 5 | + testMatcha() | ||
| 5 | } | 6 | } |
| 6 | 7 | ||
| 7 | -fun testTts() { | 8 | +fun testMatcha() { |
| 9 | + // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models | ||
| 10 | + // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 | ||
| 11 | + var config = OfflineTtsConfig( | ||
| 12 | + model=OfflineTtsModelConfig( | ||
| 13 | + matcha=OfflineTtsMatchaModelConfig( | ||
| 14 | + acousticModel="./matcha-icefall-zh-baker/model-steps-3.onnx", | ||
| 15 | + vocoder="./hifigan_v2.onnx", | ||
| 16 | + tokens="./matcha-icefall-zh-baker/tokens.txt", | ||
| 17 | + lexicon="./matcha-icefall-zh-baker/lexicon.txt", | ||
| 18 | + dictDir="./matcha-icefall-zh-baker/dict", | ||
| 19 | + ), | ||
| 20 | + numThreads=1, | ||
| 21 | + debug=true, | ||
| 22 | + ), | ||
| 23 | + ruleFsts="./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst", | ||
| 24 | + ) | ||
| 25 | + val tts = OfflineTts(config=config) | ||
| 26 | + val audio = tts.generateWithCallback(text="某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。", callback=::callback) | ||
| 27 | + audio.save(filename="test-zh.wav") | ||
| 28 | + tts.release() | ||
| 29 | + println("Saved to test-zh.wav") | ||
| 30 | +} | ||
| 31 | + | ||
| 32 | +fun testVits() { | ||
| 8 | // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models | 33 | // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models |
| 9 | // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 | 34 | // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 |
| 10 | var config = OfflineTtsConfig( | 35 | var config = OfflineTtsConfig( |
| @@ -1727,11 +1727,15 @@ const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation( | @@ -1727,11 +1727,15 @@ const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation( | ||
| 1727 | auto p = new SherpaOnnxOnlinePunctuation; | 1727 | auto p = new SherpaOnnxOnlinePunctuation; |
| 1728 | try { | 1728 | try { |
| 1729 | sherpa_onnx::OnlinePunctuationConfig punctuation_config; | 1729 | sherpa_onnx::OnlinePunctuationConfig punctuation_config; |
| 1730 | - punctuation_config.model.cnn_bilstm = SHERPA_ONNX_OR(config->model.cnn_bilstm, ""); | ||
| 1731 | - punctuation_config.model.bpe_vocab = SHERPA_ONNX_OR(config->model.bpe_vocab, ""); | ||
| 1732 | - punctuation_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); | 1730 | + punctuation_config.model.cnn_bilstm = |
| 1731 | + SHERPA_ONNX_OR(config->model.cnn_bilstm, ""); | ||
| 1732 | + punctuation_config.model.bpe_vocab = | ||
| 1733 | + SHERPA_ONNX_OR(config->model.bpe_vocab, ""); | ||
| 1734 | + punctuation_config.model.num_threads = | ||
| 1735 | + SHERPA_ONNX_OR(config->model.num_threads, 1); | ||
| 1733 | punctuation_config.model.debug = config->model.debug; | 1736 | punctuation_config.model.debug = config->model.debug; |
| 1734 | - punctuation_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu"); | 1737 | + punctuation_config.model.provider = |
| 1738 | + SHERPA_ONNX_OR(config->model.provider, "cpu"); | ||
| 1735 | 1739 | ||
| 1736 | p->impl = | 1740 | p->impl = |
| 1737 | std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config); | 1741 | std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config); |
| @@ -1381,12 +1381,14 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig { | @@ -1381,12 +1381,14 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig { | ||
| 1381 | SherpaOnnxOnlinePunctuationModelConfig model; | 1381 | SherpaOnnxOnlinePunctuationModelConfig model; |
| 1382 | } SherpaOnnxOnlinePunctuationConfig; | 1382 | } SherpaOnnxOnlinePunctuationConfig; |
| 1383 | 1383 | ||
| 1384 | -SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation; | 1384 | +SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation |
| 1385 | + SherpaOnnxOnlinePunctuation; | ||
| 1385 | 1386 | ||
| 1386 | // Create an online punctuation processor. The user has to invoke | 1387 | // Create an online punctuation processor. The user has to invoke |
| 1387 | // SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer | 1388 | // SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer |
| 1388 | // to avoid memory leak | 1389 | // to avoid memory leak |
| 1389 | -SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation( | 1390 | +SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation * |
| 1391 | +SherpaOnnxCreateOnlinePunctuation( | ||
| 1390 | const SherpaOnnxOnlinePunctuationConfig *config); | 1392 | const SherpaOnnxOnlinePunctuationConfig *config); |
| 1391 | 1393 | ||
| 1392 | // Free a pointer returned by SherpaOnnxCreateOnlinePunctuation() | 1394 | // Free a pointer returned by SherpaOnnxCreateOnlinePunctuation() |
| @@ -155,7 +155,7 @@ class JiebaLexicon::Impl { | @@ -155,7 +155,7 @@ class JiebaLexicon::Impl { | ||
| 155 | 155 | ||
| 156 | this_sentence.insert(this_sentence.end(), ids.begin(), ids.end()); | 156 | this_sentence.insert(this_sentence.end(), ids.begin(), ids.end()); |
| 157 | 157 | ||
| 158 | - if (w == "。" || w == "!" || w == "?" || w == ",") { | 158 | + if (IsPunct(w)) { |
| 159 | ans.emplace_back(std::move(this_sentence)); | 159 | ans.emplace_back(std::move(this_sentence)); |
| 160 | this_sentence = {}; | 160 | this_sentence = {}; |
| 161 | } | 161 | } |
| @@ -20,6 +20,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { | @@ -20,6 +20,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { | ||
| 20 | jobject model = env->GetObjectField(config, fid); | 20 | jobject model = env->GetObjectField(config, fid); |
| 21 | jclass model_config_cls = env->GetObjectClass(model); | 21 | jclass model_config_cls = env->GetObjectClass(model); |
| 22 | 22 | ||
| 23 | + // vits | ||
| 23 | fid = env->GetFieldID(model_config_cls, "vits", | 24 | fid = env->GetFieldID(model_config_cls, "vits", |
| 24 | "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;"); | 25 | "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;"); |
| 25 | jobject vits = env->GetObjectField(model, fid); | 26 | jobject vits = env->GetObjectField(model, fid); |
| @@ -64,6 +65,54 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { | @@ -64,6 +65,54 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { | ||
| 64 | fid = env->GetFieldID(vits_cls, "lengthScale", "F"); | 65 | fid = env->GetFieldID(vits_cls, "lengthScale", "F"); |
| 65 | ans.model.vits.length_scale = env->GetFloatField(vits, fid); | 66 | ans.model.vits.length_scale = env->GetFloatField(vits, fid); |
| 66 | 67 | ||
| 68 | + // matcha | ||
| 69 | + fid = env->GetFieldID(model_config_cls, "matcha", | ||
| 70 | + "Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;"); | ||
| 71 | + jobject matcha = env->GetObjectField(model, fid); | ||
| 72 | + jclass matcha_cls = env->GetObjectClass(matcha); | ||
| 73 | + | ||
| 74 | + fid = env->GetFieldID(matcha_cls, "acousticModel", "Ljava/lang/String;"); | ||
| 75 | + s = (jstring)env->GetObjectField(matcha, fid); | ||
| 76 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 77 | + ans.model.matcha.acoustic_model = p; | ||
| 78 | + env->ReleaseStringUTFChars(s, p); | ||
| 79 | + | ||
| 80 | + fid = env->GetFieldID(matcha_cls, "vocoder", "Ljava/lang/String;"); | ||
| 81 | + s = (jstring)env->GetObjectField(matcha, fid); | ||
| 82 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 83 | + ans.model.matcha.vocoder = p; | ||
| 84 | + env->ReleaseStringUTFChars(s, p); | ||
| 85 | + | ||
| 86 | + fid = env->GetFieldID(matcha_cls, "lexicon", "Ljava/lang/String;"); | ||
| 87 | + s = (jstring)env->GetObjectField(matcha, fid); | ||
| 88 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 89 | + ans.model.matcha.lexicon = p; | ||
| 90 | + env->ReleaseStringUTFChars(s, p); | ||
| 91 | + | ||
| 92 | + fid = env->GetFieldID(matcha_cls, "tokens", "Ljava/lang/String;"); | ||
| 93 | + s = (jstring)env->GetObjectField(matcha, fid); | ||
| 94 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 95 | + ans.model.matcha.tokens = p; | ||
| 96 | + env->ReleaseStringUTFChars(s, p); | ||
| 97 | + | ||
| 98 | + fid = env->GetFieldID(matcha_cls, "dataDir", "Ljava/lang/String;"); | ||
| 99 | + s = (jstring)env->GetObjectField(matcha, fid); | ||
| 100 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 101 | + ans.model.matcha.data_dir = p; | ||
| 102 | + env->ReleaseStringUTFChars(s, p); | ||
| 103 | + | ||
| 104 | + fid = env->GetFieldID(matcha_cls, "dictDir", "Ljava/lang/String;"); | ||
| 105 | + s = (jstring)env->GetObjectField(matcha, fid); | ||
| 106 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 107 | + ans.model.matcha.dict_dir = p; | ||
| 108 | + env->ReleaseStringUTFChars(s, p); | ||
| 109 | + | ||
| 110 | + fid = env->GetFieldID(matcha_cls, "noiseScale", "F"); | ||
| 111 | + ans.model.matcha.noise_scale = env->GetFloatField(matcha, fid); | ||
| 112 | + | ||
| 113 | + fid = env->GetFieldID(matcha_cls, "lengthScale", "F"); | ||
| 114 | + ans.model.matcha.length_scale = env->GetFloatField(matcha, fid); | ||
| 115 | + | ||
| 67 | fid = env->GetFieldID(model_config_cls, "numThreads", "I"); | 116 | fid = env->GetFieldID(model_config_cls, "numThreads", "I"); |
| 68 | ans.model.num_threads = env->GetIntField(model, fid); | 117 | ans.model.num_threads = env->GetIntField(model, fid); |
| 69 | 118 |
| @@ -14,8 +14,20 @@ data class OfflineTtsVitsModelConfig( | @@ -14,8 +14,20 @@ data class OfflineTtsVitsModelConfig( | ||
| 14 | var lengthScale: Float = 1.0f, | 14 | var lengthScale: Float = 1.0f, |
| 15 | ) | 15 | ) |
| 16 | 16 | ||
| 17 | +data class OfflineTtsMatchaModelConfig( | ||
| 18 | + var acousticModel: String = "", | ||
| 19 | + var vocoder: String = "", | ||
| 20 | + var lexicon: String = "", | ||
| 21 | + var tokens: String = "", | ||
| 22 | + var dataDir: String = "", | ||
| 23 | + var dictDir: String = "", | ||
| 24 | + var noiseScale: Float = 1.0f, | ||
| 25 | + var lengthScale: Float = 1.0f, | ||
| 26 | +) | ||
| 27 | + | ||
| 17 | data class OfflineTtsModelConfig( | 28 | data class OfflineTtsModelConfig( |
| 18 | var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(), | 29 | var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(), |
| 30 | + var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(), | ||
| 19 | var numThreads: Int = 1, | 31 | var numThreads: Int = 1, |
| 20 | var debug: Boolean = false, | 32 | var debug: Boolean = false, |
| 21 | var provider: String = "cpu", | 33 | var provider: String = "cpu", |
-
请 注册 或 登录 后发表评论