Fangjun Kuang
Committed by GitHub

Add Go API for ten-vad (#2384)

@@ -2,9 +2,10 @@ package main @@ -2,9 +2,10 @@ package main
2 2
3 import ( 3 import (
4 "fmt" 4 "fmt"
5 - portaudio "github.com/csukuangfj/portaudio-go" 5 + "github.com/gen2brain/malgo"
6 sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" 6 sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
7 "log" 7 "log"
  8 + "os"
8 ) 9 )
9 10
10 func main() { 11 func main() {
@@ -13,62 +14,79 @@ func main() { @@ -13,62 +14,79 @@ func main() {
13 config := sherpa.VadModelConfig{} 14 config := sherpa.VadModelConfig{}
14 15
15 // Please download silero_vad.onnx from 16 // Please download silero_vad.onnx from
16 - // https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx 17 + // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
  18 + // or ten-vad.onnx from
  19 + // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
17 20
  21 + if FileExists("./silero_vad.onnx") {
  22 + fmt.Println("Use silero-vad")
18 config.SileroVad.Model = "./silero_vad.onnx" 23 config.SileroVad.Model = "./silero_vad.onnx"
19 config.SileroVad.Threshold = 0.5 24 config.SileroVad.Threshold = 0.5
20 config.SileroVad.MinSilenceDuration = 0.5 25 config.SileroVad.MinSilenceDuration = 0.5
21 config.SileroVad.MinSpeechDuration = 0.25 26 config.SileroVad.MinSpeechDuration = 0.25
  27 + config.SileroVad.MaxSpeechDuration = 10
22 config.SileroVad.WindowSize = 512 28 config.SileroVad.WindowSize = 512
  29 + } else if FileExists("./ten-vad.onnx") {
  30 + fmt.Println("Use ten-vad")
  31 + config.TenVad.Model = "./ten-vad.onnx"
  32 + config.TenVad.Threshold = 0.5
  33 + config.TenVad.MinSilenceDuration = 0.5
  34 + config.TenVad.MinSpeechDuration = 0.25
  35 + config.TenVad.MaxSpeechDuration = 10
  36 + config.TenVad.WindowSize = 256
  37 + } else {
  38 + fmt.Println("Please download either ./silero_vad.onnx or ./ten-vad.onnx")
  39 + return
  40 + }
  41 +
23 config.SampleRate = 16000 42 config.SampleRate = 16000
24 config.NumThreads = 1 43 config.NumThreads = 1
25 config.Provider = "cpu" 44 config.Provider = "cpu"
26 config.Debug = 1 45 config.Debug = 1
27 46
  47 + windowSize := config.SileroVad.WindowSize
  48 + if config.TenVad.Model != "" {
  49 + windowSize = config.TenVad.WindowSize
  50 + }
  51 +
28 var bufferSizeInSeconds float32 = 5 52 var bufferSizeInSeconds float32 = 5
29 53
30 vad := sherpa.NewVoiceActivityDetector(&config, bufferSizeInSeconds) 54 vad := sherpa.NewVoiceActivityDetector(&config, bufferSizeInSeconds)
31 defer sherpa.DeleteVoiceActivityDetector(vad) 55 defer sherpa.DeleteVoiceActivityDetector(vad)
32 56
33 - err := portaudio.Initialize()  
34 - if err != nil {  
35 - log.Fatalf("Unable to initialize portaudio: %v\n", err)  
36 - }  
37 - defer portaudio.Terminate() 57 + buffer := sherpa.NewCircularBuffer(10 * config.SampleRate)
  58 + defer sherpa.DeleteCircularBuffer(buffer)
38 59
39 - default_device, err := portaudio.DefaultInputDevice()  
40 - if err != nil {  
41 - log.Fatal("Failed to get default input device: %v\n", err)  
42 - }  
43 - log.Printf("Selected default input device: %s\n", default_device.Name)  
44 - param := portaudio.StreamParameters{}  
45 - param.Input.Device = default_device  
46 - param.Input.Channels = 1  
47 - param.Input.Latency = default_device.DefaultLowInputLatency 60 + ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, func(message string) {
  61 + fmt.Printf("LOG <%v>", message)
  62 + })
  63 + chk(err)
48 64
49 - param.SampleRate = float64(config.SampleRate)  
50 - param.FramesPerBuffer = 0  
51 - param.Flags = portaudio.ClipOff 65 + defer func() {
  66 + _ = ctx.Uninit()
  67 + ctx.Free()
  68 + }()
52 69
53 - // you can choose another value for 0.1 if you want  
54 - samplesPerCall := int32(param.SampleRate * 0.1) // 0.1 second  
55 - samples := make([]float32, samplesPerCall) 70 + deviceConfig := malgo.DefaultDeviceConfig(malgo.Duplex)
  71 + deviceConfig.Capture.Format = malgo.FormatS16
  72 + deviceConfig.Capture.Channels = 1
  73 + deviceConfig.Playback.Format = malgo.FormatS16
  74 + deviceConfig.Playback.Channels = 1
  75 + deviceConfig.SampleRate = 16000
  76 + deviceConfig.Alsa.NoMMap = 1
56 77
57 - s, err := portaudio.OpenStream(param, samples)  
58 - if err != nil {  
59 - log.Fatalf("Failed to open the stream")  
60 - }  
61 -  
62 - defer s.Close()  
63 - chk(s.Start())  
64 -  
65 - log.Print("Started! Please speak")  
66 printed := false 78 printed := false
67 -  
68 k := 0 79 k := 0
69 - for {  
70 - chk(s.Read())  
71 - vad.AcceptWaveform(samples) 80 +
  81 + onRecvFrames := func(_, pSample []byte, framecount uint32) {
  82 + samples := samplesInt16ToFloat(pSample)
  83 + buffer.Push(samples)
  84 + for buffer.Size() >= windowSize {
  85 + head := buffer.Head()
  86 + s := buffer.Get(head, windowSize)
  87 + buffer.Pop(windowSize)
  88 +
  89 + vad.AcceptWaveform(s)
72 90
73 if vad.IsSpeech() && !printed { 91 if vad.IsSpeech() && !printed {
74 printed = true 92 printed = true
@@ -101,8 +119,22 @@ func main() { @@ -101,8 +119,22 @@ func main() {
101 log.Print("----------\n") 119 log.Print("----------\n")
102 } 120 }
103 } 121 }
  122 + }
  123 +
  124 + captureCallbacks := malgo.DeviceCallbacks{
  125 + Data: onRecvFrames,
  126 + }
  127 +
  128 + device, err := malgo.InitDevice(ctx.Context, deviceConfig, captureCallbacks)
  129 + chk(err)
  130 +
  131 + err = device.Start()
  132 + chk(err)
  133 +
  134 + fmt.Println("Started. Please speak. Press ctrl + C to exit")
  135 + fmt.Scanln()
  136 + device.Uninit()
104 137
105 - chk(s.Stop())  
106 } 138 }
107 139
108 func chk(err error) { 140 func chk(err error) {
@@ -110,3 +142,25 @@ func chk(err error) { @@ -110,3 +142,25 @@ func chk(err error) {
110 panic(err) 142 panic(err)
111 } 143 }
112 } 144 }
  145 +
  146 +func samplesInt16ToFloat(inSamples []byte) []float32 {
  147 + numSamples := len(inSamples) / 2
  148 + outSamples := make([]float32, numSamples)
  149 +
  150 + for i := 0; i != numSamples; i++ {
  151 + // Decode two bytes into an int16 using bit manipulation
  152 + s16 := int16(inSamples[2*i]) | int16(inSamples[2*i+1])<<8
  153 + outSamples[i] = float32(s16) / 32768
  154 + }
  155 +
  156 + return outSamples
  157 +}
  158 +
  159 +func FileExists(path string) bool {
  160 + _, err := os.Stat(path)
  161 + if err == nil {
  162 + return true
  163 + }
  164 +
  165 + return false
  166 +}
@@ -3,7 +3,11 @@ @@ -3,7 +3,11 @@
3 set -ex 3 set -ex
4 4
5 if [ ! -f ./silero_vad.onnx ]; then 5 if [ ! -f ./silero_vad.onnx ]; then
6 - curl -SL -O https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx 6 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
  7 +fi
  8 +
  9 +if [ ! -f ./ten-vad.onnx ]; then
  10 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
7 fi 11 fi
8 12
9 go mod tidy 13 go mod tidy
@@ -1142,8 +1142,18 @@ type SileroVadModelConfig struct { @@ -1142,8 +1142,18 @@ type SileroVadModelConfig struct {
1142 MaxSpeechDuration float32 1142 MaxSpeechDuration float32
1143 } 1143 }
1144 1144
  // TenVadModelConfig holds the settings for the ten-vad model.
  // NewVoiceActivityDetector copies every field into the matching member of
  // the C-side ten_vad configuration.
  1145 +type TenVadModelConfig struct {
  // Model is handed to C as a string; the example program sets it to the path
  // of ten-vad.onnx and treats a non-empty value as "ten-vad is selected".
  1146 + Model string
  // Threshold is forwarded as a C float. NOTE(review): presumably the speech
  // probability cut-off; confirm against the C API.
  1147 + Threshold float32
  // MinSilenceDuration is forwarded as a C float. NOTE(review): unit looks
  // like seconds; confirm.
  1148 + MinSilenceDuration float32
  // MinSpeechDuration is forwarded as a C float. NOTE(review): unit looks
  // like seconds; confirm.
  1149 + MinSpeechDuration float32
  // WindowSize is forwarded as a C int. The example program uses 256 and
  // feeds the detector that many samples per AcceptWaveform call.
  1150 + WindowSize int
  // MaxSpeechDuration is forwarded as a C float. NOTE(review): unit looks
  // like seconds; confirm.
  1151 + MaxSpeechDuration float32
  1152 +}
  1153 +
1145 type VadModelConfig struct { 1154 type VadModelConfig struct {
1146 SileroVad SileroVadModelConfig 1155 SileroVad SileroVadModelConfig
  1156 + TenVad TenVadModelConfig
1147 SampleRate int 1157 SampleRate int
1148 NumThreads int 1158 NumThreads int
1149 Provider string 1159 Provider string
@@ -1220,6 +1230,15 @@ func NewVoiceActivityDetector(config *VadModelConfig, bufferSizeInSeconds float3 @@ -1220,6 +1230,15 @@ func NewVoiceActivityDetector(config *VadModelConfig, bufferSizeInSeconds float3
1220 c.silero_vad.window_size = C.int(config.SileroVad.WindowSize) 1230 c.silero_vad.window_size = C.int(config.SileroVad.WindowSize)
1221 c.silero_vad.max_speech_duration = C.float(config.SileroVad.MaxSpeechDuration) 1231 c.silero_vad.max_speech_duration = C.float(config.SileroVad.MaxSpeechDuration)
1222 1232
  1233 + c.ten_vad.model = C.CString(config.TenVad.Model)
  1234 + defer C.free(unsafe.Pointer(c.ten_vad.model))
  1235 +
  1236 + c.ten_vad.threshold = C.float(config.TenVad.Threshold)
  1237 + c.ten_vad.min_silence_duration = C.float(config.TenVad.MinSilenceDuration)
  1238 + c.ten_vad.min_speech_duration = C.float(config.TenVad.MinSpeechDuration)
  1239 + c.ten_vad.window_size = C.int(config.TenVad.WindowSize)
  1240 + c.ten_vad.max_speech_duration = C.float(config.TenVad.MaxSpeechDuration)
  1241 +
1223 c.sample_rate = C.int(config.SampleRate) 1242 c.sample_rate = C.int(config.SampleRate)
1224 c.num_threads = C.int(config.NumThreads) 1243 c.num_threads = C.int(config.NumThreads)
1225 c.provider = C.CString(config.Provider) 1244 c.provider = C.CString(config.Provider)