Committed by GitHub

Speaker diarization example with onnxruntime Python API (#1395)

Showing 6 changed files with 719 additions and 1 deletion.
.github/workflows/speaker-diarization.yaml
0 → 100644
```yaml
name: speaker-diarization

on:
  push:
    branches:
      - speaker-diarization
  workflow_dispatch:

concurrency:
  group: speaker-diarization-${{ github.ref }}
  cancel-in-progress: true

jobs:
  linux:
    name: speaker diarization
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: ccache
        uses: hendrikmuhs/ccache-action@v1.2
        with:
          key: ${{ matrix.os }}-speaker-diarization

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install pyannote
        shell: bash
        run: |
          pip install pyannote.audio onnx onnxruntime

      - name: Install sherpa-onnx from source
        shell: bash
        run: |
          python3 -m pip install --upgrade pip
          python3 -m pip install wheel twine setuptools

          export CMAKE_CXX_COMPILER_LAUNCHER=ccache
          export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"

          cat sherpa-onnx/python/sherpa_onnx/__init__.py

          python3 setup.py bdist_wheel
          ls -lh dist
          pip install ./dist/*.whl

      - name: Run tests
        shell: bash
        run: |
          pushd scripts/pyannote/segmentation

          python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
          python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)"
          python3 -c "import sherpa_onnx; print(dir(sherpa_onnx))"

          curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin

          test_wavs=(
            0-two-speakers-zh.wav
            1-two-speakers-en.wav
            2-two-speakers-en.wav
            3-two-speakers-en.wav
          )

          for w in ${test_wavs[@]}; do
            curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/$w
          done

          soxi *.wav

          curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
          tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
          rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
          ls -lh sherpa-onnx-pyannote-segmentation-3-0

          curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

          for w in ${test_wavs[@]}; do
            echo "---------test $w (onnx)----------"
            time ./speaker-diarization-onnx.py \
              --seg-model ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
              --speaker-embedding-model ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
              --wav $w

            echo "---------test $w (torch)----------"
            time ./speaker-diarization-torch.py --wav $w
          done
```
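The two scripts invoked in the last step, `speaker-diarization-onnx.py` (the onnxruntime implementation) and `speaker-diarization-torch.py` (the pyannote reference), are listed further below in this commit.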
scripts/pyannote/segmentation/README.md
0 → 100644
````markdown
# File description

Please download the test wave files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models

## 0-two-speakers-zh.wav

This file is from
https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0

Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`.

## 1-two-speakers-en.wav

This file is from
https://github.com/pengzhendong/pyannote-onnx/blob/master/data/test_16k.wav
and it contains speech from two speakers.

Note that we have renamed it from `test_16k.wav` to `1-two-speakers-en.wav`.

## 2-two-speakers-en.wav

This file is from
https://huggingface.co/spaces/Xenova/whisper-speaker-diarization

Note that the original file is `./fcf059e3-689f-47ec-a000-bdace87f0113.mp4`.
We use the following command to convert it to `2-two-speakers-en.wav`:

```bash
ffmpeg -i ./fcf059e3-689f-47ec-a000-bdace87f0113.mp4 -ac 1 -ar 16000 ./2-two-speakers-en.wav
```

## 3-two-speakers-en.wav

This file is from
https://aws.amazon.com/blogs/machine-learning/deploy-a-hugging-face-pyannote-speaker-diarization-model-on-amazon-sagemaker-as-an-asynchronous-endpoint/

Note that the original file is `ML16091-Audio.mp3`. We use the following
command to convert it to `3-two-speakers-en.wav`:

```bash
sox ML16091-Audio.mp3 3-two-speakers-en.wav
```
````
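A note on the last conversion: as written, the `sox` command keeps the MP3's original sample rate and channel count. If a 16 kHz mono WAV is desired to match the other test files, the rate and channels can be forced explicitly, e.g. `sox ML16091-Audio.mp3 -r 16000 -c 1 3-two-speakers-en.wav` (the explicit `-r`/`-c` options are an illustrative addition here, not part of the command in the README). The Python scripts below also resample and take the first channel on their own.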
scripts/pyannote/segmentation/speaker-diarization-onnx.py
0 → 100644

```python
#!/usr/bin/env python3
# Copyright    2024  Xiaomi Corp.        (authors: Fangjun Kuang)

"""
Please refer to
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
for usage.
"""

import argparse
from datetime import timedelta
from pathlib import Path
from typing import List

import librosa
import numpy as np
import onnxruntime as ort
import sherpa_onnx
import soundfile as sf
from numpy.lib.stride_tricks import as_strided


class Segment:
    def __init__(
        self,
        start,
        end,
        speaker,
    ):
        assert start < end
        self.start = start
        self.end = end
        self.speaker = speaker

    def merge(self, other, gap=0.5):
        assert self.speaker == other.speaker, (self.speaker, other.speaker)
        if self.end < other.start and self.end + gap >= other.start:
            return Segment(start=self.start, end=other.end, speaker=self.speaker)
        elif other.end < self.start and other.end + gap >= self.start:
            return Segment(start=other.start, end=self.end, speaker=self.speaker)
        else:
            return None

    @property
    def duration(self):
        return self.end - self.start

    def __str__(self):
        s = f"{timedelta(seconds=self.start)}"[:-3]
        s += " --> "
        s += f"{timedelta(seconds=self.end)}"[:-3]
        s += f" speaker_{self.speaker:02d}"
        return s


def merge_segment_list(in_out: List[Segment], min_duration_off: float):
    changed = True
    while changed:
        changed = False
        for i in range(len(in_out)):
            if i + 1 >= len(in_out):
                continue

            new_segment = in_out[i].merge(in_out[i + 1], gap=min_duration_off)
            if new_segment is None:
                continue
            del in_out[i + 1]
            in_out[i] = new_segment
            changed = True
            break
```
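A minimal sketch of how these two helpers behave; the timestamps are made-up illustrative values, not output from the test files:

```python
# Illustration only: two segments of speaker 0 separated by a 0.3 s gap are
# merged in place because the gap is smaller than min_duration_off=0.5.
segments = [
    Segment(start=1.05, end=2.0, speaker=0),
    Segment(start=2.3, end=3.12, speaker=0),
]
merge_segment_list(segments, min_duration_off=0.5)
print(len(segments))   # 1
print(segments[0])     # 0:00:01.050 --> 0:00:03.120 speaker_00
```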
```python
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seg-model",
        type=str,
        required=True,
        help="Path to model.onnx for segmentation",
    )
    parser.add_argument(
        "--speaker-embedding-model",
        type=str,
        required=True,
        help="Path to model.onnx for speaker embedding extractor",
    )
    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()
```
```python
class OnnxSegmentationModel:
    def __init__(self, filename):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map
        print(meta)

        self.window_size = int(meta["window_size"])
        self.sample_rate = int(meta["sample_rate"])
        self.window_shift = int(0.1 * self.window_size)
        self.receptive_field_size = int(meta["receptive_field_size"])
        self.receptive_field_shift = int(meta["receptive_field_shift"])
        self.num_speakers = int(meta["num_speakers"])
        self.powerset_max_classes = int(meta["powerset_max_classes"])
        self.num_classes = int(meta["num_classes"])

    def __call__(self, x):
        """
        Args:
          x: (N, num_samples)
        Returns:
          A tensor of shape (N, num_frames, num_classes)
        """
        x = np.expand_dims(x, axis=1)

        (y,) = self.model.run(
            [self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: x}
        )

        return y


def load_wav(filename, expected_sample_rate) -> np.ndarray:
    audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != expected_sample_rate:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=expected_sample_rate,
        )
    return audio


def get_powerset_mapping(num_classes, num_speakers, powerset_max_classes):
    mapping = np.zeros((num_classes, num_speakers))

    k = 1
    for i in range(1, powerset_max_classes + 1):
        if i == 1:
            for j in range(0, num_speakers):
                mapping[k, j] = 1
                k += 1
        elif i == 2:
            for j in range(0, num_speakers):
                for m in range(j + 1, num_speakers):
                    mapping[k, j] = 1
                    mapping[k, m] = 1
                    k += 1
        elif i == 3:
            raise RuntimeError("Unsupported")

    return mapping
```
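A worked example of the mapping produced by `get_powerset_mapping`, assuming the typical pyannote segmentation-3.0 configuration of 3 local speakers with at most 2 active per frame, i.e. num_classes = 1 + C(3,1) + C(3,2) = 7:

```python
# Assumed values for pyannote segmentation-3.0: 7 powerset classes,
# 3 local speakers, at most 2 overlapping speakers per frame.
mapping = get_powerset_mapping(num_classes=7, num_speakers=3, powerset_max_classes=2)
print(mapping)
# [[0. 0. 0.]    class 0: no active speaker
#  [1. 0. 0.]    class 1: speaker 0 alone
#  [0. 1. 0.]    class 2: speaker 1 alone
#  [0. 0. 1.]    class 3: speaker 2 alone
#  [1. 1. 0.]    class 4: speakers 0 and 1 overlap
#  [1. 0. 1.]    class 5: speakers 0 and 2 overlap
#  [0. 1. 1.]]   class 6: speakers 1 and 2 overlap
```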
```python
def to_multi_label(y, mapping):
    """
    Args:
      y: (num_chunks, num_frames, num_classes)
      mapping: (num_classes, num_speakers) powerset-to-multilabel mapping
    Returns:
      A tensor of shape (num_chunks, num_frames, num_speakers)
    """
    y = np.argmax(y, axis=-1)
    labels = mapping[y.reshape(-1)].reshape(y.shape[0], y.shape[1], -1)
    return labels
```
```python
# speaker count per frame
def speaker_count(labels, seg_m):
    """
    Args:
      labels: (num_chunks, num_frames, num_speakers)
      seg_m: Segmentation model
    Returns:
      An integer array of shape (num_total_frames,)
    """
    labels = labels.sum(axis=-1)
    # Now labels: (num_chunks, num_frames)

    num_frames = (
        int(
            (seg_m.window_size + (labels.shape[0] - 1) * seg_m.window_shift)
            / seg_m.receptive_field_shift
        )
        + 1
    )
    ans = np.zeros((num_frames,))
    count = np.zeros((num_frames,))

    for i in range(labels.shape[0]):
        this_chunk = labels[i]
        start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5)
        end = start + this_chunk.shape[0]
        ans[start:end] += this_chunk
        count[start:end] += 1

    ans /= np.maximum(count, 1e-12)

    return (ans + 0.5).astype(np.int8)
```
```python
def load_speaker_embedding_model(filename):
    config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=filename,
        num_threads=1,
        debug=0,
    )
    if not config.validate():
        raise ValueError(f"Invalid config. {config}")
    extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
    return extractor


def get_embeddings(embedding_filename, audio, labels, seg_m, exclude_overlap):
    """
    Args:
      embedding_filename: Path to the speaker embedding extractor model
      audio: (num_samples,)
      labels: (num_chunks, num_frames, num_speakers)
      seg_m: Segmentation model
      exclude_overlap: If True, drop frames where more than one speaker is
        active before extracting embeddings
    Returns:
      A list of (chunk_idx, speaker_idx) pairs and an array of shape
      (num_pairs, embedding_dim)
    """
    if exclude_overlap:
        labels = labels * (labels.sum(axis=-1, keepdims=True) < 2)

    extractor = load_speaker_embedding_model(embedding_filename)
    buffer = np.empty(seg_m.window_size)
    num_chunks, num_frames, num_speakers = labels.shape

    ans_chunk_speaker_pair = []
    ans_embeddings = []

    for i in range(num_chunks):
        labels_T = labels[i].T
        # labels_T: (num_speakers, num_frames)

        sample_offset = i * seg_m.window_shift

        for j in range(num_speakers):
            frames = labels_T[j]
            if frames.sum() < 10:
                # skip speakers with fewer than 10 active frames in this
                # chunk, i.e., roughly 0.2 seconds of speech
                continue

            start = None
            start_samples = 0
            idx = 0
            for k in range(num_frames):
                if frames[k] != 0:
                    if start is None:
                        start = k
                elif start is not None:
                    start_samples = (
                        int(start / num_frames * seg_m.window_size) + sample_offset
                    )
                    end_samples = (
                        int(k / num_frames * seg_m.window_size) + sample_offset
                    )
                    num_samples = end_samples - start_samples
                    buffer[idx : idx + num_samples] = audio[start_samples:end_samples]
                    idx += num_samples

                    start = None
            if start is not None:
                start_samples = (
                    int(start / num_frames * seg_m.window_size) + sample_offset
                )
                end_samples = int(k / num_frames * seg_m.window_size) + sample_offset
                num_samples = end_samples - start_samples
                buffer[idx : idx + num_samples] = audio[start_samples:end_samples]
                idx += num_samples

            stream = extractor.create_stream()
            stream.accept_waveform(sample_rate=seg_m.sample_rate, waveform=buffer[:idx])
            stream.input_finished()

            assert extractor.is_ready(stream)
            embedding = extractor.compute(stream)
            embedding = np.array(embedding)

            ans_chunk_speaker_pair.append([i, j])
            ans_embeddings.append(embedding)

    assert len(ans_chunk_speaker_pair) == len(ans_embeddings), (
        len(ans_chunk_speaker_pair),
        len(ans_embeddings),
    )
    return ans_chunk_speaker_pair, np.array(ans_embeddings)
```
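As a quick sanity check of the embedding model on its own, the same `sherpa_onnx` calls used in `get_embeddings` can compute a single embedding for a whole file. This is a sketch, not part of the committed script; the model and wave file names are the ones downloaded in the workflow above:

```python
# Sketch: compute one embedding for an entire file with the helpers defined
# above. The .onnx and .wav paths are the files fetched by the CI workflow.
extractor = load_speaker_embedding_model(
    "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
)
audio = load_wav("./0-two-speakers-zh.wav", expected_sample_rate=16000)

stream = extractor.create_stream()
stream.accept_waveform(sample_rate=16000, waveform=audio)
stream.input_finished()

assert extractor.is_ready(stream)
embedding = np.array(extractor.compute(stream))
print(embedding.shape)  # (embedding_dim,)
```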
```python
def main():
    args = get_args()
    assert Path(args.seg_model).is_file(), args.seg_model
    assert Path(args.wav).is_file(), args.wav

    seg_m = OnnxSegmentationModel(args.seg_model)
    audio = load_wav(args.wav, seg_m.sample_rate)
    # audio: (num_samples,)

    num = (audio.shape[0] - seg_m.window_size) // seg_m.window_shift + 1
    # guard against recordings shorter than one window
    num = max(num, 0)

    samples = as_strided(
        audio,
        shape=(num, seg_m.window_size),
        strides=(seg_m.window_shift * audio.strides[0], audio.strides[0]),
    )

    # or use torch.Tensor.unfold
    # samples = torch.from_numpy(audio).unfold(0, seg_m.window_size, seg_m.window_shift).numpy()

    if (
        audio.shape[0] < seg_m.window_size
        or (audio.shape[0] - seg_m.window_size) % seg_m.window_shift > 0
    ):
        has_last_chunk = True
    else:
        has_last_chunk = False

    num_chunks = samples.shape[0]
    batch_size = 32
    output = []
    for i in range(0, num_chunks, batch_size):
        start = i
        end = i + batch_size
        # it's perfectly OK for end to exceed num_chunks; slicing clamps it
        y = seg_m(samples[start:end])
        output.append(y)

    if has_last_chunk:
        last_chunk = audio[num_chunks * seg_m.window_shift :]  # noqa
        pad_size = seg_m.window_size - last_chunk.shape[0]
        last_chunk = np.pad(last_chunk, (0, pad_size))
        last_chunk = np.expand_dims(last_chunk, axis=0)
        y = seg_m(last_chunk)
        output.append(y)

    y = np.vstack(output)
    # y: (num_chunks, num_frames, num_classes)
```
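As an aside, a small illustration of the chunking arithmetic above. The window parameters here are assumed values for a 10-second / 16 kHz configuration; the script itself reads the real values from the segmentation model's metadata:

```python
# Illustration with assumed metadata values: 10 s windows, 1 s shift, 16 kHz.
window_size = 160_000       # 10 s * 16000 Hz
window_shift = 16_000       # 0.1 * window_size
num_samples = 60 * 16_000   # a 60 s recording

num = (num_samples - window_size) // window_shift + 1
print(num)  # 51 full chunks

# (960000 - 160000) % 16000 == 0, so there is no partial last chunk here;
# otherwise the remaining samples are zero-padded into one extra chunk.
print((num_samples - window_size) % window_shift)  # 0
```

The remainder of `main()` continues below.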
```python
    mapping = get_powerset_mapping(
        num_classes=seg_m.num_classes,
        num_speakers=seg_m.num_speakers,
        powerset_max_classes=seg_m.powerset_max_classes,
    )
    labels = to_multi_label(y, mapping=mapping)
    # labels: (num_chunks, num_frames, num_speakers)

    inactive = (labels.sum(axis=1) == 0).astype(np.int8)
    # inactive: (num_chunks, num_speakers)

    speakers_per_frame = speaker_count(labels=labels, seg_m=seg_m)
    # speakers_per_frame: (num_total_frames,)

    if speakers_per_frame.max() == 0:
        print("No speakers found in the audio file!")
        return

    # if users specify only 1 speaker for clustering, then we could return
    # the result directly without running the embedding extractor

    # Now, get embeddings
    chunk_speaker_pair, embeddings = get_embeddings(
        args.speaker_embedding_model,
        audio=audio,
        labels=labels,
        seg_m=seg_m,
        # exclude_overlap=True,
        exclude_overlap=False,
    )
    # chunk_speaker_pair: a list of (chunk_idx, speaker_idx) pairs
    # embeddings: (num_chunk_speaker_pairs, embedding_dim)

    # Please change num_clusters or threshold to suit your data.
    clustering_config = sherpa_onnx.FastClusteringConfig(num_clusters=2)
    # clustering_config = sherpa_onnx.FastClusteringConfig(threshold=0.8)
    clustering = sherpa_onnx.FastClustering(clustering_config)
    cluster_labels = clustering(embeddings)

    chunk_speaker_to_cluster = dict()
    for (chunk_idx, speaker_idx), cluster_idx in zip(
        chunk_speaker_pair, cluster_labels
    ):
        if inactive[chunk_idx, speaker_idx] == 1:
            print("skip ", chunk_idx, speaker_idx)
            continue
        chunk_speaker_to_cluster[(chunk_idx, speaker_idx)] = cluster_idx

    num_speakers = max(cluster_labels) + 1
    relabels = np.zeros((labels.shape[0], labels.shape[1], num_speakers))
    for i in range(labels.shape[0]):
        for j in range(labels.shape[1]):
            for k in range(labels.shape[2]):
                if (i, k) not in chunk_speaker_to_cluster:
                    continue
                t = chunk_speaker_to_cluster[(i, k)]

                if labels[i, j, k] == 1:
                    relabels[i, j, t] = 1

    num_frames = (
        int(
            (seg_m.window_size + (relabels.shape[0] - 1) * seg_m.window_shift)
            / seg_m.receptive_field_shift
        )
        + 1
    )

    count = np.zeros((num_frames, relabels.shape[-1]))
    for i in range(relabels.shape[0]):
        this_chunk = relabels[i]
        start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5)
        end = start + this_chunk.shape[0]
        count[start:end] += this_chunk

    if has_last_chunk:
        stop_frame = int(audio.shape[0] / seg_m.receptive_field_shift)
        count = count[:stop_frame]

    sorted_count = np.argsort(-count, axis=-1)
    final = np.zeros((count.shape[0], count.shape[1]))

    for i, (c, sc) in enumerate(zip(speakers_per_frame, sorted_count)):
        for k in range(c):
            final[i, sc[k]] = 1

    min_duration_off = 0.5
    min_duration_on = 0.3
    onset = 0.5
    offset = 0.5
    # final: (num_frames, num_speakers)

    final = final.T
    for kk in range(final.shape[0]):
        segment_list = []
        frames = final[kk]

        is_active = frames[0] > onset

        start = None
        if is_active:
            start = 0
        scale = seg_m.receptive_field_shift / seg_m.sample_rate
        scale_offset = seg_m.receptive_field_size / seg_m.sample_rate * 0.5
        for i in range(1, len(frames)):
            if is_active:
                if frames[i] < offset:
                    segment = Segment(
                        start=start * scale + scale_offset,
                        end=i * scale + scale_offset,
                        speaker=kk,
                    )
                    segment_list.append(segment)
                    is_active = False
            else:
                if frames[i] > onset:
                    start = i
                    is_active = True

        if is_active:
            segment = Segment(
                start=start * scale + scale_offset,
                end=(len(frames) - 1) * scale + scale_offset,
                speaker=kk,
            )
            segment_list.append(segment)

        if len(segment_list) > 1:
            merge_segment_list(segment_list, min_duration_off=min_duration_off)
        for s in segment_list:
            if s.duration < min_duration_on:
                continue
            print(s)


if __name__ == "__main__":
    main()
```
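When invoked as in the workflow above, the script prints one line per detected segment in the form `H:MM:SS.mmm --> H:MM:SS.mmm speaker_NN` (see `Segment.__str__`), with speaker indices assigned by the clustering step. The next file is the pyannote-based reference implementation run on the same test waves.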
scripts/pyannote/segmentation/speaker-diarization-torch.py
0 → 100644

````python
#!/usr/bin/env python3

"""
Please refer to
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
for usage.
"""

"""
1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main

   wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx

2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py

```
    # self._embedding = PretrainedSpeakerEmbedding(
    #     self.embedding, use_auth_token=use_auth_token
    # )
    self._embedding = embedding
```
"""

import argparse
from pathlib import Path

import torch
from pyannote.audio import Model
from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline
from pyannote.audio.pipelines.speaker_verification import (
    ONNXWeSpeakerPretrainedSpeakerEmbedding,
)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()


def build_pipeline():
    embedding_filename = "./speaker-embedding.onnx"
    if Path(embedding_filename).is_file():
        # You need to modify line 166
        # of pyannote/audio/pipelines/speaker_diarization.py
        # Please see the comments at the start of this script for details
        embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding_filename)
    else:
        embedding = "hbredin/wespeaker-voxceleb-resnet34-LM"

    pt_filename = "./pytorch_model.bin"
    segmentation = Model.from_pretrained(pt_filename)
    segmentation.eval()

    pipeline = SpeakerDiarizationPipeline(
        segmentation=segmentation,
        embedding=embedding,
        embedding_exclude_overlap=True,
    )

    params = {
        "clustering": {
            "method": "centroid",
            "min_cluster_size": 12,
            "threshold": 0.7045654963945799,
        },
        "segmentation": {"min_duration_off": 0.5},
    }

    pipeline.instantiate(params)
    return pipeline


@torch.no_grad()
def main():
    args = get_args()
    assert Path(args.wav).is_file(), args.wav
    pipeline = build_pipeline()
    print(pipeline)
    t = pipeline(args.wav)
    print(type(t))
    print(t)


if __name__ == "__main__":
    main()
````
sherpa-onnx/csrc/fast-clustering.cc

```diff
@@ -52,7 +52,7 @@ class FastClustering::Impl {
     std::vector<double> height(num_rows - 1);
 
     fastclustercpp::hclust_fast(num_rows, distance.data(),
-                                fastclustercpp::HCLUST_METHOD_SINGLE,
+                                fastclustercpp::HCLUST_METHOD_COMPLETE,
                                 merge.data(), height.data());
 
     std::vector<int32_t> labels(num_rows);
```
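This one-line change switches the agglomerative clustering from single linkage to complete linkage. The sketch below is plain NumPy, not the sherpa-onnx C++ code; it only illustrates the difference between the two criteria: single linkage measures the distance between clusters by their closest pair of points (which can chain distinct speakers together), while complete linkage uses the farthest pair, favoring more compact clusters.

```python
import numpy as np

# Toy 1-D "embeddings": cluster A and cluster B
a = np.array([0.0, 0.1, 0.2]).reshape(-1, 1)
b = np.array([0.9, 1.0, 1.1]).reshape(-1, 1)

# Pairwise distances between every point in A and every point in B
d = np.abs(a - b.T)

print("single linkage distance:  ", d.min())  # 0.7 (closest pair)
print("complete linkage distance:", d.max())  # 1.1 (farthest pair)
```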