Committed by
GitHub
Added progress for callback of tts generator (#712)
Co-authored-by: leohwang <leohwang@360converter.com>
正在显示
8 个修改的文件
包含
51 行增加
和
40 行删除
| @@ -807,16 +807,10 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) { | @@ -807,16 +807,10 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) { | ||
| 807 | return tts->impl->NumSpeakers(); | 807 | return tts->impl->NumSpeakers(); |
| 808 | } | 808 | } |
| 809 | 809 | ||
| 810 | -const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate( | 810 | +static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal( |
| 811 | const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, | 811 | const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, |
| 812 | - float speed) { | ||
| 813 | - return SherpaOnnxOfflineTtsGenerateWithCallback(tts, text, sid, speed, | ||
| 814 | - nullptr); | ||
| 815 | -} | ||
| 816 | - | ||
| 817 | -const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( | ||
| 818 | - const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, | ||
| 819 | - SherpaOnnxGeneratedAudioCallback callback) { | 812 | + float speed, std::function<void(const float *, int32_t, float)> callback) |
| 813 | +{ | ||
| 820 | sherpa_onnx::GeneratedAudio audio = | 814 | sherpa_onnx::GeneratedAudio audio = |
| 821 | tts->impl->Generate(text, sid, speed, callback); | 815 | tts->impl->Generate(text, sid, speed, callback); |
| 822 | 816 | ||
| @@ -836,30 +830,39 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( | @@ -836,30 +830,39 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( | ||
| 836 | return ans; | 830 | return ans; |
| 837 | } | 831 | } |
| 838 | 832 | ||
| 839 | -const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg( | 833 | +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate( |
| 834 | + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, | ||
| 835 | + float speed) { | ||
| 836 | + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, nullptr ); | ||
| 837 | +} | ||
| 838 | + | ||
| 839 | +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( | ||
| 840 | const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, | 840 | const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, |
| 841 | - SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) { | ||
| 842 | - auto wrapper = [callback, arg](const float *samples, int32_t n) { | ||
| 843 | - callback(samples, n, arg); | 841 | + SherpaOnnxGeneratedAudioCallback callback) { |
| 842 | + auto wrapper = [callback](const float *samples, int32_t n, float /*progress*/) { | ||
| 843 | + callback(samples, n ); | ||
| 844 | }; | 844 | }; |
| 845 | 845 | ||
| 846 | - sherpa_onnx::GeneratedAudio audio = | ||
| 847 | - tts->impl->Generate(text, sid, speed, wrapper); | ||
| 848 | - | ||
| 849 | - if (audio.samples.empty()) { | ||
| 850 | - return nullptr; | ||
| 851 | - } | ||
| 852 | - | ||
| 853 | - SherpaOnnxGeneratedAudio *ans = new SherpaOnnxGeneratedAudio; | 846 | + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper ); |
| 847 | +} | ||
| 854 | 848 | ||
| 855 | - float *samples = new float[audio.samples.size()]; | ||
| 856 | - std::copy(audio.samples.begin(), audio.samples.end(), samples); | 849 | +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithProgressCallback( |
| 850 | + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, | ||
| 851 | + SherpaOnnxGeneratedAudioProgressCallback callback) { | ||
| 852 | + auto wrapper = [callback](const float *samples, int32_t n, float progress) { | ||
| 853 | + callback(samples, n, progress ); | ||
| 854 | + }; | ||
| 855 | + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper ); | ||
| 856 | +} | ||
| 857 | 857 | ||
| 858 | - ans->samples = samples; | ||
| 859 | - ans->n = audio.samples.size(); | ||
| 860 | - ans->sample_rate = audio.sample_rate; | 858 | +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg( |
| 859 | + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, | ||
| 860 | + SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) { | ||
| 861 | + auto wrapper = [callback, arg](const float *samples, int32_t n, float /*progress*/) { | ||
| 862 | + callback(samples, n, arg); | ||
| 863 | + }; | ||
| 861 | 864 | ||
| 862 | - return ans; | 865 | + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper ); |
| 863 | } | 866 | } |
| 864 | 867 | ||
| 865 | void SherpaOnnxDestroyOfflineTtsGeneratedAudio( | 868 | void SherpaOnnxDestroyOfflineTtsGeneratedAudio( |
| @@ -768,6 +768,9 @@ typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples, | @@ -768,6 +768,9 @@ typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples, | ||
| 768 | typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples, | 768 | typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples, |
| 769 | int32_t n, void *arg); | 769 | int32_t n, void *arg); |
| 770 | 770 | ||
| 771 | +typedef void (*SherpaOnnxGeneratedAudioProgressCallback)(const float *samples, | ||
| 772 | + int32_t n, float p); | ||
| 773 | + | ||
| 771 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts; | 774 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts; |
| 772 | 775 | ||
| 773 | // Create an instance of offline TTS. The user has to use DestroyOfflineTts() | 776 | // Create an instance of offline TTS. The user has to use DestroyOfflineTts() |
| @@ -134,7 +134,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -134,7 +134,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 134 | if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { | 134 | if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { |
| 135 | auto ans = Process(x, sid, speed); | 135 | auto ans = Process(x, sid, speed); |
| 136 | if (callback) { | 136 | if (callback) { |
| 137 | - callback(ans.samples.data(), ans.samples.size()); | 137 | + callback(ans.samples.data(), ans.samples.size(), 1.0); |
| 138 | } | 138 | } |
| 139 | return ans; | 139 | return ans; |
| 140 | } | 140 | } |
| @@ -168,7 +168,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -168,7 +168,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 168 | ans.samples.insert(ans.samples.end(), audio.samples.begin(), | 168 | ans.samples.insert(ans.samples.end(), audio.samples.begin(), |
| 169 | audio.samples.end()); | 169 | audio.samples.end()); |
| 170 | if (callback) { | 170 | if (callback) { |
| 171 | - callback(audio.samples.data(), audio.samples.size()); | 171 | + callback(audio.samples.data(), audio.samples.size(), b * 1.0 / num_batches); |
| 172 | // Caution(fangjun): audio is freed when the callback returns, so users | 172 | // Caution(fangjun): audio is freed when the callback returns, so users |
| 173 | // should copy the data if they want to access the data after | 173 | // should copy the data if they want to access the data after |
| 174 | // the callback returns to avoid segmentation fault. | 174 | // the callback returns to avoid segmentation fault. |
| @@ -187,7 +187,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | @@ -187,7 +187,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { | ||
| 187 | ans.samples.insert(ans.samples.end(), audio.samples.begin(), | 187 | ans.samples.insert(ans.samples.end(), audio.samples.begin(), |
| 188 | audio.samples.end()); | 188 | audio.samples.end()); |
| 189 | if (callback) { | 189 | if (callback) { |
| 190 | - callback(audio.samples.data(), audio.samples.size()); | 190 | + callback(audio.samples.data(), audio.samples.size(), 1.0); |
| 191 | // Caution(fangjun): audio is freed when the callback returns, so users | 191 | // Caution(fangjun): audio is freed when the callback returns, so users |
| 192 | // should copy the data if they want to access the data after | 192 | // should copy the data if they want to access the data after |
| 193 | // the callback returns to avoid segmentation fault. | 193 | // the callback returns to avoid segmentation fault. |
| @@ -55,7 +55,7 @@ struct GeneratedAudio { | @@ -55,7 +55,7 @@ struct GeneratedAudio { | ||
| 55 | class OfflineTtsImpl; | 55 | class OfflineTtsImpl; |
| 56 | 56 | ||
| 57 | using GeneratedAudioCallback = | 57 | using GeneratedAudioCallback = |
| 58 | - std::function<void(const float * /*samples*/, int32_t /*n*/)>; | 58 | + std::function<void(const float * /*samples*/, int32_t /*n*/, float /*progress*/)>; |
| 59 | 59 | ||
| 60 | class OfflineTts { | 60 | class OfflineTts { |
| 61 | public: | 61 | public: |
| @@ -47,7 +47,7 @@ static void Handler(int32_t /*sig*/) { | @@ -47,7 +47,7 @@ static void Handler(int32_t /*sig*/) { | ||
| 47 | fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); | 47 | fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); |
| 48 | } | 48 | } |
| 49 | 49 | ||
| 50 | -static void AudioGeneratedCallback(const float *s, int32_t n) { | 50 | +static void AudioGeneratedCallback(const float *s, int32_t n, float /*progress*/) { |
| 51 | if (n > 0) { | 51 | if (n > 0) { |
| 52 | Samples samples; | 52 | Samples samples; |
| 53 | samples.data = std::vector<float>{s, s + n}; | 53 | samples.data = std::vector<float>{s, s + n}; |
| @@ -9,6 +9,11 @@ | @@ -9,6 +9,11 @@ | ||
| 9 | #include "sherpa-onnx/csrc/parse-options.h" | 9 | #include "sherpa-onnx/csrc/parse-options.h" |
| 10 | #include "sherpa-onnx/csrc/wave-writer.h" | 10 | #include "sherpa-onnx/csrc/wave-writer.h" |
| 11 | 11 | ||
| 12 | +void audioCallback(const float *samples, int32_t n, float progress) | ||
| 13 | +{ | ||
| 14 | + printf( "sample=%d, progress=%f\n", n, progress ); | ||
| 15 | +} | ||
| 16 | + | ||
| 12 | int main(int32_t argc, char *argv[]) { | 17 | int main(int32_t argc, char *argv[]) { |
| 13 | const char *kUsageMessage = R"usage( | 18 | const char *kUsageMessage = R"usage( |
| 14 | Offline text-to-speech with sherpa-onnx | 19 | Offline text-to-speech with sherpa-onnx |
| @@ -74,7 +79,7 @@ or details. | @@ -74,7 +79,7 @@ or details. | ||
| 74 | sherpa_onnx::OfflineTts tts(config); | 79 | sherpa_onnx::OfflineTts tts(config); |
| 75 | 80 | ||
| 76 | const auto begin = std::chrono::steady_clock::now(); | 81 | const auto begin = std::chrono::steady_clock::now(); |
| 77 | - auto audio = tts.Generate(po.GetArg(1), sid); | 82 | + auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback); |
| 78 | const auto end = std::chrono::steady_clock::now(); | 83 | const auto end = std::chrono::steady_clock::now(); |
| 79 | 84 | ||
| 80 | if (audio.samples.empty()) { | 85 | if (audio.samples.empty()) { |
| @@ -797,7 +797,7 @@ class SherpaOnnxOfflineTts { | @@ -797,7 +797,7 @@ class SherpaOnnxOfflineTts { | ||
| 797 | 797 | ||
| 798 | GeneratedAudio Generate( | 798 | GeneratedAudio Generate( |
| 799 | const std::string &text, int64_t sid = 0, float speed = 1.0, | 799 | const std::string &text, int64_t sid = 0, float speed = 1.0, |
| 800 | - std::function<void(const float *, int32_t)> callback = nullptr) const { | 800 | + std::function<void(const float *, int32_t, float)> callback = nullptr) const { |
| 801 | return tts_.Generate(text, sid, speed, callback); | 801 | return tts_.Generate(text, sid, speed, callback); |
| 802 | } | 802 | } |
| 803 | 803 | ||
| @@ -1314,8 +1314,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( | @@ -1314,8 +1314,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( | ||
| 1314 | const char *p_text = env->GetStringUTFChars(text, nullptr); | 1314 | const char *p_text = env->GetStringUTFChars(text, nullptr); |
| 1315 | SHERPA_ONNX_LOGE("string is: %s", p_text); | 1315 | SHERPA_ONNX_LOGE("string is: %s", p_text); |
| 1316 | 1316 | ||
| 1317 | - std::function<void(const float *, int32_t)> callback_wrapper = | ||
| 1318 | - [env, callback](const float *samples, int32_t n) { | 1317 | + std::function<void(const float *, int32_t, float)> callback_wrapper = |
| 1318 | + [env, callback](const float *samples, int32_t n, float /*p*/) { | ||
| 1319 | jclass cls = env->GetObjectClass(callback); | 1319 | jclass cls = env->GetObjectClass(callback); |
| 1320 | jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V"); | 1320 | jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V"); |
| 1321 | 1321 |
| @@ -55,14 +55,14 @@ void PybindOfflineTts(py::module *m) { | @@ -55,14 +55,14 @@ void PybindOfflineTts(py::module *m) { | ||
| 55 | .def( | 55 | .def( |
| 56 | "generate", | 56 | "generate", |
| 57 | [](const PyClass &self, const std::string &text, int64_t sid, | 57 | [](const PyClass &self, const std::string &text, int64_t sid, |
| 58 | - float speed, std::function<void(py::array_t<float>)> callback) | 58 | + float speed, std::function<void(py::array_t<float>, float)> callback) |
| 59 | -> GeneratedAudio { | 59 | -> GeneratedAudio { |
| 60 | if (!callback) { | 60 | if (!callback) { |
| 61 | return self.Generate(text, sid, speed); | 61 | return self.Generate(text, sid, speed); |
| 62 | } | 62 | } |
| 63 | 63 | ||
| 64 | - std::function<void(const float *, int32_t)> callback_wrapper = | ||
| 65 | - [callback](const float *samples, int32_t n) { | 64 | + std::function<void(const float *, int32_t, float)> callback_wrapper = |
| 65 | + [callback](const float *samples, int32_t n, float progress) { | ||
| 66 | // CAUTION(fangjun): we have to copy samples since it is | 66 | // CAUTION(fangjun): we have to copy samples since it is |
| 67 | // freed once the call back returns. | 67 | // freed once the call back returns. |
| 68 | 68 | ||
| @@ -72,7 +72,7 @@ void PybindOfflineTts(py::module *m) { | @@ -72,7 +72,7 @@ void PybindOfflineTts(py::module *m) { | ||
| 72 | py::buffer_info buf = array.request(); | 72 | py::buffer_info buf = array.request(); |
| 73 | auto p = static_cast<float *>(buf.ptr); | 73 | auto p = static_cast<float *>(buf.ptr); |
| 74 | std::copy(samples, samples + n, p); | 74 | std::copy(samples, samples + n, p); |
| 75 | - callback(array); | 75 | + callback(array, progress); |
| 76 | }; | 76 | }; |
| 77 | 77 | ||
| 78 | return self.Generate(text, sid, speed, callback_wrapper); | 78 | return self.Generate(text, sid, speed, callback_wrapper); |
-
请 注册 或 登录 后发表评论