Leo Huang
Committed by GitHub

Added a progress argument to the TTS generator callback (#712)

Co-authored-by: leohwang <leohwang@360converter.com>
@@ -807,16 +807,10 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) { @@ -807,16 +807,10 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) {
807 return tts->impl->NumSpeakers(); 807 return tts->impl->NumSpeakers();
808 } 808 }
809 809
810 -const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate( 810 +static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal(
811 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, 811 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
812 - float speed) {  
813 - return SherpaOnnxOfflineTtsGenerateWithCallback(tts, text, sid, speed,  
814 - nullptr);  
815 -}  
816 -  
817 -const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(  
818 - const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,  
819 - SherpaOnnxGeneratedAudioCallback callback) { 812 + float speed, std::function<void(const float *, int32_t, float)> callback)
  813 +{
820 sherpa_onnx::GeneratedAudio audio = 814 sherpa_onnx::GeneratedAudio audio =
821 tts->impl->Generate(text, sid, speed, callback); 815 tts->impl->Generate(text, sid, speed, callback);
822 816
@@ -836,30 +830,39 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( @@ -836,30 +830,39 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
836 return ans; 830 return ans;
837 } 831 }
838 832
839 -const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg( 833 +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate(
  834 + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
  835 + float speed) {
  836 + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, nullptr );
  837 +}
  838 +
  839 +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
840 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, 840 const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
841 - SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) {  
842 - auto wrapper = [callback, arg](const float *samples, int32_t n) {  
843 - callback(samples, n, arg); 841 + SherpaOnnxGeneratedAudioCallback callback) {
  842 + auto wrapper = [callback](const float *samples, int32_t n, float /*progress*/) {
  843 + callback(samples, n );
844 }; 844 };
845 845
846 - sherpa_onnx::GeneratedAudio audio =  
847 - tts->impl->Generate(text, sid, speed, wrapper);  
848 -  
849 - if (audio.samples.empty()) {  
850 - return nullptr;  
851 - }  
852 -  
853 - SherpaOnnxGeneratedAudio *ans = new SherpaOnnxGeneratedAudio; 846 + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
  847 +}
854 848
855 - float *samples = new float[audio.samples.size()];  
856 - std::copy(audio.samples.begin(), audio.samples.end(), samples); 849 +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithProgressCallback(
  850 + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
  851 + SherpaOnnxGeneratedAudioProgressCallback callback) {
  852 + auto wrapper = [callback](const float *samples, int32_t n, float progress) {
  853 + callback(samples, n, progress );
  854 + };
  855 + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
  856 +}
857 857
858 - ans->samples = samples;  
859 - ans->n = audio.samples.size();  
860 - ans->sample_rate = audio.sample_rate; 858 +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
  859 + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
  860 + SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) {
  861 + auto wrapper = [callback, arg](const float *samples, int32_t n, float /*progress*/) {
  862 + callback(samples, n, arg);
  863 + };
861 864
862 - return ans; 865 + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
863 } 866 }
864 867
865 void SherpaOnnxDestroyOfflineTtsGeneratedAudio( 868 void SherpaOnnxDestroyOfflineTtsGeneratedAudio(
@@ -768,6 +768,9 @@ typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples, @@ -768,6 +768,9 @@ typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples,
768 typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples, 768 typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples,
769 int32_t n, void *arg); 769 int32_t n, void *arg);
770 770
  771 +typedef void (*SherpaOnnxGeneratedAudioProgressCallback)(const float *samples,
  772 + int32_t n, float p);
  773 +
