speaker-diarization-torch.py
#!/usr/bin/env python3
"""
Please refer to
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
for usages.
"""
"""
1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main
wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx
2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py
```
# self._embedding = PretrainedSpeakerEmbedding(
# self.embedding, use_auth_token=use_auth_token
# )
self._embedding = embedding
```
"""

import argparse
from pathlib import Path

import torch
from pyannote.audio import Model
from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline
from pyannote.audio.pipelines.speaker_verification import (
    ONNXWeSpeakerPretrainedSpeakerEmbedding,
)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()


def build_pipeline():
    embedding_filename = "./speaker-embedding.onnx"
    if Path(embedding_filename).is_file():
        # Use the local ONNX speaker embedding model. This requires the
        # change to line 166 of
        # pyannote/audio/pipelines/speaker_diarization.py described in the
        # docstring at the top of this script.
        embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding_filename)
    else:
        # Fall back to downloading the PyTorch model from Hugging Face.
        embedding = "hbredin/wespeaker-voxceleb-resnet34-LM"

    pt_filename = "./pytorch_model.bin"
    segmentation = Model.from_pretrained(pt_filename)
    segmentation.eval()

    pipeline = SpeakerDiarizationPipeline(
        segmentation=segmentation,
        embedding=embedding,
        embedding_exclude_overlap=True,
    )

    params = {
        "clustering": {
            "method": "centroid",
            "min_cluster_size": 12,
            "threshold": 0.7045654963945799,
        },
        "segmentation": {"min_duration_off": 0.5},
    }
    pipeline.instantiate(params)

    return pipeline
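

# Note: besides a file path, a pyannote.audio pipeline also accepts in-memory
# audio as a dict {"waveform": tensor, "sample_rate": sr}, where `tensor` has
# shape (channel, num_samples). A minimal sketch, added here for illustration
# (assumes torchaudio is installed and ./test.wav exists):
#
#   import torchaudio
#
#   waveform, sample_rate = torchaudio.load("./test.wav")
#   t = pipeline({"waveform": waveform, "sample_rate": sample_rate})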


@torch.no_grad()
def main():
    args = get_args()
    assert Path(args.wav).is_file(), f"{args.wav} does not exist"

    pipeline = build_pipeline()
    print(pipeline)

    # Run diarization on the input wave file.
    t = pipeline(args.wav)
    print(type(t))
    print(t)
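    # `t` is a pyannote.core.Annotation; a common way to consume the result
    # is to iterate over the speaker turns. This loop is an added sketch,
    # not part of the original script:
    for turn, _, speaker in t.itertracks(yield_label=True):
        print(f"{turn.start:.3f} -- {turn.end:.3f} {speaker}")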


if __name__ == "__main__":
    main()
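
# Example run, added for illustration (assumes ./pytorch_model.bin and the
# test wave file are present in the current directory):
#
#   python3 ./speaker-diarization-torch.py --wav ./test.wav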