offline-speaker-diarization.py
4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
# Copyright (c) 2024 Xiaomi Corporation
"""
This file shows how to use sherpa-onnx Python API for
offline/non-streaming speaker diarization.
Usage:
Step 1: Download a speaker segmentation model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Step 2: Download a speaker embedding extractor model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
Step 3: Download test wave files
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
Step 4: Run it
python3 ./python-api-examples/offline-speaker-diarization.py
"""
from pathlib import Path
import sherpa_onnx
import soundfile as sf
import librosa
def resample_audio(audio, sample_rate, target_sample_rate):
    """
    Bring audio to the target sample rate using librosa.

    Args:
      audio:
        Samples to convert. They are handed to ``librosa.resample`` unchanged.
      sample_rate:
        Sample rate of ``audio`` in Hz.
      target_sample_rate:
        Desired sample rate in Hz.

    Returns:
      A tuple ``(audio, sample_rate)``. If the two rates already match, the
      input object and its rate are returned as is; otherwise the resampled
      audio and ``target_sample_rate`` are returned.
    """
    if sample_rate == target_sample_rate:
        # Rates already agree, so there is nothing to convert.
        return audio, sample_rate

    print(f"Resampling audio from {sample_rate}Hz to {target_sample_rate}Hz...")
    resampled = librosa.resample(
        audio, orig_sr=sample_rate, target_sr=target_sample_rate
    )
    print(f"Resampling completed. New audio shape: {resampled.shape}")
    return resampled, target_sample_rate
def init_speaker_diarization(num_speakers: int = -1, cluster_threshold: float = 0.5):
    """
    Build an offline speaker diarization object from the models that the
    usage instructions at the top of this file tell you to download.

    Args:
      num_speakers:
        If you know the actual number of speakers in the wave file, then please
        specify it. Otherwise, leave it to -1
      cluster_threshold:
        If num_speakers is -1, then this threshold is used for clustering.
        A smaller cluster_threshold leads to more clusters, i.e., more speakers.
        A larger cluster_threshold leads to fewer clusters, i.e., fewer speakers.

    Returns:
      A ``sherpa_onnx.OfflineSpeakerDiarization`` created from the config.

    Raises:
      RuntimeError: If the assembled config does not pass ``validate()``.
    """
    segmentation_model = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"
    embedding_extractor_model = (
        "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
    )

    # Assemble the three sub-configs first, then combine them.
    segmentation_config = sherpa_onnx.OfflineSpeakerSegmentationModelConfig(
        pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig(
            model=segmentation_model
        ),
    )
    embedding_config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=embedding_extractor_model
    )
    clustering_config = sherpa_onnx.FastClusteringConfig(
        num_clusters=num_speakers, threshold=cluster_threshold
    )

    config = sherpa_onnx.OfflineSpeakerDiarizationConfig(
        segmentation=segmentation_config,
        embedding=embedding_config,
        clustering=clustering_config,
        min_duration_on=0.3,
        min_duration_off=0.5,
    )

    if config.validate():
        return sherpa_onnx.OfflineSpeakerDiarization(config)

    raise RuntimeError(
        "Please check your config and make sure all required files exist"
    )
def progress_callback(num_processed_chunk: int, num_total_chunks: int) -> int:
    """
    Report the processing progress on stdout as a percentage.

    Args:
      num_processed_chunk:
        Number of chunks processed so far.
      num_total_chunks:
        Total number of chunks to process.

    Returns:
      Always 0.
    """
    fraction_done = num_processed_chunk / num_total_chunks
    print(f"Progress: {fraction_done * 100:.3f}%")
    return 0
def main():
    """
    Run speaker diarization on the test wave file and print one line per
    segment with its start time, end time and speaker index.

    Raises:
      RuntimeError: If the test wave file does not exist, or if the sample
        rate of the audio differs from the one expected by the diarization
        object after resampling.
    """
    wave_filename = "./0-four-speakers-zh.wav"
    wave_path = Path(wave_filename)
    if not wave_path.is_file():
        raise RuntimeError(f"{wave_filename} does not exist")

    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    # Only the first channel is used.
    audio = samples[:, 0]

    # The test wave file above is known to contain 4 speakers, so the number
    # of speakers is given explicitly.
    sd = init_speaker_diarization(num_speakers=4)

    # Convert the audio to the sample rate expected by the diarization object.
    audio, sample_rate = resample_audio(audio, sample_rate, sd.sample_rate)

    expected_sample_rate = sd.sample_rate
    if sample_rate != expected_sample_rate:
        raise RuntimeError(
            f"Expected samples rate: {expected_sample_rate}, given: {sample_rate}"
        )

    show_progress = True
    # Pass the callback to process() only when progress output is wanted.
    extra_args = {"callback": progress_callback} if show_progress else {}
    segments = sd.process(audio, **extra_args).sort_by_start_time()

    # Note: print(segment) also works and is simpler.
    for segment in segments:
        print(
            f"{segment.start:.3f} -- {segment.end:.3f} speaker_{segment.speaker:02}"
        )
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    main()