Fangjun Kuang
Committed by GitHub

Add Go API for ten-vad (#2384)

@@ -2,9 +2,10 @@ package main @@ -2,9 +2,10 @@ package main
2 2
3 import ( 3 import (
4 "fmt" 4 "fmt"
5 - portaudio "github.com/csukuangfj/portaudio-go" 5 + "github.com/gen2brain/malgo"
6 sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" 6 sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
7 "log" 7 "log"
  8 + "os"
8 ) 9 )
9 10
10 func main() { 11 func main() {
@@ -13,96 +14,127 @@ func main() { @@ -13,96 +14,127 @@ func main() {
13 config := sherpa.VadModelConfig{} 14 config := sherpa.VadModelConfig{}
14 15
15 // Please download silero_vad.onnx from 16 // Please download silero_vad.onnx from
16 - // https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx 17 + // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
  18 + // or ten-vad.onnx from
  19 + // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
  20 +
  21 + if FileExists("./silero_vad.onnx") {
  22 + fmt.Println("Use silero-vad")
  23 + config.SileroVad.Model = "./silero_vad.onnx"
  24 + config.SileroVad.Threshold = 0.5
  25 + config.SileroVad.MinSilenceDuration = 0.5
  26 + config.SileroVad.MinSpeechDuration = 0.25
  27 + config.SileroVad.MaxSpeechDuration = 10
  28 + config.SileroVad.WindowSize = 512
  29 + } else if FileExists("./ten-vad.onnx") {
  30 + fmt.Println("Use ten-vad")
  31 + config.TenVad.Model = "./ten-vad.onnx"
  32 + config.TenVad.Threshold = 0.5
  33 + config.TenVad.MinSilenceDuration = 0.5
  34 + config.TenVad.MinSpeechDuration = 0.25
  35 + config.TenVad.MaxSpeechDuration = 10
  36 + config.TenVad.WindowSize = 256
  37 + } else {
  38 + fmt.Println("Please download either ./silero_vad.onnx or ./ten-vad.onnx")
  39 + return
  40 + }
17 41
18 - config.SileroVad.Model = "./silero_vad.onnx"  
19 - config.SileroVad.Threshold = 0.5  
20 - config.SileroVad.MinSilenceDuration = 0.5  
21 - config.SileroVad.MinSpeechDuration = 0.25  
22 - config.SileroVad.WindowSize = 512  
23 config.SampleRate = 16000 42 config.SampleRate = 16000
24 config.NumThreads = 1 43 config.NumThreads = 1
25 config.Provider = "cpu" 44 config.Provider = "cpu"
26 config.Debug = 1 45 config.Debug = 1
27 46
  47 + windowSize := config.SileroVad.WindowSize
  48 + if config.TenVad.Model != "" {
  49 + windowSize = config.TenVad.WindowSize
  50 + }
  51 +
28 var bufferSizeInSeconds float32 = 5 52 var bufferSizeInSeconds float32 = 5
29 53
30 vad := sherpa.NewVoiceActivityDetector(&config, bufferSizeInSeconds) 54 vad := sherpa.NewVoiceActivityDetector(&config, bufferSizeInSeconds)
31 defer sherpa.DeleteVoiceActivityDetector(vad) 55 defer sherpa.DeleteVoiceActivityDetector(vad)
32 56
33 - err := portaudio.Initialize()  
34 - if err != nil {  
35 - log.Fatalf("Unable to initialize portaudio: %v\n", err)  
36 - }  
37 - defer portaudio.Terminate() 57 + buffer := sherpa.NewCircularBuffer(10 * config.SampleRate)
  58 + defer sherpa.DeleteCircularBuffer(buffer)
38 59
39 - default_device, err := portaudio.DefaultInputDevice()  
40 - if err != nil {  
41 - log.Fatal("Failed to get default input device: %v\n", err)  
42 - }  
43 - log.Printf("Selected default input device: %s\n", default_device.Name)  
44 - param := portaudio.StreamParameters{}  
45 - param.Input.Device = default_device  
46 - param.Input.Channels = 1  
47 - param.Input.Latency = default_device.DefaultLowInputLatency 60 + ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, func(message string) {
  61 + fmt.Printf("LOG <%v>", message)
  62 + })
  63 + chk(err)
48 64
49 - param.SampleRate = float64(config.SampleRate)  
50 - param.FramesPerBuffer = 0  
51 - param.Flags = portaudio.ClipOff 65 + defer func() {
  66 + _ = ctx.Uninit()
  67 + ctx.Free()
  68 + }()
52 69
53 - // you can choose another value for 0.1 if you want  
54 - samplesPerCall := int32(param.SampleRate * 0.1) // 0.1 second  
55 - samples := make([]float32, samplesPerCall) 70 + deviceConfig := malgo.DefaultDeviceConfig(malgo.Duplex)
  71 + deviceConfig.Capture.Format = malgo.FormatS16
  72 + deviceConfig.Capture.Channels = 1
  73 + deviceConfig.Playback.Format = malgo.FormatS16
  74 + deviceConfig.Playback.Channels = 1
  75 + deviceConfig.SampleRate = 16000
  76 + deviceConfig.Alsa.NoMMap = 1
56 77
57 - s, err := portaudio.OpenStream(param, samples)  
58 - if err != nil {  
59 - log.Fatalf("Failed to open the stream")  
60 - }  
61 -  
62 - defer s.Close()  
63 - chk(s.Start())  
64 -  
65 - log.Print("Started! Please speak")  
66 printed := false 78 printed := false
67 -  
68 k := 0 79 k := 0
69 - for {  
70 - chk(s.Read())  
71 - vad.AcceptWaveform(samples)  
72 80
73 - if vad.IsSpeech() && !printed {  
74 - printed = true  
75 - log.Print("Detected speech\n")  
76 - } 81 + onRecvFrames := func(_, pSample []byte, framecount uint32) {
  82 + samples := samplesInt16ToFloat(pSample)
  83 + buffer.Push(samples)
  84 + for buffer.Size() >= windowSize {
  85 + head := buffer.Head()
  86 + s := buffer.Get(head, windowSize)
  87 + buffer.Pop(windowSize)
77 88
78 - if !vad.IsSpeech() {  
79 - printed = false  
80 - } 89 + vad.AcceptWaveform(s)
81 90
82 - for !vad.IsEmpty() {  
83 - speechSegment := vad.Front()  
84 - vad.Pop() 91 + if vad.IsSpeech() && !printed {
  92 + printed = true
  93 + log.Print("Detected speech\n")
  94 + }
85 95
86 - duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate) 96 + if !vad.IsSpeech() {
  97 + printed = false
  98 + }
87 99
88 - audio := sherpa.GeneratedAudio{}  
89 - audio.Samples = speechSegment.Samples  
90 - audio.SampleRate = config.SampleRate 100 + for !vad.IsEmpty() {
  101 + speechSegment := vad.Front()
  102 + vad.Pop()
91 103
92 - filename := fmt.Sprintf("seg-%d-%.2f-seconds.wav", k, duration)  
93 - ok := audio.Save(filename)  
94 - if ok {  
95 - log.Printf("Saved to %s", filename)  
96 - } 104 + duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate)
97 105
98 - k += 1 106 + audio := sherpa.GeneratedAudio{}
  107 + audio.Samples = speechSegment.Samples
  108 + audio.SampleRate = config.SampleRate
99 109
100 - log.Printf("Duration: %.2f seconds\n", duration)  
101 - log.Print("----------\n") 110 + filename := fmt.Sprintf("seg-%d-%.2f-seconds.wav", k, duration)
  111 + ok := audio.Save(filename)
  112 + if ok {
  113 + log.Printf("Saved to %s", filename)
  114 + }
  115 +
  116 + k += 1
  117 +
  118 + log.Printf("Duration: %.2f seconds\n", duration)
  119 + log.Print("----------\n")
  120 + }
102 } 121 }
103 } 122 }
104 123
105 - chk(s.Stop()) 124 + captureCallbacks := malgo.DeviceCallbacks{
  125 + Data: onRecvFrames,
  126 + }
  127 +
  128 + device, err := malgo.InitDevice(ctx.Context, deviceConfig, captureCallbacks)
  129 + chk(err)
  130 +
  131 + err = device.Start()
  132 + chk(err)
  133 +
  134 + fmt.Println("Started. Please speak. Press ctrl + C to exit")
  135 + fmt.Scanln()
  136 + device.Uninit()
  137 +
106 } 138 }
107 139
108 func chk(err error) { 140 func chk(err error) {
@@ -110,3 +142,25 @@ func chk(err error) { @@ -110,3 +142,25 @@ func chk(err error) {
110 panic(err) 142 panic(err)
111 } 143 }
112 } 144 }
  145 +
  146 +func samplesInt16ToFloat(inSamples []byte) []float32 {
  147 + numSamples := len(inSamples) / 2
  148 + outSamples := make([]float32, numSamples)
  149 +
  150 + for i := 0; i != numSamples; i++ {
  151 + // Decode two bytes into an int16 using bit manipulation
  152 + s16 := int16(inSamples[2*i]) | int16(inSamples[2*i+1])<<8
  153 + outSamples[i] = float32(s16) / 32768
  154 + }
  155 +
  156 + return outSamples
  157 +}
  158 +
  159 +func FileExists(path string) bool {
  160 + _, err := os.Stat(path)
  161 + if err == nil {
  162 + return true
  163 + }
  164 +
  165 + return false
  166 +}
