Fangjun Kuang
Committed by GitHub

Python API for speaker diarization. (#1400)

... ... @@ -8,6 +8,21 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "test offline speaker diarization"

# Download and unpack the pyannote speaker segmentation model.
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

# Download a speaker embedding extractor model.
# NOTE: "recongition" is the actual (misspelled) upstream release-tag name;
# do not "fix" it or the URL breaks.
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

# Download a test wave file (contains 4 speakers, per its filename).
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

# Run the Python API example, then remove the downloaded artifacts.
python3 ./python-api-examples/offline-speaker-diarization.py
rm -rf *.wav *.onnx ./sherpa-onnx-pyannote-segmentation-3-0

log "test_clustering"
pushd /tmp/
mkdir test-cluster
... ...
... ... @@ -93,7 +93,7 @@ jobs:
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe
.github/scripts/test-speaker-diarization.sh
... ...
... ... @@ -93,7 +93,7 @@ jobs:
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe
.github/scripts/test-speaker-diarization.sh
... ...
#!/usr/bin/env python3
# Copyright (c) 2024 Xiaomi Corporation
"""
This file shows how to use sherpa-onnx Python API for
offline/non-streaming speaker diarization.
Usage:
Step 1: Download a speaker segmentation model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Step 2: Download a speaker embedding extractor model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
Step 3. Download test wave files
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
Step 4. Run it
python3 ./python-api-examples/offline-speaker-diarization.py
"""
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def init_speaker_diarization(
    num_speakers: int = -1,
    cluster_threshold: float = 0.5,
    segmentation_model: str = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx",
    embedding_extractor_model: str = (
        "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
    ),
):
    """Create and return an offline speaker diarization object.

    Args:
      num_speakers:
        If you know the actual number of speakers in the wave file, then please
        specify it. Otherwise, leave it to -1
      cluster_threshold:
        If num_speakers is -1, then this threshold is used for clustering.
        A smaller cluster_threshold leads to more clusters, i.e., more speakers.
        A larger cluster_threshold leads to fewer clusters, i.e., fewer speakers.
      segmentation_model:
        Path to the pyannote speaker segmentation ONNX model. The default
        matches the model downloaded in the usage instructions above.
      embedding_extractor_model:
        Path to the speaker embedding extractor ONNX model. The default
        matches the model downloaded in the usage instructions above.
    Returns:
      A sherpa_onnx.OfflineSpeakerDiarization instance.
    Raises:
      RuntimeError: If the config fails validation, e.g., a model file
        does not exist.
    """
    config = sherpa_onnx.OfflineSpeakerDiarizationConfig(
        segmentation=sherpa_onnx.OfflineSpeakerSegmentationModelConfig(
            pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig(
                model=segmentation_model
            ),
        ),
        embedding=sherpa_onnx.SpeakerEmbeddingExtractorConfig(
            model=embedding_extractor_model
        ),
        clustering=sherpa_onnx.FastClusteringConfig(
            num_clusters=num_speakers, threshold=cluster_threshold
        ),
        # Presumably post-processing durations in seconds (min speech-on /
        # speech-off) — TODO confirm against the C++ config docs.
        min_duration_on=0.3,
        min_duration_off=0.5,
    )
    if not config.validate():
        raise RuntimeError(
            "Please check your config and make sure all required files exist"
        )

    return sherpa_onnx.OfflineSpeakerDiarization(config)
def progress_callback(num_processed_chunk: int, num_total_chunks: int) -> int:
    """Print how far the diarization has progressed.

    Args:
      num_processed_chunk:
        Number of chunks processed so far.
      num_total_chunks:
        Total number of chunks to process.
    Returns:
      0, always.
    """
    percent = num_processed_chunk / num_total_chunks * 100
    print(f"Progress: {percent:.3f}%")
    return 0
def main():
    """Run offline speaker diarization on a local test wave file and print
    the per-speaker segments sorted by start time.

    Raises:
      RuntimeError: If the test wave file does not exist or its sample rate
        does not match the one expected by the model.
    """
    wave_filename = "./0-four-speakers-zh.wav"
    if not Path(wave_filename).is_file():
        raise RuntimeError(f"{wave_filename} does not exist")

    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    # Since we know there are 4 speakers in the above test wave file, we use
    # num_speakers=4 here.
    sd = init_speaker_diarization(num_speakers=4)
    if sample_rate != sd.sample_rate:
        # Fixed message typo: "samples rate" -> "sample rate".
        raise RuntimeError(
            f"Expected sample rate: {sd.sample_rate}, given: {sample_rate}"
        )

    # Fixed variable-name typo: was "show_porgress".
    show_progress = True

    if show_progress:
        result = sd.process(audio, callback=progress_callback).sort_by_start_time()
    else:
        result = sd.process(audio).sort_by_start_time()

    for r in result:
        print(f"{r.start:.3f} -- {r.end:.3f} speaker_{r.speaker:02}")
        # print(r)  # this one is simpler


if __name__ == "__main__":
    main()
... ...
... ... @@ -103,7 +103,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
Matrix2D embeddings =
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
callback, callback_arg);
std::move(callback), callback_arg);
std::vector<int32_t> cluster_labels = clustering_.Cluster(
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
... ...
... ... @@ -28,6 +28,8 @@ class OfflineSpeakerDiarizationSegment {
const std::string &Text() const { return text_; }
float Duration() const { return end_ - start_; }
void SetText(const std::string &text) { text_ = text; }
std::string ToString() const;
private:
... ...
... ... @@ -34,10 +34,13 @@ struct OfflineSpeakerDiarizationConfig {
OfflineSpeakerDiarizationConfig(
const OfflineSpeakerSegmentationModelConfig &segmentation,
const SpeakerEmbeddingExtractorConfig &embedding,
const FastClusteringConfig &clustering)
const FastClusteringConfig &clustering, float min_duration_on,
float min_duration_off)
: segmentation(segmentation),
embedding(embedding),
clustering(clustering) {}
clustering(clustering),
min_duration_on(min_duration_on),
min_duration_off(min_duration_off) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -62,6 +62,8 @@ endif()
# Compile the speaker-diarization Python bindings only when the feature is
# enabled at configure time.
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
  list(APPEND srcs
    fast-clustering.cc
    offline-speaker-diarization-result.cc
    offline-speaker-diarization.cc
  )
endif()
... ...
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
namespace sherpa_onnx {
// Registers OfflineSpeakerDiarizationSegment with the Python module:
// read-only start/end/duration/speaker, a mutable "text" property, and
// __str__ delegating to ToString().
static void PybindOfflineSpeakerDiarizationSegment(py::module *m) {
  using PyClass = OfflineSpeakerDiarizationSegment;

  py::class_<PyClass> seg(*m, "OfflineSpeakerDiarizationSegment");
  seg.def_property_readonly("start", &PyClass::Start);
  seg.def_property_readonly("end", &PyClass::End);
  seg.def_property_readonly("duration", &PyClass::Duration);
  seg.def_property_readonly("speaker", &PyClass::Speaker);
  seg.def_property("text", &PyClass::Text, &PyClass::SetText);
  seg.def("__str__", &PyClass::ToString);
}
// Registers OfflineSpeakerDiarizationResult (and, first, its segment type)
// with the Python module.
void PybindOfflineSpeakerDiarizationResult(py::module *m) {
  PybindOfflineSpeakerDiarizationSegment(m);

  using PyClass = OfflineSpeakerDiarizationResult;

  py::class_<PyClass> result(*m, "OfflineSpeakerDiarizationResult");
  result.def_property_readonly("num_speakers", &PyClass::NumSpeakers);
  result.def_property_readonly("num_segments", &PyClass::NumSegments);
  result.def("sort_by_start_time", &PyClass::SortByStartTime);
  result.def("sort_by_speaker", &PyClass::SortBySpeaker);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
// Registers the OfflineSpeakerDiarizationSegment and
// OfflineSpeakerDiarizationResult classes with the given Python module.
void PybindOfflineSpeakerDiarizationResult(py::module *m);
}  // namespace sherpa_onnx
#endif  // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
... ...
// sherpa-onnx/python/csrc/offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
namespace sherpa_onnx {
// Registers OfflineSpeakerSegmentationPyannoteModelConfig: default ctor,
// ctor taking a model path, and the read-write "model" attribute.
static void PybindOfflineSpeakerSegmentationPyannoteModelConfig(py::module *m) {
  using PyClass = OfflineSpeakerSegmentationPyannoteModelConfig;

  py::class_<PyClass> config(*m, "OfflineSpeakerSegmentationPyannoteModelConfig");
  config.def(py::init<>());
  config.def(py::init<const std::string &>(), py::arg("model"));
  config.def_readwrite("model", &PyClass::model);
  config.def("__str__", &PyClass::ToString);
  config.def("validate", &PyClass::Validate);
}
// Registers OfflineSpeakerSegmentationModelConfig (and, first, the nested
// pyannote config type it wraps).
static void PybindOfflineSpeakerSegmentationModelConfig(py::module *m) {
  PybindOfflineSpeakerSegmentationPyannoteModelConfig(m);

  using PyClass = OfflineSpeakerSegmentationModelConfig;

  py::class_<PyClass> config(*m, "OfflineSpeakerSegmentationModelConfig");
  config.def(py::init<>());
  config.def(py::init<const OfflineSpeakerSegmentationPyannoteModelConfig &,
                      int32_t, bool, const std::string &>(),
             py::arg("pyannote"), py::arg("num_threads") = 1,
             py::arg("debug") = false, py::arg("provider") = "cpu");
  config.def_readwrite("pyannote", &PyClass::pyannote);
  config.def_readwrite("num_threads", &PyClass::num_threads);
  config.def_readwrite("debug", &PyClass::debug);
  config.def_readwrite("provider", &PyClass::provider);
  config.def("__str__", &PyClass::ToString);
  config.def("validate", &PyClass::Validate);
}
// Registers OfflineSpeakerDiarizationConfig (and, first, the segmentation
// model config type it contains).
static void PybindOfflineSpeakerDiarizationConfig(py::module *m) {
  PybindOfflineSpeakerSegmentationModelConfig(m);

  using PyClass = OfflineSpeakerDiarizationConfig;

  py::class_<PyClass> config(*m, "OfflineSpeakerDiarizationConfig");
  config.def(py::init<const OfflineSpeakerSegmentationModelConfig &,
                      const SpeakerEmbeddingExtractorConfig &,
                      const FastClusteringConfig &, float, float>(),
             py::arg("segmentation"), py::arg("embedding"),
             py::arg("clustering"), py::arg("min_duration_on") = 0.3,
             py::arg("min_duration_off") = 0.5);
  config.def_readwrite("segmentation", &PyClass::segmentation);
  config.def_readwrite("embedding", &PyClass::embedding);
  config.def_readwrite("clustering", &PyClass::clustering);
  config.def_readwrite("min_duration_on", &PyClass::min_duration_on);
  config.def_readwrite("min_duration_off", &PyClass::min_duration_off);
  config.def("__str__", &PyClass::ToString);
  config.def("validate", &PyClass::Validate);
}
// Registers OfflineSpeakerDiarization and its "process" method with the
// Python module.
void PybindOfflineSpeakerDiarization(py::module *m) {
  PybindOfflineSpeakerDiarizationConfig(m);

  using PyClass = OfflineSpeakerDiarization;
  py::class_<PyClass>(*m, "OfflineSpeakerDiarization")
      .def(py::init<const OfflineSpeakerDiarizationConfig &>(),
           py::arg("config"))
      .def_property_readonly("sample_rate", &PyClass::SampleRate)
      .def(
          "process",
          // Take samples by const reference: the previous const-by-value
          // parameter copied the entire waveform on every call.
          [](const PyClass &self, const std::vector<float> &samples,
             std::function<int32_t(int32_t, int32_t)> callback) {
            if (!callback) {
              return self.Process(samples.data(), samples.size());
            }

            // Adapt the 2-argument Python callback to the 3-argument C++
            // callback. Propagate the callback's return value instead of
            // discarding it, so a non-zero result actually reaches the
            // C++ side.
            std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
                [callback](int32_t processed_chunks, int32_t num_chunks,
                           void *) -> int32_t {
              return callback(processed_chunks, num_chunks);
            };

            return self.Process(samples.data(), samples.size(),
                                callback_wrapper);
          },
          py::arg("samples"), py::arg("callback") = py::none());
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-speaker-diarization.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
// Registers the OfflineSpeakerDiarization class and its config types with
// the given Python module.
void PybindOfflineSpeakerDiarization(py::module *m);
}  // namespace sherpa_onnx
#endif  // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
... ...
... ... @@ -37,6 +37,8 @@
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
#include "sherpa-onnx/python/csrc/fast-clustering.h"
#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
#endif
namespace sherpa_onnx {
... ... @@ -74,14 +76,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOfflineTts(&m);
#endif
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
PybindFastClustering(&m);
#endif
PybindSpeakerEmbeddingExtractor(&m);
PybindSpeakerEmbeddingManager(&m);
PybindSpokenLanguageIdentification(&m);
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
PybindFastClustering(&m);
PybindOfflineSpeakerDiarizationResult(&m);
PybindOfflineSpeakerDiarization(&m);
#endif
PybindAlsa(&m);
}
... ...
... ... @@ -11,6 +11,12 @@ from _sherpa_onnx import (
OfflinePunctuation,
OfflinePunctuationConfig,
OfflinePunctuationModelConfig,
OfflineSpeakerDiarization,
OfflineSpeakerDiarizationConfig,
OfflineSpeakerDiarizationResult,
OfflineSpeakerDiarizationSegment,
OfflineSpeakerSegmentationModelConfig,
OfflineSpeakerSegmentationPyannoteModelConfig,
OfflineStream,
OfflineTts,
OfflineTtsConfig,
... ...