正在显示
32 个修改的文件
包含
248 行增加
和
69 行删除
| @@ -8,7 +8,7 @@ project(sherpa-onnx) | @@ -8,7 +8,7 @@ project(sherpa-onnx) | ||
| 8 | # ./nodejs-addon-examples | 8 | # ./nodejs-addon-examples |
| 9 | # ./dart-api-examples/ | 9 | # ./dart-api-examples/ |
| 10 | # ./sherpa-onnx/flutter/CHANGELOG.md | 10 | # ./sherpa-onnx/flutter/CHANGELOG.md |
| 11 | -set(SHERPA_ONNX_VERSION "1.10.0") | 11 | +set(SHERPA_ONNX_VERSION "1.10.1") |
| 12 | 12 | ||
| 13 | # Disable warning about | 13 | # Disable warning about |
| 14 | # | 14 | # |
| @@ -26,6 +26,9 @@ class MainActivity : AppCompatActivity() { | @@ -26,6 +26,9 @@ class MainActivity : AppCompatActivity() { | ||
| 26 | private lateinit var speed: EditText | 26 | private lateinit var speed: EditText |
| 27 | private lateinit var generate: Button | 27 | private lateinit var generate: Button |
| 28 | private lateinit var play: Button | 28 | private lateinit var play: Button |
| 29 | + private lateinit var stop: Button | ||
| 30 | + private var stopped: Boolean = false | ||
| 31 | + private var mediaPlayer: MediaPlayer? = null | ||
| 29 | 32 | ||
| 30 | // see | 33 | // see |
| 31 | // https://developer.android.com/reference/kotlin/android/media/AudioTrack | 34 | // https://developer.android.com/reference/kotlin/android/media/AudioTrack |
| @@ -49,9 +52,11 @@ class MainActivity : AppCompatActivity() { | @@ -49,9 +52,11 @@ class MainActivity : AppCompatActivity() { | ||
| 49 | 52 | ||
| 50 | generate = findViewById(R.id.generate) | 53 | generate = findViewById(R.id.generate) |
| 51 | play = findViewById(R.id.play) | 54 | play = findViewById(R.id.play) |
| 55 | + stop = findViewById(R.id.stop) | ||
| 52 | 56 | ||
| 53 | generate.setOnClickListener { onClickGenerate() } | 57 | generate.setOnClickListener { onClickGenerate() } |
| 54 | play.setOnClickListener { onClickPlay() } | 58 | play.setOnClickListener { onClickPlay() } |
| 59 | + stop.setOnClickListener { onClickStop() } | ||
| 55 | 60 | ||
| 56 | sid.setText("0") | 61 | sid.setText("0") |
| 57 | speed.setText("1.0") | 62 | speed.setText("1.0") |
| @@ -70,7 +75,7 @@ class MainActivity : AppCompatActivity() { | @@ -70,7 +75,7 @@ class MainActivity : AppCompatActivity() { | ||
| 70 | AudioFormat.CHANNEL_OUT_MONO, | 75 | AudioFormat.CHANNEL_OUT_MONO, |
| 71 | AudioFormat.ENCODING_PCM_FLOAT | 76 | AudioFormat.ENCODING_PCM_FLOAT |
| 72 | ) | 77 | ) |
| 73 | - Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}") | 78 | + Log.i(TAG, "sampleRate: $sampleRate, buffLength: $bufLength") |
| 74 | 79 | ||
| 75 | val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) | 80 | val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) |
| 76 | .setUsage(AudioAttributes.USAGE_MEDIA) | 81 | .setUsage(AudioAttributes.USAGE_MEDIA) |
| @@ -90,8 +95,14 @@ class MainActivity : AppCompatActivity() { | @@ -90,8 +95,14 @@ class MainActivity : AppCompatActivity() { | ||
| 90 | } | 95 | } |
| 91 | 96 | ||
| 92 | // this function is called from C++ | 97 | // this function is called from C++ |
| 93 | - private fun callback(samples: FloatArray) { | ||
| 94 | - track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING) | 98 | + private fun callback(samples: FloatArray): Int { |
| 99 | + if (!stopped) { | ||
| 100 | + track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING) | ||
| 101 | + return 1 | ||
| 102 | + } else { | ||
| 103 | + track.stop() | ||
| 104 | + return 0 | ||
| 105 | + } | ||
| 95 | } | 106 | } |
| 96 | 107 | ||
| 97 | private fun onClickGenerate() { | 108 | private fun onClickGenerate() { |
| @@ -127,6 +138,8 @@ class MainActivity : AppCompatActivity() { | @@ -127,6 +138,8 @@ class MainActivity : AppCompatActivity() { | ||
| 127 | track.play() | 138 | track.play() |
| 128 | 139 | ||
| 129 | play.isEnabled = false | 140 | play.isEnabled = false |
| 141 | + generate.isEnabled = false | ||
| 142 | + stopped = false | ||
| 130 | Thread { | 143 | Thread { |
| 131 | val audio = tts.generateWithCallback( | 144 | val audio = tts.generateWithCallback( |
| 132 | text = textStr, | 145 | text = textStr, |
| @@ -140,6 +153,7 @@ class MainActivity : AppCompatActivity() { | @@ -140,6 +153,7 @@ class MainActivity : AppCompatActivity() { | ||
| 140 | if (ok) { | 153 | if (ok) { |
| 141 | runOnUiThread { | 154 | runOnUiThread { |
| 142 | play.isEnabled = true | 155 | play.isEnabled = true |
| 156 | + generate.isEnabled = true | ||
| 143 | track.stop() | 157 | track.stop() |
| 144 | } | 158 | } |
| 145 | } | 159 | } |
| @@ -148,11 +162,22 @@ class MainActivity : AppCompatActivity() { | @@ -148,11 +162,22 @@ class MainActivity : AppCompatActivity() { | ||
| 148 | 162 | ||
| 149 | private fun onClickPlay() { | 163 | private fun onClickPlay() { |
| 150 | val filename = application.filesDir.absolutePath + "/generated.wav" | 164 | val filename = application.filesDir.absolutePath + "/generated.wav" |
| 151 | - val mediaPlayer = MediaPlayer.create( | 165 | + mediaPlayer?.stop() |
| 166 | + mediaPlayer = MediaPlayer.create( | ||
| 152 | applicationContext, | 167 | applicationContext, |
| 153 | Uri.fromFile(File(filename)) | 168 | Uri.fromFile(File(filename)) |
| 154 | ) | 169 | ) |
| 155 | - mediaPlayer.start() | 170 | + mediaPlayer?.start() |
| 171 | + } | ||
| 172 | + | ||
| 173 | + private fun onClickStop() { | ||
| 174 | + stopped = true | ||
| 175 | + play.isEnabled = true | ||
| 176 | + generate.isEnabled = true | ||
| 177 | + track.pause() | ||
| 178 | + track.flush() | ||
| 179 | + mediaPlayer?.stop() | ||
| 180 | + mediaPlayer = null | ||
| 156 | } | 181 | } |
| 157 | 182 | ||
| 158 | private fun initTts() { | 183 | private fun initTts() { |
| @@ -76,7 +76,7 @@ class OfflineTts( | @@ -76,7 +76,7 @@ class OfflineTts( | ||
| 76 | text: String, | 76 | text: String, |
| 77 | sid: Int = 0, | 77 | sid: Int = 0, |
| 78 | speed: Float = 1.0f, | 78 | speed: Float = 1.0f, |
| 79 | - callback: (samples: FloatArray) -> Unit | 79 | + callback: (samples: FloatArray) -> Int |
| 80 | ): GeneratedAudio { | 80 | ): GeneratedAudio { |
| 81 | val objArray = generateWithCallbackImpl( | 81 | val objArray = generateWithCallbackImpl( |
| 82 | ptr, | 82 | ptr, |
| @@ -146,7 +146,7 @@ class OfflineTts( | @@ -146,7 +146,7 @@ class OfflineTts( | ||
| 146 | text: String, | 146 | text: String, |
| 147 | sid: Int = 0, | 147 | sid: Int = 0, |
| 148 | speed: Float = 1.0f, | 148 | speed: Float = 1.0f, |
| 149 | - callback: (samples: FloatArray) -> Unit | 149 | + callback: (samples: FloatArray) -> Int |
| 150 | ): Array<Any> | 150 | ): Array<Any> |
| 151 | 151 | ||
| 152 | companion object { | 152 | companion object { |
| @@ -84,4 +84,16 @@ | @@ -84,4 +84,16 @@ | ||
| 84 | app:layout_constraintLeft_toLeftOf="parent" | 84 | app:layout_constraintLeft_toLeftOf="parent" |
| 85 | app:layout_constraintRight_toRightOf="parent" | 85 | app:layout_constraintRight_toRightOf="parent" |
| 86 | app:layout_constraintTop_toBottomOf="@id/generate" /> | 86 | app:layout_constraintTop_toBottomOf="@id/generate" /> |
| 87 | + | ||
| 88 | + <Button | ||
| 89 | + android:id="@+id/stop" | ||
| 90 | + android:textAllCaps="false" | ||
| 91 | + android:layout_width="match_parent" | ||
| 92 | + android:layout_height="50dp" | ||
| 93 | + android:layout_marginTop="4dp" | ||
| 94 | + android:text="@string/stop" | ||
| 95 | + app:layout_constraintLeft_toLeftOf="parent" | ||
| 96 | + app:layout_constraintRight_toRightOf="parent" | ||
| 97 | + app:layout_constraintTop_toBottomOf="@id/play" /> | ||
| 98 | + | ||
| 87 | </androidx.constraintlayout.widget.ConstraintLayout> | 99 | </androidx.constraintlayout.widget.ConstraintLayout> |
| @@ -7,4 +7,5 @@ | @@ -7,4 +7,5 @@ | ||
| 7 | <string name="text_hint">Please input your text here</string> | 7 | <string name="text_hint">Please input your text here</string> |
| 8 | <string name="generate">Generate</string> | 8 | <string name="generate">Generate</string> |
| 9 | <string name="play">Play</string> | 9 | <string name="play">Play</string> |
| 10 | + <string name="stop">Stop</string> | ||
| 10 | </resources> | 11 | </resources> |
| @@ -126,7 +126,7 @@ class TtsService : TextToSpeechService() { | @@ -126,7 +126,7 @@ class TtsService : TextToSpeechService() { | ||
| 126 | return | 126 | return |
| 127 | } | 127 | } |
| 128 | 128 | ||
| 129 | - val ttsCallback = { floatSamples: FloatArray -> | 129 | + val ttsCallback: (FloatArray) -> Int = fun(floatSamples): Int { |
| 130 | // convert FloatArray to ByteArray | 130 | // convert FloatArray to ByteArray |
| 131 | val samples = floatArrayToByteArray(floatSamples) | 131 | val samples = floatArrayToByteArray(floatSamples) |
| 132 | val maxBufferSize: Int = callback.maxBufferSize | 132 | val maxBufferSize: Int = callback.maxBufferSize |
| @@ -137,6 +137,9 @@ class TtsService : TextToSpeechService() { | @@ -137,6 +137,9 @@ class TtsService : TextToSpeechService() { | ||
| 137 | offset += bytesToWrite | 137 | offset += bytesToWrite |
| 138 | } | 138 | } |
| 139 | 139 | ||
| 140 | + // 1 means to continue | ||
| 141 | + // 0 means to stop | ||
| 142 | + return 1 | ||
| 140 | } | 143 | } |
| 141 | 144 | ||
| 142 | Log.i(TAG, "text: $text") | 145 | Log.i(TAG, "text: $text") |
| @@ -160,4 +163,4 @@ class TtsService : TextToSpeechService() { | @@ -160,4 +163,4 @@ class TtsService : TextToSpeechService() { | ||
| 160 | } | 163 | } |
| 161 | return byteArray | 164 | return byteArray |
| 162 | } | 165 | } |
| 163 | -} | ||
| 166 | +} |
| @@ -68,6 +68,10 @@ void main(List<String> arguments) async { | @@ -68,6 +68,10 @@ void main(List<String> arguments) async { | ||
| 68 | callback: (Float32List samples) { | 68 | callback: (Float32List samples) { |
| 69 | print('${samples.length} samples received'); | 69 | print('${samples.length} samples received'); |
| 70 | // You can play samples in a separate thread/isolate | 70 | // You can play samples in a separate thread/isolate |
| 71 | + | ||
| 72 | + // 1 means to continue | ||
| 73 | + // 0 means to stop | ||
| 74 | + return 1; | ||
| 71 | }); | 75 | }); |
| 72 | tts.free(); | 76 | tts.free(); |
| 73 | 77 |
| @@ -187,6 +187,10 @@ to download more models. | @@ -187,6 +187,10 @@ to download more models. | ||
| 187 | Marshal.Copy(samples, data, 0, n); | 187 | Marshal.Copy(samples, data, 0, n); |
| 188 | 188 | ||
| 189 | dataItems.Add(data); | 189 | dataItems.Add(data); |
| 190 | + | ||
| 191 | + // 1 means to keep generating | ||
| 192 | + // 0 means to stop generating | ||
| 193 | + return 1; | ||
| 190 | }; | 194 | }; |
| 191 | 195 | ||
| 192 | bool playFinished = false; | 196 | bool playFinished = false; |
| @@ -25,6 +25,46 @@ fun testTts() { | @@ -25,6 +25,46 @@ fun testTts() { | ||
| 25 | println("Saved to test-en.wav") | 25 | println("Saved to test-en.wav") |
| 26 | } | 26 | } |
| 27 | 27 | ||
| 28 | -fun callback(samples: FloatArray): Unit { | 28 | +/* |
| 29 | +1. Unzip test_tts.jar | ||
| 30 | +2. | ||
| 31 | +javap ./com/k2fsa/sherpa/onnx/Test_ttsKt\$testTts\$audio\$1.class | ||
| 32 | + | ||
| 33 | +3. It prints: | ||
| 34 | +Compiled from "test_tts.kt" | ||
| 35 | +final class com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1<float[], java.lang.Integer> { | ||
| 36 | + public static final com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 INSTANCE; | ||
| 37 | + com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1(); | ||
| 38 | + public final java.lang.Integer invoke(float[]); | ||
| 39 | + public java.lang.Object invoke(java.lang.Object); | ||
| 40 | + static {}; | ||
| 41 | +} | ||
| 42 | + | ||
| 43 | +4. | ||
| 44 | +javap -s ./com/k2fsa/sherpa/onnx/Test_ttsKt\$testTts\$audio\$1.class | ||
| 45 | + | ||
| 46 | +5. It prints | ||
| 47 | +Compiled from "test_tts.kt" | ||
| 48 | +final class com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1<float[], java.lang.Integer> { | ||
| 49 | + public static final com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 INSTANCE; | ||
| 50 | + descriptor: Lcom/k2fsa/sherpa/onnx/Test_ttsKt$testTts$audio$1; | ||
| 51 | + com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1(); | ||
| 52 | + descriptor: ()V | ||
| 53 | + | ||
| 54 | + public final java.lang.Integer invoke(float[]); | ||
| 55 | + descriptor: ([F)Ljava/lang/Integer; | ||
| 56 | + | ||
| 57 | + public java.lang.Object invoke(java.lang.Object); | ||
| 58 | + descriptor: (Ljava/lang/Object;)Ljava/lang/Object; | ||
| 59 | + | ||
| 60 | + static {}; | ||
| 61 | + descriptor: ()V | ||
| 62 | +} | ||
| 63 | +*/ | ||
| 64 | +fun callback(samples: FloatArray): Int { | ||
| 29 | println("callback got called with ${samples.size} samples"); | 65 | println("callback got called with ${samples.size} samples"); |
| 66 | + | ||
| 67 | + // 1 means to continue | ||
| 68 | + // 0 means to stop | ||
| 69 | + return 1 | ||
| 30 | } | 70 | } |
| @@ -57,7 +57,7 @@ static bool g_started = false; | @@ -57,7 +57,7 @@ static bool g_started = false; | ||
| 57 | static bool g_stopped = false; | 57 | static bool g_stopped = false; |
| 58 | static bool g_killed = false; | 58 | static bool g_killed = false; |
| 59 | 59 | ||
| 60 | -static void AudioGeneratedCallback(const float *s, int32_t n) { | 60 | +static int32_t AudioGeneratedCallback(const float *s, int32_t n) { |
| 61 | if (n > 0) { | 61 | if (n > 0) { |
| 62 | Samples samples; | 62 | Samples samples; |
| 63 | samples.data = std::vector<float>{s, s + n}; | 63 | samples.data = std::vector<float>{s, s + n}; |
| @@ -66,6 +66,10 @@ static void AudioGeneratedCallback(const float *s, int32_t n) { | @@ -66,6 +66,10 @@ static void AudioGeneratedCallback(const float *s, int32_t n) { | ||
| 66 | g_buffer.samples.push(std::move(samples)); | 66 | g_buffer.samples.push(std::move(samples)); |
| 67 | g_started = true; | 67 | g_started = true; |
| 68 | } | 68 | } |
| 69 | + if (g_killed) { | ||
| 70 | + return 0; | ||
| 71 | + } | ||
| 72 | + return 1; | ||
| 69 | } | 73 | } |
| 70 | 74 | ||
| 71 | static int PlayCallback(const void * /*in*/, void *out, | 75 | static int PlayCallback(const void * /*in*/, void *out, |
| @@ -324,6 +328,7 @@ BEGIN_MESSAGE_MAP(CNonStreamingTextToSpeechDlg, CDialogEx) | @@ -324,6 +328,7 @@ BEGIN_MESSAGE_MAP(CNonStreamingTextToSpeechDlg, CDialogEx) | ||
| 324 | ON_WM_PAINT() | 328 | ON_WM_PAINT() |
| 325 | ON_WM_QUERYDRAGICON() | 329 | ON_WM_QUERYDRAGICON() |
| 326 | ON_BN_CLICKED(IDOK, &CNonStreamingTextToSpeechDlg::OnBnClickedOk) | 330 | ON_BN_CLICKED(IDOK, &CNonStreamingTextToSpeechDlg::OnBnClickedOk) |
| 331 | + ON_BN_CLICKED(IDC_STOP, &CNonStreamingTextToSpeechDlg::OnBnClickedStop) | ||
| 327 | END_MESSAGE_MAP() | 332 | END_MESSAGE_MAP() |
| 328 | 333 | ||
| 329 | 334 | ||
| @@ -492,11 +497,18 @@ void CNonStreamingTextToSpeechDlg::Init() { | @@ -492,11 +497,18 @@ void CNonStreamingTextToSpeechDlg::Init() { | ||
| 492 | if (tts_) { | 497 | if (tts_) { |
| 493 | SherpaOnnxDestroyOfflineTts(tts_); | 498 | SherpaOnnxDestroyOfflineTts(tts_); |
| 494 | } | 499 | } |
| 500 | + if (generate_thread_ && generate_thread_->joinable()) { | ||
| 501 | + generate_thread_->join(); | ||
| 502 | + } | ||
| 503 | + | ||
| 504 | + if (play_thread_ && play_thread_->joinable()) { | ||
| 505 | + play_thread_->join(); | ||
| 506 | + } | ||
| 495 | } | 507 | } |
| 496 | 508 | ||
| 497 | 509 | ||
| 498 | static std::string ToString(const CString &s) { | 510 | static std::string ToString(const CString &s) { |
| 499 | - CT2CA pszConvertedAnsiString( s); | 511 | + CT2CA pszConvertedAnsiString(s); |
| 500 | return std::string(pszConvertedAnsiString); | 512 | return std::string(pszConvertedAnsiString); |
| 501 | } | 513 | } |
| 502 | 514 | ||
| @@ -510,7 +522,7 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() { | @@ -510,7 +522,7 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() { | ||
| 510 | } | 522 | } |
| 511 | 523 | ||
| 512 | speed_.GetWindowText(s); | 524 | speed_.GetWindowText(s); |
| 513 | - float speed = static_cast<float>(_ttof(s)); | 525 | + float speed = static_cast<float>(_ttof(s)); |
| 514 | if (speed < 0) { | 526 | if (speed < 0) { |
| 515 | AfxMessageBox(Utf8ToUtf16("Please input a valid speed").c_str(), MB_OK); | 527 | AfxMessageBox(Utf8ToUtf16("Please input a valid speed").c_str(), MB_OK); |
| 516 | return; | 528 | return; |
| @@ -541,28 +553,40 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() { | @@ -541,28 +553,40 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() { | ||
| 541 | // for simplicity | 553 | // for simplicity |
| 542 | play_thread_ = std::make_unique<std::thread>(StartPlayback, SherpaOnnxOfflineTtsSampleRate(tts_)); | 554 | play_thread_ = std::make_unique<std::thread>(StartPlayback, SherpaOnnxOfflineTtsSampleRate(tts_)); |
| 543 | 555 | ||
| 544 | - generate_btn_.EnableWindow(FALSE); | ||
| 545 | - | ||
| 546 | - const SherpaOnnxGeneratedAudio *audio = | ||
| 547 | - SherpaOnnxOfflineTtsGenerateWithCallback(tts_, ss.c_str(), speaker_id, speed, &AudioGeneratedCallback); | ||
| 548 | - | ||
| 549 | - generate_btn_.EnableWindow(TRUE); | 556 | + if (generate_thread_ && generate_thread_->joinable()) { |
| 557 | + generate_thread_->join(); | ||
| 558 | + } | ||
| 550 | 559 | ||
| 551 | output_filename_.GetWindowText(s); | 560 | output_filename_.GetWindowText(s); |
| 552 | std::string filename = ToString(s); | 561 | std::string filename = ToString(s); |
| 553 | 562 | ||
| 554 | - int ok = SherpaOnnxWriteWave(audio->samples, audio->n, audio->sample_rate, | ||
| 555 | - filename.c_str()); | 563 | + generate_thread_ = std::make_unique<std::thread>([ss, this,filename, speaker_id, speed]() { |
| 564 | + std::string text = ss; | ||
| 556 | 565 | ||
| 557 | - SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio); | 566 | + // generate_btn_.EnableWindow(FALSE); |
| 558 | 567 | ||
| 559 | - if (ok) { | ||
| 560 | - // AfxMessageBox(Utf8ToUtf16(std::string("Saved to ") + filename + " successfully").c_str(), MB_OK); | ||
| 561 | - AppendLineToMultilineEditCtrl(my_hint_, std::string("Saved to ") + filename + " successfully"); | ||
| 562 | - } else { | ||
| 563 | - // AfxMessageBox(Utf8ToUtf16(std::string("Failed to save to ") + filename).c_str(), MB_OK); | ||
| 564 | - AppendLineToMultilineEditCtrl(my_hint_, std::string("Failed to saved to ") + filename); | ||
| 565 | - } | 568 | + const SherpaOnnxGeneratedAudio *audio = |
| 569 | + SherpaOnnxOfflineTtsGenerateWithCallback(tts_, text.c_str(), speaker_id, speed, &AudioGeneratedCallback); | ||
| 570 | + // generate_btn_.EnableWindow(TRUE); | ||
| 571 | + g_stopped = true; | ||
| 572 | + | ||
| 573 | + int ok = SherpaOnnxWriteWave(audio->samples, audio->n, audio->sample_rate, | ||
| 574 | + filename.c_str()); | ||
| 575 | + | ||
| 576 | + SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio); | ||
| 577 | + | ||
| 578 | + if (ok) { | ||
| 579 | + // AfxMessageBox(Utf8ToUtf16(std::string("Saved to ") + filename + " successfully").c_str(), MB_OK); | ||
| 580 | + | ||
| 581 | + // AppendLineToMultilineEditCtrl(my_hint_, std::string("Saved to ") + filename + " successfully"); | ||
| 582 | + } else { | ||
| 583 | + // AfxMessageBox(Utf8ToUtf16(std::string("Failed to save to ") + filename).c_str(), MB_OK); | ||
| 584 | + | ||
| 585 | + // AppendLineToMultilineEditCtrl(my_hint_, std::string("Failed to saved to ") + filename); | ||
| 586 | + } | ||
| 587 | + }); | ||
| 566 | 588 | ||
| 567 | //CDialogEx::OnOK(); | 589 | //CDialogEx::OnOK(); |
| 568 | } | 590 | } |
| 591 | + | ||
| 592 | +void CNonStreamingTextToSpeechDlg::OnBnClickedStop() { g_killed = true; } |
| @@ -60,5 +60,8 @@ public: | @@ -60,5 +60,8 @@ public: | ||
| 60 | private: | 60 | private: |
| 61 | Microphone mic_; | 61 | Microphone mic_; |
| 62 | std::unique_ptr<std::thread> play_thread_; | 62 | std::unique_ptr<std::thread> play_thread_; |
| 63 | + std::unique_ptr<std::thread> generate_thread_; | ||
| 63 | 64 | ||
| 65 | + public: | ||
| 66 | + afx_msg void OnBnClickedStop(); | ||
| 64 | }; | 67 | }; |
| @@ -13,6 +13,7 @@ | @@ -13,6 +13,7 @@ | ||
| 13 | #define IDC_HINT 1005 | 13 | #define IDC_HINT 1005 |
| 14 | #define IDC_EDIT1 1006 | 14 | #define IDC_EDIT1 1006 |
| 15 | #define IDC_OUTPUT_FILENAME 1006 | 15 | #define IDC_OUTPUT_FILENAME 1006 |
| 16 | +#define IDC_STOP 1009 | ||
| 16 | 17 | ||
| 17 | // Next default values for new objects | 18 | // Next default values for new objects |
| 18 | // | 19 | // |
| @@ -20,7 +21,7 @@ | @@ -20,7 +21,7 @@ | ||
| 20 | #ifndef APSTUDIO_READONLY_SYMBOLS | 21 | #ifndef APSTUDIO_READONLY_SYMBOLS |
| 21 | #define _APS_NEXT_RESOURCE_VALUE 130 | 22 | #define _APS_NEXT_RESOURCE_VALUE 130 |
| 22 | #define _APS_NEXT_COMMAND_VALUE 32771 | 23 | #define _APS_NEXT_COMMAND_VALUE 32771 |
| 23 | -#define _APS_NEXT_CONTROL_VALUE 1007 | 24 | +#define _APS_NEXT_CONTROL_VALUE 1010 |
| 24 | #define _APS_NEXT_SYMED_VALUE 101 | 25 | #define _APS_NEXT_SYMED_VALUE 101 |
| 25 | #endif | 26 | #endif |
| 26 | #endif | 27 | #endif |
| @@ -228,6 +228,13 @@ def generated_audio_callback(samples: np.ndarray, progress: float): | @@ -228,6 +228,13 @@ def generated_audio_callback(samples: np.ndarray, progress: float): | ||
| 228 | logging.info("Start playing ...") | 228 | logging.info("Start playing ...") |
| 229 | started = True | 229 | started = True |
| 230 | 230 | ||
| 231 | + # 1 means to keep generating | ||
| 232 | + # 0 means to stop generating | ||
| 233 | + if killed: | ||
| 234 | + return 0 | ||
| 235 | + | ||
| 236 | + return 1 | ||
| 237 | + | ||
| 231 | 238 | ||
| 232 | # see https://python-sounddevice.readthedocs.io/en/0.4.6/api/streams.html#sounddevice.OutputStream | 239 | # see https://python-sounddevice.readthedocs.io/en/0.4.6/api/streams.html#sounddevice.OutputStream |
| 233 | def play_audio_callback( | 240 | def play_audio_callback( |
| @@ -8,8 +8,8 @@ using System; | @@ -8,8 +8,8 @@ using System; | ||
| 8 | 8 | ||
| 9 | namespace SherpaOnnx | 9 | namespace SherpaOnnx |
| 10 | { | 10 | { |
| 11 | - // IntPtr is actuallly a `const float*` from C++ | ||
| 12 | - public delegate void OfflineTtsCallback(IntPtr samples, int n); | 11 | + // IntPtr is actually a `const float*` from C++ |
| 12 | + public delegate int OfflineTtsCallback(IntPtr samples, int n); | ||
| 13 | 13 | ||
| 14 | public class OfflineTts : IDisposable | 14 | public class OfflineTts : IDisposable |
| 15 | { | 15 | { |
| @@ -88,4 +88,4 @@ namespace SherpaOnnx | @@ -88,4 +88,4 @@ namespace SherpaOnnx | ||
| 88 | [DllImport(Dll.Filename, CallingConvention = CallingConvention.Cdecl)] | 88 | [DllImport(Dll.Filename, CallingConvention = CallingConvention.Cdecl)] |
| 89 | private static extern IntPtr SherpaOnnxOfflineTtsGenerateWithCallback(IntPtr handle, [MarshalAs(UnmanagedType.LPStr)] string text, int sid, float speed, OfflineTtsCallback callback); | 89 | private static extern IntPtr SherpaOnnxOfflineTtsGenerateWithCallback(IntPtr handle, [MarshalAs(UnmanagedType.LPStr)] string text, int sid, float speed, OfflineTtsCallback callback); |
| 90 | } | 90 | } |
| 91 | -} | ||
| 91 | +} |
| @@ -935,7 +935,7 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) { | @@ -935,7 +935,7 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) { | ||
| 935 | 935 | ||
| 936 | static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal( | 936 | static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal( |
| 937 | const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, | 937 | const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, |
| 938 | - std::function<void(const float *, int32_t, float)> callback) { | 938 | + std::function<int32_t(const float *, int32_t, float)> callback) { |
| 939 | sherpa_onnx::GeneratedAudio audio = | 939 | sherpa_onnx::GeneratedAudio audio = |
| 940 | tts->impl->Generate(text, sid, speed, callback); | 940 | tts->impl->Generate(text, sid, speed, callback); |
| 941 | 941 | ||
| @@ -965,7 +965,9 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( | @@ -965,7 +965,9 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( | ||
| 965 | const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, | 965 | const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, |
| 966 | SherpaOnnxGeneratedAudioCallback callback) { | 966 | SherpaOnnxGeneratedAudioCallback callback) { |
| 967 | auto wrapper = [callback](const float *samples, int32_t n, | 967 | auto wrapper = [callback](const float *samples, int32_t n, |
| 968 | - float /*progress*/) { callback(samples, n); }; | 968 | + float /*progress*/) { |
| 969 | + return callback(samples, n); | ||
| 970 | + }; | ||
| 969 | 971 | ||
| 970 | return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); | 972 | return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); |
| 971 | } | 973 | } |
| @@ -975,7 +977,7 @@ SherpaOnnxOfflineTtsGenerateWithProgressCallback( | @@ -975,7 +977,7 @@ SherpaOnnxOfflineTtsGenerateWithProgressCallback( | ||
| 975 | const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, | 977 | const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, |
| 976 | SherpaOnnxGeneratedAudioProgressCallback callback) { | 978 | SherpaOnnxGeneratedAudioProgressCallback callback) { |
| 977 | auto wrapper = [callback](const float *samples, int32_t n, float progress) { | 979 | auto wrapper = [callback](const float *samples, int32_t n, float progress) { |
| 978 | - callback(samples, n, progress); | 980 | + return callback(samples, n, progress); |
| 979 | }; | 981 | }; |
| 980 | return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); | 982 | return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); |
| 981 | } | 983 | } |
| @@ -985,7 +987,7 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg( | @@ -985,7 +987,7 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg( | ||
| 985 | SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) { | 987 | SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) { |
| 986 | auto wrapper = [callback, arg](const float *samples, int32_t n, | 988 | auto wrapper = [callback, arg](const float *samples, int32_t n, |
| 987 | float /*progress*/) { | 989 | float /*progress*/) { |
| 988 | - callback(samples, n, arg); | 990 | + return callback(samples, n, arg); |
| 989 | }; | 991 | }; |
| 990 | 992 | ||
| 991 | return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); | 993 | return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); |
| @@ -850,14 +850,17 @@ SHERPA_ONNX_API typedef struct SherpaOnnxGeneratedAudio { | @@ -850,14 +850,17 @@ SHERPA_ONNX_API typedef struct SherpaOnnxGeneratedAudio { | ||
| 850 | int32_t sample_rate; | 850 | int32_t sample_rate; |
| 851 | } SherpaOnnxGeneratedAudio; | 851 | } SherpaOnnxGeneratedAudio; |
| 852 | 852 | ||
| 853 | -typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples, | ||
| 854 | - int32_t n); | 853 | +// If the callback returns 0, then it stops generating |
| 854 | +// If the callback returns 1, then it keeps generating | ||
| 855 | +typedef int32_t (*SherpaOnnxGeneratedAudioCallback)(const float *samples, | ||
| 856 | + int32_t n); | ||
| 855 | 857 | ||
| 856 | -typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples, | ||
| 857 | - int32_t n, void *arg); | 858 | +typedef int32_t (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples, |
| 859 | + int32_t n, | ||
| 860 | + void *arg); | ||
| 858 | 861 | ||
| 859 | -typedef void (*SherpaOnnxGeneratedAudioProgressCallback)(const float *samples, | ||
| 860 | - int32_t n, float p); | 862 | +typedef int32_t (*SherpaOnnxGeneratedAudioProgressCallback)( |
| 863 | + const float *samples, int32_t n, float p); | ||
| 861 | 864 | ||
| 862 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts; | 865 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts; |
| 863 | 866 |
| @@ -216,9 +216,11 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -216,9 +216,11 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 216 | 216 | ||
| 217 | GeneratedAudio ans; | 217 | GeneratedAudio ans; |
| 218 | 218 | ||
| 219 | + int32_t should_continue = 1; | ||
| 220 | + | ||
| 219 | int32_t k = 0; | 221 | int32_t k = 0; |
| 220 | 222 | ||
| 221 | - for (int32_t b = 0; b != num_batches; ++b) { | 223 | + for (int32_t b = 0; b != num_batches && should_continue; ++b) { |
| 222 | batch.clear(); | 224 | batch.clear(); |
| 223 | for (int32_t i = 0; i != batch_size; ++i, ++k) { | 225 | for (int32_t i = 0; i != batch_size; ++i, ++k) { |
| 224 | batch.push_back(std::move(x[k])); | 226 | batch.push_back(std::move(x[k])); |
| @@ -229,8 +231,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -229,8 +231,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 229 | ans.samples.insert(ans.samples.end(), audio.samples.begin(), | 231 | ans.samples.insert(ans.samples.end(), audio.samples.begin(), |
| 230 | audio.samples.end()); | 232 | audio.samples.end()); |
| 231 | if (callback) { | 233 | if (callback) { |
| 232 | - callback(audio.samples.data(), audio.samples.size(), | ||
| 233 | - b * 1.0 / num_batches); | 234 | + should_continue = callback(audio.samples.data(), audio.samples.size(), |
| 235 | + b * 1.0 / num_batches); | ||
| 234 | // Caution(fangjun): audio is freed when the callback returns, so users | 236 | // Caution(fangjun): audio is freed when the callback returns, so users |
| 235 | // should copy the data if they want to access the data after | 237 | // should copy the data if they want to access the data after |
| 236 | // the callback returns to avoid segmentation fault. | 238 | // the callback returns to avoid segmentation fault. |
| @@ -238,7 +240,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -238,7 +240,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 238 | } | 240 | } |
| 239 | 241 | ||
| 240 | batch.clear(); | 242 | batch.clear(); |
| 241 | - while (k < static_cast<int32_t>(x.size())) { | 243 | + while (k < static_cast<int32_t>(x.size()) && should_continue) { |
| 242 | batch.push_back(std::move(x[k])); | 244 | batch.push_back(std::move(x[k])); |
| 243 | ++k; | 245 | ++k; |
| 244 | } | 246 | } |
| @@ -59,7 +59,9 @@ struct GeneratedAudio { | @@ -59,7 +59,9 @@ struct GeneratedAudio { | ||
| 59 | 59 | ||
| 60 | class OfflineTtsImpl; | 60 | class OfflineTtsImpl; |
| 61 | 61 | ||
| 62 | -using GeneratedAudioCallback = std::function<void( | 62 | +// If the callback returns 0, then it stop generating |
| 63 | +// if the callback returns 1, then it keeps generating | ||
| 64 | +using GeneratedAudioCallback = std::function<int32_t( | ||
| 63 | const float * /*samples*/, int32_t /*n*/, float /*progress*/)>; | 65 | const float * /*samples*/, int32_t /*n*/, float /*progress*/)>; |
| 64 | 66 | ||
| 65 | class OfflineTts { | 67 | class OfflineTts { |
| @@ -44,13 +44,20 @@ static void Handler(int32_t /*sig*/) { | @@ -44,13 +44,20 @@ static void Handler(int32_t /*sig*/) { | ||
| 44 | fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); | 44 | fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); |
| 45 | } | 45 | } |
| 46 | 46 | ||
| 47 | -static void AudioGeneratedCallback(const float *s, int32_t n, | ||
| 48 | - float /*progress*/) { | 47 | +static int32_t AudioGeneratedCallback(const float *s, int32_t n, |
| 48 | + float /*progress*/) { | ||
| 49 | if (n > 0) { | 49 | if (n > 0) { |
| 50 | std::lock_guard<std::mutex> lock(g_buffer.mutex); | 50 | std::lock_guard<std::mutex> lock(g_buffer.mutex); |
| 51 | g_buffer.samples.push({s, s + n}); | 51 | g_buffer.samples.push({s, s + n}); |
| 52 | g_cv.notify_all(); | 52 | g_cv.notify_all(); |
| 53 | } | 53 | } |
| 54 | + | ||
| 55 | + if (g_killed) { | ||
| 56 | + return 0; // stop generating | ||
| 57 | + } | ||
| 58 | + | ||
| 59 | + // continue generating | ||
| 60 | + return 1; | ||
| 54 | } | 61 | } |
| 55 | 62 | ||
| 56 | static void StartPlayback(const std::string &device_name, int32_t sample_rate) { | 63 | static void StartPlayback(const std::string &device_name, int32_t sample_rate) { |
| @@ -47,8 +47,8 @@ static void Handler(int32_t /*sig*/) { | @@ -47,8 +47,8 @@ static void Handler(int32_t /*sig*/) { | ||
| 47 | fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); | 47 | fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); |
| 48 | } | 48 | } |
| 49 | 49 | ||
| 50 | -static void AudioGeneratedCallback(const float *s, int32_t n, | ||
| 51 | - float /*progress*/) { | 50 | +static int32_t AudioGeneratedCallback(const float *s, int32_t n, |
| 51 | + float /*progress*/) { | ||
| 52 | if (n > 0) { | 52 | if (n > 0) { |
| 53 | Samples samples; | 53 | Samples samples; |
| 54 | samples.data = std::vector<float>{s, s + n}; | 54 | samples.data = std::vector<float>{s, s + n}; |
| @@ -57,6 +57,12 @@ static void AudioGeneratedCallback(const float *s, int32_t n, | @@ -57,6 +57,12 @@ static void AudioGeneratedCallback(const float *s, int32_t n, | ||
| 57 | g_buffer.samples.push(std::move(samples)); | 57 | g_buffer.samples.push(std::move(samples)); |
| 58 | g_started = true; | 58 | g_started = true; |
| 59 | } | 59 | } |
| 60 | + if (g_killed) { | ||
| 61 | + return 0; // stop generating | ||
| 62 | + } | ||
| 63 | + | ||
| 64 | + // continue generating | ||
| 65 | + return 1; | ||
| 60 | } | 66 | } |
| 61 | 67 | ||
| 62 | static int PlayCallback(const void * /*in*/, void *out, | 68 | static int PlayCallback(const void * /*in*/, void *out, |
| @@ -9,8 +9,9 @@ | @@ -9,8 +9,9 @@ | ||
| 9 | #include "sherpa-onnx/csrc/parse-options.h" | 9 | #include "sherpa-onnx/csrc/parse-options.h" |
| 10 | #include "sherpa-onnx/csrc/wave-writer.h" | 10 | #include "sherpa-onnx/csrc/wave-writer.h" |
| 11 | 11 | ||
| 12 | -void audioCallback(const float * /*samples*/, int32_t n, float progress) { | 12 | +int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) { |
| 13 | printf("sample=%d, progress=%f\n", n, progress); | 13 | printf("sample=%d, progress=%f\n", n, progress); |
| 14 | + return 1; | ||
| 14 | } | 15 | } |
| 15 | 16 | ||
| 16 | int main(int32_t argc, char *argv[]) { | 17 | int main(int32_t argc, char *argv[]) { |
| @@ -326,7 +326,7 @@ typedef SherpaOnnxDestroyOfflineTtsGeneratedAudioNative = Void Function( | @@ -326,7 +326,7 @@ typedef SherpaOnnxDestroyOfflineTtsGeneratedAudioNative = Void Function( | ||
| 326 | typedef SherpaOnnxDestroyOfflineTtsGeneratedAudio = void Function( | 326 | typedef SherpaOnnxDestroyOfflineTtsGeneratedAudio = void Function( |
| 327 | Pointer<SherpaOnnxGeneratedAudio>); | 327 | Pointer<SherpaOnnxGeneratedAudio>); |
| 328 | 328 | ||
| 329 | -typedef SherpaOnnxGeneratedAudioCallbackNative = Void Function( | 329 | +typedef SherpaOnnxGeneratedAudioCallbackNative = Int Function( |
| 330 | Pointer<Float>, Int32); | 330 | Pointer<Float>, Int32); |
| 331 | 331 | ||
| 332 | typedef SherpaOnnxOfflineTtsGenerateWithCallbackNative | 332 | typedef SherpaOnnxOfflineTtsGenerateWithCallbackNative |
| @@ -149,7 +149,7 @@ class OfflineTts { | @@ -149,7 +149,7 @@ class OfflineTts { | ||
| 149 | {required String text, | 149 | {required String text, |
| 150 | int sid = 0, | 150 | int sid = 0, |
| 151 | double speed = 1.0, | 151 | double speed = 1.0, |
| 152 | - required void Function(Float32List samples) callback}) { | 152 | + required int Function(Float32List samples) callback}) { |
| 153 | // see | 153 | // see |
| 154 | // https://github.com/dart-lang/sdk/issues/54276#issuecomment-1846109285 | 154 | // https://github.com/dart-lang/sdk/issues/54276#issuecomment-1846109285 |
| 155 | // https://stackoverflow.com/questions/69537440/callbacks-in-dart-dartffi-only-supports-calling-static-dart-functions-from-nat | 155 | // https://stackoverflow.com/questions/69537440/callbacks-in-dart-dartffi-only-supports-calling-static-dart-functions-from-nat |
| @@ -159,8 +159,8 @@ class OfflineTts { | @@ -159,8 +159,8 @@ class OfflineTts { | ||
| 159 | (Pointer<Float> samples, int n) { | 159 | (Pointer<Float> samples, int n) { |
| 160 | final s = samples.asTypedList(n); | 160 | final s = samples.asTypedList(n); |
| 161 | final newSamples = Float32List.fromList(s); | 161 | final newSamples = Float32List.fromList(s); |
| 162 | - callback(newSamples); | ||
| 163 | - }); | 162 | + return callback(newSamples); |
| 163 | + }, exceptionalReturn: 0); | ||
| 164 | 164 | ||
| 165 | final Pointer<Utf8> textPtr = text.toNativeUtf8(); | 165 | final Pointer<Utf8> textPtr = text.toNativeUtf8(); |
| 166 | final p = SherpaOnnxBindings.offlineTtsGenerateWithCallback | 166 | final p = SherpaOnnxBindings.offlineTtsGenerateWithCallback |
| @@ -186,14 +186,42 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( | @@ -186,14 +186,42 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( | ||
| 186 | const char *p_text = env->GetStringUTFChars(text, nullptr); | 186 | const char *p_text = env->GetStringUTFChars(text, nullptr); |
| 187 | SHERPA_ONNX_LOGE("string is: %s", p_text); | 187 | SHERPA_ONNX_LOGE("string is: %s", p_text); |
| 188 | 188 | ||
| 189 | - std::function<void(const float *, int32_t, float)> callback_wrapper = | 189 | + std::function<int32_t(const float *, int32_t, float)> callback_wrapper = |
| 190 | [env, callback](const float *samples, int32_t n, float /*progress*/) { | 190 | [env, callback](const float *samples, int32_t n, float /*progress*/) { |
| 191 | jclass cls = env->GetObjectClass(callback); | 191 | jclass cls = env->GetObjectClass(callback); |
| 192 | - jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V"); | 192 | + |
| 193 | +#if 0 | ||
| 194 | + // this block is for debugging only | ||
| 195 | + // see also | ||
| 196 | + // https://jnjosh.com/posts/kotlinfromcpp/ | ||
| 197 | + jmethodID classMethodId = | ||
| 198 | + env->GetMethodID(cls, "getClass", "()Ljava/lang/Class;"); | ||
| 199 | + jobject klassObj = env->CallObjectMethod(callback, classMethodId); | ||
| 200 | + auto klassObject = env->GetObjectClass(klassObj); | ||
| 201 | + auto nameMethodId = | ||
| 202 | + env->GetMethodID(klassObject, "getName", "()Ljava/lang/String;"); | ||
| 203 | + jstring classString = | ||
| 204 | + (jstring)env->CallObjectMethod(klassObj, nameMethodId); | ||
| 205 | + auto className = env->GetStringUTFChars(classString, NULL); | ||
| 206 | + SHERPA_ONNX_LOGE("name is: %s", className); | ||
| 207 | + env->ReleaseStringUTFChars(classString, className); | ||
| 208 | +#endif | ||
| 209 | + | ||
| 210 | + jmethodID mid = | ||
| 211 | + env->GetMethodID(cls, "invoke", "([F)Ljava/lang/Integer;"); | ||
| 212 | + if (mid == nullptr) { | ||
| 213 | + SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it."); | ||
| 214 | + return 1; | ||
| 215 | + } | ||
| 193 | 216 | ||
| 194 | jfloatArray samples_arr = env->NewFloatArray(n); | 217 | jfloatArray samples_arr = env->NewFloatArray(n); |
| 195 | env->SetFloatArrayRegion(samples_arr, 0, n, samples); | 218 | env->SetFloatArrayRegion(samples_arr, 0, n, samples); |
| 196 | - env->CallVoidMethod(callback, mid, samples_arr); | 219 | + |
| 220 | + jobject should_continue = | ||
| 221 | + env->CallObjectMethod(callback, mid, samples_arr); | ||
| 222 | + jclass jklass = env->GetObjectClass(should_continue); | ||
| 223 | + jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I"); | ||
| 224 | + return env->CallIntMethod(should_continue, int_value_mid); | ||
| 197 | }; | 225 | }; |
| 198 | 226 | ||
| 199 | auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate( | 227 | auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate( |
| @@ -57,13 +57,13 @@ void PybindOfflineTts(py::module *m) { | @@ -57,13 +57,13 @@ void PybindOfflineTts(py::module *m) { | ||
| 57 | "generate", | 57 | "generate", |
| 58 | [](const PyClass &self, const std::string &text, int64_t sid, | 58 | [](const PyClass &self, const std::string &text, int64_t sid, |
| 59 | float speed, | 59 | float speed, |
| 60 | - std::function<void(py::array_t<float>, float)> callback) | 60 | + std::function<int32_t(py::array_t<float>, float)> callback) |
| 61 | -> GeneratedAudio { | 61 | -> GeneratedAudio { |
| 62 | if (!callback) { | 62 | if (!callback) { |
| 63 | return self.Generate(text, sid, speed); | 63 | return self.Generate(text, sid, speed); |
| 64 | } | 64 | } |
| 65 | 65 | ||
| 66 | - std::function<void(const float *, int32_t, float)> | 66 | + std::function<int32_t(const float *, int32_t, float)> |
| 67 | callback_wrapper = [callback](const float *samples, int32_t n, | 67 | callback_wrapper = [callback](const float *samples, int32_t n, |
| 68 | float progress) { | 68 | float progress) { |
| 69 | // CAUTION(fangjun): we have to copy samples since it is | 69 | // CAUTION(fangjun): we have to copy samples since it is |
| @@ -75,7 +75,7 @@ void PybindOfflineTts(py::module *m) { | @@ -75,7 +75,7 @@ void PybindOfflineTts(py::module *m) { | ||
| 75 | py::buffer_info buf = array.request(); | 75 | py::buffer_info buf = array.request(); |
| 76 | auto p = static_cast<float *>(buf.ptr); | 76 | auto p = static_cast<float *>(buf.ptr); |
| 77 | std::copy(samples, samples + n, p); | 77 | std::copy(samples, samples + n, p); |
| 78 | - callback(array, progress); | 78 | + return callback(array, progress); |
| 79 | }; | 79 | }; |
| 80 | 80 | ||
| 81 | return self.Generate(text, sid, speed, callback_wrapper); | 81 | return self.Generate(text, sid, speed, callback_wrapper); |
-
请 注册 或 登录 后发表评论