771 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts; 774 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts;
772 775
773 // Create an instance of offline TTS. The user has to use DestroyOfflineTts() 776 // Create an instance of offline TTS. The user has to use DestroyOfflineTts()
@@ -134,7 +134,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -134,7 +134,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
134 if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { 134 if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
135 auto ans = Process(x, sid, speed); 135 auto ans = Process(x, sid, speed);
136 if (callback) { 136 if (callback) {
137 - callback(ans.samples.data(), ans.samples.size()); 137 + callback(ans.samples.data(), ans.samples.size(), 1.0);
138 } 138 }
139 return ans; 139 return ans;
140 } 140 }
@@ -168,7 +168,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -168,7 +168,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
168 ans.samples.insert(ans.samples.end(), audio.samples.begin(), 168 ans.samples.insert(ans.samples.end(), audio.samples.begin(),
169 audio.samples.end()); 169 audio.samples.end());
170 if (callback) { 170 if (callback) {
171 - callback(audio.samples.data(), audio.samples.size()); 171 + callback(audio.samples.data(), audio.samples.size(), b * 1.0 / num_batches);
172 // Caution(fangjun): audio is freed when the callback returns, so users 172 // Caution(fangjun): audio is freed when the callback returns, so users
173 // should copy the data if they want to access the data after 173 // should copy the data if they want to access the data after
174 // the callback returns to avoid segmentation fault. 174 // the callback returns to avoid segmentation fault.
@@ -187,7 +187,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { @@ -187,7 +187,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
187 ans.samples.insert(ans.samples.end(), audio.samples.begin(), 187 ans.samples.insert(ans.samples.end(), audio.samples.begin(),
188 audio.samples.end()); 188 audio.samples.end());
189 if (callback) { 189 if (callback) {
190 - callback(audio.samples.data(), audio.samples.size()); 190 + callback(audio.samples.data(), audio.samples.size(), 1.0);
191 // Caution(fangjun): audio is freed when the callback returns, so users 191 // Caution(fangjun): audio is freed when the callback returns, so users
192 // should copy the data if they want to access the data after 192 // should copy the data if they want to access the data after
193 // the callback returns to avoid segmentation fault. 193 // the callback returns to avoid segmentation fault.
@@ -55,7 +55,7 @@ struct GeneratedAudio { @@ -55,7 +55,7 @@ struct GeneratedAudio {
55 class OfflineTtsImpl; 55 class OfflineTtsImpl;
56 56
57 using GeneratedAudioCallback = 57 using GeneratedAudioCallback =
58 - std::function<void(const float * /*samples*/, int32_t /*n*/)>; 58 + std::function<void(const float * /*samples*/, int32_t /*n*/, float /*progress*/)>;
59 59
60 class OfflineTts { 60 class OfflineTts {
61 public: 61 public:
@@ -47,7 +47,7 @@ static void Handler(int32_t /*sig*/) { @@ -47,7 +47,7 @@ static void Handler(int32_t /*sig*/) {
47 fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); 47 fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
48 } 48 }
49 49
50 -static void AudioGeneratedCallback(const float *s, int32_t n) { 50 +static void AudioGeneratedCallback(const float *s, int32_t n, float /*progress*/) {
51 if (n > 0) { 51 if (n > 0) {
52 Samples samples; 52 Samples samples;
53 samples.data = std::vector<float>{s, s + n}; 53 samples.data = std::vector<float>{s, s + n};
@@ -9,6 +9,11 @@ @@ -9,6 +9,11 @@
9 #include "sherpa-onnx/csrc/parse-options.h" 9 #include "sherpa-onnx/csrc/parse-options.h"
10 #include "sherpa-onnx/csrc/wave-writer.h" 10 #include "sherpa-onnx/csrc/wave-writer.h"
11 11
  12 +void audioCallback(const float *samples, int32_t n, float progress)
  13 +{
  14 + printf( "sample=%d, progress=%f\n", n, progress );
  15 +}
  16 +
12 int main(int32_t argc, char *argv[]) { 17 int main(int32_t argc, char *argv[]) {
13 const char *kUsageMessage = R"usage( 18 const char *kUsageMessage = R"usage(
14 Offline text-to-speech with sherpa-onnx 19 Offline text-to-speech with sherpa-onnx
@@ -74,7 +79,7 @@ or details. @@ -74,7 +79,7 @@ or details.
74 sherpa_onnx::OfflineTts tts(config); 79 sherpa_onnx::OfflineTts tts(config);
75 80
76 const auto begin = std::chrono::steady_clock::now(); 81 const auto begin = std::chrono::steady_clock::now();
77 - auto audio = tts.Generate(po.GetArg(1), sid); 82 + auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback);
78 const auto end = std::chrono::steady_clock::now(); 83 const auto end = std::chrono::steady_clock::now();
79 84
80 if (audio.samples.empty()) { 85 if (audio.samples.empty()) {
@@ -797,7 +797,7 @@ class SherpaOnnxOfflineTts { @@ -797,7 +797,7 @@ class SherpaOnnxOfflineTts {
797 797
798 GeneratedAudio Generate( 798 GeneratedAudio Generate(
799 const std::string &text, int64_t sid = 0, float speed = 1.0, 799 const std::string &text, int64_t sid = 0, float speed = 1.0,
800 - std::function<void(const float *, int32_t)> callback = nullptr) const { 800 + std::function<void(const float *, int32_t, float)> callback = nullptr) const {
801 return tts_.Generate(text, sid, speed, callback); 801 return tts_.Generate(text, sid, speed, callback);
802 } 802 }
803 803
@@ -1314,8 +1314,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( @@ -1314,8 +1314,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
1314 const char *p_text = env->GetStringUTFChars(text, nullptr); 1314 const char *p_text = env->GetStringUTFChars(text, nullptr);
1315 SHERPA_ONNX_LOGE("string is: %s", p_text); 1315 SHERPA_ONNX_LOGE("string is: %s", p_text);
1316 1316
1317 - std::function<void(const float *, int32_t)> callback_wrapper =  
1318 - [env, callback](const float *samples, int32_t n) { 1317 + std::function<void(const float *, int32_t, float)> callback_wrapper =
  1318 + [env, callback](const float *samples, int32_t n, float /*p*/) {
1319 jclass cls = env->GetObjectClass(callback); 1319 jclass cls = env->GetObjectClass(callback);
1320 jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V"); 1320 jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");
1321 1321
@@ -55,14 +55,14 @@ void PybindOfflineTts(py::module *m) { @@ -55,14 +55,14 @@ void PybindOfflineTts(py::module *m) {
55 .def( 55 .def(
56 "generate", 56 "generate",
57 [](const PyClass &self, const std::string &text, int64_t sid, 57 [](const PyClass &self, const std::string &text, int64_t sid,
58 - float speed, std::function<void(py::array_t<float>)> callback) 58 + float speed, std::function<void(py::array_t<float>, float)> callback)
59 -> GeneratedAudio { 59 -> GeneratedAudio {
60 if (!callback) { 60 if (!callback) {
61 return self.Generate(text, sid, speed); 61 return self.Generate(text, sid, speed);
62 } 62 }
63 63
64 - std::function<void(const float *, int32_t)> callback_wrapper =  
65 - [callback](const float *samples, int32_t n) { 64 + std::function<void(const float *, int32_t, float)> callback_wrapper =
  65 + [callback](const float *samples, int32_t n, float progress) {
66 // CAUTION(fangjun): we have to copy samples since it is 66 // CAUTION(fangjun): we have to copy samples since it is
67 // freed once the call back returns. 67 // freed once the call back returns.
68 68
@@ -72,7 +72,7 @@ void PybindOfflineTts(py::module *m) { @@ -72,7 +72,7 @@ void PybindOfflineTts(py::module *m) {
72 py::buffer_info buf = array.request(); 72 py::buffer_info buf = array.request();
73 auto p = static_cast<float *>(buf.ptr); 73 auto p = static_cast<float *>(buf.ptr);
74 std::copy(samples, samples + n, p); 74 std::copy(samples, samples + n, p);
75 - callback(array); 75 + callback(array, progress);
76 }; 76 };
77 77
78 return self.Generate(text, sid, speed, callback_wrapper); 78 return self.Generate(text, sid, speed, callback_wrapper);