Fangjun Kuang
Committed by GitHub

Speaker diarization example with onnxruntime Python API (#1395)

  1 +name: speaker-diarization
  2 +
  3 +on:
  4 + push:
  5 + branches:
  6 + - speaker-diarization
  7 + workflow_dispatch:
  8 +
  9 +concurrency:
  10 + group: speaker-diarization-${{ github.ref }}
  11 + cancel-in-progress: true
  12 +
  13 +jobs:
  14 + linux:
  15 + name: speaker diarization
  16 + runs-on: ${{ matrix.os }}
  17 + strategy:
  18 + fail-fast: false
  19 + matrix:
  20 + os: [macos-latest]
  21 + python-version: ["3.10"]
  22 +
  23 + steps:
  24 + - uses: actions/checkout@v4
  25 + with:
  26 + fetch-depth: 0
  27 +
  28 + - name: ccache
  29 + uses: hendrikmuhs/ccache-action@v1.2
  30 + with:
  31 + key: ${{ matrix.os }}-speaker-diarization
  32 +
  33 + - name: Setup Python ${{ matrix.python-version }}
  34 + uses: actions/setup-python@v5
  35 + with:
  36 + python-version: ${{ matrix.python-version }}
  37 +
  38 + - name: Install pyannote
  39 + shell: bash
  40 + run: |
  41 + pip install pyannote.audio onnx onnxruntime
  42 +
  43 + - name: Install sherpa-onnx from source
  44 + shell: bash
  45 + run: |
  46 + python3 -m pip install --upgrade pip
  47 + python3 -m pip install wheel twine setuptools
  48 +
  49 + export CMAKE_CXX_COMPILER_LAUNCHER=ccache
  50 + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
  51 +
  52 + cat sherpa-onnx/python/sherpa_onnx/__init__.py
  53 +
  54 + python3 setup.py bdist_wheel
  55 + ls -lh dist
  56 + pip install ./dist/*.whl
  57 +
  58 + - name: Run tests
  59 + shell: bash
  60 + run: |
  61 + pushd scripts/pyannote/segmentation
  62 +
  63 + python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
  64 + python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)"
  65 + python3 -c "import sherpa_onnx; print(dir(sherpa_onnx))"
  66 +
  67 + curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin
  68 +
  69 + test_wavs=(
  70 + 0-two-speakers-zh.wav
  71 + 1-two-speakers-en.wav
  72 + 2-two-speakers-en.wav
  73 + 3-two-speakers-en.wav
  74 + )
  75 +
  76 + for w in ${test_wavs[@]}; do
  77 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/$w
  78 + done
  79 +
  80 + soxi *.wav
  81 +
  82 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  83 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  84 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  85 + ls -lh sherpa-onnx-pyannote-segmentation-3-0
  86 +
  87 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  88 +
  89 + for w in ${test_wavs[@]}; do
  90 + echo "---------test $w (onnx)----------"
  91 + time ./speaker-diarization-onnx.py \
  92 + --seg-model ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
  93 + --speaker-embedding-model ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
  94 + --wav $w
  95 +
  96 + echo "---------test $w (torch)----------"
  97 + time ./speaker-diarization-torch.py --wav $w
  98 + done
@@ -118,3 +118,5 @@ vits-melo-tts-zh_en
118 *.o
119 *.ppu
120 sherpa-onnx-online-punct-en-2024-08-06
  121 +*.mp4
  122 +*.mp3
  1 +# File description
  2 +
  3 +Please download test wave files from
  4 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
  5 +
  6 +## 0-two-speakers-zh.wav
  7 +
  8 +This file is from
  9 +https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0
  10 +
  11 +Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`.
  12 +
  13 +## 1-two-speakers-en.wav
  14 +
  15 +This file is from
  16 +https://github.com/pengzhendong/pyannote-onnx/blob/master/data/test_16k.wav
17 +and it contains speech from two speakers.
  18 +
19 +Note that we have renamed it from `test_16k.wav` to `1-two-speakers-en.wav`.
  20 +
  21 +
  22 +## 2-two-speakers-en.wav
  23 +This file is from
  24 +https://huggingface.co/spaces/Xenova/whisper-speaker-diarization
  25 +
  26 +Note that the original file is `./fcf059e3-689f-47ec-a000-bdace87f0113.mp4`.
27 +We use the following command to convert it to `2-two-speakers-en.wav`:
  28 +
  29 +```bash
  30 +ffmpeg -i ./fcf059e3-689f-47ec-a000-bdace87f0113.mp4 -ac 1 -ar 16000 ./2-two-speakers-en.wav
  31 +```
  32 +
  33 +## 3-two-speakers-en.wav
  34 +
  35 +This file is from
  36 +https://aws.amazon.com/blogs/machine-learning/deploy-a-hugging-face-pyannote-speaker-diarization-model-on-amazon-sagemaker-as-an-asynchronous-endpoint/
  37 +
  38 +Note that the original file is `ML16091-Audio.mp3`. We use the following
39 +command to convert it to `3-two-speakers-en.wav`:
  40 +
  41 +
  42 +```bash
  43 +sox ML16091-Audio.mp3 3-two-speakers-en.wav
  44 +```
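
Note that this `sox` command only rewrites the container; it keeps the MP3's original sample rate and channel count. That is fine here because `load_wav()` in `speaker-diarization-onnx.py` uses only the first channel and resamples on load, but if you prefer to store a 16 kHz mono copy up front, a sketch using the same `librosa`/`soundfile` packages the script already imports could look like this (the output file name is only a placeholder):

```python
import librosa
import soundfile as sf

# mirror load_wav(): keep only the first channel and resample to 16 kHz
audio, sr = sf.read("./3-two-speakers-en.wav", dtype="float32", always_2d=True)
audio = audio[:, 0]
if sr != 16000:
    audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
sf.write("./3-two-speakers-en-16k.wav", audio, 16000)
```
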
  1 +#!/usr/bin/env python3
  2 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +"""
  5 +Please refer to
  6 +https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
  7 +for usages.
  8 +"""
  9 +
  10 +import argparse
  11 +from datetime import timedelta
  12 +from pathlib import Path
  13 +from typing import List
  14 +
  15 +import librosa
  16 +import numpy as np
  17 +import onnxruntime as ort
  18 +import sherpa_onnx
  19 +import soundfile as sf
  20 +from numpy.lib.stride_tricks import as_strided
  21 +
  22 +
  23 +class Segment:
  24 + def __init__(
  25 + self,
  26 + start,
  27 + end,
  28 + speaker,
  29 + ):
  30 + assert start < end
  31 + self.start = start
  32 + self.end = end
  33 + self.speaker = speaker
  34 +
  35 + def merge(self, other, gap=0.5):
  36 + assert self.speaker == other.speaker, (self.speaker, other.speaker)
  37 + if self.end < other.start and self.end + gap >= other.start:
  38 + return Segment(start=self.start, end=other.end, speaker=self.speaker)
  39 + elif other.end < self.start and other.end + gap >= self.start:
  40 + return Segment(start=other.start, end=self.end, speaker=self.speaker)
  41 + else:
  42 + return None
  43 +
  44 + @property
  45 + def duration(self):
  46 + return self.end - self.start
  47 +
  48 + def __str__(self):
  49 + s = f"{timedelta(seconds=self.start)}"[:-3]
  50 + s += " --> "
  51 + s += f"{timedelta(seconds=self.end)}"[:-3]
  52 + s += f" speaker_{self.speaker:02d}"
  53 + return s
  54 +
  55 +
  56 +def merge_segment_list(in_out: List[Segment], min_duration_off: float):
  57 + changed = True
  58 + while changed:
  59 + changed = False
  60 + for i in range(len(in_out)):
  61 + if i + 1 >= len(in_out):
  62 + continue
  63 +
  64 + new_segment = in_out[i].merge(in_out[i + 1], gap=min_duration_off)
  65 + if new_segment is None:
  66 + continue
  67 + del in_out[i + 1]
  68 + in_out[i] = new_segment
  69 + changed = True
  70 + break
  71 +
  72 +
  73 +def get_args():
  74 + parser = argparse.ArgumentParser()
  75 + parser.add_argument(
  76 + "--seg-model",
  77 + type=str,
  78 + required=True,
  79 + help="Path to model.onnx for segmentation",
  80 + )
  81 + parser.add_argument(
  82 + "--speaker-embedding-model",
  83 + type=str,
  84 + required=True,
  85 + help="Path to model.onnx for speaker embedding extractor",
  86 + )
  87 + parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
  88 +
  89 + return parser.parse_args()
  90 +
  91 +
  92 +class OnnxSegmentationModel:
  93 + def __init__(self, filename):
  94 + session_opts = ort.SessionOptions()
  95 + session_opts.inter_op_num_threads = 1
  96 + session_opts.intra_op_num_threads = 1
  97 +
  98 + self.session_opts = session_opts
  99 +
  100 + self.model = ort.InferenceSession(
  101 + filename,
  102 + sess_options=self.session_opts,
  103 + providers=["CPUExecutionProvider"],
  104 + )
  105 +
  106 + meta = self.model.get_modelmeta().custom_metadata_map
  107 + print(meta)
  108 +
  109 + self.window_size = int(meta["window_size"])
  110 + self.sample_rate = int(meta["sample_rate"])
  111 + self.window_shift = int(0.1 * self.window_size)
  112 + self.receptive_field_size = int(meta["receptive_field_size"])
  113 + self.receptive_field_shift = int(meta["receptive_field_shift"])
  114 + self.num_speakers = int(meta["num_speakers"])
  115 + self.powerset_max_classes = int(meta["powerset_max_classes"])
  116 + self.num_classes = int(meta["num_classes"])
  117 +
  118 + def __call__(self, x):
  119 + """
  120 + Args:
  121 + x: (N, num_samples)
  122 + Returns:
  123 + A tensor of shape (N, num_frames, num_classes)
  124 + """
  125 + x = np.expand_dims(x, axis=1)
  126 +
  127 + (y,) = self.model.run(
  128 + [self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: x}
  129 + )
  130 +
  131 + return y
  132 +
  133 +
  134 +def load_wav(filename, expected_sample_rate) -> np.ndarray:
  135 + audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True)
  136 + audio = audio[:, 0] # only use the first channel
  137 + if sample_rate != expected_sample_rate:
  138 + audio = librosa.resample(
  139 + audio,
  140 + orig_sr=sample_rate,
  141 + target_sr=expected_sample_rate,
  142 + )
  143 + return audio
  144 +
  145 +
  146 +def get_powerset_mapping(num_classes, num_speakers, powerset_max_classes):
  147 + mapping = np.zeros((num_classes, num_speakers))
  148 +
  149 + k = 1
  150 + for i in range(1, powerset_max_classes + 1):
  151 + if i == 1:
  152 + for j in range(0, num_speakers):
  153 + mapping[k, j] = 1
  154 + k += 1
  155 + elif i == 2:
  156 + for j in range(0, num_speakers):
  157 + for m in range(j + 1, num_speakers):
  158 + mapping[k, j] = 1
  159 + mapping[k, m] = 1
  160 + k += 1
  161 + elif i == 3:
  162 + raise RuntimeError("Unsupported")
  163 +
  164 + return mapping
  165 +
  166 +
  167 +def to_multi_label(y, mapping):
  168 + """
  169 + Args:
  170 + y: (num_chunks, num_frames, num_classes)
  171 + Returns:
  172 + A tensor of shape (num_chunks, num_frames, num_speakers)
  173 + """
  174 + y = np.argmax(y, axis=-1)
  175 + labels = mapping[y.reshape(-1)].reshape(y.shape[0], y.shape[1], -1)
  176 + return labels
  177 +
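# A worked example of the powerset decoding above, assuming the metadata of
# pyannote segmentation-3.0 (num_speakers=3, powerset_max_classes=2,
# num_classes=7): get_powerset_mapping() returns the 7x3 matrix
#
#   [0, 0, 0]   class 0: no active speaker
#   [1, 0, 0]   class 1: only speaker 0
#   [0, 1, 0]   class 2: only speaker 1
#   [0, 0, 1]   class 3: only speaker 2
#   [1, 1, 0]   class 4: speakers 0 and 1 overlap
#   [1, 0, 1]   class 5: speakers 0 and 2 overlap
#   [0, 1, 1]   class 6: speakers 1 and 2 overlap
#
# so to_multi_label() simply replaces the argmax class of each frame with the
# corresponding row, turning (num_chunks, num_frames, 7) scores into
# (num_chunks, num_frames, 3) per-speaker activities.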
  178 +
  179 +# speaker count per frame
  180 +def speaker_count(labels, seg_m):
  181 + """
  182 + Args:
  183 + labels: (num_chunks, num_frames, num_speakers)
  184 + seg_m: Segmentation model
  185 + Returns:
186 + An integer array of shape (num_total_frames,)
  187 + """
  188 + labels = labels.sum(axis=-1)
  189 + # Now labels: (num_chunks, num_frames)
  190 +
  191 + num_frames = (
  192 + int(
  193 + (seg_m.window_size + (labels.shape[0] - 1) * seg_m.window_shift)
  194 + / seg_m.receptive_field_shift
  195 + )
  196 + + 1
  197 + )
  198 + ans = np.zeros((num_frames,))
  199 + count = np.zeros((num_frames,))
  200 +
  201 + for i in range(labels.shape[0]):
  202 + this_chunk = labels[i]
  203 + start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5)
  204 + end = start + this_chunk.shape[0]
  205 + ans[start:end] += this_chunk
  206 + count[start:end] += 1
  207 +
  208 + ans /= np.maximum(count, 1e-12)
  209 +
  210 + return (ans + 0.5).astype(np.int8)
  211 +
  212 +
  213 +def load_speaker_embedding_model(filename):
  214 + config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
  215 + model=filename,
  216 + num_threads=1,
  217 + debug=0,
  218 + )
  219 + if not config.validate():
  220 + raise ValueError(f"Invalid config. {config}")
  221 + extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
  222 + return extractor
  223 +
  224 +
  225 +def get_embeddings(embedding_filename, audio, labels, seg_m, exclude_overlap):
  226 + """
  227 + Args:
  228 + embedding_filename: Path to the speaker embedding extractor model
  229 + audio: (num_samples,)
  230 + labels: (num_chunks, num_frames, num_speakers)
  231 + seg_m: segmentation model
  232 + Returns:
233 + Return a pair (chunk_speaker_pair, embeddings), where embeddings has shape (num_segments, embedding_dim)
  234 + """
  235 + if exclude_overlap:
  236 + labels = labels * (labels.sum(axis=-1, keepdims=True) < 2)
  237 +
  238 + extractor = load_speaker_embedding_model(embedding_filename)
  239 + buffer = np.empty(seg_m.window_size)
  240 + num_chunks, num_frames, num_speakers = labels.shape
  241 +
  242 + ans_chunk_speaker_pair = []
  243 + ans_embeddings = []
  244 +
  245 + for i in range(num_chunks):
  246 + labels_T = labels[i].T
247 + # labels_T: (num_speakers, num_frames)
  248 +
  249 + sample_offset = i * seg_m.window_shift
  250 +
  251 + for j in range(num_speakers):
  252 + frames = labels_T[j]
  253 + if frames.sum() < 10:
254 + # skip segments with fewer than 10 active frames, i.e., about 0.2 seconds
  255 + continue
  256 +
  257 + start = None
  258 + start_samples = 0
  259 + idx = 0
  260 + for k in range(num_frames):
  261 + if frames[k] != 0:
  262 + if start is None:
  263 + start = k
  264 + elif start is not None:
  265 + start_samples = (
  266 + int(start / num_frames * seg_m.window_size) + sample_offset
  267 + )
  268 + end_samples = (
  269 + int(k / num_frames * seg_m.window_size) + sample_offset
  270 + )
  271 + num_samples = end_samples - start_samples
  272 + buffer[idx : idx + num_samples] = audio[start_samples:end_samples]
  273 + idx += num_samples
  274 +
  275 + start = None
  276 + if start is not None:
  277 + start_samples = (
  278 + int(start / num_frames * seg_m.window_size) + sample_offset
  279 + )
  280 + end_samples = int(k / num_frames * seg_m.window_size) + sample_offset
  281 + num_samples = end_samples - start_samples
  282 + buffer[idx : idx + num_samples] = audio[start_samples:end_samples]
  283 + idx += num_samples
  284 +
  285 + stream = extractor.create_stream()
  286 + stream.accept_waveform(sample_rate=seg_m.sample_rate, waveform=buffer[:idx])
  287 + stream.input_finished()
  288 +
  289 + assert extractor.is_ready(stream)
  290 + embedding = extractor.compute(stream)
  291 + embedding = np.array(embedding)
  292 +
  293 + ans_chunk_speaker_pair.append([i, j])
  294 + ans_embeddings.append(embedding)
  295 +
  296 + assert len(ans_chunk_speaker_pair) == len(ans_embeddings), (
  297 + len(ans_chunk_speaker_pair),
  298 + len(ans_embeddings),
  299 + )
  300 + return ans_chunk_speaker_pair, np.array(ans_embeddings)
  301 +
  302 +
  303 +def main():
  304 + args = get_args()
  305 + assert Path(args.seg_model).is_file(), args.seg_model
  306 + assert Path(args.wav).is_file(), args.wav
  307 +
  308 + seg_m = OnnxSegmentationModel(args.seg_model)
  309 + audio = load_wav(args.wav, seg_m.sample_rate)
  310 + # audio: (num_samples,)
  311 +
  312 + num = (audio.shape[0] - seg_m.window_size) // seg_m.window_shift + 1
  313 +
  314 + samples = as_strided(
  315 + audio,
  316 + shape=(num, seg_m.window_size),
  317 + strides=(seg_m.window_shift * audio.strides[0], audio.strides[0]),
  318 + )
  319 +
  320 + # or use torch.Tensor.unfold
  321 + # samples = torch.from_numpy(audio).unfold(0, seg_m.window_size, seg_m.window_shift).numpy()
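 # For instance, with the (assumed) segmentation-3.0 metadata of a 10 s window
 # at 16 kHz (window_size = 160000, hence window_shift = 16000), a 60 s file
 # has 960000 samples and num = (960000 - 160000) // 16000 + 1 = 51 overlapping
 # chunks, each shifted by 1 second.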
  322 +
  323 + if (
  324 + audio.shape[0] < seg_m.window_size
  325 + or (audio.shape[0] - seg_m.window_size) % seg_m.window_shift > 0
  326 + ):
  327 + has_last_chunk = True
  328 + else:
  329 + has_last_chunk = False
  330 +
  331 + num_chunks = samples.shape[0]
  332 + batch_size = 32
  333 + output = []
  334 + for i in range(0, num_chunks, batch_size):
  335 + start = i
  336 + end = i + batch_size
  337 + # it's perfectly ok to use end > num_chunks
  338 + y = seg_m(samples[start:end])
  339 + output.append(y)
  340 +
  341 + if has_last_chunk:
  342 + last_chunk = audio[num_chunks * seg_m.window_shift :] # noqa
  343 + pad_size = seg_m.window_size - last_chunk.shape[0]
  344 + last_chunk = np.pad(last_chunk, (0, pad_size))
  345 + last_chunk = np.expand_dims(last_chunk, axis=0)
  346 + y = seg_m(last_chunk)
  347 + output.append(y)
  348 +
  349 + y = np.vstack(output)
  350 + # y: (num_chunks, num_frames, num_classes)
  351 +
  352 + mapping = get_powerset_mapping(
  353 + num_classes=seg_m.num_classes,
  354 + num_speakers=seg_m.num_speakers,
  355 + powerset_max_classes=seg_m.powerset_max_classes,
  356 + )
  357 + labels = to_multi_label(y, mapping=mapping)
  358 + # labels: (num_chunks, num_frames, num_speakers)
  359 +
  360 + inactive = (labels.sum(axis=1) == 0).astype(np.int8)
  361 + # inactive: (num_chunks, num_speakers)
  362 +
  363 + speakers_per_frame = speaker_count(labels=labels, seg_m=seg_m)
364 + # speakers_per_frame: (num_total_frames,), number of active speakers in each frame
  365 +
  366 + if speakers_per_frame.max() == 0:
  367 + print("No speakers found in the audio file!")
  368 + return
  369 +
  370 + # if users specify only 1 speaker for clustering, then return the
  371 + # result directly
  372 +
  373 + # Now, get embeddings
  374 + chunk_speaker_pair, embeddings = get_embeddings(
  375 + args.speaker_embedding_model,
  376 + audio=audio,
  377 + labels=labels,
  378 + seg_m=seg_m,
  379 + # exclude_overlap=True,
  380 + exclude_overlap=False,
  381 + )
  382 + # chunk_speaker_pair: a list of (chunk_idx, speaker_idx)
383 + # embeddings: (num_segments, embedding_dim), one row per (chunk_idx, speaker_idx) pair
  384 +
  385 + # Please change num_clusters or threshold by yourself.
  386 + clustering_config = sherpa_onnx.FastClusteringConfig(num_clusters=2)
  387 + # clustering_config = sherpa_onnx.FastClusteringConfig(threshold=0.8)
  388 + clustering = sherpa_onnx.FastClustering(clustering_config)
  389 + cluster_labels = clustering(embeddings)
  390 +
  391 + chunk_speaker_to_cluster = dict()
  392 + for (chunk_idx, speaker_idx), cluster_idx in zip(
  393 + chunk_speaker_pair, cluster_labels
  394 + ):
  395 + if inactive[chunk_idx, speaker_idx] == 1:
  396 + print("skip ", chunk_idx, speaker_idx)
  397 + continue
  398 + chunk_speaker_to_cluster[(chunk_idx, speaker_idx)] = cluster_idx
  399 +
  400 + num_speakers = max(cluster_labels) + 1
  401 + relabels = np.zeros((labels.shape[0], labels.shape[1], num_speakers))
  402 + for i in range(labels.shape[0]):
  403 + for j in range(labels.shape[1]):
  404 + for k in range(labels.shape[2]):
  405 + if (i, k) not in chunk_speaker_to_cluster:
  406 + continue
  407 + t = chunk_speaker_to_cluster[(i, k)]
  408 +
  409 + if labels[i, j, k] == 1:
  410 + relabels[i, j, t] = 1
  411 +
  412 + num_frames = (
  413 + int(
  414 + (seg_m.window_size + (relabels.shape[0] - 1) * seg_m.window_shift)
  415 + / seg_m.receptive_field_shift
  416 + )
  417 + + 1
  418 + )
  419 +
  420 + count = np.zeros((num_frames, relabels.shape[-1]))
  421 + for i in range(relabels.shape[0]):
  422 + this_chunk = relabels[i]
  423 + start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5)
  424 + end = start + this_chunk.shape[0]
  425 + count[start:end] += this_chunk
  426 +
  427 + if has_last_chunk:
  428 + stop_frame = int(audio.shape[0] / seg_m.receptive_field_shift)
  429 + count = count[:stop_frame]
  430 +
  431 + sorted_count = np.argsort(-count, axis=-1)
  432 + final = np.zeros((count.shape[0], count.shape[1]))
  433 +
  434 + for i, (c, sc) in enumerate(zip(speakers_per_frame, sorted_count)):
  435 + for k in range(c):
  436 + final[i, sc[k]] = 1
  437 +
  438 + min_duration_off = 0.5
  439 + min_duration_on = 0.3
  440 + onset = 0.5
  441 + offset = 0.5
  442 + # final: (num_frames, num_speakers)
  443 +
  444 + final = final.T
  445 + for kk in range(final.shape[0]):
  446 + segment_list = []
  447 + frames = final[kk]
  448 +
  449 + is_active = frames[0] > onset
  450 +
  451 + start = None
  452 + if is_active:
  453 + start = 0
  454 + scale = seg_m.receptive_field_shift / seg_m.sample_rate
  455 + scale_offset = seg_m.receptive_field_size / seg_m.sample_rate * 0.5
  456 + for i in range(1, len(frames)):
  457 + if is_active:
  458 + if frames[i] < offset:
  459 + segment = Segment(
  460 + start=start * scale + scale_offset,
  461 + end=i * scale + scale_offset,
  462 + speaker=kk,
  463 + )
  464 + segment_list.append(segment)
  465 + is_active = False
  466 + else:
  467 + if frames[i] > onset:
  468 + start = i
  469 + is_active = True
  470 +
  471 + if is_active:
  472 + segment = Segment(
  473 + start=start * scale + scale_offset,
  474 + end=(len(frames) - 1) * scale + scale_offset,
  475 + speaker=kk,
  476 + )
  477 + segment_list.append(segment)
  478 +
  479 + if len(segment_list) > 1:
  480 + merge_segment_list(segment_list, min_duration_off=min_duration_off)
  481 + for s in segment_list:
  482 + if s.duration < min_duration_on:
  483 + continue
  484 + print(s)
  485 +
  486 +
  487 +if __name__ == "__main__":
  488 + main()
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +Please refer to
  5 +https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
  6 +for usages.
  7 +"""
  8 +
  9 +"""
  10 +1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main
  11 +wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx
  12 +
  13 +2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py
  14 +
  15 +```
  16 + # self._embedding = PretrainedSpeakerEmbedding(
  17 + # self.embedding, use_auth_token=use_auth_token
  18 + # )
  19 + self._embedding = embedding
  20 +```
  21 +"""
  22 +
  23 +import argparse
  24 +from pathlib import Path
  25 +
  26 +import torch
  27 +from pyannote.audio import Model
  28 +from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline
  29 +from pyannote.audio.pipelines.speaker_verification import (
  30 + ONNXWeSpeakerPretrainedSpeakerEmbedding,
  31 +)
  32 +
  33 +
  34 +def get_args():
  35 + parser = argparse.ArgumentParser()
  36 + parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
  37 +
  38 + return parser.parse_args()
  39 +
  40 +
  41 +def build_pipeline():
  42 + embedding_filename = "./speaker-embedding.onnx"
  43 + if Path(embedding_filename).is_file():
  44 + # You need to modify line 166
  45 + # of pyannote/audio/pipelines/speaker_diarization.py
  46 + # Please see the comments at the start of this script for details
  47 + embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding_filename)
  48 + else:
  49 + embedding = "hbredin/wespeaker-voxceleb-resnet34-LM"
  50 +
  51 + pt_filename = "./pytorch_model.bin"
  52 + segmentation = Model.from_pretrained(pt_filename)
  53 + segmentation.eval()
  54 +
  55 + pipeline = SpeakerDiarizationPipeline(
  56 + segmentation=segmentation,
  57 + embedding=embedding,
  58 + embedding_exclude_overlap=True,
  59 + )
  60 +
  61 + params = {
  62 + "clustering": {
  63 + "method": "centroid",
  64 + "min_cluster_size": 12,
  65 + "threshold": 0.7045654963945799,
  66 + },
  67 + "segmentation": {"min_duration_off": 0.5},
  68 + }
  69 +
  70 + pipeline.instantiate(params)
  71 + return pipeline
  72 +
  73 +
  74 +@torch.no_grad()
  75 +def main():
  76 + args = get_args()
  77 + assert Path(args.wav).is_file(), args.wav
  78 + pipeline = build_pipeline()
  79 + print(pipeline)
  80 + t = pipeline(args.wav)
  81 + print(type(t))
  82 + print(t)
  83 +
  84 +
  85 +if __name__ == "__main__":
  86 + main()
@@ -52,7 +52,7 @@ class FastClustering::Impl {
52 std::vector<double> height(num_rows - 1);
53
54 fastclustercpp::hclust_fast(num_rows, distance.data(),
55 - fastclustercpp::HCLUST_METHOD_SINGLE,
55 + fastclustercpp::HCLUST_METHOD_COMPLETE,
56 merge.data(), height.data());
57
58 std::vector<int32_t> labels(num_rows);
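
For context, the hunk above switches `FastClustering` from single-linkage to complete-linkage agglomerative clustering. If you want to experiment with the two criteria on your own embeddings outside of sherpa-onnx, here is a minimal sketch using SciPy; SciPy, the cosine metric, and the 0.5 threshold are illustrative assumptions, not what `FastClustering` itself uses:

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

# toy "embeddings": two bundles of vectors pointing in nearly orthogonal directions
rng = np.random.default_rng(0)
a = np.array([1.0, 0.0, 0.0, 0.0])
b = np.array([0.0, 1.0, 0.0, 0.0])
x = np.vstack([a + rng.normal(0, 0.05, (5, 4)), b + rng.normal(0, 0.05, (5, 4))])

for method in ("single", "complete"):
    z = linkage(x, method=method, metric="cosine")
    labels = fcluster(z, t=0.5, criterion="distance")
    print(method, labels)
```

Complete linkage merges clusters based on their farthest pair of members, so it is less prone to the chaining effect that single linkage can exhibit on noisy embeddings.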