Fangjun Kuang
Committed by GitHub

Add Kotlin API for Matcha-TTS models. (#1668)

@@ -75,3 +75,8 @@ jobs: @@ -75,3 +75,8 @@ jobs:
75 75
76 cd ./kotlin-api-examples 76 cd ./kotlin-api-examples
77 ./run.sh 77 ./run.sh
  78 +
  79 + - uses: actions/upload-artifact@v4
  80 + with:
  81 + name: tts-files-${{ matrix.os }}
  82 + path: kotlin-api-examples/test-*.wav
@@ -125,3 +125,4 @@ sherpa-onnx-moonshine-tiny-en-int8 @@ -125,3 +125,4 @@ sherpa-onnx-moonshine-tiny-en-int8
125 sherpa-onnx-moonshine-base-en-int8 125 sherpa-onnx-moonshine-base-en-int8
126 harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE 126 harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE
127 harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md 127 harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md
  128 +matcha-icefall-zh-baker
@@ -105,6 +105,16 @@ function testTts() { @@ -105,6 +105,16 @@ function testTts() {
105 rm vits-piper-en_US-amy-low.tar.bz2 105 rm vits-piper-en_US-amy-low.tar.bz2
106 fi 106 fi
107 107
  108 + if [ ! -f ./matcha-icefall-zh-baker/model-steps-3.onnx ]; then
  109 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2
  110 + tar xvf matcha-icefall-zh-baker.tar.bz2
  111 + rm matcha-icefall-zh-baker.tar.bz2
  112 + fi
  113 +
  114 + if [ ! -f ./hifigan_v2.onnx ]; then
  115 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx
  116 + fi
  117 +
108 out_filename=test_tts.jar 118 out_filename=test_tts.jar
109 kotlinc-jvm -include-runtime -d $out_filename \ 119 kotlinc-jvm -include-runtime -d $out_filename \
110 test_tts.kt \ 120 test_tts.kt \
1 package com.k2fsa.sherpa.onnx 1 package com.k2fsa.sherpa.onnx
2 2
3 fun main() { 3 fun main() {
4 - testTts() 4 + testVits()
  5 + testMatcha()
5 } 6 }
6 7
7 -fun testTts() { 8 +fun testMatcha() {
  9 + // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
  10 + // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2
  11 + var config = OfflineTtsConfig(
  12 + model=OfflineTtsModelConfig(
  13 + matcha=OfflineTtsMatchaModelConfig(
  14 + acousticModel="./matcha-icefall-zh-baker/model-steps-3.onnx",
  15 + vocoder="./hifigan_v2.onnx",
  16 + tokens="./matcha-icefall-zh-baker/tokens.txt",
  17 + lexicon="./matcha-icefall-zh-baker/lexicon.txt",
  18 + dictDir="./matcha-icefall-zh-baker/dict",
  19 + ),
  20 + numThreads=1,
  21 + debug=true,
  22 + ),
  23 + ruleFsts="./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst",
  24 + )
  25 + val tts = OfflineTts(config=config)
  26 + val audio = tts.generateWithCallback(text="某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。", callback=::callback)
  27 + audio.save(filename="test-zh.wav")
  28 + tts.release()
  29 + println("Saved to test-zh.wav")
  30 +}
  31 +
  32 +fun testVits() {
8 // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models 33 // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
9 // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 34 // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
10 var config = OfflineTtsConfig( 35 var config = OfflineTtsConfig(
@@ -1727,11 +1727,15 @@ const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation( @@ -1727,11 +1727,15 @@ const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
1727 auto p = new SherpaOnnxOnlinePunctuation; 1727 auto p = new SherpaOnnxOnlinePunctuation;
1728 try { 1728 try {
1729 sherpa_onnx::OnlinePunctuationConfig punctuation_config; 1729 sherpa_onnx::OnlinePunctuationConfig punctuation_config;
1730 - punctuation_config.model.cnn_bilstm = SHERPA_ONNX_OR(config->model.cnn_bilstm, "");  
1731 - punctuation_config.model.bpe_vocab = SHERPA_ONNX_OR(config->model.bpe_vocab, "");  
1732 - punctuation_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); 1730 + punctuation_config.model.cnn_bilstm =
  1731 + SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
  1732 + punctuation_config.model.bpe_vocab =
  1733 + SHERPA_ONNX_OR(config->model.bpe_vocab, "");
  1734 + punctuation_config.model.num_threads =
  1735 + SHERPA_ONNX_OR(config->model.num_threads, 1);
1733 punctuation_config.model.debug = config->model.debug; 1736 punctuation_config.model.debug = config->model.debug;
1734 - punctuation_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu"); 1737 + punctuation_config.model.provider =
  1738 + SHERPA_ONNX_OR(config->model.provider, "cpu");
1735 1739
1736 p->impl = 1740 p->impl =
1737 std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config); 1741 std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);
@@ -1381,12 +1381,14 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig { @@ -1381,12 +1381,14 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
1381 SherpaOnnxOnlinePunctuationModelConfig model; 1381 SherpaOnnxOnlinePunctuationModelConfig model;
1382 } SherpaOnnxOnlinePunctuationConfig; 1382 } SherpaOnnxOnlinePunctuationConfig;
1383 1383
1384 -SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation; 1384 +SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation
  1385 + SherpaOnnxOnlinePunctuation;
1385 1386
1386 // Create an online punctuation processor. The user has to invoke 1387 // Create an online punctuation processor. The user has to invoke
1387 // SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer 1388 // SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
1388 // to avoid memory leak 1389 // to avoid memory leak
1389 -SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation( 1390 +SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *
  1391 +SherpaOnnxCreateOnlinePunctuation(
1390 const SherpaOnnxOnlinePunctuationConfig *config); 1392 const SherpaOnnxOnlinePunctuationConfig *config);
1391 1393
1392 // Free a pointer returned by SherpaOnnxCreateOnlinePunctuation() 1394 // Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()
@@ -155,7 +155,7 @@ class JiebaLexicon::Impl { @@ -155,7 +155,7 @@ class JiebaLexicon::Impl {
155 155
156 this_sentence.insert(this_sentence.end(), ids.begin(), ids.end()); 156 this_sentence.insert(this_sentence.end(), ids.begin(), ids.end());
157 157
158 - if (w == "。" || w == "!" || w == "?" || w == ",") { 158 + if (IsPunct(w)) {
159 ans.emplace_back(std::move(this_sentence)); 159 ans.emplace_back(std::move(this_sentence));
160 this_sentence = {}; 160 this_sentence = {};
161 } 161 }
@@ -20,6 +20,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { @@ -20,6 +20,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
20 jobject model = env->GetObjectField(config, fid); 20 jobject model = env->GetObjectField(config, fid);
21 jclass model_config_cls = env->GetObjectClass(model); 21 jclass model_config_cls = env->GetObjectClass(model);
22 22
  23 + // vits
23 fid = env->GetFieldID(model_config_cls, "vits", 24 fid = env->GetFieldID(model_config_cls, "vits",
24 "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;"); 25 "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
25 jobject vits = env->GetObjectField(model, fid); 26 jobject vits = env->GetObjectField(model, fid);
@@ -64,6 +65,54 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { @@ -64,6 +65,54 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
64 fid = env->GetFieldID(vits_cls, "lengthScale", "F"); 65 fid = env->GetFieldID(vits_cls, "lengthScale", "F");
65 ans.model.vits.length_scale = env->GetFloatField(vits, fid); 66 ans.model.vits.length_scale = env->GetFloatField(vits, fid);
66 67
  68 + // matcha
  69 + fid = env->GetFieldID(model_config_cls, "matcha",
  70 + "Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;");
  71 + jobject matcha = env->GetObjectField(model, fid);
  72 + jclass matcha_cls = env->GetObjectClass(matcha);
  73 +
  74 + fid = env->GetFieldID(matcha_cls, "acousticModel", "Ljava/lang/String;");
  75 + s = (jstring)env->GetObjectField(matcha, fid);
  76 + p = env->GetStringUTFChars(s, nullptr);
  77 + ans.model.matcha.acoustic_model = p;
  78 + env->ReleaseStringUTFChars(s, p);
  79 +
  80 + fid = env->GetFieldID(matcha_cls, "vocoder", "Ljava/lang/String;");
  81 + s = (jstring)env->GetObjectField(matcha, fid);
  82 + p = env->GetStringUTFChars(s, nullptr);
  83 + ans.model.matcha.vocoder = p;
  84 + env->ReleaseStringUTFChars(s, p);
  85 +
  86 + fid = env->GetFieldID(matcha_cls, "lexicon", "Ljava/lang/String;");
  87 + s = (jstring)env->GetObjectField(matcha, fid);
  88 + p = env->GetStringUTFChars(s, nullptr);
  89 + ans.model.matcha.lexicon = p;
  90 + env->ReleaseStringUTFChars(s, p);
  91 +
  92 + fid = env->GetFieldID(matcha_cls, "tokens", "Ljava/lang/String;");
  93 + s = (jstring)env->GetObjectField(matcha, fid);
  94 + p = env->GetStringUTFChars(s, nullptr);
  95 + ans.model.matcha.tokens = p;
  96 + env->ReleaseStringUTFChars(s, p);
  97 +
  98 + fid = env->GetFieldID(matcha_cls, "dataDir", "Ljava/lang/String;");
  99 + s = (jstring)env->GetObjectField(matcha, fid);
  100 + p = env->GetStringUTFChars(s, nullptr);
  101 + ans.model.matcha.data_dir = p;
  102 + env->ReleaseStringUTFChars(s, p);
  103 +
  104 + fid = env->GetFieldID(matcha_cls, "dictDir", "Ljava/lang/String;");
  105 + s = (jstring)env->GetObjectField(matcha, fid);
  106 + p = env->GetStringUTFChars(s, nullptr);
  107 + ans.model.matcha.dict_dir = p;
  108 + env->ReleaseStringUTFChars(s, p);
  109 +
  110 + fid = env->GetFieldID(matcha_cls, "noiseScale", "F");
  111 + ans.model.matcha.noise_scale = env->GetFloatField(matcha, fid);
  112 +
  113 + fid = env->GetFieldID(matcha_cls, "lengthScale", "F");
  114 + ans.model.matcha.length_scale = env->GetFloatField(matcha, fid);
  115 +
67 fid = env->GetFieldID(model_config_cls, "numThreads", "I"); 116 fid = env->GetFieldID(model_config_cls, "numThreads", "I");
68 ans.model.num_threads = env->GetIntField(model, fid); 117 ans.model.num_threads = env->GetIntField(model, fid);
69 118
@@ -14,8 +14,20 @@ data class OfflineTtsVitsModelConfig( @@ -14,8 +14,20 @@ data class OfflineTtsVitsModelConfig(
14 var lengthScale: Float = 1.0f, 14 var lengthScale: Float = 1.0f,
15 ) 15 )
16 16
  17 +data class OfflineTtsMatchaModelConfig(
  18 + var acousticModel: String = "",
  19 + var vocoder: String = "",
  20 + var lexicon: String = "",
  21 + var tokens: String = "",
  22 + var dataDir: String = "",
  23 + var dictDir: String = "",
  24 + var noiseScale: Float = 1.0f,
  25 + var lengthScale: Float = 1.0f,
  26 +)
  27 +
17 data class OfflineTtsModelConfig( 28 data class OfflineTtsModelConfig(
18 var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(), 29 var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
  30 + var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
19 var numThreads: Int = 1, 31 var numThreads: Int = 1,
20 var debug: Boolean = false, 32 var debug: Boolean = false,
21 var provider: String = "cpu", 33 var provider: String = "cpu",