Fangjun Kuang
Committed by GitHub

Support user provided data in tts callback. (#653)

@@ -636,6 +636,32 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( @@ -636,6 +636,32 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
636 return ans; 636 return ans;
637 } 637 }
638 638
  // Reviewer notes on the added definition. The leading "NNN +" columns are the
  // diff gutter, not part of the C++ code.
  //
  // Runs tts->impl->Generate() for `text` / `sid` / `speed` and, whenever
  // Generate() invokes its callback, forwards the call to the user-supplied
  // `callback` together with the opaque `arg` pointer.
  //
  // Returns nullptr when the generated audio contains no samples; otherwise
  // returns a heap-allocated SherpaOnnxGeneratedAudio that holds its own copy
  // of the samples.
  // NOTE(review): the result is presumably released with
  // SherpaOnnxDestroyOfflineTtsGeneratedAudio() -- its body is not visible
  // here, confirm.
  // NOTE(review): neither `tts` nor `callback` is checked for null before use.
  639 +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
  640 + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
  641 + SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) {
  // Adapter: Generate() is handed a callable taking (samples, n), so the user
  // pointer is captured by value and appended as the third argument.
  642 + auto wrapper = [callback, arg](const float *samples, int32_t n) {
  643 + callback(samples, n, arg);
  644 + };
  645 +
  646 + sherpa_onnx::GeneratedAudio audio =
  647 + tts->impl->Generate(text, sid, speed, wrapper);
  648 +
  // Nothing was generated: signal that with nullptr rather than an empty struct.
  649 + if (audio.samples.empty()) {
  650 + return nullptr;
  651 + }
  652 +
  653 + SherpaOnnxGeneratedAudio *ans = new SherpaOnnxGeneratedAudio;
  654 +
  // Copy into a new[]-allocated buffer so the returned struct does not point
  // into the local `audio`, which goes out of scope when this function returns.
  655 + float *samples = new float[audio.samples.size()];
  656 + std::copy(audio.samples.begin(), audio.samples.end(), samples);
  657 +
  658 + ans->samples = samples;
  // NOTE(review): size() yields an unsigned size type; if `n` is int32_t this
  // assignment narrows -- confirm against the struct definition.
  659 + ans->n = audio.samples.size();
  660 + ans->sample_rate = audio.sample_rate;
  661 +
  662 + return ans;
  663 +}
  664 +
639 void SherpaOnnxDestroyOfflineTtsGeneratedAudio( 665 void SherpaOnnxDestroyOfflineTtsGeneratedAudio(
640 const SherpaOnnxGeneratedAudio *p) { 666 const SherpaOnnxGeneratedAudio *p) {
641 if (p) { 667 if (p) {
@@ -644,6 +644,9 @@ SHERPA_ONNX_API typedef struct SherpaOnnxGeneratedAudio { @@ -644,6 +644,9 @@ SHERPA_ONNX_API typedef struct SherpaOnnxGeneratedAudio {
644 typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples, 644 typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples,
645 int32_t n); 645 int32_t n);
646 646
  // Callback type with an extra user-data parameter. `samples` and `n` match
  // the parameters of SherpaOnnxGeneratedAudioCallback; `arg` is the pointer
  // the caller passed to SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(),
  // forwarded unchanged.
  // NOTE(review): `n` is presumably the number of floats at `samples` -- confirm.
  647 +typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples,
  648 + int32_t n, void *arg);
  649 +
647 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts; 650 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts;
648 651
649 // Create an instance of offline TTS. The user has to use DestroyOfflineTts() 652 // Create an instance of offline TTS. The user has to use DestroyOfflineTts()
@@ -678,6 +681,13 @@ SherpaOnnxOfflineTtsGenerateWithCallback( @@ -678,6 +681,13 @@ SherpaOnnxOfflineTtsGenerateWithCallback(
678 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, 681 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
679 SherpaOnnxGeneratedAudioCallback callback); 682 SherpaOnnxGeneratedAudioCallback callback);
680 683
  684 +// Same as SherpaOnnxOfflineTtsGenerateWithCallback but you can pass an additional
  685 +// `void* arg` to the callback. Returns NULL if no samples are generated.
  686 +SHERPA_ONNX_API const SherpaOnnxGeneratedAudio *
  687 +SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
  688 + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
  689 + SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg);
  690 +
681 SHERPA_ONNX_API void SherpaOnnxDestroyOfflineTtsGeneratedAudio( 691 SHERPA_ONNX_API void SherpaOnnxDestroyOfflineTtsGeneratedAudio(
682 const SherpaOnnxGeneratedAudio *p); 692 const SherpaOnnxGeneratedAudio *p);
683 693
@@ -43,5 +43,3 @@ void PybindAlsa(py::module *m) { @@ -43,5 +43,3 @@ void PybindAlsa(py::module *m) {
43 } 43 }
44 44
45 } // namespace sherpa_onnx 45 } // namespace sherpa_onnx
46 -  
47 -#endif // SHERPA_ONNX_PYTHON_CSRC_FAKED_ALSA_H_