@@ -3,7 +3,11 @@ @@ -3,7 +3,11 @@
3 set -ex 3 set -ex
4 4
5 if [ ! -f ./silero_vad.onnx ]; then 5 if [ ! -f ./silero_vad.onnx ]; then
6 - curl -SL -O https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx 6 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
  7 +fi
  8 +
  9 +if [ ! -f ./ten-vad.onnx ]; then
  10 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
7 fi 11 fi
8 12
9 go mod tidy 13 go mod tidy
@@ -1142,8 +1142,18 @@ type SileroVadModelConfig struct { @@ -1142,8 +1142,18 @@ type SileroVadModelConfig struct {
1142 MaxSpeechDuration float32 1142 MaxSpeechDuration float32
1143 } 1143 }
1144 1144
  1145 +type TenVadModelConfig struct {
  1146 + Model string
  1147 + Threshold float32
  1148 + MinSilenceDuration float32
  1149 + MinSpeechDuration float32
  1150 + WindowSize int
  1151 + MaxSpeechDuration float32
  1152 +}
  1153 +
1145 type VadModelConfig struct { 1154 type VadModelConfig struct {
1146 SileroVad SileroVadModelConfig 1155 SileroVad SileroVadModelConfig
  1156 + TenVad TenVadModelConfig
1147 SampleRate int 1157 SampleRate int
1148 NumThreads int 1158 NumThreads int
1149 Provider string 1159 Provider string
@@ -1220,6 +1230,15 @@ func NewVoiceActivityDetector(config *VadModelConfig, bufferSizeInSeconds float3 @@ -1220,6 +1230,15 @@ func NewVoiceActivityDetector(config *VadModelConfig, bufferSizeInSeconds float3
1220 c.silero_vad.window_size = C.int(config.SileroVad.WindowSize) 1230 c.silero_vad.window_size = C.int(config.SileroVad.WindowSize)
1221 c.silero_vad.max_speech_duration = C.float(config.SileroVad.MaxSpeechDuration) 1231 c.silero_vad.max_speech_duration = C.float(config.SileroVad.MaxSpeechDuration)
1222 1232
  1233 + c.ten_vad.model = C.CString(config.TenVad.Model)
  1234 + defer C.free(unsafe.Pointer(c.ten_vad.model))
  1235 +
  1236 + c.ten_vad.threshold = C.float(config.TenVad.Threshold)
  1237 + c.ten_vad.min_silence_duration = C.float(config.TenVad.MinSilenceDuration)
  1238 + c.ten_vad.min_speech_duration = C.float(config.TenVad.MinSpeechDuration)
  1239 + c.ten_vad.window_size = C.int(config.TenVad.WindowSize)
  1240 + c.ten_vad.max_speech_duration = C.float(config.TenVad.MaxSpeechDuration)
  1241 +
1223 c.sample_rate = C.int(config.SampleRate) 1242 c.sample_rate = C.int(config.SampleRate)
1224 c.num_threads = C.int(config.NumThreads) 1243 c.num_threads = C.int(config.NumThreads)
1225 c.provider = C.CString(config.Provider) 1244 c.provider = C.CString(config.Provider)