Committed by GitHub
Add Kotlin and Java API for Kokoro TTS models (#1728)
Showing 18 changed files with 549 additions and 40 deletions.
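For readers skimming the diff: this commit threads a new `OfflineTtsKokoroModelConfig` through the Kotlin and Java APIs, the JNI layer, and the Android demo apps. A minimal Kotlin sketch, condensed from the `testKokoro()` example added to `test_tts.kt` below — the paths assume `kokoro-en-v0_19` has been downloaded and unpacked, and the callback convention (return 1 to continue, 0 to stop) follows the Android example in this commit:

```kotlin
import com.k2fsa.sherpa.onnx.*

fun main() {
    val config = OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            kokoro = OfflineTtsKokoroModelConfig(
                model = "./kokoro-en-v0_19/model.onnx",
                voices = "./kokoro-en-v0_19/voices.bin",
                tokens = "./kokoro-en-v0_19/tokens.txt",
                dataDir = "./kokoro-en-v0_19/espeak-ng-data",
            ),
            numThreads = 2,
            debug = true,
        ),
    )
    val tts = OfflineTts(config = config)
    // Callback receives each generated chunk; return 1 to keep generating.
    val audio = tts.generateWithCallback(
        text = "How are you doing today?",
        callback = { _ -> 1 },
    )
    audio.save(filename = "kokoro-demo.wav")
    tts.release()
}
```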
| @@ -234,8 +234,12 @@ jobs: | @@ -234,8 +234,12 @@ jobs: | ||
| 234 | run: | | 234 | run: | |
| 235 | cd ./java-api-examples | 235 | cd ./java-api-examples |
| 236 | 236 | ||
| 237 | + ./run-non-streaming-tts-kokoro-en.sh | ||
| 237 | ./run-non-streaming-tts-matcha-zh.sh | 238 | ./run-non-streaming-tts-matcha-zh.sh |
| 238 | ./run-non-streaming-tts-matcha-en.sh | 239 | ./run-non-streaming-tts-matcha-en.sh |
| 240 | + ls -lh | ||
| 241 | + | ||
| 242 | + rm -rf kokoro-en-* | ||
| 239 | 243 | ||
| 240 | rm -rf matcha-icefall-* | 244 | rm -rf matcha-icefall-* |
| 241 | rm hifigan_v2.onnx | 245 | rm hifigan_v2.onnx |
| @@ -185,6 +185,7 @@ class MainActivity : AppCompatActivity() { | @@ -185,6 +185,7 @@ class MainActivity : AppCompatActivity() { | ||
| 185 | var modelName: String? | 185 | var modelName: String? |
| 186 | var acousticModelName: String? | 186 | var acousticModelName: String? |
| 187 | var vocoder: String? | 187 | var vocoder: String? |
| 188 | + var voices: String? | ||
| 188 | var ruleFsts: String? | 189 | var ruleFsts: String? |
| 189 | var ruleFars: String? | 190 | var ruleFars: String? |
| 190 | var lexicon: String? | 191 | var lexicon: String? |
| @@ -205,6 +206,10 @@ class MainActivity : AppCompatActivity() { | @@ -205,6 +206,10 @@ class MainActivity : AppCompatActivity() { | ||
| 205 | vocoder = null | 206 | vocoder = null |
| 206 | // Matcha -- end | 207 | // Matcha -- end |
| 207 | 208 | ||
| 209 | + // For Kokoro -- begin | ||
| 210 | + voices = null | ||
| 211 | + // For Kokoro -- end | ||
| 212 | + | ||
| 208 | 213 | ||
| 209 | modelDir = null | 214 | modelDir = null |
| 210 | ruleFsts = null | 215 | ruleFsts = null |
| @@ -269,6 +274,13 @@ class MainActivity : AppCompatActivity() { | @@ -269,6 +274,13 @@ class MainActivity : AppCompatActivity() { | ||
| 269 | // vocoder = "hifigan_v2.onnx" | 274 | // vocoder = "hifigan_v2.onnx" |
| 270 | // dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data" | 275 | // dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data" |
| 271 | 276 | ||
| 277 | + // Example 9 | ||
| 278 | + // kokoro-en-v0_19 | ||
| 279 | + // modelDir = "kokoro-en-v0_19" | ||
| 280 | + // modelName = "model.onnx" | ||
| 281 | + // voices = "voices.bin" | ||
| 282 | + // dataDir = "kokoro-en-v0_19/espeak-ng-data" | ||
| 283 | + | ||
| 272 | if (dataDir != null) { | 284 | if (dataDir != null) { |
| 273 | val newDir = copyDataDir(dataDir!!) | 285 | val newDir = copyDataDir(dataDir!!) |
| 274 | dataDir = "$newDir/$dataDir" | 286 | dataDir = "$newDir/$dataDir" |
| @@ -285,6 +297,7 @@ class MainActivity : AppCompatActivity() { | @@ -285,6 +297,7 @@ class MainActivity : AppCompatActivity() { | ||
| 285 | modelName = modelName ?: "", | 297 | modelName = modelName ?: "", |
| 286 | acousticModelName = acousticModelName ?: "", | 298 | acousticModelName = acousticModelName ?: "", |
| 287 | vocoder = vocoder ?: "", | 299 | vocoder = vocoder ?: "", |
| 300 | + voices = voices ?: "", | ||
| 288 | lexicon = lexicon ?: "", | 301 | lexicon = lexicon ?: "", |
| 289 | dataDir = dataDir ?: "", | 302 | dataDir = dataDir ?: "", |
| 290 | dictDir = dictDir ?: "", | 303 | dictDir = dictDir ?: "", |
| @@ -47,7 +47,7 @@ fun getSampleText(lang: String): String { | @@ -47,7 +47,7 @@ fun getSampleText(lang: String): String { | ||
| 47 | } | 47 | } |
| 48 | 48 | ||
| 49 | "eng" -> { | 49 | "eng" -> { |
| 50 | - text = "This is a text-to-speech engine using next generation Kaldi" | 50 | + text = "How are you doing today? This is a text-to-speech engine using next generation Kaldi" |
| 51 | } | 51 | } |
| 52 | 52 | ||
| 53 | "est" -> { | 53 | "est" -> { |
| @@ -3,6 +3,10 @@ | @@ -3,6 +3,10 @@ | ||
| 3 | package com.k2fsa.sherpa.onnx.tts.engine | 3 | package com.k2fsa.sherpa.onnx.tts.engine |
| 4 | 4 | ||
| 5 | import PreferenceHelper | 5 | import PreferenceHelper |
| 6 | +import android.media.AudioAttributes | ||
| 7 | +import android.media.AudioFormat | ||
| 8 | +import android.media.AudioManager | ||
| 9 | +import android.media.AudioTrack | ||
| 6 | import android.media.MediaPlayer | 10 | import android.media.MediaPlayer |
| 7 | import android.net.Uri | 11 | import android.net.Uri |
| 8 | import android.os.Bundle | 12 | import android.os.Bundle |
| @@ -36,7 +40,13 @@ import androidx.compose.ui.Modifier | @@ -36,7 +40,13 @@ import androidx.compose.ui.Modifier | ||
| 36 | import androidx.compose.ui.text.input.KeyboardType | 40 | import androidx.compose.ui.text.input.KeyboardType |
| 37 | import androidx.compose.ui.unit.dp | 41 | import androidx.compose.ui.unit.dp |
| 38 | import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme | 42 | import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme |
| 43 | +import kotlinx.coroutines.CoroutineScope | ||
| 44 | +import kotlinx.coroutines.Dispatchers | ||
| 45 | +import kotlinx.coroutines.channels.Channel | ||
| 46 | +import kotlinx.coroutines.launch | ||
| 47 | +import kotlinx.coroutines.withContext | ||
| 39 | import java.io.File | 48 | import java.io.File |
| 49 | +import kotlin.time.TimeSource | ||
| 40 | 50 | ||
| 41 | const val TAG = "sherpa-onnx-tts-engine" | 51 | const val TAG = "sherpa-onnx-tts-engine" |
| 42 | 52 | ||
| @@ -45,9 +55,26 @@ class MainActivity : ComponentActivity() { | @@ -45,9 +55,26 @@ class MainActivity : ComponentActivity() { | ||
| 45 | private val ttsViewModel: TtsViewModel by viewModels() | 55 | private val ttsViewModel: TtsViewModel by viewModels() |
| 46 | 56 | ||
| 47 | private var mediaPlayer: MediaPlayer? = null | 57 | private var mediaPlayer: MediaPlayer? = null |
| 58 | + | ||
| 59 | + // see | ||
| 60 | + // https://developer.android.com/reference/kotlin/android/media/AudioTrack | ||
| 61 | + private lateinit var track: AudioTrack | ||
| 62 | + | ||
| 63 | + private var stopped: Boolean = false | ||
| 64 | + | ||
| 65 | + private var samplesChannel = Channel<FloatArray>() | ||
| 66 | + | ||
| 48 | override fun onCreate(savedInstanceState: Bundle?) { | 67 | override fun onCreate(savedInstanceState: Bundle?) { |
| 49 | super.onCreate(savedInstanceState) | 68 | super.onCreate(savedInstanceState) |
| 69 | + | ||
| 70 | + Log.i(TAG, "Start to initialize TTS") | ||
| 50 | TtsEngine.createTts(this) | 71 | TtsEngine.createTts(this) |
| 72 | + Log.i(TAG, "Finish initializing TTS") | ||
| 73 | + | ||
| 74 | + Log.i(TAG, "Start to initialize AudioTrack") | ||
| 75 | + initAudioTrack() | ||
| 76 | + Log.i(TAG, "Finish initializing AudioTrack") | ||
| 77 | + | ||
| 51 | val preferenceHelper = PreferenceHelper(this) | 78 | val preferenceHelper = PreferenceHelper(this) |
| 52 | setContent { | 79 | setContent { |
| 53 | SherpaOnnxTtsEngineTheme { | 80 | SherpaOnnxTtsEngineTheme { |
| @@ -77,6 +104,11 @@ class MainActivity : ComponentActivity() { | @@ -77,6 +104,11 @@ class MainActivity : ComponentActivity() { | ||
| 77 | val testTextContent = getSampleText(TtsEngine.lang ?: "") | 104 | val testTextContent = getSampleText(TtsEngine.lang ?: "") |
| 78 | 105 | ||
| 79 | var testText by remember { mutableStateOf(testTextContent) } | 106 | var testText by remember { mutableStateOf(testTextContent) } |
| 107 | + var startEnabled by remember { mutableStateOf(true) } | ||
| 108 | + var playEnabled by remember { mutableStateOf(false) } | ||
| 109 | + var rtfText by remember { | ||
| 110 | + mutableStateOf("") | ||
| 111 | + } | ||
| 80 | 112 | ||
| 81 | val numSpeakers = TtsEngine.tts!!.numSpeakers() | 113 | val numSpeakers = TtsEngine.tts!!.numSpeakers() |
| 82 | if (numSpeakers > 1) { | 114 | if (numSpeakers > 1) { |
| @@ -119,52 +151,117 @@ class MainActivity : ComponentActivity() { | @@ -119,52 +151,117 @@ class MainActivity : ComponentActivity() { | ||
| 119 | 151 | ||
| 120 | Row { | 152 | Row { |
| 121 | Button( | 153 | Button( |
| 122 | - modifier = Modifier.padding(20.dp), | 154 | + enabled = startEnabled, |
| 155 | + modifier = Modifier.padding(5.dp), | ||
| 123 | onClick = { | 156 | onClick = { |
| 124 | Log.i(TAG, "Clicked, text: $testText") | 157 | Log.i(TAG, "Clicked, text: $testText") |
| 125 | if (testText.isBlank() || testText.isEmpty()) { | 158 | if (testText.isBlank() || testText.isEmpty()) { |
| 126 | Toast.makeText( | 159 | Toast.makeText( |
| 127 | applicationContext, | 160 | applicationContext, |
| 128 | - "Please input a test sentence", | 161 | + "Please input some text to generate", |
| 129 | Toast.LENGTH_SHORT | 162 | Toast.LENGTH_SHORT |
| 130 | ).show() | 163 | ).show() |
| 131 | } else { | 164 | } else { |
| 132 | - val audio = TtsEngine.tts!!.generate( | ||
| 133 | - text = testText, | ||
| 134 | - sid = TtsEngine.speakerId, | ||
| 135 | - speed = TtsEngine.speed, | ||
| 136 | - ) | ||
| 137 | - | ||
| 138 | - val filename = | ||
| 139 | - application.filesDir.absolutePath + "/generated.wav" | ||
| 140 | - val ok = | ||
| 141 | - audio.samples.isNotEmpty() && audio.save( | ||
| 142 | - filename | ||
| 143 | - ) | 165 | + startEnabled = false |
| 166 | + playEnabled = false | ||
| 167 | + stopped = false | ||
| 144 | 168 | ||
| 145 | - if (ok) { | ||
| 146 | - stopMediaPlayer() | ||
| 147 | - mediaPlayer = MediaPlayer.create( | ||
| 148 | - applicationContext, | ||
| 149 | - Uri.fromFile(File(filename)) | ||
| 150 | - ) | ||
| 151 | - mediaPlayer?.start() | ||
| 152 | - } else { | ||
| 153 | - Log.i(TAG, "Failed to generate or save audio") | 169 | + track.pause() |
| 170 | + track.flush() | ||
| 171 | + track.play() | ||
| 172 | + rtfText = "" | ||
| 173 | + Log.i(TAG, "Started with text $testText") | ||
| 174 | + | ||
| 175 | + samplesChannel = Channel<FloatArray>() | ||
| 176 | + | ||
| 177 | + CoroutineScope(Dispatchers.IO).launch { | ||
| 178 | + for (samples in samplesChannel) { | ||
| 179 | + track.write( | ||
| 180 | + samples, | ||
| 181 | + 0, | ||
| 182 | + samples.size, | ||
| 183 | + AudioTrack.WRITE_BLOCKING | ||
| 184 | + ) | ||
| 185 | + if (stopped) { | ||
| 186 | + break | ||
| 187 | + } | ||
| 188 | + } | ||
| 154 | } | 189 | } |
| 190 | + | ||
| 191 | + CoroutineScope(Dispatchers.Default).launch { | ||
| 192 | + val timeSource = TimeSource.Monotonic | ||
| 193 | + val startTime = timeSource.markNow() | ||
| 194 | + | ||
| 195 | + val audio = | ||
| 196 | + TtsEngine.tts!!.generateWithCallback( | ||
| 197 | + text = testText, | ||
| 198 | + sid = TtsEngine.speakerId, | ||
| 199 | + speed = TtsEngine.speed, | ||
| 200 | + callback = ::callback, | ||
| 201 | + ) | ||
| 202 | + | ||
| 203 | + val elapsed = | ||
| 204 | + startTime.elapsedNow().inWholeMilliseconds.toFloat() / 1000; | ||
| 205 | + val audioDuration = | ||
| 206 | + audio.samples.size / TtsEngine.tts!!.sampleRate() | ||
| 207 | + .toFloat() | ||
| 208 | + val RTF = String.format( | ||
| 209 | + "Number of threads: %d\nElapsed: %.3f s\nAudio duration: %.3f s\nRTF: %.3f/%.3f = %.3f", | ||
| 210 | + TtsEngine.tts!!.config.model.numThreads, | ||
| 211 | + elapsed, | ||
| 212 | + audioDuration, | ||
| 213 | + elapsed, | ||
| 214 | + audioDuration, | ||
| 215 | + elapsed / audioDuration | ||
| 216 | + ) | ||
| 217 | + samplesChannel.close() | ||
| 218 | + | ||
| 219 | + val filename = | ||
| 220 | + application.filesDir.absolutePath + "/generated.wav" | ||
| 221 | + | ||
| 222 | + | ||
| 223 | + val ok = | ||
| 224 | + audio.samples.isNotEmpty() && audio.save( | ||
| 225 | + filename | ||
| 226 | + ) | ||
| 227 | + | ||
| 228 | + if (ok) { | ||
| 229 | + withContext(Dispatchers.Main) { | ||
| 230 | + startEnabled = true | ||
| 231 | + playEnabled = true | ||
| 232 | + rtfText = RTF | ||
| 233 | + } | ||
| 234 | + } | ||
| 235 | + }.start() | ||
| 155 | } | 236 | } |
| 156 | }) { | 237 | }) { |
| 157 | - Text("Test") | 238 | + Text("Start") |
| 158 | } | 239 | } |
| 159 | 240 | ||
| 160 | Button( | 241 | Button( |
| 161 | - modifier = Modifier.padding(20.dp), | 242 | + modifier = Modifier.padding(5.dp), |
| 243 | + enabled = playEnabled, | ||
| 162 | onClick = { | 244 | onClick = { |
| 163 | - TtsEngine.speakerId = 0 | ||
| 164 | - TtsEngine.speed = 1.0f | ||
| 165 | - testText = "" | 245 | + stopped = true |
| 246 | + track.pause() | ||
| 247 | + track.flush() | ||
| 248 | + onClickPlay() | ||
| 166 | }) { | 249 | }) { |
| 167 | - Text("Reset") | 250 | + Text("Play") |
| 251 | + } | ||
| 252 | + | ||
| 253 | + Button( | ||
| 254 | + modifier = Modifier.padding(5.dp), | ||
| 255 | + onClick = { | ||
| 256 | + onClickStop() | ||
| 257 | + startEnabled = true | ||
| 258 | + }) { | ||
| 259 | + Text("Stop") | ||
| 260 | + } | ||
| 261 | + } | ||
| 262 | + if (rtfText.isNotEmpty()) { | ||
| 263 | + Row { | ||
| 264 | + Text(rtfText) | ||
| 168 | } | 265 | } |
| 169 | } | 266 | } |
| 170 | } | 267 | } |
| @@ -185,4 +282,63 @@ class MainActivity : ComponentActivity() { | @@ -185,4 +282,63 @@ class MainActivity : ComponentActivity() { | ||
| 185 | mediaPlayer?.release() | 282 | mediaPlayer?.release() |
| 186 | mediaPlayer = null | 283 | mediaPlayer = null |
| 187 | } | 284 | } |
| 285 | + | ||
| 286 | + private fun onClickPlay() { | ||
| 287 | + val filename = application.filesDir.absolutePath + "/generated.wav" | ||
| 288 | + stopMediaPlayer() | ||
| 289 | + mediaPlayer = MediaPlayer.create( | ||
| 290 | + applicationContext, | ||
| 291 | + Uri.fromFile(File(filename)) | ||
| 292 | + ) | ||
| 293 | + mediaPlayer?.start() | ||
| 294 | + } | ||
| 295 | + | ||
| 296 | + private fun onClickStop() { | ||
| 297 | + stopped = true | ||
| 298 | + track.pause() | ||
| 299 | + track.flush() | ||
| 300 | + | ||
| 301 | + stopMediaPlayer() | ||
| 302 | + } | ||
| 303 | + | ||
| 304 | + // this function is called from C++ | ||
| 305 | + private fun callback(samples: FloatArray): Int { | ||
| 306 | + if (!stopped) { | ||
| 307 | + val samplesCopy = samples.copyOf() | ||
| 308 | + CoroutineScope(Dispatchers.IO).launch { | ||
| 309 | + samplesChannel.send(samplesCopy) | ||
| 310 | + } | ||
| 311 | + return 1 | ||
| 312 | + } else { | ||
| 313 | + track.stop() | ||
| 314 | + Log.i(TAG, " return 0") | ||
| 315 | + return 0 | ||
| 316 | + } | ||
| 317 | + } | ||
| 318 | + | ||
| 319 | + private fun initAudioTrack() { | ||
| 320 | + val sampleRate = TtsEngine.tts!!.sampleRate() | ||
| 321 | + val bufLength = AudioTrack.getMinBufferSize( | ||
| 322 | + sampleRate, | ||
| 323 | + AudioFormat.CHANNEL_OUT_MONO, | ||
| 324 | + AudioFormat.ENCODING_PCM_FLOAT | ||
| 325 | + ) | ||
| 326 | + Log.i(TAG, "sampleRate: $sampleRate, buffLength: $bufLength") | ||
| 327 | + | ||
| 328 | + val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) | ||
| 329 | + .setUsage(AudioAttributes.USAGE_MEDIA) | ||
| 330 | + .build() | ||
| 331 | + | ||
| 332 | + val format = AudioFormat.Builder() | ||
| 333 | + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) | ||
| 334 | + .setChannelMask(AudioFormat.CHANNEL_OUT_MONO) | ||
| 335 | + .setSampleRate(sampleRate) | ||
| 336 | + .build() | ||
| 337 | + | ||
| 338 | + track = AudioTrack( | ||
| 339 | + attr, format, bufLength, AudioTrack.MODE_STREAM, | ||
| 340 | + AudioManager.AUDIO_SESSION_ID_GENERATE | ||
| 341 | + ) | ||
| 342 | + track.play() | ||
| 343 | + } | ||
| 188 | } | 344 | } |
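To make the control flow of the reworked engine `MainActivity` easier to follow: the UI no longer waits for the full waveform before playback. Generation runs on a background coroutine and pushes each chunk of samples into a `Channel<FloatArray>`, while a second coroutine drains the channel into an `AudioTrack` opened in `MODE_STREAM`. A condensed restatement of that producer/consumer pattern (not new API; `track`, `testText`, and `TtsEngine` are the objects set up in the diff above):

```kotlin
val samplesChannel = Channel<FloatArray>()

// Consumer: write chunks to the AudioTrack as soon as they arrive.
CoroutineScope(Dispatchers.IO).launch {
    for (samples in samplesChannel) {
        track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
    }
}

// Producer: generateWithCallback() invokes the callback once per chunk;
// returning 1 continues generation, returning 0 aborts it.
CoroutineScope(Dispatchers.Default).launch {
    TtsEngine.tts!!.generateWithCallback(
        text = testText,
        sid = TtsEngine.speakerId,
        speed = TtsEngine.speed,
        callback = { samples ->
            val copy = samples.copyOf()
            CoroutineScope(Dispatchers.IO).launch { samplesChannel.send(copy) }
            1
        },
    )
    samplesChannel.close()
}
```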
| @@ -41,8 +41,9 @@ object TtsEngine { | @@ -41,8 +41,9 @@ object TtsEngine { | ||
| 41 | 41 | ||
| 42 | private var modelDir: String? = null | 42 | private var modelDir: String? = null |
| 43 | private var modelName: String? = null | 43 | private var modelName: String? = null |
| 44 | - private var acousticModelName: String? = null | ||
| 45 | - private var vocoder: String? = null | 44 | + private var acousticModelName: String? = null // for matcha tts |
| 45 | + private var vocoder: String? = null // for matcha tts | ||
| 46 | + private var voices: String? = null // for kokoro | ||
| 46 | private var ruleFsts: String? = null | 47 | private var ruleFsts: String? = null |
| 47 | private var ruleFars: String? = null | 48 | private var ruleFars: String? = null |
| 48 | private var lexicon: String? = null | 49 | private var lexicon: String? = null |
| @@ -64,6 +65,10 @@ object TtsEngine { | @@ -64,6 +65,10 @@ object TtsEngine { | ||
| 64 | vocoder = null | 65 | vocoder = null |
| 65 | // For Matcha -- end | 66 | // For Matcha -- end |
| 66 | 67 | ||
| 68 | + // For Kokoro -- begin | ||
| 69 | + voices = null | ||
| 70 | + // For Kokoro -- end | ||
| 71 | + | ||
| 67 | modelDir = null | 72 | modelDir = null |
| 68 | ruleFsts = null | 73 | ruleFsts = null |
| 69 | ruleFars = null | 74 | ruleFars = null |
| @@ -139,6 +144,14 @@ object TtsEngine { | @@ -139,6 +144,14 @@ object TtsEngine { | ||
| 139 | // vocoder = "hifigan_v2.onnx" | 144 | // vocoder = "hifigan_v2.onnx" |
| 140 | // dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data" | 145 | // dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data" |
| 141 | // lang = "eng" | 146 | // lang = "eng" |
| 147 | + | ||
| 148 | + // Example 9 | ||
| 149 | + // kokoro-en-v0_19 | ||
| 150 | + // modelDir = "kokoro-en-v0_19" | ||
| 151 | + // modelName = "model.onnx" | ||
| 152 | + // voices = "voices.bin" | ||
| 153 | + // dataDir = "kokoro-en-v0_19/espeak-ng-data" | ||
| 154 | + // lang = "eng" | ||
| 142 | } | 155 | } |
| 143 | 156 | ||
| 144 | fun createTts(context: Context) { | 157 | fun createTts(context: Context) { |
| @@ -167,6 +180,7 @@ object TtsEngine { | @@ -167,6 +180,7 @@ object TtsEngine { | ||
| 167 | modelName = modelName ?: "", | 180 | modelName = modelName ?: "", |
| 168 | acousticModelName = acousticModelName ?: "", | 181 | acousticModelName = acousticModelName ?: "", |
| 169 | vocoder = vocoder ?: "", | 182 | vocoder = vocoder ?: "", |
| 183 | + voices = voices ?: "", | ||
| 170 | lexicon = lexicon ?: "", | 184 | lexicon = lexicon ?: "", |
| 171 | dataDir = dataDir ?: "", | 185 | dataDir = dataDir ?: "", |
| 172 | dictDir = dictDir ?: "", | 186 | dictDir = dictDir ?: "", |
| 1 | +// Copyright 2025 Xiaomi Corporation | ||
| 2 | + | ||
| 3 | +// This file shows how to use a Kokoro English model | ||
| 4 | +// to convert text to speech | ||
| 5 | +import com.k2fsa.sherpa.onnx.*; | ||
| 6 | + | ||
| 7 | +public class NonStreamingTtsKokoroEn { | ||
| 8 | + public static void main(String[] args) { | ||
| 9 | + // please visit | ||
| 10 | + // https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/kokoro.html | ||
| 11 | + // to download model files | ||
| 12 | + String model = "./kokoro-en-v0_19/model.onnx"; | ||
| 13 | + String voices = "./kokoro-en-v0_19/voices.bin"; | ||
| 14 | + String tokens = "./kokoro-en-v0_19/tokens.txt"; | ||
| 15 | + String dataDir = "./kokoro-en-v0_19/espeak-ng-data"; | ||
| 16 | + String text = | ||
| 17 | + "Today as always, men fall into two groups: slaves and free men. Whoever does not have" | ||
| 18 | + + " two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a" | ||
| 19 | + + " businessman, an official, or a scholar."; | ||
| 20 | + | ||
| 21 | + OfflineTtsKokoroModelConfig kokoroModelConfig = | ||
| 22 | + OfflineTtsKokoroModelConfig.builder() | ||
| 23 | + .setModel(model) | ||
| 24 | + .setVoices(voices) | ||
| 25 | + .setTokens(tokens) | ||
| 26 | + .setDataDir(dataDir) | ||
| 27 | + .build(); | ||
| 28 | + | ||
| 29 | + OfflineTtsModelConfig modelConfig = | ||
| 30 | + OfflineTtsModelConfig.builder() | ||
| 31 | + .setKokoro(kokoroModelConfig) | ||
| 32 | + .setNumThreads(2) | ||
| 33 | + .setDebug(true) | ||
| 34 | + .build(); | ||
| 35 | + | ||
| 36 | + OfflineTtsConfig config = OfflineTtsConfig.builder().setModel(modelConfig).build(); | ||
| 37 | + OfflineTts tts = new OfflineTts(config); | ||
| 38 | + | ||
| 39 | + int sid = 0; | ||
| 40 | + float speed = 1.0f; | ||
| 41 | + long start = System.currentTimeMillis(); | ||
| 42 | + GeneratedAudio audio = tts.generate(text, sid, speed); | ||
| 43 | + long stop = System.currentTimeMillis(); | ||
| 44 | + | ||
| 45 | + float timeElapsedSeconds = (stop - start) / 1000.0f; | ||
| 46 | + | ||
| 47 | + float audioDuration = audio.getSamples().length / (float) audio.getSampleRate(); | ||
| 48 | + float real_time_factor = timeElapsedSeconds / audioDuration; | ||
| 49 | + | ||
| 50 | + String waveFilename = "tts-kokoro-en.wav"; | ||
| 51 | + audio.save(waveFilename); | ||
| 52 | + System.out.printf("-- elapsed : %.3f seconds\n", timeElapsedSeconds); | ||
| 53 | + System.out.printf("-- audio duration: %.3f seconds\n", audioDuration); | ||
| 54 | + System.out.printf("-- real-time factor (RTF): %.3f\n", real_time_factor); | ||
| 55 | + System.out.printf("-- text: %s\n", text); | ||
| 56 | + System.out.printf("-- Saved to %s\n", waveFilename); | ||
| 57 | + | ||
| 58 | + tts.release(); | ||
| 59 | + } | ||
| 60 | +} |
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -ex | ||
| 4 | + | ||
| 5 | +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then | ||
| 6 | + mkdir -p ../build | ||
| 7 | + pushd ../build | ||
| 8 | + cmake \ | ||
| 9 | + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ | ||
| 10 | + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ | ||
| 11 | + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ | ||
| 12 | + -DBUILD_SHARED_LIBS=ON \ | ||
| 13 | + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ | ||
| 14 | + -DSHERPA_ONNX_ENABLE_JNI=ON \ | ||
| 15 | + .. | ||
| 16 | + | ||
| 17 | + make -j4 | ||
| 18 | + ls -lh lib | ||
| 19 | + popd | ||
| 20 | +fi | ||
| 21 | + | ||
| 22 | +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then | ||
| 23 | + pushd ../sherpa-onnx/java-api | ||
| 24 | + make | ||
| 25 | + popd | ||
| 26 | +fi | ||
| 27 | + | ||
| 28 | +# please visit | ||
| 29 | +# https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/kokoro.html | ||
| 30 | +# to download more models | ||
| 31 | +if [ ! -f ./kokoro-en-v0_19/model.onnx ]; then | ||
| 32 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-en-v0_19.tar.bz2 | ||
| 33 | + tar xf kokoro-en-v0_19.tar.bz2 | ||
| 34 | + rm kokoro-en-v0_19.tar.bz2 | ||
| 35 | +fi | ||
| 36 | + | ||
| 37 | +java \ | ||
| 38 | + -Djava.library.path=$PWD/../build/lib \ | ||
| 39 | + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \ | ||
| 40 | + NonStreamingTtsKokoroEn.java |
| @@ -115,6 +115,12 @@ function testTts() { | @@ -115,6 +115,12 @@ function testTts() { | ||
| 115 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx | 115 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx |
| 116 | fi | 116 | fi |
| 117 | 117 | ||
| 118 | + if [ ! -f ./kokoro-en-v0_19/model.onnx ]; then | ||
| 119 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-en-v0_19.tar.bz2 | ||
| 120 | + tar xf kokoro-en-v0_19.tar.bz2 | ||
| 121 | + rm kokoro-en-v0_19.tar.bz2 | ||
| 122 | + fi | ||
| 123 | + | ||
| 118 | out_filename=test_tts.jar | 124 | out_filename=test_tts.jar |
| 119 | kotlinc-jvm -include-runtime -d $out_filename \ | 125 | kotlinc-jvm -include-runtime -d $out_filename \ |
| 120 | test_tts.kt \ | 126 | test_tts.kt \ |
| @@ -3,6 +3,28 @@ package com.k2fsa.sherpa.onnx | @@ -3,6 +3,28 @@ package com.k2fsa.sherpa.onnx | ||
| 3 | fun main() { | 3 | fun main() { |
| 4 | testVits() | 4 | testVits() |
| 5 | testMatcha() | 5 | testMatcha() |
| 6 | + testKokoro() | ||
| 7 | +} | ||
| 8 | + | ||
| 9 | +fun testKokoro() { | ||
| 10 | + // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models | ||
| 11 | + var config = OfflineTtsConfig( | ||
| 12 | + model=OfflineTtsModelConfig( | ||
| 13 | + kokoro=OfflineTtsKokoroModelConfig( | ||
| 14 | + model="./kokoro-en-v0_19/model.onnx", | ||
| 15 | + voices="./kokoro-en-v0_19/voices.bin", | ||
| 16 | + tokens="./kokoro-en-v0_19/tokens.txt", | ||
| 17 | + dataDir="./kokoro-en-v0_19/espeak-ng-data", | ||
| 18 | + ), | ||
| 19 | + numThreads=2, | ||
| 20 | + debug=true, | ||
| 21 | + ), | ||
| 22 | + ) | ||
| 23 | + val tts = OfflineTts(config=config) | ||
| 24 | + val audio = tts.generateWithCallback(text="How are you doing today?", callback=::callback) | ||
| 25 | + audio.save(filename="test-kokoro-en.wav") | ||
| 26 | + tts.release() | ||
| 27 | + println("Saved to test-kokoro-en.wav") | ||
| 6 | } | 28 | } |
| 7 | 29 | ||
| 8 | fun testMatcha() { | 30 | fun testMatcha() { |
| @@ -24,9 +46,9 @@ fun testMatcha() { | @@ -24,9 +46,9 @@ fun testMatcha() { | ||
| 24 | ) | 46 | ) |
| 25 | val tts = OfflineTts(config=config) | 47 | val tts = OfflineTts(config=config) |
| 26 | val audio = tts.generateWithCallback(text="某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。", callback=::callback) | 48 | val audio = tts.generateWithCallback(text="某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。", callback=::callback) |
| 27 | - audio.save(filename="test-zh.wav") | 49 | + audio.save(filename="test-matcha-zh.wav") |
| 28 | tts.release() | 50 | tts.release() |
| 29 | - println("Saved to test-zh.wav") | 51 | + println("Saved to test-matcha-zh.wav") |
| 30 | } | 52 | } |
| 31 | 53 | ||
| 32 | fun testVits() { | 54 | fun testVits() { |
| @@ -39,6 +39,7 @@ model_dir={{ tts_model.model_dir }} | @@ -39,6 +39,7 @@ model_dir={{ tts_model.model_dir }} | ||
| 39 | model_name={{ tts_model.model_name }} | 39 | model_name={{ tts_model.model_name }} |
| 40 | acoustic_model_name={{ tts_model.acoustic_model_name }} | 40 | acoustic_model_name={{ tts_model.acoustic_model_name }} |
| 41 | vocoder={{ tts_model.vocoder }} | 41 | vocoder={{ tts_model.vocoder }} |
| 42 | +voices={{ tts_model.voices }} | ||
| 42 | lang={{ tts_model.lang }} | 43 | lang={{ tts_model.lang }} |
| 43 | lang_iso_639_3={{ tts_model.lang_iso_639_3 }} | 44 | lang_iso_639_3={{ tts_model.lang_iso_639_3 }} |
| 44 | 45 | ||
| @@ -70,6 +71,10 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt | @@ -70,6 +71,10 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt | ||
| 70 | sed -i.bak s/"vocoder = null"/"vocoder = \"$vocoder\""/ ./TtsEngine.kt | 71 | sed -i.bak s/"vocoder = null"/"vocoder = \"$vocoder\""/ ./TtsEngine.kt |
| 71 | {% endif %} | 72 | {% endif %} |
| 72 | 73 | ||
| 74 | +{% if tts_model.voices %} | ||
| 75 | + sed -i.bak s/"voices = null"/"voices = \"$voices\""/ ./TtsEngine.kt | ||
| 76 | +{% endif %} | ||
| 77 | + | ||
| 73 | {% if tts_model.rule_fsts %} | 78 | {% if tts_model.rule_fsts %} |
| 74 | rule_fsts={{ tts_model.rule_fsts }} | 79 | rule_fsts={{ tts_model.rule_fsts }} |
| 75 | sed -i.bak s%"ruleFsts = null"%"ruleFsts = \"$rule_fsts\""% ./TtsEngine.kt | 80 | sed -i.bak s%"ruleFsts = null"%"ruleFsts = \"$rule_fsts\""% ./TtsEngine.kt |
| @@ -39,6 +39,7 @@ model_dir={{ tts_model.model_dir }} | @@ -39,6 +39,7 @@ model_dir={{ tts_model.model_dir }} | ||
| 39 | model_name={{ tts_model.model_name }} | 39 | model_name={{ tts_model.model_name }} |
| 40 | acoustic_model_name={{ tts_model.acoustic_model_name }} | 40 | acoustic_model_name={{ tts_model.acoustic_model_name }} |
| 41 | vocoder={{ tts_model.vocoder }} | 41 | vocoder={{ tts_model.vocoder }} |
| 42 | +voices={{ tts_model.voices }} | ||
| 42 | lang={{ tts_model.lang }} | 43 | lang={{ tts_model.lang }} |
| 43 | 44 | ||
| 44 | wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/$model_dir.tar.bz2 | 45 | wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/$model_dir.tar.bz2 |
| @@ -69,6 +70,9 @@ sed -i.bak s/"modelDir = null"/"modelDir = \"$model_dir\""/ ./MainActivity.kt | @@ -69,6 +70,9 @@ sed -i.bak s/"modelDir = null"/"modelDir = \"$model_dir\""/ ./MainActivity.kt | ||
| 69 | sed -i.bak s/"vocoder = null"/"vocoder = \"$vocoder\""/ ./MainActivity.kt | 70 | sed -i.bak s/"vocoder = null"/"vocoder = \"$vocoder\""/ ./MainActivity.kt |
| 70 | {% endif %} | 71 | {% endif %} |
| 71 | 72 | ||
| 73 | +{% if tts_model.voices %} | ||
| 74 | + sed -i.bak s/"voices = null"/"voices = \"$voices\""/ ./MainActivity.kt | ||
| 75 | +{% endif %} | ||
| 72 | 76 | ||
| 73 | {% if tts_model.rule_fsts %} | 77 | {% if tts_model.rule_fsts %} |
| 74 | rule_fsts={{ tts_model.rule_fsts }} | 78 | rule_fsts={{ tts_model.rule_fsts }} |
| @@ -33,6 +33,7 @@ class TtsModel: | @@ -33,6 +33,7 @@ class TtsModel: | ||
| 33 | model_name: str = "" # for vits | 33 | model_name: str = "" # for vits |
| 34 | acoustic_model_name: str = "" # for matcha | 34 | acoustic_model_name: str = "" # for matcha |
| 35 | vocoder: str = "" # for matcha | 35 | vocoder: str = "" # for matcha |
| 36 | + voices: str = "" # for kokoro | ||
| 36 | lang: str = "" # en, zh, fr, de, etc. | 37 | lang: str = "" # en, zh, fr, de, etc. |
| 37 | rule_fsts: Optional[List[str]] = None | 38 | rule_fsts: Optional[List[str]] = None |
| 38 | rule_fars: Optional[List[str]] = None | 39 | rule_fars: Optional[List[str]] = None |
| @@ -409,6 +410,21 @@ def get_matcha_models() -> List[TtsModel]: | @@ -409,6 +410,21 @@ def get_matcha_models() -> List[TtsModel]: | ||
| 409 | return chinese_models + english_models | 410 | return chinese_models + english_models |
| 410 | 411 | ||
| 411 | 412 | ||
| 413 | +def get_kokoro_models() -> List[TtsModel]: | ||
| 414 | + english_models = [ | ||
| 415 | + TtsModel( | ||
| 416 | + model_dir="kokoro-en-v0_19", | ||
| 417 | + model_name="model.onnx", | ||
| 418 | + lang="en", | ||
| 419 | + ) | ||
| 420 | + ] | ||
| 421 | + for m in english_models: | ||
| 422 | + m.data_dir = f"{m.model_dir}/espeak-ng-data" | ||
| 423 | + m.voices = "voices.bin" | ||
| 424 | + | ||
| 425 | + return english_models | ||
| 426 | + | ||
| 427 | + | ||
| 412 | def main(): | 428 | def main(): |
| 413 | args = get_args() | 429 | args = get_args() |
| 414 | index = args.index | 430 | index = args.index |
| @@ -421,6 +437,7 @@ def main(): | @@ -421,6 +437,7 @@ def main(): | ||
| 421 | all_model_list += get_mimic3_models() | 437 | all_model_list += get_mimic3_models() |
| 422 | all_model_list += get_coqui_models() | 438 | all_model_list += get_coqui_models() |
| 423 | all_model_list += get_matcha_models() | 439 | all_model_list += get_matcha_models() |
| 440 | + all_model_list += get_kokoro_models() | ||
| 424 | 441 | ||
| 425 | convert_lang_to_iso_639_3(all_model_list) | 442 | convert_lang_to_iso_639_3(all_model_list) |
| 426 | print(all_model_list) | 443 | print(all_model_list) |
| @@ -35,6 +35,7 @@ java_files += OfflineRecognizerResult.java | @@ -35,6 +35,7 @@ java_files += OfflineRecognizerResult.java | ||
| 35 | java_files += OfflineStream.java | 35 | java_files += OfflineStream.java |
| 36 | java_files += OfflineRecognizer.java | 36 | java_files += OfflineRecognizer.java |
| 37 | 37 | ||
| 38 | +java_files += OfflineTtsKokoroModelConfig.java | ||
| 38 | java_files += OfflineTtsMatchaModelConfig.java | 39 | java_files += OfflineTtsMatchaModelConfig.java |
| 39 | java_files += OfflineTtsVitsModelConfig.java | 40 | java_files += OfflineTtsVitsModelConfig.java |
| 40 | java_files += OfflineTtsModelConfig.java | 41 | java_files += OfflineTtsModelConfig.java |
| 1 | +// Copyright 2025 Xiaomi Corporation | ||
| 2 | +package com.k2fsa.sherpa.onnx; | ||
| 3 | + | ||
| 4 | +public class OfflineTtsKokoroModelConfig { | ||
| 5 | + private final String model; | ||
| 6 | + private final String voices; | ||
| 7 | + private final String tokens; | ||
| 8 | + private final String dataDir; | ||
| 9 | + private final float lengthScale; | ||
| 10 | + | ||
| 11 | + private OfflineTtsKokoroModelConfig(Builder builder) { | ||
| 12 | + this.model = builder.model; | ||
| 13 | + this.voices = builder.voices; | ||
| 14 | + this.tokens = builder.tokens; | ||
| 15 | + this.dataDir = builder.dataDir; | ||
| 16 | + this.lengthScale = builder.lengthScale; | ||
| 17 | + } | ||
| 18 | + | ||
| 19 | + public static Builder builder() { | ||
| 20 | + return new Builder(); | ||
| 21 | + } | ||
| 22 | + | ||
| 23 | + public String getModel() { | ||
| 24 | + return model; | ||
| 25 | + } | ||
| 26 | + | ||
| 27 | + public String getVoices() { | ||
| 28 | + return voices; | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + public String getTokens() { | ||
| 32 | + return tokens; | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + public String getDataDir() { | ||
| 36 | + return dataDir; | ||
| 37 | + } | ||
| 38 | + | ||
| 39 | + public float getLengthScale() { | ||
| 40 | + return lengthScale; | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + | ||
| 44 | + public static class Builder { | ||
| 45 | + private String model = ""; | ||
| 46 | + private String voices = ""; | ||
| 47 | + private String tokens = ""; | ||
| 48 | + private String dataDir = ""; | ||
| 49 | + private float lengthScale = 1.0f; | ||
| 50 | + | ||
| 51 | + public OfflineTtsKokoroModelConfig build() { | ||
| 52 | + return new OfflineTtsKokoroModelConfig(this); | ||
| 53 | + } | ||
| 54 | + | ||
| 55 | + public Builder setModel(String model) { | ||
| 56 | + this.model = model; | ||
| 57 | + return this; | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + public Builder setVoices(String voices) { | ||
| 61 | + this.voices = voices; | ||
| 62 | + return this; | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + public Builder setTokens(String tokens) { | ||
| 66 | + this.tokens = tokens; | ||
| 67 | + return this; | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + public Builder setDataDir(String dataDir) { | ||
| 71 | + this.dataDir = dataDir; | ||
| 72 | + return this; | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + public Builder setLengthScale(float lengthScale) { | ||
| 76 | + this.lengthScale = lengthScale; | ||
| 77 | + return this; | ||
| 78 | + } | ||
| 79 | + } | ||
| 80 | +} |
| @@ -5,6 +5,7 @@ package com.k2fsa.sherpa.onnx; | @@ -5,6 +5,7 @@ package com.k2fsa.sherpa.onnx; | ||
| 5 | public class OfflineTtsModelConfig { | 5 | public class OfflineTtsModelConfig { |
| 6 | private final OfflineTtsVitsModelConfig vits; | 6 | private final OfflineTtsVitsModelConfig vits; |
| 7 | private final OfflineTtsMatchaModelConfig matcha; | 7 | private final OfflineTtsMatchaModelConfig matcha; |
| 8 | + private final OfflineTtsKokoroModelConfig kokoro; | ||
| 8 | private final int numThreads; | 9 | private final int numThreads; |
| 9 | private final boolean debug; | 10 | private final boolean debug; |
| 10 | private final String provider; | 11 | private final String provider; |
| @@ -12,6 +13,7 @@ public class OfflineTtsModelConfig { | @@ -12,6 +13,7 @@ public class OfflineTtsModelConfig { | ||
| 12 | private OfflineTtsModelConfig(Builder builder) { | 13 | private OfflineTtsModelConfig(Builder builder) { |
| 13 | this.vits = builder.vits; | 14 | this.vits = builder.vits; |
| 14 | this.matcha = builder.matcha; | 15 | this.matcha = builder.matcha; |
| 16 | + this.kokoro = builder.kokoro; | ||
| 15 | this.numThreads = builder.numThreads; | 17 | this.numThreads = builder.numThreads; |
| 16 | this.debug = builder.debug; | 18 | this.debug = builder.debug; |
| 17 | this.provider = builder.provider; | 19 | this.provider = builder.provider; |
| @@ -29,9 +31,14 @@ public class OfflineTtsModelConfig { | @@ -29,9 +31,14 @@ public class OfflineTtsModelConfig { | ||
| 29 | return matcha; | 31 | return matcha; |
| 30 | } | 32 | } |
| 31 | 33 | ||
| 34 | + public OfflineTtsKokoroModelConfig getKokoro() { | ||
| 35 | + return kokoro; | ||
| 36 | + } | ||
| 37 | + | ||
| 32 | public static class Builder { | 38 | public static class Builder { |
| 33 | private OfflineTtsVitsModelConfig vits = OfflineTtsVitsModelConfig.builder().build(); | 39 | private OfflineTtsVitsModelConfig vits = OfflineTtsVitsModelConfig.builder().build(); |
| 34 | private OfflineTtsMatchaModelConfig matcha = OfflineTtsMatchaModelConfig.builder().build(); | 40 | private OfflineTtsMatchaModelConfig matcha = OfflineTtsMatchaModelConfig.builder().build(); |
| 41 | + private OfflineTtsKokoroModelConfig kokoro = OfflineTtsKokoroModelConfig.builder().build(); | ||
| 35 | private int numThreads = 1; | 42 | private int numThreads = 1; |
| 36 | private boolean debug = true; | 43 | private boolean debug = true; |
| 37 | private String provider = "cpu"; | 44 | private String provider = "cpu"; |
| @@ -50,6 +57,11 @@ public class OfflineTtsModelConfig { | @@ -50,6 +57,11 @@ public class OfflineTtsModelConfig { | ||
| 50 | return this; | 57 | return this; |
| 51 | } | 58 | } |
| 52 | 59 | ||
| 60 | + public Builder setKokoro(OfflineTtsKokoroModelConfig kokoro) { | ||
| 61 | + this.kokoro = kokoro; | ||
| 62 | + return this; | ||
| 63 | + } | ||
| 64 | + | ||
| 53 | public Builder setNumThreads(int numThreads) { | 65 | public Builder setNumThreads(int numThreads) { |
| 54 | this.numThreads = numThreads; | 66 | this.numThreads = numThreads; |
| 55 | return this; | 67 | return this; |
| @@ -113,6 +113,39 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { | @@ -113,6 +113,39 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { | ||
| 113 | fid = env->GetFieldID(matcha_cls, "lengthScale", "F"); | 113 | fid = env->GetFieldID(matcha_cls, "lengthScale", "F"); |
| 114 | ans.model.matcha.length_scale = env->GetFloatField(matcha, fid); | 114 | ans.model.matcha.length_scale = env->GetFloatField(matcha, fid); |
| 115 | 115 | ||
| 116 | + // kokoro | ||
| 117 | + fid = env->GetFieldID(model_config_cls, "kokoro", | ||
| 118 | + "Lcom/k2fsa/sherpa/onnx/OfflineTtsKokoroModelConfig;"); | ||
| 119 | + jobject kokoro = env->GetObjectField(model, fid); | ||
| 120 | + jclass kokoro_cls = env->GetObjectClass(kokoro); | ||
| 121 | + | ||
| 122 | + fid = env->GetFieldID(kokoro_cls, "model", "Ljava/lang/String;"); | ||
| 123 | + s = (jstring)env->GetObjectField(kokoro, fid); | ||
| 124 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 125 | + ans.model.kokoro.model = p; | ||
| 126 | + env->ReleaseStringUTFChars(s, p); | ||
| 127 | + | ||
| 128 | + fid = env->GetFieldID(kokoro_cls, "voices", "Ljava/lang/String;"); | ||
| 129 | + s = (jstring)env->GetObjectField(kokoro, fid); | ||
| 130 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 131 | + ans.model.kokoro.voices = p; | ||
| 132 | + env->ReleaseStringUTFChars(s, p); | ||
| 133 | + | ||
| 134 | + fid = env->GetFieldID(kokoro_cls, "tokens", "Ljava/lang/String;"); | ||
| 135 | + s = (jstring)env->GetObjectField(kokoro, fid); | ||
| 136 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 137 | + ans.model.kokoro.tokens = p; | ||
| 138 | + env->ReleaseStringUTFChars(s, p); | ||
| 139 | + | ||
| 140 | + fid = env->GetFieldID(kokoro_cls, "dataDir", "Ljava/lang/String;"); | ||
| 141 | + s = (jstring)env->GetObjectField(kokoro, fid); | ||
| 142 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 143 | + ans.model.kokoro.data_dir = p; | ||
| 144 | + env->ReleaseStringUTFChars(s, p); | ||
| 145 | + | ||
| 146 | + fid = env->GetFieldID(kokoro_cls, "lengthScale", "F"); | ||
| 147 | + ans.model.kokoro.length_scale = env->GetFloatField(kokoro, fid); | ||
| 148 | + | ||
| 116 | fid = env->GetFieldID(model_config_cls, "numThreads", "I"); | 149 | fid = env->GetFieldID(model_config_cls, "numThreads", "I"); |
| 117 | ans.model.num_threads = env->GetIntField(model, fid); | 150 | ans.model.num_threads = env->GetIntField(model, fid); |
| 118 | 151 | ||
| @@ -273,8 +306,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( | @@ -273,8 +306,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( | ||
| 273 | return env->CallIntMethod(should_continue, int_value_mid); | 306 | return env->CallIntMethod(should_continue, int_value_mid); |
| 274 | }; | 307 | }; |
| 275 | 308 | ||
| 276 | - auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate( | ||
| 277 | - p_text, sid, speed, callback_wrapper); | 309 | + auto tts = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr); |
| 310 | + auto audio = tts->Generate(p_text, sid, speed, callback_wrapper); | ||
| 278 | 311 | ||
| 279 | jfloatArray samples_arr = env->NewFloatArray(audio.samples.size()); | 312 | jfloatArray samples_arr = env->NewFloatArray(audio.samples.size()); |
| 280 | env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(), | 313 | env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(), |
| @@ -25,9 +25,18 @@ data class OfflineTtsMatchaModelConfig( | @@ -25,9 +25,18 @@ data class OfflineTtsMatchaModelConfig( | ||
| 25 | var lengthScale: Float = 1.0f, | 25 | var lengthScale: Float = 1.0f, |
| 26 | ) | 26 | ) |
| 27 | 27 | ||
| 28 | +data class OfflineTtsKokoroModelConfig( | ||
| 29 | + var model: String = "", | ||
| 30 | + var voices: String = "", | ||
| 31 | + var tokens: String = "", | ||
| 32 | + var dataDir: String = "", | ||
| 33 | + var lengthScale: Float = 1.0f, | ||
| 34 | +) | ||
| 35 | + | ||
| 28 | data class OfflineTtsModelConfig( | 36 | data class OfflineTtsModelConfig( |
| 29 | var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(), | 37 | var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(), |
| 30 | var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(), | 38 | var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(), |
| 39 | + var kokoro: OfflineTtsKokoroModelConfig = OfflineTtsKokoroModelConfig(), | ||
| 31 | var numThreads: Int = 1, | 40 | var numThreads: Int = 1, |
| 32 | var debug: Boolean = false, | 41 | var debug: Boolean = false, |
| 33 | var provider: String = "cpu", | 42 | var provider: String = "cpu", |
| @@ -176,12 +185,32 @@ fun getOfflineTtsConfig( | @@ -176,12 +185,32 @@ fun getOfflineTtsConfig( | ||
| 176 | modelName: String, // for VITS | 185 | modelName: String, // for VITS |
| 177 | acousticModelName: String, // for Matcha | 186 | acousticModelName: String, // for Matcha |
| 178 | vocoder: String, // for Matcha | 187 | vocoder: String, // for Matcha |
| 188 | + voices: String, // for Kokoro | ||
| 179 | lexicon: String, | 189 | lexicon: String, |
| 180 | dataDir: String, | 190 | dataDir: String, |
| 181 | dictDir: String, | 191 | dictDir: String, |
| 182 | ruleFsts: String, | 192 | ruleFsts: String, |
| 183 | - ruleFars: String | 193 | + ruleFars: String, |
| 194 | + numThreads: Int? = null | ||
| 184 | ): OfflineTtsConfig { | 195 | ): OfflineTtsConfig { |
| 196 | + // For Matcha TTS, please set | ||
| 197 | + // acousticModelName, vocoder | ||
| 198 | + | ||
| 199 | + // For Kokoro TTS, please set | ||
| 200 | + // modelName, voices | ||
| 201 | + | ||
| 202 | + // For VITS, please set | ||
| 203 | + // modelName | ||
| 204 | + | ||
| 205 | + val numberOfThreads = if (numThreads != null) { | ||
| 206 | + numThreads | ||
| 207 | + } else if (voices.isNotEmpty()) { | ||
| 208 | + // for Kokoro TTS models, we use more threads | ||
| 209 | + 4 | ||
| 210 | + } else { | ||
| 211 | + 2 | ||
| 212 | + } | ||
| 213 | + | ||
| 185 | if (modelName.isEmpty() && acousticModelName.isEmpty()) { | 214 | if (modelName.isEmpty() && acousticModelName.isEmpty()) { |
| 186 | throw IllegalArgumentException("Please specify a TTS model") | 215 | throw IllegalArgumentException("Please specify a TTS model") |
| 187 | } | 216 | } |
| @@ -193,7 +222,8 @@ fun getOfflineTtsConfig( | @@ -193,7 +222,8 @@ fun getOfflineTtsConfig( | ||
| 193 | if (acousticModelName.isNotEmpty() && vocoder.isEmpty()) { | 222 | if (acousticModelName.isNotEmpty() && vocoder.isEmpty()) { |
| 194 | throw IllegalArgumentException("Please provide vocoder for Matcha TTS") | 223 | throw IllegalArgumentException("Please provide vocoder for Matcha TTS") |
| 195 | } | 224 | } |
| 196 | - val vits = if (modelName.isNotEmpty()) { | 225 | + |
| 226 | + val vits = if (modelName.isNotEmpty() && voices.isEmpty()) { | ||
| 197 | OfflineTtsVitsModelConfig( | 227 | OfflineTtsVitsModelConfig( |
| 198 | model = "$modelDir/$modelName", | 228 | model = "$modelDir/$modelName", |
| 199 | lexicon = "$modelDir/$lexicon", | 229 | lexicon = "$modelDir/$lexicon", |
| @@ -218,11 +248,23 @@ fun getOfflineTtsConfig( | @@ -218,11 +248,23 @@ fun getOfflineTtsConfig( | ||
| 218 | OfflineTtsMatchaModelConfig() | 248 | OfflineTtsMatchaModelConfig() |
| 219 | } | 249 | } |
| 220 | 250 | ||
| 251 | + val kokoro = if (voices.isNotEmpty()) { | ||
| 252 | + OfflineTtsKokoroModelConfig( | ||
| 253 | + model = "$modelDir/$modelName", | ||
| 254 | + voices = "$modelDir/$voices", | ||
| 255 | + tokens = "$modelDir/tokens.txt", | ||
| 256 | + dataDir = dataDir, | ||
| 257 | + ) | ||
| 258 | + } else { | ||
| 259 | + OfflineTtsKokoroModelConfig() | ||
| 260 | + } | ||
| 261 | + | ||
| 221 | return OfflineTtsConfig( | 262 | return OfflineTtsConfig( |
| 222 | model = OfflineTtsModelConfig( | 263 | model = OfflineTtsModelConfig( |
| 223 | vits = vits, | 264 | vits = vits, |
| 224 | matcha = matcha, | 265 | matcha = matcha, |
| 225 | - numThreads = 2, | 266 | + kokoro = kokoro, |
| 267 | + numThreads = numberOfThreads, | ||
| 226 | debug = true, | 268 | debug = true, |
| 227 | provider = "cpu", | 269 | provider = "cpu", |
| 228 | ), | 270 | ), |
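With the new `voices` parameter, `getOfflineTtsConfig()` selects the backend from whichever fields are non-empty: `voices` selects Kokoro, `acousticModelName`/`vocoder` select Matcha, and a bare `modelName` selects VITS. A hypothetical call for the `kokoro-en-v0_19` model used throughout this commit — values mirror the commented-out Example 9 in `MainActivity.kt`, and `numThreads` is omitted so the Kokoro default of 4 threads applies:

```kotlin
val config = getOfflineTtsConfig(
    modelDir = "kokoro-en-v0_19",
    modelName = "model.onnx",
    acousticModelName = "",   // Matcha only
    vocoder = "",             // Matcha only
    voices = "voices.bin",    // non-empty => Kokoro backend
    lexicon = "",
    dataDir = "kokoro-en-v0_19/espeak-ng-data",
    dictDir = "",
    ruleFsts = "",
    ruleFars = "",
)
```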