sherpa-onnx-offline-speaker-diarization.cc
4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
// Progress hook handed to OfflineSpeakerDiarization::Process() in main().
//
// Writes the percentage of processed chunks to stderr.
//
// @param processed_chunks Number of chunks handled so far.
// @param num_chunks       Total number of chunks.
// @param (unnamed)        Opaque pointer; not used by this callback.
// @return Always 0.
static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks,
                                void * /*unused*/) {
  // The ratio is evaluated in double precision and then narrowed to float
  // before it is printed.
  const double percent = 100.0 * processed_chunks / num_chunks;
  const auto shown = static_cast<float>(percent);

  fprintf(stderr, "progress %.2f%%\n", shown);

  // the return value is currently ignored
  return 0;
}
// Command-line entry point for offline (non-streaming) speaker diarization.
//
// Workflow:
//   1. Register the options of OfflineSpeakerDiarizationConfig, parse the
//      command line and print the resulting config to stdout.
//   2. Validate the config and require exactly one positional argument
//      (the wave file).
//   3. Read the wave file and check that its sample rate equals the one
//      reported by the diarizer.
//   4. Run diarization, sort the result by start time and print each entry
//      to stdout.
//   5. Print the audio duration, the elapsed time and the real time factor
//      (RTF) to stderr.
//
// Returns 0 on success and -1 when the config is invalid, the number of
// positional arguments is not 1, the wave file cannot be read, or the sample
// rate does not match.
int main(int32_t argc, char *argv[]) {
  const char *kUsageMessage = R"usage(
Offline/Non-streaming speaker diarization with sherpa-onnx
Usage example:
Step 1: Download a speaker segmentation model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Step 2: Download a speaker embedding extractor model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
Step 3. Download test wave files
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
Step 4. Build sherpa-onnx
Step 5. Run it
./bin/sherpa-onnx-offline-speaker-diarization \
--clustering.num-clusters=4 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav
Since we know that there are four speakers in the test wave file, we use
--clustering.num-clusters=4 in the above example.
If we don't know number of speakers in the given wave file, we can use
the argument --clustering.cluster-threshold. The following is an example:
./bin/sherpa-onnx-offline-speaker-diarization \
--clustering.cluster-threshold=0.90 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav
A larger threshold leads to few clusters, i.e., few speakers;
a smaller threshold leads to more clusters, i.e., more speakers
)usage";

  sherpa_onnx::OfflineSpeakerDiarizationConfig config;
  sherpa_onnx::ParseOptions po(kUsageMessage);
  config.Register(&po);
  po.Read(argc, argv);

  std::cout << config.ToString() << "\n";

  if (!config.Validate()) {
    po.PrintUsage();
    std::cerr << "Errors in config!\n";
    return -1;
  }

  if (po.NumArgs() != 1) {
    std::cerr << "Error: Please provide exactly 1 wave file.\n\n";
    po.PrintUsage();
    return -1;
  }

  sherpa_onnx::OfflineSpeakerDiarization sd(config);

  std::cout << "Started\n";

  // The clock starts before the wave file is read, so the elapsed time
  // reported below covers loading, processing and printing the result.
  const auto begin = std::chrono::steady_clock::now();

  const std::string wav_filename = po.GetArg(1);
  int32_t sample_rate = -1;
  bool is_ok = false;
  const std::vector<float> samples =
      sherpa_onnx::ReadWave(wav_filename, &sample_rate, &is_ok);
  if (!is_ok) {
    std::cerr << "Failed to read " << wav_filename << "\n";
    return -1;
  }

  if (sample_rate != sd.SampleRate()) {
    std::cerr << "Expect sample rate " << sd.SampleRate()
              << ". Given: " << sample_rate << "\n";
    return -1;
  }

  // Audio length in seconds.
  const float duration = samples.size() / static_cast<float>(sample_rate);

  const auto result =
      sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr)
          .SortByStartTime();

  for (const auto &r : result) {
    std::cout << r.ToString() << "\n";
  }

  const auto end = std::chrono::steady_clock::now();

  // Millisecond resolution, converted to seconds.
  const float elapsed_seconds =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
          .count() /
      1000.;

  fprintf(stderr, "Duration : %.3f s\n", duration);
  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);

  // A wave file without samples has a duration of 0; dividing by it would
  // print inf/nan, so the RTF is only reported for non-empty audio.
  if (duration > 0) {
    const float rtf = elapsed_seconds / duration;
    fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
            elapsed_seconds, duration, rtf);
  } else {
    fprintf(stderr,
            "Real time factor (RTF): n/a (the input contains no samples)\n");
  }

  return 0;
}