Fangjun Kuang
Committed by GitHub

Support passing TTS callback in Swift API (#1218)

@@ -757,6 +757,14 @@ class SherpaOnnxGeneratedAudioWrapper { @@ -757,6 +757,14 @@ class SherpaOnnxGeneratedAudioWrapper {
757 } 757 }
758 } 758 }
759 759
  760 +typealias TtsCallbackWithArg = (
  761 + @convention(c) (
  762 + UnsafePointer<Float>?, // const float* samples
  763 + Int32, // int32_t n
  764 + UnsafeMutableRawPointer? // void *arg
  765 + ) -> Int32
  766 +)?
  767 +
760 class SherpaOnnxOfflineTtsWrapper { 768 class SherpaOnnxOfflineTtsWrapper {
761 /// A pointer to the underlying counterpart in C 769 /// A pointer to the underlying counterpart in C
762 let tts: OpaquePointer! 770 let tts: OpaquePointer!
@@ -780,6 +788,17 @@ class SherpaOnnxOfflineTtsWrapper { @@ -780,6 +788,17 @@ class SherpaOnnxOfflineTtsWrapper {
780 788
781 return SherpaOnnxGeneratedAudioWrapper(audio: audio) 789 return SherpaOnnxGeneratedAudioWrapper(audio: audio)
782 } 790 }
  791 +
  792 + func generateWithCallbackWithArg(
  793 + text: String, callback: TtsCallbackWithArg, arg: UnsafeMutableRawPointer, sid: Int = 0,
  794 + speed: Float = 1.0
  795 + ) -> SherpaOnnxGeneratedAudioWrapper {
  796 + let audio: UnsafePointer<SherpaOnnxGeneratedAudio>? =
  797 + SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
  798 + tts, toCPointer(text), Int32(sid), speed, callback, arg)
  799 +
  800 + return SherpaOnnxGeneratedAudioWrapper(audio: audio)
  801 + }
783 } 802 }
784 803
785 // spoken language identification 804 // spoken language identification
  1 +class MyClass {
  2 + func playSamples(samples: [Float]) {
  3 + print("Play \(samples.count) samples")
  4 + }
  5 +}
  6 +
1 func run() { 7 func run() {
2 let model = "./vits-piper-en_US-amy-low/en_US-amy-low.onnx" 8 let model = "./vits-piper-en_US-amy-low/en_US-amy-low.onnx"
3 let tokens = "./vits-piper-en_US-amy-low/tokens.txt" 9 let tokens = "./vits-piper-en_US-amy-low/tokens.txt"
@@ -11,6 +17,27 @@ func run() { @@ -11,6 +17,27 @@ func run() {
11 let modelConfig = sherpaOnnxOfflineTtsModelConfig(vits: vits) 17 let modelConfig = sherpaOnnxOfflineTtsModelConfig(vits: vits)
12 var ttsConfig = sherpaOnnxOfflineTtsConfig(model: modelConfig) 18 var ttsConfig = sherpaOnnxOfflineTtsConfig(model: modelConfig)
13 19
  20 + let myClass = MyClass()
  21 +
  22 + // We use Unretained here so myClass must be kept alive as the callback is invoked
  23 + //
  24 + // See also
  25 + // https://medium.com/codex/swift-c-callback-interoperability-6d57da6c8ee6
  26 + let arg = Unmanaged<MyClass>.passUnretained(myClass).toOpaque()
  27 +
  28 + let callback: TtsCallbackWithArg = { samples, n, arg in
  29 + let o = Unmanaged<MyClass>.fromOpaque(arg!).takeUnretainedValue()
  30 + var savedSamples: [Float] = []
  31 + for index in 0..<n {
  32 + savedSamples.append(samples![Int(index)])
  33 + }
  34 +
  35 + o.playSamples(samples: savedSamples)
  36 +
  37 + // return 1 so that it continues generating
  38 + return 1
  39 + }
  40 +
14 let tts = SherpaOnnxOfflineTtsWrapper(config: &ttsConfig) 41 let tts = SherpaOnnxOfflineTtsWrapper(config: &ttsConfig)
15 42
16 let text = 43 let text =
@@ -18,11 +45,15 @@ func run() { @@ -18,11 +45,15 @@ func run() {
18 let sid = 99 45 let sid = 99
19 let speed: Float = 1.0 46 let speed: Float = 1.0
20 47
21 - let audio = tts.generate(text: text, sid: sid, speed: speed) 48 + let audio = tts.generateWithCallbackWithArg(
  49 + text: text, callback: callback, arg: arg, sid: sid, speed: speed)
22 let filename = "test.wav" 50 let filename = "test.wav"
23 - audio.save(filename: filename)  
24 -  
25 - print("\nSaved to:\n\(filename)") 51 + let ok = audio.save(filename: filename)
  52 + if ok == 1 {
  53 + print("\nSaved to:\(filename)")
  54 + } else {
  55 + print("Failed to save to \(filename)")
  56 + }
26 } 57 }
27 58
28 @main 59 @main