Fangjun Kuang
Committed by GitHub

Enable to stop TTS generation (#1041)

正在显示 32 个修改的文件 包含 248 行增加69 行删除
@@ -8,7 +8,7 @@ project(sherpa-onnx) @@ -8,7 +8,7 @@ project(sherpa-onnx)
8 # ./nodejs-addon-examples 8 # ./nodejs-addon-examples
9 # ./dart-api-examples/ 9 # ./dart-api-examples/
10 # ./sherpa-onnx/flutter/CHANGELOG.md 10 # ./sherpa-onnx/flutter/CHANGELOG.md
11 -set(SHERPA_ONNX_VERSION "1.10.0") 11 +set(SHERPA_ONNX_VERSION "1.10.1")
12 12
13 # Disable warning about 13 # Disable warning about
14 # 14 #
@@ -26,6 +26,9 @@ class MainActivity : AppCompatActivity() { @@ -26,6 +26,9 @@ class MainActivity : AppCompatActivity() {
26 private lateinit var speed: EditText 26 private lateinit var speed: EditText
27 private lateinit var generate: Button 27 private lateinit var generate: Button
28 private lateinit var play: Button 28 private lateinit var play: Button
  29 + private lateinit var stop: Button
  30 + private var stopped: Boolean = false
  31 + private var mediaPlayer: MediaPlayer? = null
29 32
30 // see 33 // see
31 // https://developer.android.com/reference/kotlin/android/media/AudioTrack 34 // https://developer.android.com/reference/kotlin/android/media/AudioTrack
@@ -49,9 +52,11 @@ class MainActivity : AppCompatActivity() { @@ -49,9 +52,11 @@ class MainActivity : AppCompatActivity() {
49 52
50 generate = findViewById(R.id.generate) 53 generate = findViewById(R.id.generate)
51 play = findViewById(R.id.play) 54 play = findViewById(R.id.play)
  55 + stop = findViewById(R.id.stop)
52 56
53 generate.setOnClickListener { onClickGenerate() } 57 generate.setOnClickListener { onClickGenerate() }
54 play.setOnClickListener { onClickPlay() } 58 play.setOnClickListener { onClickPlay() }
  59 + stop.setOnClickListener { onClickStop() }
55 60
56 sid.setText("0") 61 sid.setText("0")
57 speed.setText("1.0") 62 speed.setText("1.0")
@@ -70,7 +75,7 @@ class MainActivity : AppCompatActivity() { @@ -70,7 +75,7 @@ class MainActivity : AppCompatActivity() {
70 AudioFormat.CHANNEL_OUT_MONO, 75 AudioFormat.CHANNEL_OUT_MONO,
71 AudioFormat.ENCODING_PCM_FLOAT 76 AudioFormat.ENCODING_PCM_FLOAT
72 ) 77 )
73 - Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}") 78 + Log.i(TAG, "sampleRate: $sampleRate, buffLength: $bufLength")
74 79
75 val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) 80 val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
76 .setUsage(AudioAttributes.USAGE_MEDIA) 81 .setUsage(AudioAttributes.USAGE_MEDIA)
@@ -90,8 +95,14 @@ class MainActivity : AppCompatActivity() { @@ -90,8 +95,14 @@ class MainActivity : AppCompatActivity() {
90 } 95 }
91 96
92 // this function is called from C++ 97 // this function is called from C++
93 - private fun callback(samples: FloatArray) {  
94 - track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING) 98 + private fun callback(samples: FloatArray): Int {
  99 + if (!stopped) {
  100 + track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
  101 + return 1
  102 + } else {
  103 + track.stop()
  104 + return 0
  105 + }
95 } 106 }
96 107
97 private fun onClickGenerate() { 108 private fun onClickGenerate() {
@@ -127,6 +138,8 @@ class MainActivity : AppCompatActivity() { @@ -127,6 +138,8 @@ class MainActivity : AppCompatActivity() {
127 track.play() 138 track.play()
128 139
129 play.isEnabled = false 140 play.isEnabled = false
  141 + generate.isEnabled = false
  142 + stopped = false
130 Thread { 143 Thread {
131 val audio = tts.generateWithCallback( 144 val audio = tts.generateWithCallback(
132 text = textStr, 145 text = textStr,
@@ -140,6 +153,7 @@ class MainActivity : AppCompatActivity() { @@ -140,6 +153,7 @@ class MainActivity : AppCompatActivity() {
140 if (ok) { 153 if (ok) {
141 runOnUiThread { 154 runOnUiThread {
142 play.isEnabled = true 155 play.isEnabled = true
  156 + generate.isEnabled = true
143 track.stop() 157 track.stop()
144 } 158 }
145 } 159 }
@@ -148,11 +162,22 @@ class MainActivity : AppCompatActivity() { @@ -148,11 +162,22 @@ class MainActivity : AppCompatActivity() {
148 162
149 private fun onClickPlay() { 163 private fun onClickPlay() {
150 val filename = application.filesDir.absolutePath + "/generated.wav" 164 val filename = application.filesDir.absolutePath + "/generated.wav"
151 - val mediaPlayer = MediaPlayer.create( 165 + mediaPlayer?.stop()
  166 + mediaPlayer = MediaPlayer.create(
152 applicationContext, 167 applicationContext,
153 Uri.fromFile(File(filename)) 168 Uri.fromFile(File(filename))
154 ) 169 )
155 - mediaPlayer.start() 170 + mediaPlayer?.start()
  171 + }
  172 +
  173 + private fun onClickStop() {
  174 + stopped = true
  175 + play.isEnabled = true
  176 + generate.isEnabled = true
  177 + track.pause()
  178 + track.flush()
  179 + mediaPlayer?.stop()
  180 + mediaPlayer = null
156 } 181 }
157 182
158 private fun initTts() { 183 private fun initTts() {
@@ -76,7 +76,7 @@ class OfflineTts( @@ -76,7 +76,7 @@ class OfflineTts(
76 text: String, 76 text: String,
77 sid: Int = 0, 77 sid: Int = 0,
78 speed: Float = 1.0f, 78 speed: Float = 1.0f,
79 - callback: (samples: FloatArray) -> Unit 79 + callback: (samples: FloatArray) -> Int
80 ): GeneratedAudio { 80 ): GeneratedAudio {
81 val objArray = generateWithCallbackImpl( 81 val objArray = generateWithCallbackImpl(
82 ptr, 82 ptr,
@@ -146,7 +146,7 @@ class OfflineTts( @@ -146,7 +146,7 @@ class OfflineTts(
146 text: String, 146 text: String,
147 sid: Int = 0, 147 sid: Int = 0,
148 speed: Float = 1.0f, 148 speed: Float = 1.0f,
149 - callback: (samples: FloatArray) -> Unit 149 + callback: (samples: FloatArray) -> Int
150 ): Array<Any> 150 ): Array<Any>
151 151
152 companion object { 152 companion object {
@@ -84,4 +84,16 @@ @@ -84,4 +84,16 @@
84 app:layout_constraintLeft_toLeftOf="parent" 84 app:layout_constraintLeft_toLeftOf="parent"
85 app:layout_constraintRight_toRightOf="parent" 85 app:layout_constraintRight_toRightOf="parent"
86 app:layout_constraintTop_toBottomOf="@id/generate" /> 86 app:layout_constraintTop_toBottomOf="@id/generate" />
  87 +
  88 + <Button
  89 + android:id="@+id/stop"
  90 + android:textAllCaps="false"
  91 + android:layout_width="match_parent"
  92 + android:layout_height="50dp"
  93 + android:layout_marginTop="4dp"
  94 + android:text="@string/stop"
  95 + app:layout_constraintLeft_toLeftOf="parent"
  96 + app:layout_constraintRight_toRightOf="parent"
  97 + app:layout_constraintTop_toBottomOf="@id/play" />
  98 +
87 </androidx.constraintlayout.widget.ConstraintLayout> 99 </androidx.constraintlayout.widget.ConstraintLayout>
@@ -7,4 +7,5 @@ @@ -7,4 +7,5 @@
7 <string name="text_hint">Please input your text here</string> 7 <string name="text_hint">Please input your text here</string>
8 <string name="generate">Generate</string> 8 <string name="generate">Generate</string>
9 <string name="play">Play</string> 9 <string name="play">Play</string>
  10 + <string name="stop">Stop</string>
10 </resources> 11 </resources>
@@ -126,7 +126,7 @@ class TtsService : TextToSpeechService() { @@ -126,7 +126,7 @@ class TtsService : TextToSpeechService() {
126 return 126 return
127 } 127 }
128 128
129 - val ttsCallback = { floatSamples: FloatArray -> 129 + val ttsCallback: (FloatArray) -> Int = fun(floatSamples): Int {
130 // convert FloatArray to ByteArray 130 // convert FloatArray to ByteArray
131 val samples = floatArrayToByteArray(floatSamples) 131 val samples = floatArrayToByteArray(floatSamples)
132 val maxBufferSize: Int = callback.maxBufferSize 132 val maxBufferSize: Int = callback.maxBufferSize
@@ -137,6 +137,9 @@ class TtsService : TextToSpeechService() { @@ -137,6 +137,9 @@ class TtsService : TextToSpeechService() {
137 offset += bytesToWrite 137 offset += bytesToWrite
138 } 138 }
139 139
  140 + // 1 means to continue
  141 + // 0 means to stop
  142 + return 1
140 } 143 }
141 144
142 Log.i(TAG, "text: $text") 145 Log.i(TAG, "text: $text")
@@ -160,4 +163,4 @@ class TtsService : TextToSpeechService() { @@ -160,4 +163,4 @@ class TtsService : TextToSpeechService() {
160 } 163 }
161 return byteArray 164 return byteArray
162 } 165 }
163 -}  
  166 +}
@@ -10,7 +10,7 @@ environment: @@ -10,7 +10,7 @@ environment:
10 10
11 # Add regular dependencies here. 11 # Add regular dependencies here.
12 dependencies: 12 dependencies:
13 - sherpa_onnx: ^1.10.0 13 + sherpa_onnx: ^1.10.1
14 path: ^1.9.0 14 path: ^1.9.0
15 args: ^2.5.0 15 args: ^2.5.0
16 16
@@ -11,7 +11,7 @@ environment: @@ -11,7 +11,7 @@ environment:
11 11
12 # Add regular dependencies here. 12 # Add regular dependencies here.
13 dependencies: 13 dependencies:
14 - sherpa_onnx: ^1.10.0 14 + sherpa_onnx: ^1.10.1
15 path: ^1.9.0 15 path: ^1.9.0
16 args: ^2.5.0 16 args: ^2.5.0
17 17
@@ -68,6 +68,10 @@ void main(List<String> arguments) async { @@ -68,6 +68,10 @@ void main(List<String> arguments) async {
68 callback: (Float32List samples) { 68 callback: (Float32List samples) {
69 print('${samples.length} samples received'); 69 print('${samples.length} samples received');
70 // You can play samples in a separate thread/isolate 70 // You can play samples in a separate thread/isolate
  71 +
  72 + // 1 means to continue
  73 + // 0 means to stop
  74 + return 1;
71 }); 75 });
72 tts.free(); 76 tts.free();
73 77
@@ -8,7 +8,7 @@ environment: @@ -8,7 +8,7 @@ environment:
8 8
9 # Add regular dependencies here. 9 # Add regular dependencies here.
10 dependencies: 10 dependencies:
11 - sherpa_onnx: ^1.10.0 11 + sherpa_onnx: ^1.10.1
12 path: ^1.9.0 12 path: ^1.9.0
13 args: ^2.5.0 13 args: ^2.5.0
14 14
@@ -9,7 +9,7 @@ environment: @@ -9,7 +9,7 @@ environment:
9 sdk: ^3.4.0 9 sdk: ^3.4.0
10 10
11 dependencies: 11 dependencies:
12 - sherpa_onnx: ^1.10.0 12 + sherpa_onnx: ^1.10.1
13 path: ^1.9.0 13 path: ^1.9.0
14 args: ^2.5.0 14 args: ^2.5.0
15 15
@@ -187,6 +187,10 @@ to download more models. @@ -187,6 +187,10 @@ to download more models.
187 Marshal.Copy(samples, data, 0, n); 187 Marshal.Copy(samples, data, 0, n);
188 188
189 dataItems.Add(data); 189 dataItems.Add(data);
  190 +
  191 + // 1 means to keep generating
  192 + // 0 means to stop generating
  193 + return 1;
190 }; 194 };
191 195
192 bool playFinished = false; 196 bool playFinished = false;
@@ -25,6 +25,46 @@ fun testTts() { @@ -25,6 +25,46 @@ fun testTts() {
25 println("Saved to test-en.wav") 25 println("Saved to test-en.wav")
26 } 26 }
27 27
28 -fun callback(samples: FloatArray): Unit { 28 +/*
  29 +1. Unzip test_tts.jar
  30 +2.
  31 +javap ./com/k2fsa/sherpa/onnx/Test_ttsKt\$testTts\$audio\$1.class
  32 +
  33 +3. It prints:
  34 +Compiled from "test_tts.kt"
  35 +final class com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1<float[], java.lang.Integer> {
  36 + public static final com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 INSTANCE;
  37 + com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1();
  38 + public final java.lang.Integer invoke(float[]);
  39 + public java.lang.Object invoke(java.lang.Object);
  40 + static {};
  41 +}
  42 +
  43 +4.
  44 +javap -s ./com/k2fsa/sherpa/onnx/Test_ttsKt\$testTts\$audio\$1.class
  45 +
  46 +5. It prints
  47 +Compiled from "test_tts.kt"
  48 +final class com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1<float[], java.lang.Integer> {
  49 + public static final com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 INSTANCE;
  50 + descriptor: Lcom/k2fsa/sherpa/onnx/Test_ttsKt$testTts$audio$1;
  51 + com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1();
  52 + descriptor: ()V
  53 +
  54 + public final java.lang.Integer invoke(float[]);
  55 + descriptor: ([F)Ljava/lang/Integer;
  56 +
  57 + public java.lang.Object invoke(java.lang.Object);
  58 + descriptor: (Ljava/lang/Object;)Ljava/lang/Object;
  59 +
  60 + static {};
  61 + descriptor: ()V
  62 +}
  63 +*/
  64 +fun callback(samples: FloatArray): Int {
29 println("callback got called with ${samples.size} samples"); 65 println("callback got called with ${samples.size} samples");
  66 +
  67 + // 1 means to continue
  68 + // 0 means to stop
  69 + return 1
30 } 70 }
1 B// Microsoft Visual C++ generated resource script. 1 B// Microsoft Visual C++ generated resource script.
@@ -57,7 +57,7 @@ static bool g_started = false; @@ -57,7 +57,7 @@ static bool g_started = false;
57 static bool g_stopped = false; 57 static bool g_stopped = false;
58 static bool g_killed = false; 58 static bool g_killed = false;
59 59
60 -static void AudioGeneratedCallback(const float *s, int32_t n) { 60 +static int32_t AudioGeneratedCallback(const float *s, int32_t n) {
61 if (n > 0) { 61 if (n > 0) {
62 Samples samples; 62 Samples samples;
63 samples.data = std::vector<float>{s, s + n}; 63 samples.data = std::vector<float>{s, s + n};
@@ -66,6 +66,10 @@ static void AudioGeneratedCallback(const float *s, int32_t n) { @@ -66,6 +66,10 @@ static void AudioGeneratedCallback(const float *s, int32_t n) {
66 g_buffer.samples.push(std::move(samples)); 66 g_buffer.samples.push(std::move(samples));
67 g_started = true; 67 g_started = true;
68 } 68 }
  69 + if (g_killed) {
  70 + return 0;
  71 + }
  72 + return 1;
69 } 73 }
70 74
71 static int PlayCallback(const void * /*in*/, void *out, 75 static int PlayCallback(const void * /*in*/, void *out,
@@ -324,6 +328,7 @@ BEGIN_MESSAGE_MAP(CNonStreamingTextToSpeechDlg, CDialogEx) @@ -324,6 +328,7 @@ BEGIN_MESSAGE_MAP(CNonStreamingTextToSpeechDlg, CDialogEx)
324 ON_WM_PAINT() 328 ON_WM_PAINT()
325 ON_WM_QUERYDRAGICON() 329 ON_WM_QUERYDRAGICON()
326 ON_BN_CLICKED(IDOK, &CNonStreamingTextToSpeechDlg::OnBnClickedOk) 330 ON_BN_CLICKED(IDOK, &CNonStreamingTextToSpeechDlg::OnBnClickedOk)
  331 + ON_BN_CLICKED(IDC_STOP, &CNonStreamingTextToSpeechDlg::OnBnClickedStop)
327 END_MESSAGE_MAP() 332 END_MESSAGE_MAP()
328 333
329 334
@@ -492,11 +497,18 @@ void CNonStreamingTextToSpeechDlg::Init() { @@ -492,11 +497,18 @@ void CNonStreamingTextToSpeechDlg::Init() {
492 if (tts_) { 497 if (tts_) {
493 SherpaOnnxDestroyOfflineTts(tts_); 498 SherpaOnnxDestroyOfflineTts(tts_);
494 } 499 }
  500 + if (generate_thread_ && generate_thread_->joinable()) {
  501 + generate_thread_->join();
  502 + }
  503 +
  504 + if (play_thread_ && play_thread_->joinable()) {
  505 + play_thread_->join();
  506 + }
495 } 507 }
496 508
497 509
498 static std::string ToString(const CString &s) { 510 static std::string ToString(const CString &s) {
499 - CT2CA pszConvertedAnsiString( s); 511 + CT2CA pszConvertedAnsiString(s);
500 return std::string(pszConvertedAnsiString); 512 return std::string(pszConvertedAnsiString);
501 } 513 }
502 514
@@ -510,7 +522,7 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() { @@ -510,7 +522,7 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() {
510 } 522 }
511 523
512 speed_.GetWindowText(s); 524 speed_.GetWindowText(s);
513 - float speed = static_cast<float>(_ttof(s)); 525 + float speed = static_cast<float>(_ttof(s));
514 if (speed < 0) { 526 if (speed < 0) {
515 AfxMessageBox(Utf8ToUtf16("Please input a valid speed").c_str(), MB_OK); 527 AfxMessageBox(Utf8ToUtf16("Please input a valid speed").c_str(), MB_OK);
516 return; 528 return;
@@ -541,28 +553,40 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() { @@ -541,28 +553,40 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() {
541 // for simplicity 553 // for simplicity
542 play_thread_ = std::make_unique<std::thread>(StartPlayback, SherpaOnnxOfflineTtsSampleRate(tts_)); 554 play_thread_ = std::make_unique<std::thread>(StartPlayback, SherpaOnnxOfflineTtsSampleRate(tts_));
543 555
544 - generate_btn_.EnableWindow(FALSE);  
545 -  
546 - const SherpaOnnxGeneratedAudio *audio =  
547 - SherpaOnnxOfflineTtsGenerateWithCallback(tts_, ss.c_str(), speaker_id, speed, &AudioGeneratedCallback);  
548 -  
549 - generate_btn_.EnableWindow(TRUE); 556 + if (generate_thread_ && generate_thread_->joinable()) {
  557 + generate_thread_->join();
  558 + }
550 559
551 output_filename_.GetWindowText(s); 560 output_filename_.GetWindowText(s);
552 std::string filename = ToString(s); 561 std::string filename = ToString(s);
553 562
554 - int ok = SherpaOnnxWriteWave(audio->samples, audio->n, audio->sample_rate,  
555 - filename.c_str()); 563 + generate_thread_ = std::make_unique<std::thread>([ss, this,filename, speaker_id, speed]() {
  564 + std::string text = ss;
556 565
557 - SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio); 566 + // generate_btn_.EnableWindow(FALSE);
558 567
559 - if (ok) {  
560 - // AfxMessageBox(Utf8ToUtf16(std::string("Saved to ") + filename + " successfully").c_str(), MB_OK);  
561 - AppendLineToMultilineEditCtrl(my_hint_, std::string("Saved to ") + filename + " successfully");  
562 - } else {  
563 - // AfxMessageBox(Utf8ToUtf16(std::string("Failed to save to ") + filename).c_str(), MB_OK);  
564 - AppendLineToMultilineEditCtrl(my_hint_, std::string("Failed to saved to ") + filename);  
565 - } 568 + const SherpaOnnxGeneratedAudio *audio =
  569 + SherpaOnnxOfflineTtsGenerateWithCallback(tts_, text.c_str(), speaker_id, speed, &AudioGeneratedCallback);
  570 + // generate_btn_.EnableWindow(TRUE);
  571 + g_stopped = true;
  572 +
  573 + int ok = SherpaOnnxWriteWave(audio->samples, audio->n, audio->sample_rate,
  574 + filename.c_str());
  575 +
  576 + SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);
  577 +
  578 + if (ok) {
  579 + // AfxMessageBox(Utf8ToUtf16(std::string("Saved to ") + filename + " successfully").c_str(), MB_OK);
  580 +
  581 + // AppendLineToMultilineEditCtrl(my_hint_, std::string("Saved to ") + filename + " successfully");
  582 + } else {
  583 + // AfxMessageBox(Utf8ToUtf16(std::string("Failed to save to ") + filename).c_str(), MB_OK);
  584 +
  585 + // AppendLineToMultilineEditCtrl(my_hint_, std::string("Failed to saved to ") + filename);
  586 + }
  587 + });
566 588
567 //CDialogEx::OnOK(); 589 //CDialogEx::OnOK();
568 } 590 }
  591 +
  592 +void CNonStreamingTextToSpeechDlg::OnBnClickedStop() { g_killed = true; }
@@ -60,5 +60,8 @@ public: @@ -60,5 +60,8 @@ public:
60 private: 60 private:
61 Microphone mic_; 61 Microphone mic_;
62 std::unique_ptr<std::thread> play_thread_; 62 std::unique_ptr<std::thread> play_thread_;
  63 + std::unique_ptr<std::thread> generate_thread_;
63 64
  65 + public:
  66 + afx_msg void OnBnClickedStop();
64 }; 67 };
@@ -13,6 +13,7 @@ @@ -13,6 +13,7 @@
13 #define IDC_HINT 1005 13 #define IDC_HINT 1005
14 #define IDC_EDIT1 1006 14 #define IDC_EDIT1 1006
15 #define IDC_OUTPUT_FILENAME 1006 15 #define IDC_OUTPUT_FILENAME 1006
  16 +#define IDC_STOP 1009
16 17
17 // Next default values for new objects 18 // Next default values for new objects
18 // 19 //
@@ -20,7 +21,7 @@ @@ -20,7 +21,7 @@
20 #ifndef APSTUDIO_READONLY_SYMBOLS 21 #ifndef APSTUDIO_READONLY_SYMBOLS
21 #define _APS_NEXT_RESOURCE_VALUE 130 22 #define _APS_NEXT_RESOURCE_VALUE 130
22 #define _APS_NEXT_COMMAND_VALUE 32771 23 #define _APS_NEXT_COMMAND_VALUE 32771
23 -#define _APS_NEXT_CONTROL_VALUE 1007 24 +#define _APS_NEXT_CONTROL_VALUE 1010
24 #define _APS_NEXT_SYMED_VALUE 101 25 #define _APS_NEXT_SYMED_VALUE 101
25 #endif 26 #endif
26 #endif 27 #endif
1 { 1 {
2 "dependencies": { 2 "dependencies": {
3 - "sherpa-onnx-node": "^1.10.0" 3 + "sherpa-onnx-node": "^1.10.1"
4 } 4 }
5 } 5 }
@@ -228,6 +228,13 @@ def generated_audio_callback(samples: np.ndarray, progress: float): @@ -228,6 +228,13 @@ def generated_audio_callback(samples: np.ndarray, progress: float):
228 logging.info("Start playing ...") 228 logging.info("Start playing ...")
229 started = True 229 started = True
230 230
  231 + # 1 means to keep generating
  232 + # 0 means to stop generating
  233 + if killed:
  234 + return 0
  235 +
  236 + return 1
  237 +
231 238
232 # see https://python-sounddevice.readthedocs.io/en/0.4.6/api/streams.html#sounddevice.OutputStream 239 # see https://python-sounddevice.readthedocs.io/en/0.4.6/api/streams.html#sounddevice.OutputStream
233 def play_audio_callback( 240 def play_audio_callback(
@@ -8,8 +8,8 @@ using System; @@ -8,8 +8,8 @@ using System;
8 8
9 namespace SherpaOnnx 9 namespace SherpaOnnx
10 { 10 {
11 - // IntPtr is actuallly a `const float*` from C++  
12 - public delegate void OfflineTtsCallback(IntPtr samples, int n); 11 + // IntPtr is actually a `const float*` from C++
  12 + public delegate int OfflineTtsCallback(IntPtr samples, int n);
13 13
14 public class OfflineTts : IDisposable 14 public class OfflineTts : IDisposable
15 { 15 {
@@ -88,4 +88,4 @@ namespace SherpaOnnx @@ -88,4 +88,4 @@ namespace SherpaOnnx
88 [DllImport(Dll.Filename, CallingConvention = CallingConvention.Cdecl)] 88 [DllImport(Dll.Filename, CallingConvention = CallingConvention.Cdecl)]
89 private static extern IntPtr SherpaOnnxOfflineTtsGenerateWithCallback(IntPtr handle, [MarshalAs(UnmanagedType.LPStr)] string text, int sid, float speed, OfflineTtsCallback callback); 89 private static extern IntPtr SherpaOnnxOfflineTtsGenerateWithCallback(IntPtr handle, [MarshalAs(UnmanagedType.LPStr)] string text, int sid, float speed, OfflineTtsCallback callback);
90 } 90 }
91 -}  
  91 +}
@@ -935,7 +935,7 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) { @@ -935,7 +935,7 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) {
935 935
936 static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal( 936 static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal(
937 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, 937 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
938 - std::function<void(const float *, int32_t, float)> callback) { 938 + std::function<int32_t(const float *, int32_t, float)> callback) {
939 sherpa_onnx::GeneratedAudio audio = 939 sherpa_onnx::GeneratedAudio audio =
940 tts->impl->Generate(text, sid, speed, callback); 940 tts->impl->Generate(text, sid, speed, callback);
941 941
@@ -965,7 +965,9 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( @@ -965,7 +965,9 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
965 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, 965 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
966 SherpaOnnxGeneratedAudioCallback callback) { 966 SherpaOnnxGeneratedAudioCallback callback) {
967 auto wrapper = [callback](const float *samples, int32_t n, 967 auto wrapper = [callback](const float *samples, int32_t n,
968 - float /*progress*/) { callback(samples, n); }; 968 + float /*progress*/) {
  969 + return callback(samples, n);
  970 + };
969 971
970 return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); 972 return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
971 } 973 }
@@ -975,7 +977,7 @@ SherpaOnnxOfflineTtsGenerateWithProgressCallback( @@ -975,7 +977,7 @@ SherpaOnnxOfflineTtsGenerateWithProgressCallback(
975 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, 977 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
976 SherpaOnnxGeneratedAudioProgressCallback callback) { 978 SherpaOnnxGeneratedAudioProgressCallback callback) {
977 auto wrapper = [callback](const float *samples, int32_t n, float progress) { 979 auto wrapper = [callback](const float *samples, int32_t n, float progress) {
978 - callback(samples, n, progress); 980 + return callback(samples, n, progress);
979 }; 981 };
980 return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); 982 return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
981 } 983 }
@@ -985,7 +987,7 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg( @@ -985,7 +987,7 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
985 SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) { 987 SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) {
986 auto wrapper = [callback, arg](const float *samples, int32_t n, 988 auto wrapper = [callback, arg](const float *samples, int32_t n,
987 float /*progress*/) { 989 float /*progress*/) {
988 - callback(samples, n, arg); 990 + return callback(samples, n, arg);
989 }; 991 };
990 992
991 return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); 993 return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
@@ -850,14 +850,17 @@ SHERPA_ONNX_API typedef struct SherpaOnnxGeneratedAudio { @@ -850,14 +850,17 @@ SHERPA_ONNX_API typedef struct SherpaOnnxGeneratedAudio {
850 int32_t sample_rate; 850 int32_t sample_rate;
851 } SherpaOnnxGeneratedAudio; 851 } SherpaOnnxGeneratedAudio;
852 852
853 -typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples,  
854 - int32_t n); 853 +// If the callback returns 0, then it stops generating
  854 +// If the callback returns 1, then it keeps generating
  855 +typedef int32_t (*SherpaOnnxGeneratedAudioCallback)(const float *samples,
  856 + int32_t n);
855 857
856 -typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples,  
857 - int32_t n, void *arg); 858 +typedef int32_t (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples,
  859 + int32_t n,
  860 + void *arg);
858 861
859 -typedef void (*SherpaOnnxGeneratedAudioProgressCallback)(const float *samples,  
860 - int32_t n, float p); 862 +typedef int32_t (*SherpaOnnxGeneratedAudioProgressCallback)(
  863 + const float *samples, int32_t n, float p);
861 864
862 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts; 865 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts;
863 866
@@ -216,9 +216,11 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -216,9 +216,11 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
216 216
217 GeneratedAudio ans; 217 GeneratedAudio ans;
218 218
  219 + int32_t should_continue = 1;
  220 +
219 int32_t k = 0; 221 int32_t k = 0;
220 222
221 - for (int32_t b = 0; b != num_batches; ++b) { 223 + for (int32_t b = 0; b != num_batches && should_continue; ++b) {
222 batch.clear(); 224 batch.clear();
223 for (int32_t i = 0; i != batch_size; ++i, ++k) { 225 for (int32_t i = 0; i != batch_size; ++i, ++k) {
224 batch.push_back(std::move(x[k])); 226 batch.push_back(std::move(x[k]));
@@ -229,8 +231,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -229,8 +231,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
229 ans.samples.insert(ans.samples.end(), audio.samples.begin(), 231 ans.samples.insert(ans.samples.end(), audio.samples.begin(),
230 audio.samples.end()); 232 audio.samples.end());
231 if (callback) { 233 if (callback) {
232 - callback(audio.samples.data(), audio.samples.size(),  
233 - b * 1.0 / num_batches); 234 + should_continue = callback(audio.samples.data(), audio.samples.size(),
  235 + b * 1.0 / num_batches);
234 // Caution(fangjun): audio is freed when the callback returns, so users 236 // Caution(fangjun): audio is freed when the callback returns, so users
235 // should copy the data if they want to access the data after 237 // should copy the data if they want to access the data after
236 // the callback returns to avoid segmentation fault. 238 // the callback returns to avoid segmentation fault.
@@ -238,7 +240,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -238,7 +240,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
238 } 240 }
239 241
240 batch.clear(); 242 batch.clear();
241 - while (k < static_cast<int32_t>(x.size())) { 243 + while (k < static_cast<int32_t>(x.size()) && should_continue) {
242 batch.push_back(std::move(x[k])); 244 batch.push_back(std::move(x[k]));
243 ++k; 245 ++k;
244 } 246 }
@@ -59,7 +59,9 @@ struct GeneratedAudio { @@ -59,7 +59,9 @@ struct GeneratedAudio {
59 59
60 class OfflineTtsImpl; 60 class OfflineTtsImpl;
61 61
62 -using GeneratedAudioCallback = std::function<void( 62 +// If the callback returns 0, then it stops generating
  63 +// if the callback returns 1, then it keeps generating
  64 +using GeneratedAudioCallback = std::function<int32_t(
63 const float * /*samples*/, int32_t /*n*/, float /*progress*/)>; 65 const float * /*samples*/, int32_t /*n*/, float /*progress*/)>;
64 66
65 class OfflineTts { 67 class OfflineTts {
@@ -44,13 +44,20 @@ static void Handler(int32_t /*sig*/) { @@ -44,13 +44,20 @@ static void Handler(int32_t /*sig*/) {
44 fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); 44 fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
45 } 45 }
46 46
47 -static void AudioGeneratedCallback(const float *s, int32_t n,  
48 - float /*progress*/) { 47 +static int32_t AudioGeneratedCallback(const float *s, int32_t n,
  48 + float /*progress*/) {
49 if (n > 0) { 49 if (n > 0) {
50 std::lock_guard<std::mutex> lock(g_buffer.mutex); 50 std::lock_guard<std::mutex> lock(g_buffer.mutex);
51 g_buffer.samples.push({s, s + n}); 51 g_buffer.samples.push({s, s + n});
52 g_cv.notify_all(); 52 g_cv.notify_all();
53 } 53 }
  54 +
  55 + if (g_killed) {
  56 + return 0; // stop generating
  57 + }
  58 +
  59 + // continue generating
  60 + return 1;
54 } 61 }
55 62
56 static void StartPlayback(const std::string &device_name, int32_t sample_rate) { 63 static void StartPlayback(const std::string &device_name, int32_t sample_rate) {
@@ -47,8 +47,8 @@ static void Handler(int32_t /*sig*/) { @@ -47,8 +47,8 @@ static void Handler(int32_t /*sig*/) {
47 fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); 47 fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
48 } 48 }
49 49
50 -static void AudioGeneratedCallback(const float *s, int32_t n,  
51 - float /*progress*/) { 50 +static int32_t AudioGeneratedCallback(const float *s, int32_t n,
  51 + float /*progress*/) {
52 if (n > 0) { 52 if (n > 0) {
53 Samples samples; 53 Samples samples;
54 samples.data = std::vector<float>{s, s + n}; 54 samples.data = std::vector<float>{s, s + n};
@@ -57,6 +57,12 @@ static void AudioGeneratedCallback(const float *s, int32_t n, @@ -57,6 +57,12 @@ static void AudioGeneratedCallback(const float *s, int32_t n,
57 g_buffer.samples.push(std::move(samples)); 57 g_buffer.samples.push(std::move(samples));
58 g_started = true; 58 g_started = true;
59 } 59 }
  60 + if (g_killed) {
  61 + return 0; // stop generating
  62 + }
  63 +
  64 + // continue generating
  65 + return 1;
60 } 66 }
61 67
62 static int PlayCallback(const void * /*in*/, void *out, 68 static int PlayCallback(const void * /*in*/, void *out,
@@ -9,8 +9,9 @@ @@ -9,8 +9,9 @@
9 #include "sherpa-onnx/csrc/parse-options.h" 9 #include "sherpa-onnx/csrc/parse-options.h"
10 #include "sherpa-onnx/csrc/wave-writer.h" 10 #include "sherpa-onnx/csrc/wave-writer.h"
11 11
12 -void audioCallback(const float * /*samples*/, int32_t n, float progress) { 12 +int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) {
13 printf("sample=%d, progress=%f\n", n, progress); 13 printf("sample=%d, progress=%f\n", n, progress);
  14 + return 1;
14 } 15 }
15 16
16 int main(int32_t argc, char *argv[]) { 17 int main(int32_t argc, char *argv[]) {
  1 +## 1.10.1
  2 +
  3 +* Enable to stop TTS generation
  4 +
1 ## 1.10.0 5 ## 1.10.0
2 6
3 * Add inverse text normalization 7 * Add inverse text normalization
@@ -326,7 +326,7 @@ typedef SherpaOnnxDestroyOfflineTtsGeneratedAudioNative = Void Function( @@ -326,7 +326,7 @@ typedef SherpaOnnxDestroyOfflineTtsGeneratedAudioNative = Void Function(
326 typedef SherpaOnnxDestroyOfflineTtsGeneratedAudio = void Function( 326 typedef SherpaOnnxDestroyOfflineTtsGeneratedAudio = void Function(
327 Pointer<SherpaOnnxGeneratedAudio>); 327 Pointer<SherpaOnnxGeneratedAudio>);
328 328
329 -typedef SherpaOnnxGeneratedAudioCallbackNative = Void Function( 329 +typedef SherpaOnnxGeneratedAudioCallbackNative = Int Function(
330 Pointer<Float>, Int32); 330 Pointer<Float>, Int32);
331 331
332 typedef SherpaOnnxOfflineTtsGenerateWithCallbackNative 332 typedef SherpaOnnxOfflineTtsGenerateWithCallbackNative
@@ -149,7 +149,7 @@ class OfflineTts { @@ -149,7 +149,7 @@ class OfflineTts {
149 {required String text, 149 {required String text,
150 int sid = 0, 150 int sid = 0,
151 double speed = 1.0, 151 double speed = 1.0,
152 - required void Function(Float32List samples) callback}) { 152 + required int Function(Float32List samples) callback}) {
153 // see 153 // see
154 // https://github.com/dart-lang/sdk/issues/54276#issuecomment-1846109285 154 // https://github.com/dart-lang/sdk/issues/54276#issuecomment-1846109285
155 // https://stackoverflow.com/questions/69537440/callbacks-in-dart-dartffi-only-supports-calling-static-dart-functions-from-nat 155 // https://stackoverflow.com/questions/69537440/callbacks-in-dart-dartffi-only-supports-calling-static-dart-functions-from-nat
@@ -159,8 +159,8 @@ class OfflineTts { @@ -159,8 +159,8 @@ class OfflineTts {
159 (Pointer<Float> samples, int n) { 159 (Pointer<Float> samples, int n) {
160 final s = samples.asTypedList(n); 160 final s = samples.asTypedList(n);
161 final newSamples = Float32List.fromList(s); 161 final newSamples = Float32List.fromList(s);
162 - callback(newSamples);  
163 - }); 162 + return callback(newSamples);
  163 + }, exceptionalReturn: 0);
164 164
165 final Pointer<Utf8> textPtr = text.toNativeUtf8(); 165 final Pointer<Utf8> textPtr = text.toNativeUtf8();
166 final p = SherpaOnnxBindings.offlineTtsGenerateWithCallback 166 final p = SherpaOnnxBindings.offlineTtsGenerateWithCallback
@@ -186,14 +186,42 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( @@ -186,14 +186,42 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
186 const char *p_text = env->GetStringUTFChars(text, nullptr); 186 const char *p_text = env->GetStringUTFChars(text, nullptr);
187 SHERPA_ONNX_LOGE("string is: %s", p_text); 187 SHERPA_ONNX_LOGE("string is: %s", p_text);
188 188
189 - std::function<void(const float *, int32_t, float)> callback_wrapper = 189 + std::function<int32_t(const float *, int32_t, float)> callback_wrapper =
190 [env, callback](const float *samples, int32_t n, float /*progress*/) { 190 [env, callback](const float *samples, int32_t n, float /*progress*/) {
191 jclass cls = env->GetObjectClass(callback); 191 jclass cls = env->GetObjectClass(callback);
192 - jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V"); 192 +
  193 +#if 0
  194 + // this block is for debugging only
  195 + // see also
  196 + // https://jnjosh.com/posts/kotlinfromcpp/
  197 + jmethodID classMethodId =
  198 + env->GetMethodID(cls, "getClass", "()Ljava/lang/Class;");
  199 + jobject klassObj = env->CallObjectMethod(callback, classMethodId);
  200 + auto klassObject = env->GetObjectClass(klassObj);
  201 + auto nameMethodId =
  202 + env->GetMethodID(klassObject, "getName", "()Ljava/lang/String;");
  203 + jstring classString =
  204 + (jstring)env->CallObjectMethod(klassObj, nameMethodId);
  205 + auto className = env->GetStringUTFChars(classString, NULL);
  206 + SHERPA_ONNX_LOGE("name is: %s", className);
  207 + env->ReleaseStringUTFChars(classString, className);
  208 +#endif
  209 +
  210 + jmethodID mid =
  211 + env->GetMethodID(cls, "invoke", "([F)Ljava/lang/Integer;");
  212 + if (mid == nullptr) {
  213 + SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it.");
  214 + return 1;
  215 + }
193 216
194 jfloatArray samples_arr = env->NewFloatArray(n); 217 jfloatArray samples_arr = env->NewFloatArray(n);
195 env->SetFloatArrayRegion(samples_arr, 0, n, samples); 218 env->SetFloatArrayRegion(samples_arr, 0, n, samples);
196 - env->CallVoidMethod(callback, mid, samples_arr); 219 +
  220 + jobject should_continue =
  221 + env->CallObjectMethod(callback, mid, samples_arr);
  222 + jclass jklass = env->GetObjectClass(should_continue);
  223 + jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I");
  224 + return env->CallIntMethod(should_continue, int_value_mid);
197 }; 225 };
198 226
199 auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate( 227 auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate(
@@ -57,13 +57,13 @@ void PybindOfflineTts(py::module *m) { @@ -57,13 +57,13 @@ void PybindOfflineTts(py::module *m) {
57 "generate", 57 "generate",
58 [](const PyClass &self, const std::string &text, int64_t sid, 58 [](const PyClass &self, const std::string &text, int64_t sid,
59 float speed, 59 float speed,
60 - std::function<void(py::array_t<float>, float)> callback) 60 + std::function<int32_t(py::array_t<float>, float)> callback)
61 -> GeneratedAudio { 61 -> GeneratedAudio {
62 if (!callback) { 62 if (!callback) {
63 return self.Generate(text, sid, speed); 63 return self.Generate(text, sid, speed);
64 } 64 }
65 65
66 - std::function<void(const float *, int32_t, float)> 66 + std::function<int32_t(const float *, int32_t, float)>
67 callback_wrapper = [callback](const float *samples, int32_t n, 67 callback_wrapper = [callback](const float *samples, int32_t n,
68 float progress) { 68 float progress) {
69 // CAUTION(fangjun): we have to copy samples since it is 69 // CAUTION(fangjun): we have to copy samples since it is
@@ -75,7 +75,7 @@ void PybindOfflineTts(py::module *m) { @@ -75,7 +75,7 @@ void PybindOfflineTts(py::module *m) {
75 py::buffer_info buf = array.request(); 75 py::buffer_info buf = array.request();
76 auto p = static_cast<float *>(buf.ptr); 76 auto p = static_cast<float *>(buf.ptr);
77 std::copy(samples, samples + n, p); 77 std::copy(samples, samples + n, p);
78 - callback(array, progress); 78 + return callback(array, progress);
79 }; 79 };
80 80
81 return self.Generate(text, sid, speed, callback_wrapper); 81 return self.Generate(text, sid, speed, callback_wrapper);