Fangjun Kuang
Committed by GitHub

Support playing as it is generating for Android (#477)
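This change makes the Android TTS demo play audio while it is still being generated, instead of waiting for the whole waveform, saving it to a wav file, and only then starting playback. MainActivity now owns a streaming AudioTrack (mono float PCM, MODE_STREAM) and hands a callback to the new OfflineTts.generateWithCallback() API; the native generator invokes the callback with each chunk of samples, the callback writes the chunk to the AudioTrack with WRITE_BLOCKING, and playback therefore starts almost immediately (the blocking write also throttles generation if it runs ahead of the track buffer). Generation is moved to a background Thread, UI updates go back through runOnUiThread, and the complete waveform is still saved to generated.wav when synthesis finishes. The JNI layer gains getSampleRate() and generateWithCallbackImpl() to support this.

The sketch below condenses the new calling pattern from the diff. It assumes `tts` is an already-initialized OfflineTts; the helper names (buildTrack, speak) and the output path are illustrative, and the model configuration, UI wiring, and error handling are omitted.

    import android.media.AudioAttributes
    import android.media.AudioFormat
    import android.media.AudioManager
    import android.media.AudioTrack

    // Build a streaming AudioTrack that accepts float samples at the model's rate.
    fun buildTrack(sampleRate: Int): AudioTrack {
        val attr = AudioAttributes.Builder()
            .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
            .setUsage(AudioAttributes.USAGE_MEDIA)
            .build()
        val format = AudioFormat.Builder()
            .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
            .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
            .setSampleRate(sampleRate)
            .build()
        return AudioTrack(
            attr, format, (sampleRate * 0.1).toInt(), AudioTrack.MODE_STREAM,
            AudioManager.AUDIO_SESSION_ID_GENERATE
        )
    }

    // Generate on a worker thread; every chunk is played as soon as it arrives.
    fun speak(tts: OfflineTts, text: String) {
        val track = buildTrack(tts.sampleRate())
        track.play()
        Thread {
            val audio = tts.generateWithCallback(text = text, callback = { samples ->
                track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
            })
            audio.save(filename = "generated.wav")  // illustrative path
            track.stop()
        }.start()
    }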

@@ -1,7 +1,7 @@
 package com.k2fsa.sherpa.onnx
 
 import android.content.res.AssetManager
-import android.media.MediaPlayer
+import android.media.*
 import android.net.Uri
 import android.os.Bundle
 import android.util.Log
@@ -23,6 +23,10 @@ class MainActivity : AppCompatActivity() {
     private lateinit var generate: Button
     private lateinit var play: Button
 
+    // see
+    // https://developer.android.com/reference/kotlin/android/media/AudioTrack
+    private lateinit var track: AudioTrack
+
     override fun onCreate(savedInstanceState: Bundle?) {
         super.onCreate(savedInstanceState)
         setContentView(R.layout.activity_main)
@@ -31,6 +35,10 @@ class MainActivity : AppCompatActivity() {
         initTts()
         Log.i(TAG, "Finish initializing TTS")
 
+        Log.i(TAG, "Start to initialize AudioTrack")
+        initAudioTrack()
+        Log.i(TAG, "Finish initializing AudioTrack")
+
         text = findViewById(R.id.text)
         sid = findViewById(R.id.sid)
         speed = findViewById(R.id.speed)
@@ -51,6 +59,33 @@ class MainActivity : AppCompatActivity() {
         play.isEnabled = false
     }
 
+    private fun initAudioTrack() {
+        val sampleRate = tts.sampleRate()
+        val bufLength = (sampleRate * 0.1).toInt()
+        Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}")
+
+        val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
+            .setUsage(AudioAttributes.USAGE_MEDIA)
+            .build()
+
+        val format = AudioFormat.Builder()
+            .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
+            .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
+            .setSampleRate(sampleRate)
+            .build()
+
+        track = AudioTrack(
+            attr, format, bufLength, AudioTrack.MODE_STREAM,
+            AudioManager.AUDIO_SESSION_ID_GENERATE
+        )
+        track.play()
+    }
+
+    // this function is called from C++
+    private fun callback(samples: FloatArray) {
+        track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
+    }
+
     private fun onClickGenerate() {
         val sidInt = sid.text.toString().toIntOrNull()
         if (sidInt == null || sidInt < 0) {
@@ -79,16 +114,28 @@ class MainActivity : AppCompatActivity() {
             return
         }
 
-        play.isEnabled = false
-        val audio = tts.generate(text = textStr, sid = sidInt, speed = speedFloat)
+        track.pause()
+        track.flush()
+        track.play()
 
-        val filename = application.filesDir.absolutePath + "/generated.wav"
-        val ok = audio.samples.size > 0 && audio.save(filename)
-        if (ok) {
-            play.isEnabled = true
-            // Play automatically after generation
-            onClickPlay()
-        }
+        play.isEnabled = false
+        Thread {
+            val audio = tts.generateWithCallback(
+                text = textStr,
+                sid = sidInt,
+                speed = speedFloat,
+                callback = this::callback
+            )
+
+            val filename = application.filesDir.absolutePath + "/generated.wav"
+            val ok = audio.samples.size > 0 && audio.save(filename)
+            if (ok) {
+                runOnUiThread {
+                    play.isEnabled = true
+                    track.stop()
+                }
+            }
+        }.start()
     }
 
     private fun onClickPlay() {
@@ -54,6 +54,8 @@ class OfflineTts(
         }
     }
 
+    fun sampleRate() = getSampleRate(ptr)
+
     fun generate(
         text: String,
         sid: Int = 0,
@@ -66,6 +68,19 @@ class OfflineTts(
         )
     }
 
+    fun generateWithCallback(
+        text: String,
+        sid: Int = 0,
+        speed: Float = 1.0f,
+        callback: (samples: FloatArray) -> Unit
+    ): GeneratedAudio {
+        var objArray = generateWithCallbackImpl(ptr, text = text, sid = sid, speed = speed, callback=callback)
+        return GeneratedAudio(
+            samples = objArray[0] as FloatArray,
+            sampleRate = objArray[1] as Int
+        )
+    }
+
     fun allocate(assetManager: AssetManager? = null) {
         if (ptr == 0L) {
             if (assetManager != null) {
@@ -97,6 +112,7 @@ class OfflineTts(
     ): Long
 
     private external fun delete(ptr: Long)
+    private external fun getSampleRate(ptr: Long): Int
 
     // The returned array has two entries:
     //  - the first entry is an 1-D float array containing audio samples.
@@ -109,6 +125,14 @@ class OfflineTts(
         speed: Float = 1.0f
     ): Array<Any>
 
+    external fun generateWithCallbackImpl(
+        ptr: Long,
+        text: String,
+        sid: Int = 0,
+        speed: Float = 1.0f,
+        callback: (samples: FloatArray) -> Unit
+    ): Array<Any>
+
     companion object {
         init {
             System.loadLibrary("sherpa-onnx-jni")
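On the Kotlin API side, the callback does not have to drive an AudioTrack; any `(FloatArray) -> Unit` works, and the complete GeneratedAudio is still returned once synthesis finishes. A minimal sketch of incremental progress reporting, assuming `tts` is an already-initialized OfflineTts and the text is illustrative:

    // Count streamed samples as chunks arrive, then compare with the final result.
    var streamed = 0
    val audio = tts.generateWithCallback(text = "Hello from sherpa-onnx") { samples ->
        streamed += samples.size
        println("received ${samples.size} samples, ${streamed.toFloat() / tts.sampleRate()} s so far")
    }
    println("total: ${audio.samples.size.toFloat() / audio.sampleRate} s")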
@@ -2,6 +2,10 @@ package com.k2fsa.sherpa.onnx
 
 import android.content.res.AssetManager
 
+fun callback(samples: FloatArray): Unit {
+    println("callback got called with ${samples.size} samples");
+}
+
 fun main() {
     testTts()
     testAsr()
@@ -22,7 +26,7 @@ fun testTts() {
         )
     )
     val tts = OfflineTts(config=config)
-    val audio = tts.generate(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”")
+    val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback)
     audio.save(filename="test-en.wav")
 }
 
@@ -172,57 +172,57 @@ def get_vits_models() -> List[TtsModel]:
             lang="zh",
             rule_fsts="vits-zh-aishell3/rule.fst",
         ),
-        TtsModel(
-            model_dir="vits-zh-hf-doom",
-            model_name="doom.onnx",
-            lang="zh",
-            rule_fsts="vits-zh-hf-doom/rule.fst",
-        ),
-        TtsModel(
-            model_dir="vits-zh-hf-echo",
-            model_name="echo.onnx",
-            lang="zh",
-            rule_fsts="vits-zh-hf-echo/rule.fst",
-        ),
-        TtsModel(
-            model_dir="vits-zh-hf-zenyatta",
-            model_name="zenyatta.onnx",
-            lang="zh",
-            rule_fsts="vits-zh-hf-zenyatta/rule.fst",
-        ),
-        TtsModel(
-            model_dir="vits-zh-hf-abyssinvoker",
-            model_name="abyssinvoker.onnx",
-            lang="zh",
-            rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
-        ),
-        TtsModel(
-            model_dir="vits-zh-hf-keqing",
-            model_name="keqing.onnx",
-            lang="zh",
-            rule_fsts="vits-zh-hf-keqing/rule.fst",
-        ),
-        TtsModel(
-            model_dir="vits-zh-hf-eula",
-            model_name="eula.onnx",
-            lang="zh",
-            rule_fsts="vits-zh-hf-eula/rule.fst",
-        ),
-        TtsModel(
-            model_dir="vits-zh-hf-bronya",
-            model_name="bronya.onnx",
-            lang="zh",
-            rule_fsts="vits-zh-hf-bronya/rule.fst",
-        ),
-        TtsModel(
-            model_dir="vits-zh-hf-theresa",
-            model_name="theresa.onnx",
-            lang="zh",
-            rule_fsts="vits-zh-hf-theresa/rule.fst",
-        ),
+        # TtsModel(
+        #     model_dir="vits-zh-hf-doom",
+        #     model_name="doom.onnx",
+        #     lang="zh",
+        #     rule_fsts="vits-zh-hf-doom/rule.fst",
+        # ),
+        # TtsModel(
+        #     model_dir="vits-zh-hf-echo",
+        #     model_name="echo.onnx",
+        #     lang="zh",
+        #     rule_fsts="vits-zh-hf-echo/rule.fst",
+        # ),
+        # TtsModel(
+        #     model_dir="vits-zh-hf-zenyatta",
+        #     model_name="zenyatta.onnx",
+        #     lang="zh",
+        #     rule_fsts="vits-zh-hf-zenyatta/rule.fst",
+        # ),
+        # TtsModel(
+        #     model_dir="vits-zh-hf-abyssinvoker",
+        #     model_name="abyssinvoker.onnx",
+        #     lang="zh",
+        #     rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
+        # ),
+        # TtsModel(
+        #     model_dir="vits-zh-hf-keqing",
+        #     model_name="keqing.onnx",
+        #     lang="zh",
+        #     rule_fsts="vits-zh-hf-keqing/rule.fst",
+        # ),
+        # TtsModel(
+        #     model_dir="vits-zh-hf-eula",
+        #     model_name="eula.onnx",
+        #     lang="zh",
+        #     rule_fsts="vits-zh-hf-eula/rule.fst",
+        # ),
+        # TtsModel(
+        #     model_dir="vits-zh-hf-bronya",
+        #     model_name="bronya.onnx",
+        #     lang="zh",
+        #     rule_fsts="vits-zh-hf-bronya/rule.fst",
+        # ),
+        # TtsModel(
+        #     model_dir="vits-zh-hf-theresa",
+        #     model_name="theresa.onnx",
+        #     lang="zh",
+        #     rule_fsts="vits-zh-hf-theresa/rule.fst",
+        # ),
         # English (US)
         TtsModel(model_dir="vits-vctk", model_name="vits-vctk.onnx", lang="en"),
-        TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"),
+        # TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"),
         # fmt: on
     ]
 
@@ -238,8 +238,8 @@ def main():
     template = environment.from_string(s)
     d = dict()
 
-    # all_model_list = get_vits_models()
-    all_model_list = get_piper_models()
+    all_model_list = get_vits_models()
+    all_model_list += get_piper_models()
     all_model_list += get_coqui_models()
 
     num_models = len(all_model_list)
@@ -11,13 +11,15 @@
 // android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
 #include "jni.h"  // NOLINT
 
+#include <fstream>
+#include <functional>
 #include <strstream>
 #include <utility>
+
 #if __ANDROID_API__ >= 9
 #include "android/asset_manager.h"
 #include "android/asset_manager_jni.h"
 #endif
-#include <fstream>
 
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -502,11 +504,14 @@ class SherpaOnnxOfflineTts {
   explicit SherpaOnnxOfflineTts(const OfflineTtsConfig &config)
       : tts_(config) {}
 
-  GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
-                          float speed = 1.0) const {
-    return tts_.Generate(text, sid, speed);
+  GeneratedAudio Generate(
+      const std::string &text, int64_t sid = 0, float speed = 1.0,
+      std::function<void(const float *, int32_t)> callback = nullptr) const {
+    return tts_.Generate(text, sid, speed, callback);
   }
 
+  int32_t SampleRate() const { return tts_.SampleRate(); }
+
  private:
   OfflineTts tts_;
 };
@@ -628,6 +633,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete(
   delete reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr);
 }
 
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  return reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)
+      ->SampleRate();
+}
+
 // see
 // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
 static jobject NewInteger(JNIEnv *env, int32_t value) {
@@ -664,6 +676,43 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/,
 }
 
 SHERPA_ONNX_EXTERN_C
+JNIEXPORT jobjectArray JNICALL
+Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid,
+    jfloat speed, jobject callback) {
+  const char *p_text = env->GetStringUTFChars(text, nullptr);
+  SHERPA_ONNX_LOGE("string is: %s", p_text);
+
+  std::function<void(const float *, int32_t)> callback_wrapper =
+      [env, callback](const float *samples, int32_t n) {
+        jclass cls = env->GetObjectClass(callback);
+        jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");
+
+        jfloatArray samples_arr = env->NewFloatArray(n);
+        env->SetFloatArrayRegion(samples_arr, 0, n, samples);
+        env->CallVoidMethod(callback, mid, samples_arr);
+      };
+
+  auto audio =
+      reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)->Generate(
+          p_text, sid, speed, callback_wrapper);
+
+  jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
+  env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
+                           audio.samples.data());
+
+  jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
+      2, env->FindClass("java/lang/Object"), nullptr);
+
+  env->SetObjectArrayElement(obj_arr, 0, samples_arr);
+  env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));
+
+  env->ReleaseStringUTFChars(text, p_text);
+
+  return obj_arr;
+}
+
+SHERPA_ONNX_EXTERN_C
 JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
     JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
     jint sample_rate) {