Fangjun Kuang
Committed by GitHub

Add Go API for speaker identification (#718)

@@ -66,6 +66,12 @@ jobs: @@ -66,6 +66,12 @@ jobs:
66 run: | 66 run: |
67 gcc --version 67 gcc --version
68 68
  69 + - name: Test speaker identification
  70 + shell: bash
  71 + run: |
  72 + cd go-api-examples/speaker-identification
  73 + ./run.sh
  74 +
69 - name: Test non-streaming TTS (Linux/macOS) 75 - name: Test non-streaming TTS (Linux/macOS)
70 if: matrix.os != 'windows-latest' 76 if: matrix.os != 'windows-latest'
71 shell: bash 77 shell: bash
@@ -74,6 +74,12 @@ jobs: @@ -74,6 +74,12 @@ jobs:
74 go mod tidy 74 go mod tidy
75 go build 75 go build
76 76
  77 + - name: Test speaker identification
  78 + shell: bash
  79 + run: |
  80 + cd scripts/go/_internal/speaker-identification/
  81 + ./run.sh
  82 +
77 - name: Test non-streaming TTS (macOS) 83 - name: Test non-streaming TTS (macOS)
78 shell: bash 84 shell: bash
79 run: | 85 run: |
@@ -88,3 +88,5 @@ vits-mms-* @@ -88,3 +88,5 @@ vits-mms-*
88 *.tar.bz2 88 *.tar.bz2
89 sherpa-onnx-paraformer-trilingual-zh-cantonese-en 89 sherpa-onnx-paraformer-trilingual-zh-cantonese-en
90 sr-data 90 sr-data
  91 +*xcworkspace/xcuserdata/*
  92 +
@@ -26,4 +26,8 @@ for details. @@ -26,4 +26,8 @@ for details.
26 - [./vad-spoken-language-identification](./vad-spoken-language-identification) It shows how to use silero VAD + Whisper 26 - [./vad-spoken-language-identification](./vad-spoken-language-identification) It shows how to use silero VAD + Whisper
27 for spoken language identification. 27 for spoken language identification.
28 28
  29 +- [./speaker-identification](./speaker-identification) It shows how to use the Go API for speaker identification.
  30 +
  31 +- [./vad-speaker-identification](./vad-speaker-identification) It shows how to use the Go API for VAD + speaker identification.
  32 +
29 [sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx 33 [sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
  1 +module speaker-identification
  2 +
  3 +go 1.12
  1 +package main
  2 +
  3 +import (
  4 + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
  5 + "log"
  6 +)
  7 +
  8 +func createSpeakerEmbeddingExtractor() *sherpa.SpeakerEmbeddingExtractor {
  9 + config := sherpa.SpeakerEmbeddingExtractorConfig{}
  10 +
  11 + // Please download the model from
  12 + // https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx
  13 + //
  14 + // You can find more models at
  15 + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
  16 +
  17 + config.Model = "./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx"
  18 + config.NumThreads = 1
  19 + config.Debug = 1
  20 + config.Provider = "cpu"
  21 +
  22 + ex := sherpa.NewSpeakerEmbeddingExtractor(&config)
  23 + return ex
  24 +}
  25 +
  26 +func computeEmbeddings(ex *sherpa.SpeakerEmbeddingExtractor, files []string) [][]float32 {
  27 + embeddings := make([][]float32, len(files))
  28 +
  29 + for i, f := range files {
  30 + wave := sherpa.ReadWave(f)
  31 +
  32 + stream := ex.CreateStream()
  33 + defer sherpa.DeleteOnlineStream(stream)
  34 + stream.AcceptWaveform(wave.SampleRate, wave.Samples)
  35 + stream.InputFinished()
  36 + embeddings[i] = ex.Compute(stream)
  37 + }
  38 +
  39 + return embeddings
  40 +
  41 +}
  42 +
  43 +func registerSpeakers(ex *sherpa.SpeakerEmbeddingExtractor, manager *sherpa.SpeakerEmbeddingManager) {
  44 + // Please download the test data from
  45 + // https://github.com/csukuangfj/sr-data
  46 + spk1_files := []string{
  47 + "./sr-data/enroll/fangjun-sr-1.wav",
  48 + "./sr-data/enroll/fangjun-sr-2.wav",
  49 + "./sr-data/enroll/fangjun-sr-3.wav",
  50 + }
  51 +
  52 + spk2_files := []string{
  53 + "./sr-data/enroll/leijun-sr-1.wav",
  54 + "./sr-data/enroll/leijun-sr-2.wav",
  55 + }
  56 +
  57 + spk1_embeddings := computeEmbeddings(ex, spk1_files)
  58 + spk2_embeddings := computeEmbeddings(ex, spk2_files)
  59 +
  60 + ok := manager.RegisterV("fangjun", spk1_embeddings)
  61 + if !ok {
  62 + panic("Failed to register fangjun")
  63 + }
  64 +
  65 + ok = manager.RegisterV("leijun", spk2_embeddings)
  66 + if !ok {
  67 + panic("Failed to register leijun")
  68 + }
  69 +
  70 + if !manager.Contains("fangjun") {
  71 + panic("Failed to find fangjun")
  72 + }
  73 +
  74 + if !manager.Contains("leijun") {
  75 + panic("Failed to find leijun")
  76 + }
  77 +
  78 + if manager.NumSpeakers() != 2 {
  79 + panic("There should be only 2 speakers")
  80 + }
  81 +
  82 + all_speakers := manager.AllSpeakers()
  83 + log.Printf("All speakers: %v\n", all_speakers)
  84 +}
  85 +
  86 +func main() {
  87 + log.SetFlags(log.LstdFlags | log.Lmicroseconds)
  88 +
  89 + ex := createSpeakerEmbeddingExtractor()
  90 + defer sherpa.DeleteSpeakerEmbeddingExtractor(ex)
  91 +
  92 + manager := sherpa.NewSpeakerEmbeddingManager(ex.Dim())
  93 + defer sherpa.DeleteSpeakerEmbeddingManager(manager)
  94 + registerSpeakers(ex, manager)
  95 +
  96 + // Please download the test data from
  97 + // https://github.com/csukuangfj/sr-data
  98 + test1 := "./sr-data/test/fangjun-test-sr-1.wav"
  99 + embeddings := computeEmbeddings(ex, []string{test1})[0]
  100 + threshold := float32(0.6)
  101 + name := manager.Search(embeddings, threshold)
  102 + if len(name) > 0 {
  103 + log.Printf("%v matches %v", test1, name)
  104 + } else {
  105 + log.Printf("No matches found for %v", test1)
  106 + }
  107 +
  108 + test2 := "./sr-data/test/leijun-test-sr-1.wav"
  109 + embeddings = computeEmbeddings(ex, []string{test2})[0]
  110 + name = manager.Search(embeddings, threshold)
  111 + if len(name) > 0 {
  112 + log.Printf("%v matches %v", test2, name)
  113 + } else {
  114 + log.Printf("No matches found for %v", test2)
  115 + }
  116 +
  117 + test3 := "./sr-data/test/liudehua-test-sr-1.wav"
  118 + embeddings = computeEmbeddings(ex, []string{test3})[0]
  119 + name = manager.Search(embeddings, threshold)
  120 + if len(name) > 0 {
  121 + log.Printf("%v matches %v", test3, name)
  122 + } else {
  123 + log.Printf("No matches found for %v", test3)
  124 + }
  125 +
  126 + if !manager.Remove("fangjun") {
  127 + panic("Failed to deregister fangjun")
  128 + } else {
  129 + log.Print("fangjun deregistered\n")
  130 + }
  131 +
  132 + test1 = "./sr-data/test/fangjun-test-sr-1.wav"
  133 + embeddings = computeEmbeddings(ex, []string{test1})[0]
  134 + name = manager.Search(embeddings, threshold)
  135 + if len(name) > 0 {
  136 + log.Printf("%v matches %v", test1, name)
  137 + } else {
  138 + log.Printf("No matches found for %v", test1)
  139 + }
  140 +}
  141 +
// chk panics with err when err is non-nil; a nil err is a no-op.
func chk(err error) {
	if err == nil {
		return
	}

	panic(err)
}
#!/usr/bin/env bash
#
# Download the speaker embedding model and the test data if they are not
# present yet, then build and run the speaker identification example.

# Stop at the first failing command so that a failed download or build is
# reported where it happens instead of being masked by later commands.
set -e

if [ ! -f ./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx ]; then
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx
fi

if [ ! -f ./sr-data/enroll/fangjun-sr-1.wav ]; then
  git clone https://github.com/csukuangfj/sr-data
fi

go mod tidy
go build
./speaker-identification
@@ -104,7 +104,7 @@ func main() { @@ -104,7 +104,7 @@ func main() {
104 104
105 duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate) 105 duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate)
106 106
107 - audio := &sherpa.GeneratedAudio{} 107 + audio := &sherpa.Wave{}
108 audio.Samples = speechSegment.Samples 108 audio.Samples = speechSegment.Samples
109 audio.SampleRate = config.SampleRate 109 audio.SampleRate = config.SampleRate
110 110
@@ -120,7 +120,7 @@ func main() { @@ -120,7 +120,7 @@ func main() {
120 chk(s.Stop()) 120 chk(s.Stop())
121 } 121 }
122 122
123 -func decode(recognizer *sherpa.OfflineRecognizer, audio *sherpa.GeneratedAudio, id int) { 123 +func decode(recognizer *sherpa.OfflineRecognizer, audio *sherpa.Wave, id int) {
124 stream := sherpa.NewOfflineStream(recognizer) 124 stream := sherpa.NewOfflineStream(recognizer)
125 defer sherpa.DeleteOfflineStream(stream) 125 defer sherpa.DeleteOfflineStream(stream)
126 stream.AcceptWaveform(audio.SampleRate, audio.Samples) 126 stream.AcceptWaveform(audio.SampleRate, audio.Samples)
@@ -102,7 +102,7 @@ func main() { @@ -102,7 +102,7 @@ func main() {
102 102
103 duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate) 103 duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate)
104 104
105 - audio := &sherpa.GeneratedAudio{} 105 + audio := &sherpa.Wave{}
106 audio.Samples = speechSegment.Samples 106 audio.Samples = speechSegment.Samples
107 audio.SampleRate = config.SampleRate 107 audio.SampleRate = config.SampleRate
108 108
@@ -118,7 +118,7 @@ func main() { @@ -118,7 +118,7 @@ func main() {
118 chk(s.Stop()) 118 chk(s.Stop())
119 } 119 }
120 120
121 -func decode(recognizer *sherpa.OfflineRecognizer, audio *sherpa.GeneratedAudio, id int) { 121 +func decode(recognizer *sherpa.OfflineRecognizer, audio *sherpa.Wave, id int) {
122 stream := sherpa.NewOfflineStream(recognizer) 122 stream := sherpa.NewOfflineStream(recognizer)
123 defer sherpa.DeleteOfflineStream(stream) 123 defer sherpa.DeleteOfflineStream(stream)
124 stream.AcceptWaveform(audio.SampleRate, audio.Samples) 124 stream.AcceptWaveform(audio.SampleRate, audio.Samples)
  1 +module vad-speaker-identification
  2 +
  3 +go 1.12
  1 +package main
  2 +
  3 +import (
  4 + "fmt"
  5 + "github.com/gordonklaus/portaudio"
  6 + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
  7 + "log"
  8 +)
  9 +
  10 +func createSpeakerEmbeddingExtractor() *sherpa.SpeakerEmbeddingExtractor {
  11 + config := sherpa.SpeakerEmbeddingExtractorConfig{}
  12 +
  13 + // Please download the model from
  14 + // https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx
  15 + //
  16 + // You can find more models at
  17 + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
  18 +
  19 + config.Model = "./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx"
  20 + config.NumThreads = 2
  21 + config.Debug = 1
  22 + config.Provider = "cpu"
  23 +
  24 + ex := sherpa.NewSpeakerEmbeddingExtractor(&config)
  25 + return ex
  26 +}
  27 +
  28 +func computeEmbeddings(ex *sherpa.SpeakerEmbeddingExtractor, files []string) [][]float32 {
  29 + embeddings := make([][]float32, len(files))
  30 +
  31 + for i, f := range files {
  32 + wave := sherpa.ReadWave(f)
  33 +
  34 + stream := ex.CreateStream()
  35 + defer sherpa.DeleteOnlineStream(stream)
  36 + stream.AcceptWaveform(wave.SampleRate, wave.Samples)
  37 + stream.InputFinished()
  38 + embeddings[i] = ex.Compute(stream)
  39 + }
  40 +
  41 + return embeddings
  42 +
  43 +}
  44 +
  45 +func registerSpeakers(ex *sherpa.SpeakerEmbeddingExtractor, manager *sherpa.SpeakerEmbeddingManager) {
  46 + // Please download the test data from
  47 + // https://github.com/csukuangfj/sr-data
  48 + spk1_files := []string{
  49 + "./sr-data/enroll/fangjun-sr-1.wav",
  50 + "./sr-data/enroll/fangjun-sr-2.wav",
  51 + "./sr-data/enroll/fangjun-sr-3.wav",
  52 + }
  53 +
  54 + spk2_files := []string{
  55 + "./sr-data/enroll/leijun-sr-1.wav",
  56 + "./sr-data/enroll/leijun-sr-2.wav",
  57 + }
  58 +
  59 + spk1_embeddings := computeEmbeddings(ex, spk1_files)
  60 + spk2_embeddings := computeEmbeddings(ex, spk2_files)
  61 +
  62 + ok := manager.RegisterV("fangjun", spk1_embeddings)
  63 + if !ok {
  64 + panic("Failed to register fangjun")
  65 + }
  66 +
  67 + ok = manager.RegisterV("leijun", spk2_embeddings)
  68 + if !ok {
  69 + panic("Failed to register leijun")
  70 + }
  71 +
  72 + if !manager.Contains("fangjun") {
  73 + panic("Failed to find fangjun")
  74 + }
  75 +
  76 + if !manager.Contains("leijun") {
  77 + panic("Failed to find leijun")
  78 + }
  79 +
  80 + if manager.NumSpeakers() != 2 {
  81 + panic("There should be only 2 speakers")
  82 + }
  83 +
  84 + all_speakers := manager.AllSpeakers()
  85 + log.Printf("All speakers: %v\n", all_speakers)
  86 +}
  87 +
  88 +func createVad() *sherpa.VoiceActivityDetector {
  89 + config := sherpa.VadModelConfig{}
  90 +
  91 + // Please download silero_vad.onnx from
  92 + // https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
  93 +
  94 + config.SileroVad.Model = "./silero_vad.onnx"
  95 + config.SileroVad.Threshold = 0.5
  96 + config.SileroVad.MinSilenceDuration = 0.5
  97 + config.SileroVad.MinSpeechDuration = 0.5
  98 + config.SileroVad.WindowSize = 512
  99 + config.SampleRate = 16000
  100 + config.NumThreads = 1
  101 + config.Provider = "cpu"
  102 + config.Debug = 1
  103 +
  104 + var bufferSizeInSeconds float32 = 20
  105 +
  106 + vad := sherpa.NewVoiceActivityDetector(&config, bufferSizeInSeconds)
  107 + return vad
  108 +}
  109 +
  110 +func main() {
  111 + log.SetFlags(log.LstdFlags | log.Lmicroseconds)
  112 +
  113 + vad := createVad()
  114 + defer sherpa.DeleteVoiceActivityDetector(vad)
  115 +
  116 + ex := createSpeakerEmbeddingExtractor()
  117 + defer sherpa.DeleteSpeakerEmbeddingExtractor(ex)
  118 +
  119 + manager := sherpa.NewSpeakerEmbeddingManager(ex.Dim())
  120 + defer sherpa.DeleteSpeakerEmbeddingManager(manager)
  121 + registerSpeakers(ex, manager)
  122 +
  123 + err := portaudio.Initialize()
  124 + if err != nil {
  125 + log.Fatalf("Unable to initialize portaudio: %v\n", err)
  126 + }
  127 + defer portaudio.Terminate()
  128 +
  129 + default_device, err := portaudio.DefaultInputDevice()
  130 + if err != nil {
  131 + log.Fatal("Failed to get default input device: %v\n", err)
  132 + }
  133 + log.Printf("Selected default input device: %s\n", default_device.Name)
  134 + param := portaudio.StreamParameters{}
  135 + param.Input.Device = default_device
  136 + param.Input.Channels = 1
  137 + param.Input.Latency = default_device.DefaultHighInputLatency
  138 +
  139 + param.SampleRate = 16000
  140 + param.FramesPerBuffer = 0
  141 + param.Flags = portaudio.ClipOff
  142 +
  143 + // you can choose another value for 0.1 if you want
  144 + samplesPerCall := int32(param.SampleRate * 0.1) // 0.1 second
  145 + samples := make([]float32, samplesPerCall)
  146 +
  147 + s, err := portaudio.OpenStream(param, samples)
  148 + if err != nil {
  149 + log.Fatalf("Failed to open the stream")
  150 + }
  151 +
  152 + defer s.Close()
  153 + chk(s.Start())
  154 +
  155 + log.Print("Started! Please speak")
  156 + printed := false
  157 +
  158 + k := 0
  159 + for {
  160 + chk(s.Read())
  161 + vad.AcceptWaveform(samples)
  162 +
  163 + if vad.IsSpeech() && !printed {
  164 + printed = true
  165 + log.Print("Detected speech\n")
  166 + }
  167 +
  168 + if !vad.IsSpeech() {
  169 + printed = false
  170 + }
  171 +
  172 + for !vad.IsEmpty() {
  173 + speechSegment := vad.Front()
  174 + vad.Pop()
  175 +
  176 + audio := &sherpa.Wave{}
  177 + audio.Samples = speechSegment.Samples
  178 + audio.SampleRate = 16000
  179 +
  180 + // Now decode it
  181 + go decode(ex, manager, audio, k)
  182 +
  183 + k += 1
  184 + }
  185 + }
  186 +
  187 + chk(s.Stop())
  188 +
  189 +}
  190 +
// chk panics with err when err is non-nil; a nil err is a no-op.
func chk(err error) {
	if err == nil {
		return
	}

	panic(err)
}
  196 +
  197 +func decode(ex *sherpa.SpeakerEmbeddingExtractor, manager *sherpa.SpeakerEmbeddingManager, audio *sherpa.GeneratedAudio, id int) {
  198 + stream := ex.CreateStream()
  199 + defer sherpa.DeleteOnlineStream(stream)
  200 +
  201 + stream.AcceptWaveform(audio.SampleRate, audio.Samples)
  202 + stream.InputFinished()
  203 + embeddings := ex.Compute(stream)
  204 + threshold := float32(0.5)
  205 + name := manager.Search(embeddings, threshold)
  206 + if len(name) > 0 {
  207 + log.Printf("Found speaker: %v\n", name)
  208 + } else {
  209 + log.Print("Unknown speaker\n")
  210 + name = "Unknown"
  211 + }
  212 +
  213 + duration := float32(len(audio.Samples)) / float32(audio.SampleRate)
  214 +
  215 + filename := fmt.Sprintf("seg-%d-%.2f-seconds-%s.wav", id, duration, name)
  216 + ok := audio.Save(filename)
  217 + if ok {
  218 + log.Printf("Saved to %s", filename)
  219 + }
  220 + log.Print("----------\n")
  221 +}
#!/usr/bin/env bash
#
# Download the models and the test data if they are not present yet, then
# build and run the VAD + speaker identification example.

# Stop at the first failing command so that a failed download or build is
# reported where it happens instead of being masked by later commands.
set -e

if [ ! -f ./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx ]; then
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx
fi

if [ ! -f ./sr-data/enroll/fangjun-sr-1.wav ]; then
  git clone https://github.com/csukuangfj/sr-data
fi

# main.go loads ./silero_vad.onnx for voice activity detection.
if [ ! -f ./silero_vad.onnx ]; then
  curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
fi

go mod tidy
go build

# "go build" names the binary after the module, vad-speaker-identification.
./vad-speaker-identification
@@ -99,7 +99,7 @@ func main() { @@ -99,7 +99,7 @@ func main() {
99 99
100 duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate) 100 duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate)
101 101
102 - audio := &sherpa.GeneratedAudio{} 102 + audio := &sherpa.Wave{}
103 audio.Samples = speechSegment.Samples 103 audio.Samples = speechSegment.Samples
104 audio.SampleRate = config.SampleRate 104 audio.SampleRate = config.SampleRate
105 105
@@ -115,7 +115,7 @@ func main() { @@ -115,7 +115,7 @@ func main() {
115 chk(s.Stop()) 115 chk(s.Stop())
116 } 116 }
117 117
118 -func decode(slid *sherpa.SpokenLanguageIdentification, audio *sherpa.GeneratedAudio, id int) { 118 +func decode(slid *sherpa.SpokenLanguageIdentification, audio *sherpa.Wave, id int) {
119 stream := slid.CreateStream() 119 stream := slid.CreateStream()
120 defer sherpa.DeleteOfflineStream(stream) 120 defer sherpa.DeleteOfflineStream(stream)
121 121
1 -<?xml version="1.0" encoding="UTF-8"?>  
2 -<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">  
3 -<plist version="1.0">  
4 -<dict>  
5 - <key>SchemeUserState</key>  
6 - <dict>  
7 - <key>SherpaOnnx.xcscheme_^#shared#^_</key>  
8 - <dict>  
9 - <key>orderHint</key>  
10 - <integer>0</integer>  
11 - </dict>  
12 - </dict>  
13 -</dict>  
14 -</plist>  
  1 +module speaker-identification
  2 +
  3 +go 1.12
  4 +
  5 +replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
  1 +../../../../go-api-examples/speaker-identification/main.go
  1 +../../../../go-api-examples/speaker-identification/run.sh
@@ -3,8 +3,3 @@ module vad-asr-paraformer @@ -3,8 +3,3 @@ module vad-asr-paraformer
3 go 1.12 3 go 1.12
4 4
5 replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../ 5 replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
6 -  
7 -require (  
8 - github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5  
9 - github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx v0.0.0-00010101000000-000000000000  
10 -)  
  1 +module vad-speaker-identification
  2 +
  3 +go 1.12
  4 +
  5 +replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
  1 +../../../../go-api-examples/vad-speaker-identification/main.go
  1 +../../../../go-api-examples/vad-speaker-identification/run.sh
@@ -746,11 +746,11 @@ func (vad *VoiceActivityDetector) AcceptWaveform(samples []float32) { @@ -746,11 +746,11 @@ func (vad *VoiceActivityDetector) AcceptWaveform(samples []float32) {
746 } 746 }
747 747
748 func (vad *VoiceActivityDetector) IsEmpty() bool { 748 func (vad *VoiceActivityDetector) IsEmpty() bool {
749 - return 1 == int(C.SherpaOnnxVoiceActivityDetectorEmpty(vad.impl)) 749 + return int(C.SherpaOnnxVoiceActivityDetectorEmpty(vad.impl)) == 1
750 } 750 }
751 751
752 func (vad *VoiceActivityDetector) IsSpeech() bool { 752 func (vad *VoiceActivityDetector) IsSpeech() bool {
753 - return 1 == int(C.SherpaOnnxVoiceActivityDetectorDetected(vad.impl)) 753 + return int(C.SherpaOnnxVoiceActivityDetectorDetected(vad.impl)) == 1
754 } 754 }
755 755
756 func (vad *VoiceActivityDetector) Pop() { 756 func (vad *VoiceActivityDetector) Pop() {
@@ -852,3 +852,204 @@ func (slid *SpokenLanguageIdentification) Compute(stream *OfflineStream) *Spoken @@ -852,3 +852,204 @@ func (slid *SpokenLanguageIdentification) Compute(stream *OfflineStream) *Spoken
852 852
853 return ans 853 return ans
854 } 854 }
  855 +
  856 +// ============================================================
  857 +// For speaker embedding extraction
  858 +// ============================================================
  859 +
// SpeakerEmbeddingExtractorConfig holds the options used by
// [NewSpeakerEmbeddingExtractor] to create a [SpeakerEmbeddingExtractor].
type SpeakerEmbeddingExtractorConfig struct {
	Model      string // Path to the speaker embedding model file
	NumThreads int    // Passed to the C API as num_threads
	Debug      int    // Passed to the C API as debug
	Provider   string // Passed to the C API as provider, e.g., "cpu"
}
  866 +
  867 +type SpeakerEmbeddingExtractor struct {
  868 + impl *C.struct_SherpaOnnxSpeakerEmbeddingExtractor
  869 +}
  870 +
  871 +// The user has to invoke [DeleteSpeakerEmbeddingExtractor]() to free the returned value
  872 +// to avoid memory leak
  873 +func NewSpeakerEmbeddingExtractor(config *SpeakerEmbeddingExtractorConfig) *SpeakerEmbeddingExtractor {
  874 + c := C.struct_SherpaOnnxSpeakerEmbeddingExtractorConfig{}
  875 +
  876 + c.model = C.CString(config.Model)
  877 + defer C.free(unsafe.Pointer(c.model))
  878 +
  879 + c.num_threads = C.int(config.NumThreads)
  880 + c.debug = C.int(config.Debug)
  881 +
  882 + c.provider = C.CString(config.Provider)
  883 + defer C.free(unsafe.Pointer(c.provider))
  884 +
  885 + ex := &SpeakerEmbeddingExtractor{}
  886 + ex.impl = C.SherpaOnnxCreateSpeakerEmbeddingExtractor(&c)
  887 +
  888 + return ex
  889 +}
  890 +
  891 +func DeleteSpeakerEmbeddingExtractor(ex *SpeakerEmbeddingExtractor) {
  892 + C.SherpaOnnxDestroySpeakerEmbeddingExtractor(ex.impl)
  893 + ex.impl = nil
  894 +}
  895 +
  896 +func (ex *SpeakerEmbeddingExtractor) Dim() int {
  897 + return int(C.SherpaOnnxSpeakerEmbeddingExtractorDim(ex.impl))
  898 +}
  899 +
  900 +// The user is responsible to invoke [DeleteOnlineStream]() to free
  901 +// the returned stream to avoid memory leak
  902 +func (ex *SpeakerEmbeddingExtractor) CreateStream() *OnlineStream {
  903 + stream := &OnlineStream{}
  904 + stream.impl = C.SherpaOnnxSpeakerEmbeddingExtractorCreateStream(ex.impl)
  905 + return stream
  906 +}
  907 +
  908 +func (ex *SpeakerEmbeddingExtractor) IsReady(stream *OnlineStream) bool {
  909 + return int(C.SherpaOnnxSpeakerEmbeddingExtractorIsReady(ex.impl, stream.impl)) == 1
  910 +}
  911 +
  912 +func (ex *SpeakerEmbeddingExtractor) Compute(stream *OnlineStream) []float32 {
  913 + embedding := C.SherpaOnnxSpeakerEmbeddingExtractorComputeEmbedding(ex.impl, stream.impl)
  914 + defer C.SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(embedding)
  915 +
  916 + n := ex.Dim()
  917 + ans := make([]float32, n)
  918 +
  919 + // see https://stackoverflow.com/questions/48756732/what-does-1-30c-yourtype-do-exactly-in-cgo
  920 + // :n:n means 0:n:n, means low:high:capacity
  921 + c := (*[1 << 28]C.float)(unsafe.Pointer(embedding))[:n:n]
  922 +
  923 + for i := 0; i < n; i++ {
  924 + ans[i] = float32(c[i])
  925 + }
  926 +
  927 + return ans
  928 +}
  929 +
  930 +type SpeakerEmbeddingManager struct {
  931 + impl *C.struct_SherpaOnnxSpeakerEmbeddingManager
  932 +}
  933 +
  934 +// The user has to invoke [DeleteSpeakerEmbeddingManager]() to free the returned
  935 +// value to avoid memory leak
  936 +func NewSpeakerEmbeddingManager(dim int) *SpeakerEmbeddingManager {
  937 + m := &SpeakerEmbeddingManager{}
  938 + m.impl = C.SherpaOnnxCreateSpeakerEmbeddingManager(C.int(dim))
  939 + return m
  940 +}
  941 +
  942 +func DeleteSpeakerEmbeddingManager(m *SpeakerEmbeddingManager) {
  943 + C.SherpaOnnxDestroySpeakerEmbeddingManager(m.impl)
  944 + m.impl = nil
  945 +}
  946 +
  947 +func (m *SpeakerEmbeddingManager) Register(name string, embedding []float32) bool {
  948 + s := C.CString(name)
  949 + defer C.free(unsafe.Pointer(s))
  950 +
  951 + return C.int(C.SherpaOnnxSpeakerEmbeddingManagerAdd(m.impl, s, (*C.float)(&embedding[0]))) == 1
  952 +}
  953 +
  954 +func (m *SpeakerEmbeddingManager) RegisterV(name string, embeddings [][]float32) bool {
  955 + s := C.CString(name)
  956 + defer C.free(unsafe.Pointer(s))
  957 +
  958 + if len(embeddings) == 0 {
  959 + return false
  960 + }
  961 +
  962 + dim := len(embeddings[0])
  963 + v := make([]float32, 0, dim*len(embeddings))
  964 + for _, embedding := range embeddings {
  965 + v = append(v, embedding...)
  966 + }
  967 +
  968 + return C.int(C.SherpaOnnxSpeakerEmbeddingManagerAddListFlattened(m.impl, s, (*C.float)(&v[0]), C.int(len(embeddings)))) == 1
  969 +}
  970 +
  971 +func (m *SpeakerEmbeddingManager) Remove(name string) bool {
  972 + s := C.CString(name)
  973 + defer C.free(unsafe.Pointer(s))
  974 +
  975 + return C.int(C.SherpaOnnxSpeakerEmbeddingManagerRemove(m.impl, s)) == 1
  976 +}
  977 +
  978 +func (m *SpeakerEmbeddingManager) Search(embedding []float32, threshold float32) string {
  979 + var s string
  980 +
  981 + name := C.SherpaOnnxSpeakerEmbeddingManagerSearch(m.impl, (*C.float)(&embedding[0]), C.float(threshold))
  982 + defer C.SherpaOnnxSpeakerEmbeddingManagerFreeSearch(name)
  983 +
  984 + if name != nil {
  985 + s = C.GoString(name)
  986 + }
  987 +
  988 + return s
  989 +}
  990 +
  991 +func (m *SpeakerEmbeddingManager) Verify(name string, embedding []float32, threshold float32) bool {
  992 + s := C.CString(name)
  993 + defer C.free(unsafe.Pointer(s))
  994 +
  995 + return C.int(C.SherpaOnnxSpeakerEmbeddingManagerVerify(m.impl, s, (*C.float)(&embedding[0]), C.float(threshold))) == 1
  996 +}
  997 +
  998 +func (m *SpeakerEmbeddingManager) Contains(name string) bool {
  999 + s := C.CString(name)
  1000 + defer C.free(unsafe.Pointer(s))
  1001 +
  1002 + return C.int(C.SherpaOnnxSpeakerEmbeddingManagerContains(m.impl, s)) == 1
  1003 +}
  1004 +
  1005 +func (m *SpeakerEmbeddingManager) NumSpeakers() int {
  1006 + return int(C.SherpaOnnxSpeakerEmbeddingManagerNumSpeakers(m.impl))
  1007 +}
  1008 +
  1009 +func (m *SpeakerEmbeddingManager) AllSpeakers() []string {
  1010 + all_speakers := C.SherpaOnnxSpeakerEmbeddingManagerGetAllSpeakers(m.impl)
  1011 + defer C.SherpaOnnxSpeakerEmbeddingManagerFreeAllSpeakers(all_speakers)
  1012 +
  1013 + n := m.NumSpeakers()
  1014 + if n == 0 {
  1015 + return nil
  1016 + }
  1017 +
  1018 + // https://stackoverflow.com/questions/62012070/convert-array-of-strings-from-cgo-in-go
  1019 + p := (*[1 << 28]*C.char)(unsafe.Pointer(all_speakers))[:n:n]
  1020 +
  1021 + ans := make([]string, n)
  1022 +
  1023 + for i := 0; i < n; i++ {
  1024 + ans[i] = C.GoString(p[i])
  1025 + }
  1026 +
  1027 + return ans
  1028 +}
  1029 +
  1030 +// Wave
  1031 +
  1032 +// single channel wave
  1033 +type Wave = GeneratedAudio
  1034 +
  1035 +func ReadWave(filename string) *Wave {
  1036 + s := C.CString(filename)
  1037 + defer C.free(unsafe.Pointer(s))
  1038 +
  1039 + w := C.SherpaOnnxReadWave(s)
  1040 + defer C.SherpaOnnxFreeWave(w)
  1041 +
  1042 + n := int(w.num_samples)
  1043 +
  1044 + ans := &Wave{}
  1045 + ans.SampleRate = int(w.sample_rate)
  1046 + samples := (*[1 << 28]C.float)(unsafe.Pointer(w.samples))[:n:n]
  1047 +
  1048 + ans.Samples = make([]float32, n)
  1049 +
  1050 + for i := 0; i < n; i++ {
  1051 + ans.Samples[i] = float32(samples[i])
  1052 + }
  1053 +
  1054 + return ans
  1055 +}