Committed by
GitHub
Support passing TTS callback in Swift API (#1218)
正在显示
2 个修改的文件
包含
54 行增加
和
4 行删除
| @@ -757,6 +757,14 @@ class SherpaOnnxGeneratedAudioWrapper { | @@ -757,6 +757,14 @@ class SherpaOnnxGeneratedAudioWrapper { | ||
| 757 | } | 757 | } |
| 758 | } | 758 | } |
| 759 | 759 | ||
| 760 | +typealias TtsCallbackWithArg = ( | ||
| 761 | + @convention(c) ( | ||
| 762 | + UnsafePointer<Float>?, // const float* samples | ||
| 763 | + Int32, // int32_t n | ||
| 764 | + UnsafeMutableRawPointer? // void *arg | ||
| 765 | + ) -> Int32 | ||
| 766 | +)? | ||
| 767 | + | ||
| 760 | class SherpaOnnxOfflineTtsWrapper { | 768 | class SherpaOnnxOfflineTtsWrapper { |
| 761 | /// A pointer to the underlying counterpart in C | 769 | /// A pointer to the underlying counterpart in C |
| 762 | let tts: OpaquePointer! | 770 | let tts: OpaquePointer! |
| @@ -780,6 +788,17 @@ class SherpaOnnxOfflineTtsWrapper { | @@ -780,6 +788,17 @@ class SherpaOnnxOfflineTtsWrapper { | ||
| 780 | 788 | ||
| 781 | return SherpaOnnxGeneratedAudioWrapper(audio: audio) | 789 | return SherpaOnnxGeneratedAudioWrapper(audio: audio) |
| 782 | } | 790 | } |
| 791 | + | ||
| 792 | + func generateWithCallbackWithArg( | ||
| 793 | + text: String, callback: TtsCallbackWithArg, arg: UnsafeMutableRawPointer, sid: Int = 0, | ||
| 794 | + speed: Float = 1.0 | ||
| 795 | + ) -> SherpaOnnxGeneratedAudioWrapper { | ||
| 796 | + let audio: UnsafePointer<SherpaOnnxGeneratedAudio>? = | ||
| 797 | + SherpaOnnxOfflineTtsGenerateWithCallbackWithArg( | ||
| 798 | + tts, toCPointer(text), Int32(sid), speed, callback, arg) | ||
| 799 | + | ||
| 800 | + return SherpaOnnxGeneratedAudioWrapper(audio: audio) | ||
| 801 | + } | ||
| 783 | } | 802 | } |
| 784 | 803 | ||
| 785 | // spoken language identification | 804 | // spoken language identification |
| 1 | +class MyClass { | ||
| 2 | + func playSamples(samples: [Float]) { | ||
| 3 | + print("Play \(samples.count) samples") | ||
| 4 | + } | ||
| 5 | +} | ||
| 6 | + | ||
| 1 | func run() { | 7 | func run() { |
| 2 | let model = "./vits-piper-en_US-amy-low/en_US-amy-low.onnx" | 8 | let model = "./vits-piper-en_US-amy-low/en_US-amy-low.onnx" |
| 3 | let tokens = "./vits-piper-en_US-amy-low/tokens.txt" | 9 | let tokens = "./vits-piper-en_US-amy-low/tokens.txt" |
| @@ -11,6 +17,27 @@ func run() { | @@ -11,6 +17,27 @@ func run() { | ||
| 11 | let modelConfig = sherpaOnnxOfflineTtsModelConfig(vits: vits) | 17 | let modelConfig = sherpaOnnxOfflineTtsModelConfig(vits: vits) |
| 12 | var ttsConfig = sherpaOnnxOfflineTtsConfig(model: modelConfig) | 18 | var ttsConfig = sherpaOnnxOfflineTtsConfig(model: modelConfig) |
| 13 | 19 | ||
| 20 | + let myClass = MyClass() | ||
| 21 | + | ||
| 22 | + // We use passUnretained here, so myClass must be kept alive for as long as the callback may be invoked | ||
| 23 | + // | ||
| 24 | + // See also | ||
| 25 | + // https://medium.com/codex/swift-c-callback-interoperability-6d57da6c8ee6 | ||
| 26 | + let arg = Unmanaged<MyClass>.passUnretained(myClass).toOpaque() | ||
| 27 | + | ||
| 28 | + let callback: TtsCallbackWithArg = { samples, n, arg in | ||
| 29 | + let o = Unmanaged<MyClass>.fromOpaque(arg!).takeUnretainedValue() | ||
| 30 | + var savedSamples: [Float] = [] | ||
| 31 | + for index in 0..<n { | ||
| 32 | + savedSamples.append(samples![Int(index)]) | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + o.playSamples(samples: savedSamples) | ||
| 36 | + | ||
| 37 | + // return 1 so that it continues generating | ||
| 38 | + return 1 | ||
| 39 | + } | ||
| 40 | + | ||
| 14 | let tts = SherpaOnnxOfflineTtsWrapper(config: &ttsConfig) | 41 | let tts = SherpaOnnxOfflineTtsWrapper(config: &ttsConfig) |
| 15 | 42 | ||
| 16 | let text = | 43 | let text = |
| @@ -18,11 +45,15 @@ func run() { | @@ -18,11 +45,15 @@ func run() { | ||
| 18 | let sid = 99 | 45 | let sid = 99 |
| 19 | let speed: Float = 1.0 | 46 | let speed: Float = 1.0 |
| 20 | 47 | ||
| 21 | - let audio = tts.generate(text: text, sid: sid, speed: speed) | 48 | + let audio = tts.generateWithCallbackWithArg( |
| 49 | + text: text, callback: callback, arg: arg, sid: sid, speed: speed) | ||
| 22 | let filename = "test.wav" | 50 | let filename = "test.wav" |
| 23 | - audio.save(filename: filename) | ||
| 24 | - | ||
| 25 | - print("\nSaved to:\n\(filename)") | 51 | + let ok = audio.save(filename: filename) |
| 52 | + if ok == 1 { | ||
| 53 | + print("\nSaved to:\(filename)") | ||
| 54 | + } else { | ||
| 55 | + print("Failed to save to \(filename)") | ||
| 56 | + } | ||
| 26 | } | 57 | } |
| 27 | 58 | ||
| 28 | @main | 59 | @main |
-
请 注册 或 登录 后发表评论