Fangjun Kuang
Committed by GitHub

Add Koltin and Java API for Kokoro TTS models (#1728)

... ... @@ -234,8 +234,12 @@ jobs:
run: |
cd ./java-api-examples
./run-non-streaming-tts-kokoro-en.sh
./run-non-streaming-tts-matcha-zh.sh
./run-non-streaming-tts-matcha-en.sh
ls -lh
rm -rf kokoro-en-*
rm -rf matcha-icefall-*
rm hifigan_v2.onnx
... ...
... ... @@ -185,6 +185,7 @@ class MainActivity : AppCompatActivity() {
var modelName: String?
var acousticModelName: String?
var vocoder: String?
var voices: String?
var ruleFsts: String?
var ruleFars: String?
var lexicon: String?
... ... @@ -205,6 +206,10 @@ class MainActivity : AppCompatActivity() {
vocoder = null
// Matcha -- end
// For Kokoro -- begin
voices = null
// For Kokoro -- end
modelDir = null
ruleFsts = null
... ... @@ -269,6 +274,13 @@ class MainActivity : AppCompatActivity() {
// vocoder = "hifigan_v2.onnx"
// dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data"
// Example 9
// kokoro-en-v0_19
// modelDir = "kokoro-en-v0_19"
// modelName = "model.onnx"
// voices = "voices.bin"
// dataDir = "kokoro-en-v0_19/espeak-ng-data"
if (dataDir != null) {
val newDir = copyDataDir(dataDir!!)
dataDir = "$newDir/$dataDir"
... ... @@ -285,6 +297,7 @@ class MainActivity : AppCompatActivity() {
modelName = modelName ?: "",
acousticModelName = acousticModelName ?: "",
vocoder = vocoder ?: "",
voices = voices ?: "",
lexicon = lexicon ?: "",
dataDir = dataDir ?: "",
dictDir = dictDir ?: "",
... ...
... ... @@ -47,7 +47,7 @@ fun getSampleText(lang: String): String {
}
"eng" -> {
text = "This is a text-to-speech engine using next generation Kaldi"
text = "How are you doing today? This is a text-to-speech engine using next generation Kaldi"
}
"est" -> {
... ...
... ... @@ -3,6 +3,10 @@
package com.k2fsa.sherpa.onnx.tts.engine
import PreferenceHelper
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioManager
import android.media.AudioTrack
import android.media.MediaPlayer
import android.net.Uri
import android.os.Bundle
... ... @@ -36,7 +40,13 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.unit.dp
import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import kotlin.time.TimeSource
const val TAG = "sherpa-onnx-tts-engine"
... ... @@ -45,9 +55,26 @@ class MainActivity : ComponentActivity() {
private val ttsViewModel: TtsViewModel by viewModels()
private var mediaPlayer: MediaPlayer? = null
// see
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
private lateinit var track: AudioTrack
private var stopped: Boolean = false
private var samplesChannel = Channel<FloatArray>()
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
Log.i(TAG, "Start to initialize TTS")
TtsEngine.createTts(this)
Log.i(TAG, "Finish initializing TTS")
Log.i(TAG, "Start to initialize AudioTrack")
initAudioTrack()
Log.i(TAG, "Finish initializing AudioTrack")
val preferenceHelper = PreferenceHelper(this)
setContent {
SherpaOnnxTtsEngineTheme {
... ... @@ -77,6 +104,11 @@ class MainActivity : ComponentActivity() {
val testTextContent = getSampleText(TtsEngine.lang ?: "")
var testText by remember { mutableStateOf(testTextContent) }
var startEnabled by remember { mutableStateOf(true) }
var playEnabled by remember { mutableStateOf(false) }
var rtfText by remember {
mutableStateOf("")
}
val numSpeakers = TtsEngine.tts!!.numSpeakers()
if (numSpeakers > 1) {
... ... @@ -119,52 +151,117 @@ class MainActivity : ComponentActivity() {
Row {
Button(
modifier = Modifier.padding(20.dp),
enabled = startEnabled,
modifier = Modifier.padding(5.dp),
onClick = {
Log.i(TAG, "Clicked, text: $testText")
if (testText.isBlank() || testText.isEmpty()) {
Toast.makeText(
applicationContext,
"Please input a test sentence",
"Please input some text to generate",
Toast.LENGTH_SHORT
).show()
} else {
val audio = TtsEngine.tts!!.generate(
text = testText,
sid = TtsEngine.speakerId,
speed = TtsEngine.speed,
)
val filename =
application.filesDir.absolutePath + "/generated.wav"
val ok =
audio.samples.isNotEmpty() && audio.save(
filename
)
startEnabled = false
playEnabled = false
stopped = false
if (ok) {
stopMediaPlayer()
mediaPlayer = MediaPlayer.create(
applicationContext,
Uri.fromFile(File(filename))
)
mediaPlayer?.start()
} else {
Log.i(TAG, "Failed to generate or save audio")
track.pause()
track.flush()
track.play()
rtfText = ""
Log.i(TAG, "Started with text $testText")
samplesChannel = Channel<FloatArray>()
CoroutineScope(Dispatchers.IO).launch {
for (samples in samplesChannel) {
track.write(
samples,
0,
samples.size,
AudioTrack.WRITE_BLOCKING
)
if (stopped) {
break
}
}
}
CoroutineScope(Dispatchers.Default).launch {
val timeSource = TimeSource.Monotonic
val startTime = timeSource.markNow()
val audio =
TtsEngine.tts!!.generateWithCallback(
text = testText,
sid = TtsEngine.speakerId,
speed = TtsEngine.speed,
callback = ::callback,
)
val elapsed =
startTime.elapsedNow().inWholeMilliseconds.toFloat() / 1000;
val audioDuration =
audio.samples.size / TtsEngine.tts!!.sampleRate()
.toFloat()
val RTF = String.format(
"Number of threads: %d\nElapsed: %.3f s\nAudio duration: %.3f s\nRTF: %.3f/%.3f = %.3f",
TtsEngine.tts!!.config.model.numThreads,
audioDuration,
elapsed,
elapsed,
audioDuration,
elapsed / audioDuration
)
samplesChannel.close()
val filename =
application.filesDir.absolutePath + "/generated.wav"
val ok =
audio.samples.isNotEmpty() && audio.save(
filename
)
if (ok) {
withContext(Dispatchers.Main) {
startEnabled = true
playEnabled = true
rtfText = RTF
}
}
}.start()
}
}) {
Text("Test")
Text("Start")
}
Button(
modifier = Modifier.padding(20.dp),
modifier = Modifier.padding(5.dp),
enabled = playEnabled,
onClick = {
TtsEngine.speakerId = 0
TtsEngine.speed = 1.0f
testText = ""
stopped = true
track.pause()
track.flush()
onClickPlay()
}) {
Text("Reset")
Text("Play")
}
Button(
modifier = Modifier.padding(5.dp),
onClick = {
onClickStop()
startEnabled = true
}) {
Text("Stop")
}
}
if (rtfText.isNotEmpty()) {
Row {
Text(rtfText)
}
}
}
... ... @@ -185,4 +282,63 @@ class MainActivity : ComponentActivity() {
mediaPlayer?.release()
mediaPlayer = null
}
private fun onClickPlay() {
val filename = application.filesDir.absolutePath + "/generated.wav"
stopMediaPlayer()
mediaPlayer = MediaPlayer.create(
applicationContext,
Uri.fromFile(File(filename))
)
mediaPlayer?.start()
}
private fun onClickStop() {
stopped = true
track.pause()
track.flush()
stopMediaPlayer()
}
// this function is called from C++
private fun callback(samples: FloatArray): Int {
if (!stopped) {
val samplesCopy = samples.copyOf()
CoroutineScope(Dispatchers.IO).launch {
samplesChannel.send(samplesCopy)
}
return 1
} else {
track.stop()
Log.i(TAG, " return 0")
return 0
}
}
private fun initAudioTrack() {
val sampleRate = TtsEngine.tts!!.sampleRate()
val bufLength = AudioTrack.getMinBufferSize(
sampleRate,
AudioFormat.CHANNEL_OUT_MONO,
AudioFormat.ENCODING_PCM_FLOAT
)
Log.i(TAG, "sampleRate: $sampleRate, buffLength: $bufLength")
val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.setUsage(AudioAttributes.USAGE_MEDIA)
.build()
val format = AudioFormat.Builder()
.setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
.setSampleRate(sampleRate)
.build()
track = AudioTrack(
attr, format, bufLength, AudioTrack.MODE_STREAM,
AudioManager.AUDIO_SESSION_ID_GENERATE
)
track.play()
}
}
... ...
... ... @@ -41,8 +41,9 @@ object TtsEngine {
private var modelDir: String? = null
private var modelName: String? = null
private var acousticModelName: String? = null
private var vocoder: String? = null
private var acousticModelName: String? = null // for matcha tts
private var vocoder: String? = null // for matcha tts
private var voices: String? = null // for kokoro
private var ruleFsts: String? = null
private var ruleFars: String? = null
private var lexicon: String? = null
... ... @@ -64,6 +65,10 @@ object TtsEngine {
vocoder = null
// For Matcha -- end
// For Kokoro -- begin
voices = null
// For Kokoro -- end
modelDir = null
ruleFsts = null
ruleFars = null
... ... @@ -139,6 +144,14 @@ object TtsEngine {
// vocoder = "hifigan_v2.onnx"
// dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data"
// lang = "eng"
// Example 9
// kokoro-en-v0_19
// modelDir = "kokoro-en-v0_19"
// modelName = "model.onnx"
// voices = "voices.bin"
// dataDir = "kokoro-en-v0_19/espeak-ng-data"
// lang = "eng"
}
fun createTts(context: Context) {
... ... @@ -167,6 +180,7 @@ object TtsEngine {
modelName = modelName ?: "",
acousticModelName = acousticModelName ?: "",
vocoder = vocoder ?: "",
voices = voices ?: "",
lexicon = lexicon ?: "",
dataDir = dataDir ?: "",
dictDir = dictDir ?: "",
... ...
<resources>
<string name="app_name">TTS Engine</string>
<string name="app_name">TTS Engine: Next-gen Kaldi</string>
</resources>
\ No newline at end of file
... ...
// Copyright 2025 Xiaomi Corporation
// This file shows how to use a Kokoro English model
// to convert text to speech
import com.k2fsa.sherpa.onnx.*;
public class NonStreamingTtsKokoroEn {
public static void main(String[] args) {
// please visit
// https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/kokoro.html
// to download model files
String model = "./kokoro-en-v0_19/model.onnx";
String voices = "./kokoro-en-v0_19/voices.bin";
String tokens = "./kokoro-en-v0_19/tokens.txt";
String dataDir = "./kokoro-en-v0_19/espeak-ng-data";
String text =
"Today as always, men fall into two groups: slaves and free men. Whoever does not have"
+ " two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a"
+ " businessman, an official, or a scholar.";
OfflineTtsKokoroModelConfig kokoroModelConfig =
OfflineTtsKokoroModelConfig.builder()
.setModel(model)
.setVoices(voices)
.setTokens(tokens)
.setDataDir(dataDir)
.build();
OfflineTtsModelConfig modelConfig =
OfflineTtsModelConfig.builder()
.setKokoro(kokoroModelConfig)
.setNumThreads(2)
.setDebug(true)
.build();
OfflineTtsConfig config = OfflineTtsConfig.builder().setModel(modelConfig).build();
OfflineTts tts = new OfflineTts(config);
int sid = 0;
float speed = 1.0f;
long start = System.currentTimeMillis();
GeneratedAudio audio = tts.generate(text, sid, speed);
long stop = System.currentTimeMillis();
float timeElapsedSeconds = (stop - start) / 1000.0f;
float audioDuration = audio.getSamples().length / (float) audio.getSampleRate();
float real_time_factor = timeElapsedSeconds / audioDuration;
String waveFilename = "tts-kokoro-en.wav";
audio.save(waveFilename);
System.out.printf("-- elapsed : %.3f seconds\n", timeElapsedSeconds);
System.out.printf("-- audio duration: %.3f seconds\n", timeElapsedSeconds);
System.out.printf("-- real-time factor (RTF): %.3f\n", real_time_factor);
System.out.printf("-- text: %s\n", text);
System.out.printf("-- Saved to %s\n", waveFilename);
tts.release();
}
}
... ...
#!/usr/bin/env bash
set -ex
if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
mkdir -p ../build
pushd ../build
cmake \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
..
make -j4
ls -lh lib
popd
fi
if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
pushd ../sherpa-onnx/java-api
make
popd
fi
# please visit
# https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/kokoro.html
# to download more models
if [ ! -f ./kokoro-en-v0_19/model.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-en-v0_19.tar.bz2
tar xf kokoro-en-v0_19.tar.bz2
rm kokoro-en-v0_19.tar.bz2
fi
java \
-Djava.library.path=$PWD/../build/lib \
-cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
NonStreamingTtsKokoroEn.java
... ...
... ... @@ -115,6 +115,12 @@ function testTts() {
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx
fi
if [ ! -f ./kokoro-en-v0_19/model.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-en-v0_19.tar.bz2
tar xf kokoro-en-v0_19.tar.bz2
rm kokoro-en-v0_19.tar.bz2
fi
out_filename=test_tts.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_tts.kt \
... ...
... ... @@ -3,6 +3,28 @@ package com.k2fsa.sherpa.onnx
fun main() {
testVits()
testMatcha()
testKokoro()
}
fun testKokoro() {
// see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
var config = OfflineTtsConfig(
model=OfflineTtsModelConfig(
kokoro=OfflineTtsKokoroModelConfig(
model="./kokoro-en-v0_19/model.onnx",
voices="./kokoro-en-v0_19/voices.bin",
tokens="./kokoro-en-v0_19/tokens.txt",
dataDir="./kokoro-en-v0_19/espeak-ng-data",
),
numThreads=2,
debug=true,
),
)
val tts = OfflineTts(config=config)
val audio = tts.generateWithCallback(text="How are you doing today?", callback=::callback)
audio.save(filename="test-kokoro-en.wav")
tts.release()
println("Saved to test-kokoro-en.wav")
}
fun testMatcha() {
... ... @@ -24,9 +46,9 @@ fun testMatcha() {
)
val tts = OfflineTts(config=config)
val audio = tts.generateWithCallback(text="某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。", callback=::callback)
audio.save(filename="test-zh.wav")
audio.save(filename="test-matcha-zh.wav")
tts.release()
println("Saved to test-zh.wav")
println("Saved to test-matcha-zh.wav")
}
fun testVits() {
... ...
... ... @@ -39,6 +39,7 @@ model_dir={{ tts_model.model_dir }}
model_name={{ tts_model.model_name }}
acoustic_model_name={{ tts_model.acoustic_model_name }}
vocoder={{ tts_model.vocoder }}
voices={{ tts_model.voices }}
lang={{ tts_model.lang }}
lang_iso_639_3={{ tts_model.lang_iso_639_3 }}
... ... @@ -70,6 +71,10 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt
sed -i.bak s/"vocoder = null"/"vocoder = \"$vocoder\""/ ./TtsEngine.kt
{% endif %}
{% if tts_model.voices %}
sed -i.bak s/"voices = null"/"voices = \"$voices\""/ ./TtsEngine.kt
{% endif %}
{% if tts_model.rule_fsts %}
rule_fsts={{ tts_model.rule_fsts }}
sed -i.bak s%"ruleFsts = null"%"ruleFsts = \"$rule_fsts\""% ./TtsEngine.kt
... ...
... ... @@ -39,6 +39,7 @@ model_dir={{ tts_model.model_dir }}
model_name={{ tts_model.model_name }}
acoustic_model_name={{ tts_model.acoustic_model_name }}
vocoder={{ tts_model.vocoder }}
voices={{ tts_model.voices }}
lang={{ tts_model.lang }}
wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/$model_dir.tar.bz2
... ... @@ -69,6 +70,9 @@ sed -i.bak s/"modelDir = null"/"modelDir = \"$model_dir\""/ ./MainActivity.kt
sed -i.bak s/"vocoder = null"/"vocoder = \"$vocoder\""/ ./MainActivity.kt
{% endif %}
{% if tts_model.voices %}
sed -i.bak s/"voices = null"/"voices = \"$voices\""/ ./MainActivity.kt
{% endif %}
{% if tts_model.rule_fsts %}
rule_fsts={{ tts_model.rule_fsts }}
... ...
... ... @@ -33,6 +33,7 @@ class TtsModel:
model_name: str = "" # for vits
acoustic_model_name: str = "" # for matcha
vocoder: str = "" # for matcha
voices: str = "" # for kokoro
lang: str = "" # en, zh, fr, de, etc.
rule_fsts: Optional[List[str]] = None
rule_fars: Optional[List[str]] = None
... ... @@ -409,6 +410,21 @@ def get_matcha_models() -> List[TtsModel]:
return chinese_models + english_models
def get_kokoro_models() -> List[TtsModel]:
english_models = [
TtsModel(
model_dir="kokoro-en-v0_19",
model_name="model.onnx",
lang="en",
)
]
for m in english_models:
m.data_dir = f"{m.model_dir}/espeak-ng-data"
m.voices = "voices.bin"
return english_models
def main():
args = get_args()
index = args.index
... ... @@ -421,6 +437,7 @@ def main():
all_model_list += get_mimic3_models()
all_model_list += get_coqui_models()
all_model_list += get_matcha_models()
all_model_list += get_kokoro_models()
convert_lang_to_iso_639_3(all_model_list)
print(all_model_list)
... ...
... ... @@ -35,6 +35,7 @@ java_files += OfflineRecognizerResult.java
java_files += OfflineStream.java
java_files += OfflineRecognizer.java
java_files += OfflineTtsKokoroModelConfig.java
java_files += OfflineTtsMatchaModelConfig.java
java_files += OfflineTtsVitsModelConfig.java
java_files += OfflineTtsModelConfig.java
... ...
// Copyright 2025 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineTtsKokoroModelConfig {
private final String model;
private final String voices;
private final String tokens;
private final String dataDir;
private final float lengthScale;
private OfflineTtsKokoroModelConfig(Builder builder) {
this.model = builder.model;
this.voices = builder.voices;
this.tokens = builder.tokens;
this.dataDir = builder.dataDir;
this.lengthScale = builder.lengthScale;
}
public static Builder builder() {
return new Builder();
}
public String getModel() {
return model;
}
public String getVoices() {
return voices;
}
public String getTokens() {
return tokens;
}
public String getDataDir() {
return dataDir;
}
public float getLengthScale() {
return lengthScale;
}
public static class Builder {
private String model = "";
private String voices = "";
private String tokens = "";
private String dataDir = "";
private float lengthScale = 1.0f;
public OfflineTtsKokoroModelConfig build() {
return new OfflineTtsKokoroModelConfig(this);
}
public Builder setModel(String model) {
this.model = model;
return this;
}
public Builder setVoices(String voices) {
this.voices = voices;
return this;
}
public Builder setTokens(String tokens) {
this.tokens = tokens;
return this;
}
public Builder setDataDir(String dataDir) {
this.dataDir = dataDir;
return this;
}
public Builder setLengthScale(float lengthScale) {
this.lengthScale = lengthScale;
return this;
}
}
}
... ...
... ... @@ -5,6 +5,7 @@ package com.k2fsa.sherpa.onnx;
public class OfflineTtsModelConfig {
private final OfflineTtsVitsModelConfig vits;
private final OfflineTtsMatchaModelConfig matcha;
private final OfflineTtsKokoroModelConfig kokoro;
private final int numThreads;
private final boolean debug;
private final String provider;
... ... @@ -12,6 +13,7 @@ public class OfflineTtsModelConfig {
private OfflineTtsModelConfig(Builder builder) {
this.vits = builder.vits;
this.matcha = builder.matcha;
this.kokoro = builder.kokoro;
this.numThreads = builder.numThreads;
this.debug = builder.debug;
this.provider = builder.provider;
... ... @@ -29,9 +31,14 @@ public class OfflineTtsModelConfig {
return matcha;
}
public OfflineTtsKokoroModelConfig getKokoro() {
return kokoro;
}
public static class Builder {
private OfflineTtsVitsModelConfig vits = OfflineTtsVitsModelConfig.builder().build();
private OfflineTtsMatchaModelConfig matcha = OfflineTtsMatchaModelConfig.builder().build();
private OfflineTtsKokoroModelConfig kokoro = OfflineTtsKokoroModelConfig.builder().build();
private int numThreads = 1;
private boolean debug = true;
private String provider = "cpu";
... ... @@ -50,6 +57,11 @@ public class OfflineTtsModelConfig {
return this;
}
public Builder setKokoro(OfflineTtsKokoroModelConfig kokoro) {
this.kokoro = kokoro;
return this;
}
public Builder setNumThreads(int numThreads) {
this.numThreads = numThreads;
return this;
... ...
... ... @@ -113,6 +113,39 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(matcha_cls, "lengthScale", "F");
ans.model.matcha.length_scale = env->GetFloatField(matcha, fid);
// kokoro
fid = env->GetFieldID(model_config_cls, "kokoro",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsKokoroModelConfig;");
jobject kokoro = env->GetObjectField(model, fid);
jclass kokoro_cls = env->GetObjectClass(kokoro);
fid = env->GetFieldID(kokoro_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "voices", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.voices = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "dataDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.data_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "lengthScale", "F");
ans.model.kokoro.length_scale = env->GetFloatField(kokoro, fid);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);
... ... @@ -273,8 +306,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
return env->CallIntMethod(should_continue, int_value_mid);
};
auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate(
p_text, sid, speed, callback_wrapper);
auto tts = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr);
auto audio = tts->Generate(p_text, sid, speed, callback_wrapper);
jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
... ...
... ... @@ -25,9 +25,18 @@ data class OfflineTtsMatchaModelConfig(
var lengthScale: Float = 1.0f,
)
data class OfflineTtsKokoroModelConfig(
var model: String = "",
var voices: String = "",
var tokens: String = "",
var dataDir: String = "",
var lengthScale: Float = 1.0f,
)
data class OfflineTtsModelConfig(
var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
var kokoro: OfflineTtsKokoroModelConfig = OfflineTtsKokoroModelConfig(),
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
... ... @@ -176,12 +185,32 @@ fun getOfflineTtsConfig(
modelName: String, // for VITS
acousticModelName: String, // for Matcha
vocoder: String, // for Matcha
voices: String, // for Kokoro
lexicon: String,
dataDir: String,
dictDir: String,
ruleFsts: String,
ruleFars: String
ruleFars: String,
numThreads: Int? = null
): OfflineTtsConfig {
// For Matcha TTS, please set
// acousticModelName, vocoder
// For Kokoro TTS, please set
// modelName, voices
// For VITS, please set
// modelName
val numberOfThreads = if (numThreads != null) {
numThreads
} else if (voices.isNotEmpty()) {
// for Kokoro TTS models, we use more threads
4
} else {
2
}
if (modelName.isEmpty() && acousticModelName.isEmpty()) {
throw IllegalArgumentException("Please specify a TTS model")
}
... ... @@ -193,7 +222,8 @@ fun getOfflineTtsConfig(
if (acousticModelName.isNotEmpty() && vocoder.isEmpty()) {
throw IllegalArgumentException("Please provide vocoder for Matcha TTS")
}
val vits = if (modelName.isNotEmpty()) {
val vits = if (modelName.isNotEmpty() && voices.isEmpty()) {
OfflineTtsVitsModelConfig(
model = "$modelDir/$modelName",
lexicon = "$modelDir/$lexicon",
... ... @@ -218,11 +248,23 @@ fun getOfflineTtsConfig(
OfflineTtsMatchaModelConfig()
}
val kokoro = if (voices.isNotEmpty()) {
OfflineTtsKokoroModelConfig(
model = "$modelDir/$modelName",
voices = "$modelDir/$voices",
tokens = "$modelDir/tokens.txt",
dataDir = dataDir,
)
} else {
OfflineTtsKokoroModelConfig()
}
return OfflineTtsConfig(
model = OfflineTtsModelConfig(
vits = vits,
matcha = matcha,
numThreads = 2,
kokoro = kokoro,
numThreads = numberOfThreads,
debug = true,
provider = "cpu",
),
... ...