Fangjun Kuang
Committed by GitHub

Python API for speaker diarization. (#1400)

@@ -8,6 +8,21 @@ log() { @@ -8,6 +8,21 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +log "test offline speaker diarization"
  12 +
  13 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  14 +tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  15 +rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  16 +
  17 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  18 +
  19 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
  20 +
  21 +python3 ./python-api-examples/offline-speaker-diarization.py
  22 +
  23 +rm -rf *.wav *.onnx ./sherpa-onnx-pyannote-segmentation-3-0
  24 +
  25 +
11 log "test_clustering" 26 log "test_clustering"
12 pushd /tmp/ 27 pushd /tmp/
13 mkdir test-cluster 28 mkdir test-cluster
@@ -93,7 +93,7 @@ jobs: @@ -93,7 +93,7 @@ jobs:
93 shell: bash 93 shell: bash
94 run: | 94 run: |
95 du -h -d1 . 95 du -h -d1 .
96 - export PATH=$PWD/build/bin:$PATH 96 + export PATH=$PWD/build/bin/Release:$PATH
97 export EXE=sherpa-onnx-offline-speaker-diarization.exe 97 export EXE=sherpa-onnx-offline-speaker-diarization.exe
98 98
99 .github/scripts/test-speaker-diarization.sh 99 .github/scripts/test-speaker-diarization.sh
@@ -93,7 +93,7 @@ jobs: @@ -93,7 +93,7 @@ jobs:
93 shell: bash 93 shell: bash
94 run: | 94 run: |
95 du -h -d1 . 95 du -h -d1 .
96 - export PATH=$PWD/build/bin:$PATH 96 + export PATH=$PWD/build/bin/Release:$PATH
97 export EXE=sherpa-onnx-offline-speaker-diarization.exe 97 export EXE=sherpa-onnx-offline-speaker-diarization.exe
98 98
99 .github/scripts/test-speaker-diarization.sh 99 .github/scripts/test-speaker-diarization.sh
  1 +#!/usr/bin/env python3
  2 +# Copyright (c) 2024 Xiaomi Corporation
  3 +
  4 +"""
  5 +This file shows how to use sherpa-onnx Python API for
  6 +offline/non-streaming speaker diarization.
  7 +
  8 +Usage:
  9 +
  10 +Step 1: Download a speaker segmentation model
  11 +
  12 +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
  13 +for a list of available models. The following is an example
  14 +
  15 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  16 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  17 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  18 +
  19 +Step 2: Download a speaker embedding extractor model
  20 +
  21 +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
  22 +for a list of available models. The following is an example
  23 +
  24 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  25 +
  26 +Step 3. Download test wave files
  27 +
  28 +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
  29 +for a list of available test wave files. The following is an example
  30 +
  31 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
  32 +
  33 +Step 4. Run it
  34 +
  35 + python3 ./python-api-examples/offline-speaker-diarization.py
  36 +
  37 +"""
  38 +from pathlib import Path
  39 +
  40 +import sherpa_onnx
  41 +import soundfile as sf
  42 +
  43 +
def init_speaker_diarization(
    num_speakers: int = -1,
    cluster_threshold: float = 0.5,
    segmentation_model: str = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx",
    embedding_extractor_model: str = (
        "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
    ),
):
    """Create and return an OfflineSpeakerDiarization object.

    Args:
      num_speakers:
        If you know the actual number of speakers in the wave file, then please
        specify it. Otherwise, leave it to -1
      cluster_threshold:
        If num_speakers is -1, then this threshold is used for clustering.
        A smaller cluster_threshold leads to more clusters, i.e., more speakers.
        A larger cluster_threshold leads to fewer clusters, i.e., fewer speakers.
      segmentation_model:
        Path to the pyannote speaker segmentation ONNX model.
      embedding_extractor_model:
        Path to the speaker embedding extractor ONNX model.
    Raises:
      RuntimeError: if the config fails validation, e.g., a model file
        does not exist.
    """
    config = sherpa_onnx.OfflineSpeakerDiarizationConfig(
        segmentation=sherpa_onnx.OfflineSpeakerSegmentationModelConfig(
            pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig(
                model=segmentation_model
            ),
        ),
        embedding=sherpa_onnx.SpeakerEmbeddingExtractorConfig(
            model=embedding_extractor_model
        ),
        clustering=sherpa_onnx.FastClusteringConfig(
            num_clusters=num_speakers, threshold=cluster_threshold
        ),
        min_duration_on=0.3,
        min_duration_off=0.5,
    )
    # validate() checks that the referenced model files actually exist.
    if not config.validate():
        raise RuntimeError(
            "Please check your config and make sure all required files exist"
        )

    return sherpa_onnx.OfflineSpeakerDiarization(config)
  81 +
  82 +
def progress_callback(num_processed_chunk: int, num_total_chunks: int) -> int:
    """Report diarization progress as a percentage.

    Returning 0 tells the caller to continue processing.
    """
    percent = num_processed_chunk / num_total_chunks * 100
    print(f"Progress: {percent:.3f}%")
    return 0
  87 +
  88 +
def main():
    """Run offline speaker diarization on a test wave file."""
    wave_filename = "./0-four-speakers-zh.wav"
    if not Path(wave_filename).is_file():
        raise RuntimeError(f"{wave_filename} does not exist")

    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    # Since we know there are 4 speakers in the above test wave file, we use
    # num_speakers=4 here
    sd = init_speaker_diarization(num_speakers=4)
    if sample_rate != sd.sample_rate:
        raise RuntimeError(
            f"Expected sample rate: {sd.sample_rate}, given: {sample_rate}"
        )

    # fixed typo: "show_porgress" -> "show_progress"
    show_progress = True

    if show_progress:
        result = sd.process(audio, callback=progress_callback).sort_by_start_time()
    else:
        result = sd.process(audio).sort_by_start_time()

    for r in result:
        print(f"{r.start:.3f} -- {r.end:.3f} speaker_{r.speaker:02}")
        # print(r) # this one is simpler


if __name__ == "__main__":
    main()
@@ -103,7 +103,7 @@ class OfflineSpeakerDiarizationPyannoteImpl @@ -103,7 +103,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
103 auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); 103 auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
104 Matrix2D embeddings = 104 Matrix2D embeddings =
105 ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, 105 ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
106 - callback, callback_arg); 106 + std::move(callback), callback_arg);
107 107
108 std::vector<int32_t> cluster_labels = clustering_.Cluster( 108 std::vector<int32_t> cluster_labels = clustering_.Cluster(
109 &embeddings(0, 0), embeddings.rows(), embeddings.cols()); 109 &embeddings(0, 0), embeddings.rows(), embeddings.cols());
@@ -28,6 +28,8 @@ class OfflineSpeakerDiarizationSegment { @@ -28,6 +28,8 @@ class OfflineSpeakerDiarizationSegment {
28 const std::string &Text() const { return text_; } 28 const std::string &Text() const { return text_; }
29 float Duration() const { return end_ - start_; } 29 float Duration() const { return end_ - start_; }
30 30
  31 + void SetText(const std::string &text) { text_ = text; }
  32 +
31 std::string ToString() const; 33 std::string ToString() const;
32 34
33 private: 35 private:
@@ -34,10 +34,13 @@ struct OfflineSpeakerDiarizationConfig { @@ -34,10 +34,13 @@ struct OfflineSpeakerDiarizationConfig {
34 OfflineSpeakerDiarizationConfig( 34 OfflineSpeakerDiarizationConfig(
35 const OfflineSpeakerSegmentationModelConfig &segmentation, 35 const OfflineSpeakerSegmentationModelConfig &segmentation,
36 const SpeakerEmbeddingExtractorConfig &embedding, 36 const SpeakerEmbeddingExtractorConfig &embedding,
37 - const FastClusteringConfig &clustering) 37 + const FastClusteringConfig &clustering, float min_duration_on,
  38 + float min_duration_off)
38 : segmentation(segmentation), 39 : segmentation(segmentation),
39 embedding(embedding), 40 embedding(embedding),
40 - clustering(clustering) {} 41 + clustering(clustering),
  42 + min_duration_on(min_duration_on),
  43 + min_duration_off(min_duration_off) {}
41 44
42 void Register(ParseOptions *po); 45 void Register(ParseOptions *po);
43 bool Validate() const; 46 bool Validate() const;
@@ -62,6 +62,8 @@ endif() @@ -62,6 +62,8 @@ endif()
62 if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) 62 if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
63 list(APPEND srcs 63 list(APPEND srcs
64 fast-clustering.cc 64 fast-clustering.cc
  65 + offline-speaker-diarization-result.cc
  66 + offline-speaker-diarization.cc
65 ) 67 )
66 endif() 68 endif()
67 69
  1 +// sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
  6 +
  7 +#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
// Exposes OfflineSpeakerDiarizationSegment to Python with read-only
// start/end/duration/speaker properties, a read-write "text" property
// (backed by Text()/SetText()), and __str__ via ToString().
static void PybindOfflineSpeakerDiarizationSegment(py::module *m) {
  using PyClass = OfflineSpeakerDiarizationSegment;
  py::class_<PyClass>(*m, "OfflineSpeakerDiarizationSegment")
      .def_property_readonly("start", &PyClass::Start)
      .def_property_readonly("end", &PyClass::End)
      .def_property_readonly("duration", &PyClass::Duration)
      .def_property_readonly("speaker", &PyClass::Speaker)
      .def_property("text", &PyClass::Text, &PyClass::SetText)
      .def("__str__", &PyClass::ToString);
}
  21 +
// Exposes OfflineSpeakerDiarizationResult to Python. Also registers the
// segment class first, since results are containers of segments.
void PybindOfflineSpeakerDiarizationResult(py::module *m) {
  PybindOfflineSpeakerDiarizationSegment(m);
  using PyClass = OfflineSpeakerDiarizationResult;
  py::class_<PyClass>(*m, "OfflineSpeakerDiarizationResult")
      .def_property_readonly("num_speakers", &PyClass::NumSpeakers)
      .def_property_readonly("num_segments", &PyClass::NumSegments)
      .def("sort_by_start_time", &PyClass::SortByStartTime)
      .def("sort_by_speaker", &PyClass::SortBySpeaker);
}
  31 +
  32 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
namespace sherpa_onnx {

// Registers OfflineSpeakerDiarizationSegment and
// OfflineSpeakerDiarizationResult bindings with the given pybind11 module.
void PybindOfflineSpeakerDiarizationResult(py::module *m);

}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
  1 +// sherpa-onnx/python/csrc/offline-speaker-diarization.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
  11 +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
  12 +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
// Binds OfflineSpeakerSegmentationPyannoteModelConfig: default-constructible
// or constructible from a single keyword argument `model` (path to the
// pyannote segmentation ONNX model).
static void PybindOfflineSpeakerSegmentationPyannoteModelConfig(py::module *m) {
  using PyClass = OfflineSpeakerSegmentationPyannoteModelConfig;
  py::class_<PyClass>(*m, "OfflineSpeakerSegmentationPyannoteModelConfig")
      .def(py::init<>())
      .def(py::init<const std::string &>(), py::arg("model"))
      .def_readwrite("model", &PyClass::model)
      .def("__str__", &PyClass::ToString)
      .def("validate", &PyClass::Validate);
}
  25 +
// Binds OfflineSpeakerSegmentationModelConfig. Registers the nested pyannote
// config first so it is available as a constructor argument type.
static void PybindOfflineSpeakerSegmentationModelConfig(py::module *m) {
  PybindOfflineSpeakerSegmentationPyannoteModelConfig(m);

  using PyClass = OfflineSpeakerSegmentationModelConfig;
  py::class_<PyClass>(*m, "OfflineSpeakerSegmentationModelConfig")
      .def(py::init<>())
      .def(py::init<const OfflineSpeakerSegmentationPyannoteModelConfig &,
                    int32_t, bool, const std::string &>(),
           py::arg("pyannote"), py::arg("num_threads") = 1,
           py::arg("debug") = false, py::arg("provider") = "cpu")
      .def_readwrite("pyannote", &PyClass::pyannote)
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
      .def_readwrite("provider", &PyClass::provider)
      .def("__str__", &PyClass::ToString)
      .def("validate", &PyClass::Validate);
}
  43 +
// Binds OfflineSpeakerDiarizationConfig. Registers the nested segmentation
// model config first. Python defaults (min_duration_on=0.3,
// min_duration_off=0.5) mirror the C++ constructor defaults added in this
// change.
static void PybindOfflineSpeakerDiarizationConfig(py::module *m) {
  PybindOfflineSpeakerSegmentationModelConfig(m);

  using PyClass = OfflineSpeakerDiarizationConfig;
  py::class_<PyClass>(*m, "OfflineSpeakerDiarizationConfig")
      .def(py::init<const OfflineSpeakerSegmentationModelConfig &,
                    const SpeakerEmbeddingExtractorConfig &,
                    const FastClusteringConfig &, float, float>(),
           py::arg("segmentation"), py::arg("embedding"), py::arg("clustering"),
           py::arg("min_duration_on") = 0.3, py::arg("min_duration_off") = 0.5)
      .def_readwrite("segmentation", &PyClass::segmentation)
      .def_readwrite("embedding", &PyClass::embedding)
      .def_readwrite("clustering", &PyClass::clustering)
      .def_readwrite("min_duration_on", &PyClass::min_duration_on)
      .def_readwrite("min_duration_off", &PyClass::min_duration_off)
      .def("__str__", &PyClass::ToString)
      .def("validate", &PyClass::Validate);
}
  62 +
  63 +void PybindOfflineSpeakerDiarization(py::module *m) {
  64 + PybindOfflineSpeakerDiarizationConfig(m);
  65 +
  66 + using PyClass = OfflineSpeakerDiarization;
  67 + py::class_<PyClass>(*m, "OfflineSpeakerDiarization")
  68 + .def(py::init<const OfflineSpeakerDiarizationConfig &>(),
  69 + py::arg("config"))
  70 + .def_property_readonly("sample_rate", &PyClass::SampleRate)
  71 + .def(
  72 + "process",
  73 + [](const PyClass &self, const std::vector<float> samples,
  74 + std::function<int32_t(int32_t, int32_t)> callback) {
  75 + if (!callback) {
  76 + return self.Process(samples.data(), samples.size());
  77 + }
  78 +
  79 + std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
  80 + [callback](int32_t processed_chunks, int32_t num_chunks,
  81 + void *) -> int32_t {
  82 + callback(processed_chunks, num_chunks);
  83 + return 0;
  84 + };
  85 +
  86 + return self.Process(samples.data(), samples.size(),
  87 + callback_wrapper);
  88 + },
  89 + py::arg("samples"), py::arg("callback") = py::none());
  90 +}
  91 +
  92 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-speaker-diarization.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
namespace sherpa_onnx {

// Registers the OfflineSpeakerDiarization class and its config bindings
// with the given pybind11 module.
void PybindOfflineSpeakerDiarization(py::module *m);

}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
@@ -37,6 +37,8 @@ @@ -37,6 +37,8 @@
37 37
38 #if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 38 #if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
39 #include "sherpa-onnx/python/csrc/fast-clustering.h" 39 #include "sherpa-onnx/python/csrc/fast-clustering.h"
  40 +#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
  41 +#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
40 #endif 42 #endif
41 43
42 namespace sherpa_onnx { 44 namespace sherpa_onnx {
@@ -74,14 +76,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -74,14 +76,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
74 PybindOfflineTts(&m); 76 PybindOfflineTts(&m);
75 #endif 77 #endif
76 78
77 -#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1  
78 - PybindFastClustering(&m);  
79 -#endif  
80 -  
81 PybindSpeakerEmbeddingExtractor(&m); 79 PybindSpeakerEmbeddingExtractor(&m);
82 PybindSpeakerEmbeddingManager(&m); 80 PybindSpeakerEmbeddingManager(&m);
83 PybindSpokenLanguageIdentification(&m); 81 PybindSpokenLanguageIdentification(&m);
84 82
  83 +#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
  84 + PybindFastClustering(&m);
  85 + PybindOfflineSpeakerDiarizationResult(&m);
  86 + PybindOfflineSpeakerDiarization(&m);
  87 +#endif
  88 +
85 PybindAlsa(&m); 89 PybindAlsa(&m);
86 } 90 }
87 91
@@ -11,6 +11,12 @@ from _sherpa_onnx import ( @@ -11,6 +11,12 @@ from _sherpa_onnx import (
11 OfflinePunctuation, 11 OfflinePunctuation,
12 OfflinePunctuationConfig, 12 OfflinePunctuationConfig,
13 OfflinePunctuationModelConfig, 13 OfflinePunctuationModelConfig,
  14 + OfflineSpeakerDiarization,
  15 + OfflineSpeakerDiarizationConfig,
  16 + OfflineSpeakerDiarizationResult,
  17 + OfflineSpeakerDiarizationSegment,
  18 + OfflineSpeakerSegmentationModelConfig,
  19 + OfflineSpeakerSegmentationPyannoteModelConfig,
14 OfflineStream, 20 OfflineStream,
15 OfflineTts, 21 OfflineTts,
16 OfflineTtsConfig, 22 OfflineTtsConfig,