Fangjun Kuang
Committed by GitHub

Go API for speaker diarization (#1403)

@@ -68,6 +68,50 @@ jobs: @@ -68,6 +68,50 @@ jobs:
68 run: | 68 run: |
69 gcc --version 69 gcc --version
70 70
  71 + - name: Test non-streaming speaker diarization
  72 + if: matrix.os != 'windows-latest'
  73 + shell: bash
  74 + run: |
  75 + cd go-api-examples/non-streaming-speaker-diarization/
  76 + ./run.sh
  77 +
  78 + - name: Test non-streaming speaker diarization
  79 + if: matrix.os == 'windows-latest' && matrix.arch == 'x64'
  80 + shell: bash
  81 + run: |
  82 + cd go-api-examples/non-streaming-speaker-diarization/
  83 + go mod tidy
  84 + cat go.mod
  85 + go build
  86 +
  87 + echo $PWD
  88 + ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/
  89 + ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/*
  90 + cp -v /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-windows*/lib/x86_64-pc-windows-gnu/*.dll .
  91 +
  92 + ./run.sh
  93 +
  94 + - name: Test non-streaming speaker diarization
  95 + if: matrix.os == 'windows-latest' && matrix.arch == 'x86'
  96 + shell: bash
  97 + run: |
  98 + cd go-api-examples/non-streaming-speaker-diarization/
  99 +
  100 + go env GOARCH
  101 + go env -w GOARCH=386
  102 + go env -w CGO_ENABLED=1
  103 +
  104 + go mod tidy
  105 + cat go.mod
  106 + go build
  107 +
  108 + echo $PWD
  109 + ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/
  110 + ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/*
  111 + cp -v /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-windows*/lib/i686-pc-windows-gnu/*.dll .
  112 +
  113 + ./run.sh
  114 +
71 - name: Test streaming HLG decoding (Linux/macOS) 115 - name: Test streaming HLG decoding (Linux/macOS)
72 if: matrix.os != 'windows-latest' 116 if: matrix.os != 'windows-latest'
73 shell: bash 117 shell: bash
@@ -134,6 +134,12 @@ jobs: @@ -134,6 +134,12 @@ jobs:
134 name: ${{ matrix.os }}-libs 134 name: ${{ matrix.os }}-libs
135 path: to-upload/ 135 path: to-upload/
136 136
  137 + - name: Test non-streaming speaker diarization
  138 + shell: bash
  139 + run: |
  140 + cd scripts/go/_internal/non-streaming-speaker-diarization/
  141 + ./run.sh
  142 +
137 - name: Test speaker identification 143 - name: Test speaker identification
138 shell: bash 144 shell: bash
139 run: | 145 run: |
  1 +module non-streaming-speaker-diarization
  2 +
  3 +go 1.12
  1 +package main
  2 +
  3 +import (
  4 + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
  5 + "log"
  6 +)
  7 +
  8 +/*
  9 +Usage:
  10 +
  11 +Step 1: Download a speaker segmentation model
  12 +
  13 +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
  14 +for a list of available models. The following is an example
  15 +
  16 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  17 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  18 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  19 +
  20 +Step 2: Download a speaker embedding extractor model
  21 +
  22 +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
  23 +for a list of available models. The following is an example
  24 +
  25 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  26 +
  27 +Step 3. Download test wave files
  28 +
  29 +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
  30 +for a list of available test wave files. The following is an example
  31 +
  32 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
  33 +
  34 +Step 4. Run it
  35 +*/
  36 +
  37 +func initSpeakerDiarization() *sherpa.OfflineSpeakerDiarization {
  38 + config := sherpa.OfflineSpeakerDiarizationConfig{}
  39 +
  40 + config.Segmentation.Pyannote.Model = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"
  41 + config.Embedding.Model = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
  42 +
  43 + // The test wave file contains 4 speakers, so we use 4 here
  44 + config.Clustering.NumClusters = 4
  45 +
  46 + // if you don't know the actual numbers in the wave file,
  47 + // then please don't set NumClusters; you need to use
  48 + //
  49 + // config.Clustering.Threshold = 0.5
  50 + //
  51 +
  52 + // A larger Threshold leads to fewer clusters
  53 + // A smaller Threshold leads to more clusters
  54 +
  55 + sd := sherpa.NewOfflineSpeakerDiarization(&config)
  56 + return sd
  57 +}
  58 +
  59 +func main() {
  60 + wave_filename := "./0-four-speakers-zh.wav"
  61 + wave := sherpa.ReadWave(wave_filename)
  62 + if wave == nil {
  63 + log.Printf("Failed to read %v", wave_filename)
  64 + return
  65 + }
  66 +
  67 + sd := initSpeakerDiarization()
  68 + if sd == nil {
  69 + log.Printf("Please check your config")
  70 + return
  71 + }
  72 +
  73 + defer sherpa.DeleteOfflineSpeakerDiarization(sd)
  74 +
  75 + if wave.SampleRate != sd.SampleRate() {
  76 + log.Printf("Expected sample rate: %v, given: %d\n", sd.SampleRate(), wave.SampleRate)
  77 + return
  78 + }
  79 +
  80 + log.Println("Started")
  81 + segments := sd.Process(wave.Samples)
  82 + n := len(segments)
  83 +
  84 + for i := 0; i < n; i++ {
  85 + log.Printf("%.3f -- %.3f speaker_%02d\n", segments[i].Start, segments[i].End, segments[i].Speaker)
  86 + }
  87 +}
#!/usr/bin/env bash
#
# Downloads the models and the test wave file used by this example (anything
# already present is skipped), then builds and runs the example.

# Stop at the first failing command so that a failed download, extraction, or
# build is reported instead of being masked by the commands that follow.
set -e

# Speaker segmentation model.
if [ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then
  # -f makes curl fail on an HTTP error instead of saving the error page.
  curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
fi

# Speaker embedding extractor model.
if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
  curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
fi

# Test wave file.
if [ ! -f ./0-four-speakers-zh.wav ]; then
  curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
fi

go mod tidy
go build
./non-streaming-speaker-diarization
  1 +module non-streaming-speaker-diarization
  2 +
  3 +go 1.12
  4 +
  5 +replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
  1 +../../../../go-api-examples/non-streaming-speaker-diarization/main.go
  1 +../../../../go-api-examples/non-streaming-speaker-diarization/run.sh
@@ -1175,7 +1175,14 @@ func ReadWave(filename string) *Wave { @@ -1175,7 +1175,14 @@ func ReadWave(filename string) *Wave {
1175 w := C.SherpaOnnxReadWave(s) 1175 w := C.SherpaOnnxReadWave(s)
1176 defer C.SherpaOnnxFreeWave(w) 1176 defer C.SherpaOnnxFreeWave(w)
1177 1177
  1178 + if w == nil {
  1179 + return nil
  1180 + }
  1181 +
1178 n := int(w.num_samples) 1182 n := int(w.num_samples)
  1183 + if n == 0 {
  1184 + return nil
  1185 + }
1179 1186
1180 ans := &Wave{} 1187 ans := &Wave{}
1181 ans.SampleRate = int(w.sample_rate) 1188 ans.SampleRate = int(w.sample_rate)
@@ -1189,3 +1196,114 @@ func ReadWave(filename string) *Wave { @@ -1189,3 +1196,114 @@ func ReadWave(filename string) *Wave {
1189 1196
1190 return ans 1197 return ans
1191 } 1198 }
  1199 +
  1200 +// ============================================================
  1201 +// For offline speaker diarization
  1202 +// ============================================================
  1203 +type OfflineSpeakerSegmentationPyannoteModelConfig struct {
  1204 + Model string
  1205 +}
  1206 +
  1207 +type OfflineSpeakerSegmentationModelConfig struct {
  1208 + Pyannote OfflineSpeakerSegmentationPyannoteModelConfig
  1209 + NumThreads int
  1210 + Debug int
  1211 + Provider string
  1212 +}
  1213 +
  1214 +type FastClusteringConfig struct {
  1215 + NumClusters int
  1216 + Threshold float32
  1217 +}
  1218 +
  1219 +type OfflineSpeakerDiarizationConfig struct {
  1220 + Segmentation OfflineSpeakerSegmentationModelConfig
  1221 + Embedding SpeakerEmbeddingExtractorConfig
  1222 + Clustering FastClusteringConfig
  1223 + MinDurationOn float32
  1224 + MinDurationOff float32
  1225 +}
  1226 +
  1227 +type OfflineSpeakerDiarization struct {
  1228 + impl *C.struct_SherpaOnnxOfflineSpeakerDiarization
  1229 +}
  1230 +
  1231 +func DeleteOfflineSpeakerDiarization(sd *OfflineSpeakerDiarization) {
  1232 + C.SherpaOnnxDestroyOfflineSpeakerDiarization(sd.impl)
  1233 + sd.impl = nil
  1234 +}
  1235 +
  1236 +func NewOfflineSpeakerDiarization(config *OfflineSpeakerDiarizationConfig) *OfflineSpeakerDiarization {
  1237 + c := C.struct_SherpaOnnxOfflineSpeakerDiarizationConfig{}
  1238 + c.segmentation.pyannote.model = C.CString(config.Segmentation.Pyannote.Model)
  1239 + defer C.free(unsafe.Pointer(c.segmentation.pyannote.model))
  1240 +
  1241 + c.segmentation.num_threads = C.int(config.Segmentation.NumThreads)
  1242 +
  1243 + c.segmentation.debug = C.int(config.Segmentation.Debug)
  1244 +
  1245 + c.segmentation.provider = C.CString(config.Segmentation.Provider)
  1246 + defer C.free(unsafe.Pointer(c.segmentation.provider))
  1247 +
  1248 + c.embedding.model = C.CString(config.Embedding.Model)
  1249 + defer C.free(unsafe.Pointer(c.embedding.model))
  1250 +
  1251 + c.embedding.num_threads = C.int(config.Embedding.NumThreads)
  1252 +
  1253 + c.embedding.debug = C.int(config.Embedding.Debug)
  1254 +
  1255 + c.embedding.provider = C.CString(config.Embedding.Provider)
  1256 + defer C.free(unsafe.Pointer(c.embedding.provider))
  1257 +
  1258 + c.clustering.num_clusters = C.int(config.Clustering.NumClusters)
  1259 + c.clustering.threshold = C.float(config.Clustering.Threshold)
  1260 + c.min_duration_on = C.float(config.MinDurationOn)
  1261 + c.min_duration_off = C.float(config.MinDurationOff)
  1262 +
  1263 + p := C.SherpaOnnxCreateOfflineSpeakerDiarization(&c)
  1264 +
  1265 + if p == nil {
  1266 + return nil
  1267 + }
  1268 +
  1269 + sd := &OfflineSpeakerDiarization{}
  1270 + sd.impl = p
  1271 +
  1272 + return sd
  1273 +}
  1274 +
  1275 +func (sd *OfflineSpeakerDiarization) SampleRate() int {
  1276 + return int(C.SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd.impl))
  1277 +}
  1278 +
// OfflineSpeakerDiarizationSegment is one entry of the result returned by
// (*OfflineSpeakerDiarization).Process.
type OfflineSpeakerDiarizationSegment struct {
	// Start and End delimit the segment.
	// NOTE(review): presumably seconds from the start of the audio — confirm
	// against the C API documentation.
	Start, End float32

	// Speaker is the integer speaker label copied from the C result.
	Speaker int
}
  1284 +
  1285 +func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeakerDiarizationSegment {
  1286 + r := C.SherpaOnnxOfflineSpeakerDiarizationProcess(sd.impl, (*C.float)(&samples[0]), C.int(len(samples)))
  1287 + defer C.SherpaOnnxOfflineSpeakerDiarizationDestroyResult(r)
  1288 +
  1289 + n := int(C.SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r))
  1290 +
  1291 + if n == 0 {
  1292 + return nil
  1293 + }
  1294 +
  1295 + s := C.SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(r)
  1296 + defer C.SherpaOnnxOfflineSpeakerDiarizationDestroySegment(s)
  1297 +
  1298 + ans := make([]OfflineSpeakerDiarizationSegment, n)
  1299 +
  1300 + p := (*[1 << 28]C.struct_SherpaOnnxOfflineSpeakerDiarizationSegment)(unsafe.Pointer(s))[:n:n]
  1301 +
  1302 + for i := 0; i < n; i++ {
  1303 + ans[i].Start = float32(p[i].start)
  1304 + ans[i].End = float32(p[i].end)
  1305 + ans[i].Speaker = int(p[i].speaker)
  1306 + }
  1307 +
  1308 + return ans
  1309 +}