Fangjun Kuang
Committed by GitHub

Export Pyannote speaker segmentation models to onnx (#1382)

  1 +name: export-pyannote-segmentation-to-onnx
  2 +
  3 +on:
  4 + workflow_dispatch:
  5 +
  6 +concurrency:
  7 + group: export-pyannote-segmentation-to-onnx-${{ github.ref }}
  8 + cancel-in-progress: true
  9 +
  10 +jobs:
  11 + export-pyannote-segmentation-to-onnx:
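   + # Run only in the upstream repo (k2-fsa) or in the author's fork (csukuangfj).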
  12 + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
  13 + name: export Pyannote segmentation models to ONNX
  14 + runs-on: ${{ matrix.os }}
  15 + strategy:
  16 + fail-fast: false
  17 + matrix:
  18 + os: [macos-latest]
  19 + python-version: ["3.10"]
  20 +
  21 + steps:
  22 + - uses: actions/checkout@v4
  23 +
  24 + - name: Setup Python ${{ matrix.python-version }}
  25 + uses: actions/setup-python@v5
  26 + with:
  27 + python-version: ${{ matrix.python-version }}
  28 +
  29 + - name: Install pyannote
  30 + shell: bash
  31 + run: |
  32 + pip install pyannote.audio onnx onnxruntime
  33 +
  34 + - name: Run
  35 + shell: bash
  36 + run: |
  37 + d=sherpa-onnx-pyannote-segmentation-3-0
  38 + src=$PWD/$d
  39 + mkdir -p $src
  40 +
  41 + pushd scripts/pyannote/segmentation
  42 + ./run.sh
  43 + cp ./*.onnx $src/
  44 + cp ./README.md $src/
  45 + cp ./LICENSE $src/
  46 + cp ./run.sh $src/
  47 + cp ./*.py $src/
  48 +
  49 + popd
  50 + ls -lh $d
  51 + tar cjfv $d.tar.bz2 $d
  52 +
  53 + - name: Release
  54 + uses: svenstaro/upload-release-action@v2
  55 + with:
  56 + file_glob: true
  57 + file: ./*.tar.bz2
  58 + overwrite: true
  59 + repo_name: k2-fsa/sherpa-onnx
  60 + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
  61 + tag: speaker-segmentation-models
  62 +
  63 + - name: Publish to huggingface
  64 + env:
  65 + HF_TOKEN: ${{ secrets.HF_TOKEN }}
  66 + uses: nick-fields/retry@v3
  67 + with:
  68 + max_attempts: 20
  69 + timeout_seconds: 200
  70 + shell: bash
  71 + command: |
  72 + git config --global user.email "csukuangfj@gmail.com"
  73 + git config --global user.name "Fangjun Kuang"
  74 +
  75 + d=sherpa-onnx-pyannote-segmentation-3-0
  76 + export GIT_LFS_SKIP_SMUDGE=1
  77 + export GIT_CLONE_PROTECTION_ACTIVE=false
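   + # GIT_LFS_SKIP_SMUDGE=1 clones the repo without downloading existing LFS files;
   + # only the newly exported models need to be added and pushed.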
  78 + git clone https://huggingface.co/csukuangfj/$d huggingface
  79 + cp -v $d/* ./huggingface
  80 + cd huggingface
  81 + git lfs track "*.onnx"
  82 + git status
  83 + git add .
  84 + git status
  85 + git commit -m "add models"
  86 + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
scripts/pyannote/segmentation/export-onnx.py

  1 +#!/usr/bin/env python3
  2 +
  3 +from typing import Any, Dict
  4 +
  5 +import onnx
  6 +import torch
  7 +from onnxruntime.quantization import QuantType, quantize_dynamic
  8 +from pyannote.audio import Model
  9 +from pyannote.audio.core.task import Problem, Resolution
  10 +
  11 +
  12 +def add_meta_data(filename: str, meta_data: Dict[str, Any]):
  13 + """Add meta data to an ONNX model. It is changed in-place.
  14 +
  15 + Args:
  16 + filename:
  17 + Filename of the ONNX model to be changed.
  18 + meta_data:
  19 + Key-value pairs.
  20 + """
  21 + model = onnx.load(filename)
  22 +
  23 + while len(model.metadata_props):
  24 + model.metadata_props.pop()
  25 +
  26 + for key, value in meta_data.items():
  27 + meta = model.metadata_props.add()
  28 + meta.key = key
  29 + meta.value = str(value)
  30 +
  31 + onnx.save(model, filename)
  32 +
  33 +
  34 +@torch.no_grad()
  35 +def main():
  36 + # You can download ./pytorch_model.bin from
  37 + # https://hf-mirror.com/csukuangfj/pyannote-models/tree/main/segmentation-3.0
  38 + pt_filename = "./pytorch_model.bin"
  39 + model = Model.from_pretrained(pt_filename)
  40 + model.eval()
  41 + assert model.dimension == 7, model.dimension
  42 + print(model.specifications)
  43 +
  44 + assert (
  45 + model.specifications.problem == Problem.MONO_LABEL_CLASSIFICATION
  46 + ), model.specifications.problem
  47 +
  48 + assert (
  49 + model.specifications.resolution == Resolution.FRAME
  50 + ), model.specifications.resolution
  51 +
  52 + assert model.specifications.duration == 10.0, model.specifications.duration
  53 +
  54 + assert model.audio.sample_rate == 16000, model.audio.sample_rate
  55 +
  56 + # (batch, num_channels, num_samples)
  57 + assert list(model.example_input_array.shape) == [
  58 + 1,
  59 + 1,
  60 + 16000 * 10,
  61 + ], model.example_input_array.shape
  62 +
  63 + example_output = model(model.example_input_array)
  64 +
  65 + # (batch, num_frames, num_classes)
  66 + assert list(example_output.shape) == [1, 589, 7], example_output.shape
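   + # 10 seconds at 16 kHz, i.e. T = 160000 samples, produce 589 output frames;
   + # see ./preprocess.sh for how the frame count follows from the ONNX output shape.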
  67 +
  68 + assert model.receptive_field.step == 0.016875, model.receptive_field.step
  69 + assert model.receptive_field.duration == 0.0619375, model.receptive_field.duration
  70 + assert model.receptive_field.step * 16000 == 270, model.receptive_field.step * 16000
  71 + assert model.receptive_field.duration * 16000 == 991, (
  72 + model.receptive_field.duration * 16000
  73 + )
  74 +
  75 + opset_version = 18
  76 +
  77 + filename = "model.onnx"
  78 + torch.onnx.export(
  79 + model,
  80 + model.example_input_array,
  81 + filename,
  82 + opset_version=opset_version,
  83 + input_names=["x"],
  84 + output_names=["y"],
  85 + dynamic_axes={
  86 + "x": {0: "N", 2: "T"},
  87 + "y": {0: "N", 1: "T"},
  88 + },
  89 + )
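   + # The exported graph maps x of shape (N, 1, T), mono audio samples, to y of shape
   + # (N, num_frames, num_classes): per-frame log-probabilities over the powerset classes.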
  90 +
  91 + sample_rate = model.audio.sample_rate
  92 +
  93 + window_size = int(model.specifications.duration) * 16000
  94 + receptive_field_size = int(model.receptive_field.duration * 16000)
  95 + receptive_field_shift = int(model.receptive_field.step * 16000)
  96 +
  97 + meta_data = {
  98 + "num_speakers": len(model.specifications.classes),
  99 + "powerset_max_classes": model.specifications.powerset_max_classes,
  100 + "num_classes": model.dimension,
  101 + "sample_rate": sample_rate,
  102 + "window_size": window_size,
  103 + "receptive_field_size": receptive_field_size,
  104 + "receptive_field_shift": receptive_field_shift,
  105 + "model_type": "pyannote-segmentation-3.0",
  106 + "version": "1",
  107 + "model_author": "pyannote",
  108 + "maintainer": "k2-fsa",
  109 + "url_1": "https://huggingface.co/pyannote/segmentation-3.0",
  110 + "url_2": "https://huggingface.co/csukuangfj/pyannote-models/tree/main/segmentation-3.0",
  111 + "license": "https://huggingface.co/pyannote/segmentation-3.0/blob/main/LICENSE",
  112 + }
  113 + add_meta_data(filename=filename, meta_data=meta_data)
  114 +
  115 + print("Generate int8 quantization models")
  116 +
  117 + filename_int8 = "model.int8.onnx"
  118 + quantize_dynamic(
  119 + model_input=filename,
  120 + model_output=filename_int8,
  121 + weight_type=QuantType.QUInt8,
  122 + )
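   + # Dynamic quantization stores the weights as 8-bit unsigned integers (QUInt8);
   + # activations are quantized on the fly at inference time.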
  123 +
  124 + print(f"Saved to {filename} and {filename_int8}")
  125 +
  126 +
  127 +if __name__ == "__main__":
  128 + main()
scripts/pyannote/segmentation/README.md

  1 +
  2 +# config.yaml
  3 +
  4 +
  5 +```yaml
  6 +task:
  7 + _target_: pyannote.audio.tasks.SpeakerDiarization
  8 + duration: 10.0
  9 + max_speakers_per_chunk: 3
  10 + max_speakers_per_frame: 2
  11 +model:
  12 + _target_: pyannote.audio.models.segmentation.PyanNet
  13 + sample_rate: 16000
  14 + num_channels: 1
  15 + sincnet:
  16 + stride: 10
  17 + lstm:
  18 + hidden_size: 128
  19 + num_layers: 4
  20 + bidirectional: true
  21 + monolithic: true
  22 + linear:
  23 + hidden_size: 128
  24 + num_layers: 2
  25 +```
  26 +
  27 +# Model architecture of ./pytorch_model.bin
  28 +
  29 +`print(model)`:
  30 +
  31 +```python3
  32 +PyanNet(
  33 + (sincnet): SincNet(
  34 + (wav_norm1d): InstanceNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  35 + (conv1d): ModuleList(
  36 + (0): Encoder(
  37 + (filterbank): ParamSincFB()
  38 + )
  39 + (1): Conv1d(80, 60, kernel_size=(5,), stride=(1,))
  40 + (2): Conv1d(60, 60, kernel_size=(5,), stride=(1,))
  41 + )
  42 + (pool1d): ModuleList(
  43 + (0-2): 3 x MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  44 + )
  45 + (norm1d): ModuleList(
  46 + (0): InstanceNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  47 + (1-2): 2 x InstanceNorm1d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  48 + )
  49 + )
  50 + (lstm): LSTM(60, 128, num_layers=4, batch_first=True, dropout=0.5, bidirectional=True)
  51 + (linear): ModuleList(
  52 + (0): Linear(in_features=256, out_features=128, bias=True)
  53 + (1): Linear(in_features=128, out_features=128, bias=True)
  54 + )
  55 + (classifier): Linear(in_features=128, out_features=7, bias=True)
  56 + (activation): LogSoftmax(dim=-1)
  57 +)
  58 +```
  59 +
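   +The SincNet front end downsamples by a factor of 270: the ParamSincFB encoder uses the configured stride of 10, followed by three MaxPool1d layers with stride 3, so one output frame corresponds to 10 × 3 × 3 × 3 = 270 input samples. This is consistent with the `receptive_field_shift` of 270 samples that `export-onnx.py` writes into the ONNX metadata.
   +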
  60 +```python3
  61 +>>> list(model.specifications)
  62 +[Specifications(problem=<Problem.MONO_LABEL_CLASSIFICATION: 1>, resolution=<Resolution.FRAME: 1>, duration=10.0, min_duration=None, warm_up=(0.0, 0.0), classes=['speaker#1', 'speaker#2', 'speaker#3'], powerset_max_classes=2, permutation_invariant=True)]
  63 +```
  64 +
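   +The 7 output classes (`model.dimension == 7`) come from the powerset encoding of at most `powerset_max_classes = 2` simultaneously active speakers out of the 3 speaker classes: one class for silence, three for single speakers, and three for speaker pairs. A small illustrative check (not part of the exported files):
   +
   +```python3
   +from math import comb
   +
   +num_speakers = 3
   +powerset_max_classes = 2
   +
   +# 1 (silence) + C(3, 1) + C(3, 2) = 1 + 3 + 3 = 7
   +print(sum(comb(num_speakers, k) for k in range(powerset_max_classes + 1)))
   +```
   +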
  65 +```python3
  66 +>>> model.hparams
  67 +"linear": {'hidden_size': 128, 'num_layers': 2}
  68 +"lstm": {'hidden_size': 128, 'num_layers': 4, 'bidirectional': True, 'monolithic': True, 'dropout': 0.5, 'batch_first': True}
  69 +"num_channels": 1
  70 +"sample_rate": 16000
  71 +"sincnet": {'stride': 10, 'sample_rate': 16000}
  72 +```
  73 +
  74 +## Papers
  75 +
  76 +- [pyannote.audio 2.1 speaker diarization pipeline: principle, benchmark, and recipe](https://hal.science/hal-04247212/document)
  77 +- [pyannote.audio speaker diarization pipeline at VoxSRC 2023](https://mmai.io/datasets/voxceleb/voxsrc/data_workshop_2023/reports/pyannote_report.pdf)
  78 +
scripts/pyannote/segmentation/preprocess.sh

  1 +#!/usr/bin/env bash
  2 +
  3 +
  4 +python3 -m onnxruntime.quantization.preprocess --input model.onnx --output tmp.preprocessed.onnx
  5 +mv ./tmp.preprocessed.onnx ./model.onnx
  6 +./show-onnx.py --filename ./model.onnx
  7 +
  8 +<<EOF
  9 +=========./model.onnx==========
  10 +NodeArg(name='x', type='tensor(float)', shape=[1, 1, 'T'])
  11 +-----
  12 +NodeArg(name='y', type='tensor(float)', shape=[1, 'floor(floor(floor(floor(T/10 - 251/10)/3 - 2/3)/3)/3 - 8/3) + 1', 7])
  13 +
  14 + floor(floor(floor(floor(T/10 - 251/10)/3 - 2/3)/3)/3 - 8/3) + 1
  15 += floor(floor(floor(floor(T - 251)/30 - 2/3)/3)/3 - 8/3) + 1
  16 += floor(floor(floor(floor(T - 271)/30)/3)/3 - 8/3) + 1
  17 += floor(floor(floor(floor(T - 271)/90))/3 - 8/3) + 1
  18 += floor(floor(floor(T - 271)/90)/3 - 8/3) + 1
  19 += floor(floor((T - 271)/90)/3 - 8/3) + 1
  20 += floor(floor((T - 271)/90 - 8)/3) + 1
  21 += floor(floor((T - 271 - 720)/90)/3) + 1
  22 += floor(floor((T - 991)/90)/3) + 1
  23 += floor(floor((T - 991)/270)) + 1
  24 += (T - 991)/270 + 1
  25 += (T - 991 + 270)/270
  26 += (T - 721)/270
  27 +
  28 +It means:
  29 + - Number of input samples should be at least 721 for the frame count to be non-negative; the first output frame is produced once T reaches 991 samples (the receptive field size).
  30 + - One frame corresponds to 270 samples. (If we use T + 270 samples, the model outputs one more frame)
  31 +EOF
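   +
   +# A quick numeric check of the simplified formula above (illustrative only): for a
   +# 10-second chunk at 16 kHz, T = 160000 and (160000 - 721) / 270 = 589 (integer
   +# division), matching the (1, 589, 7) example output asserted in ./export-onnx.py.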
scripts/pyannote/segmentation/run.sh

  1 +#!/usr/bin/env bash
  2 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +set -ex
  5 +function install_pyannote() {
  6 + pip install pyannote.audio onnx onnxruntime
  7 +}
  8 +
  9 +function download_test_files() {
  10 + curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin
  11 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
  12 +}
  13 +
  14 +install_pyannote
  15 +download_test_files
  16 +
  17 +./export-onnx.py
  18 +./preprocess.sh
  19 +
  20 +echo "----------torch----------"
  21 +./vad-torch.py
  22 +
  23 +echo "----------onnx model.onnx----------"
  24 +./vad-onnx.py --model ./model.onnx --wav ./lei-jun-test.wav
  25 +
  26 +echo "----------onnx model.int8.onnx----------"
  27 +./vad-onnx.py --model ./model.int8.onnx --wav ./lei-jun-test.wav
  28 +
  29 +cat >README.md << EOF
  30 +# Introduction
  31 +
  32 +Models in this folder are converted from
  33 +https://huggingface.co/pyannote/segmentation-3.0/tree/main
  34 +
  35 +EOF
  36 +
  37 +cat >LICENSE <<EOF
  38 +MIT License
  39 +
  40 +Copyright (c) 2022 CNRS
  41 +
  42 +Permission is hereby granted, free of charge, to any person obtaining a copy
  43 +of this software and associated documentation files (the "Software"), to deal
  44 +in the Software without restriction, including without limitation the rights
  45 +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  46 +copies of the Software, and to permit persons to whom the Software is
  47 +furnished to do so, subject to the following conditions:
  48 +
  49 +The above copyright notice and this permission notice shall be included in all
  50 +copies or substantial portions of the Software.
  51 +
  52 +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  53 +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  54 +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  55 +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  56 +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  57 +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  58 +SOFTWARE.
  59 +EOF
scripts/pyannote/segmentation/show-onnx.py

  1 +#!/usr/bin/env python3
  2 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
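   +#
   +# Usage: ./show-onnx.py --filename ./model.onnx
   +#
   +# It prints the input and output NodeArgs (name, element type, shape) of the given model.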
  3 +
  4 +import argparse
  5 +import onnxruntime
  6 +
  7 +
  8 +def get_args():
  9 + parser = argparse.ArgumentParser(
  10 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  11 + )
  12 +
  13 + parser.add_argument(
  14 + "--filename",
  15 + type=str,
  16 + required=True,
  17 + help="Path to model.onnx",
  18 + )
  19 +
  20 + return parser.parse_args()
  21 +
  22 +
  23 +def show(filename):
  24 + session_opts = onnxruntime.SessionOptions()
  25 + session_opts.log_severity_level = 3
  26 + sess = onnxruntime.InferenceSession(filename, session_opts)
  27 + for i in sess.get_inputs():
  28 + print(i)
  29 +
  30 + print("-----")
  31 +
  32 + for i in sess.get_outputs():
  33 + print(i)
  34 +
  35 +
  36 +def main():
  37 + args = get_args()
  38 + print(f"========={args.filename}==========")
  39 + show(args.filename)
  40 +
  41 +
  42 +if __name__ == "__main__":
  43 + main()
scripts/pyannote/segmentation/vad-onnx.py

  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +./export-onnx.py
  5 +./preprocess.sh
  6 +
  7 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
  8 +./vad-onnx.py --model ./model.onnx --wav ./lei-jun-test.wav
  9 +"""
  10 +
  11 +import argparse
  12 +from pathlib import Path
  13 +
  14 +import librosa
  15 +import numpy as np
  16 +import onnxruntime as ort
  17 +import soundfile as sf
  18 +from numpy.lib.stride_tricks import as_strided
  19 +
  20 +
  21 +def get_args():
  22 + parser = argparse.ArgumentParser()
  23 + parser.add_argument("--model", type=str, required=True, help="Path to model.onnx")
  24 + parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
  25 +
  26 + return parser.parse_args()
  27 +
  28 +
  29 +class OnnxModel:
  30 + def __init__(self, filename):
  31 + session_opts = ort.SessionOptions()
  32 + session_opts.inter_op_num_threads = 1
  33 + session_opts.intra_op_num_threads = 1
  34 +
  35 + self.session_opts = session_opts
  36 +
  37 + self.model = ort.InferenceSession(
  38 + filename,
  39 + sess_options=self.session_opts,
  40 + providers=["CPUExecutionProvider"],
  41 + )
  42 +
  43 + meta = self.model.get_modelmeta().custom_metadata_map
  44 + print(meta)
  45 +
  46 + self.window_size = int(meta["window_size"])
  47 + self.sample_rate = int(meta["sample_rate"])
  48 + self.window_shift = int(0.1 * self.window_size)
  49 + self.receptive_field_size = int(meta["receptive_field_size"])
  50 + self.receptive_field_shift = int(meta["receptive_field_shift"])
  51 + self.num_speakers = int(meta["num_speakers"])
  52 + self.powerset_max_classes = int(meta["powerset_max_classes"])
  53 + self.num_classes = int(meta["num_classes"])
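   + # For segmentation-3.0, ./export-onnx.py writes: sample_rate=16000,
   + # window_size=160000 (10 s), receptive_field_size=991, receptive_field_shift=270,
   + # num_speakers=3, powerset_max_classes=2, num_classes=7.
   + # window_shift is 10% of the window, so consecutive chunks overlap by 90%.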
  54 +
  55 + def __call__(self, x):
  56 + """
  57 + Args:
  58 + x: (N, num_samples)
  59 + Returns:
  60 + A tensor of shape (N, num_frames, num_classes)
  61 + """
  62 + x = np.expand_dims(x, axis=1)
  63 +
  64 + (y,) = self.model.run(
  65 + [self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: x}
  66 + )
  67 +
  68 + return y
  69 +
  70 +
  71 +def load_wav(filename, expected_sample_rate) -> np.ndarray:
  72 + audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True)
  73 + audio = audio[:, 0] # only use the first channel
  74 + if sample_rate != expected_sample_rate:
  75 + audio = librosa.resample(
  76 + audio,
  77 + orig_sr=sample_rate,
  78 + target_sr=expected_sample_rate,
  79 + )
  80 + return audio
  81 +
  82 +
  83 +def get_powerset_mapping(num_classes, num_speakers, powerset_max_classes):
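   + """Map powerset classes to multi-hot speaker vectors.
   +
   + Returns a (num_classes, num_speakers) matrix whose k-th row marks the speakers
   + active in powerset class k: row 0 (all zeros) is silence, rows 1..num_speakers
   + are single speakers, and the remaining rows are unordered speaker pairs.
   + Only powerset_max_classes <= 2 is supported.
   + """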
  84 + mapping = np.zeros((num_classes, num_speakers))
  85 +
  86 + k = 1
  87 + for i in range(1, powerset_max_classes + 1):
  88 + if i == 1:
  89 + for j in range(0, num_speakers):
  90 + mapping[k, j] = 1
  91 + k += 1
  92 + elif i == 2:
  93 + for j in range(0, num_speakers):
  94 + for m in range(j + 1, num_speakers):
  95 + mapping[k, j] = 1
  96 + mapping[k, m] = 1
  97 + k += 1
  98 + elif i == 3:
  99 + raise RuntimeError("Unsupported")
  100 +
  101 + return mapping
  102 +
  103 +
  104 +def to_multi_label(y, mapping):
  105 + """
  106 + Args:
  107 + y: (num_chunks, num_frames, num_classes)
  108 + Returns:
  109 + A tensor of shape (num_chunks, num_frames, num_speakers)
  110 + """
  111 + y = np.argmax(y, axis=-1)
  112 + labels = mapping[y.reshape(-1)].reshape(y.shape[0], y.shape[1], -1)
  113 + return labels
  114 +
  115 +
  116 +def main():
  117 + args = get_args()
  118 + assert Path(args.model).is_file(), args.model
  119 + assert Path(args.wav).is_file(), args.wav
  120 +
  121 + m = OnnxModel(args.model)
  122 + audio = load_wav(args.wav, m.sample_rate)
  123 + # audio: (num_samples,)
  124 + print("audio", audio.shape, audio.min(), audio.max(), audio.sum())
  125 +
  126 + num = max((audio.shape[0] - m.window_size) // m.window_shift + 1, 0)
  127 +
  128 + samples = as_strided(
  129 + audio,
  130 + shape=(num, m.window_size),
  131 + strides=(m.window_shift * audio.strides[0], audio.strides[0]),
  132 + )
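   + # samples is a zero-copy view of audio with overlapping chunks: each row is one
   + # window of window_size samples, and consecutive rows are shifted by window_shift.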
  133 +
  134 + # or use torch.Tensor.unfold
  135 + # samples = torch.from_numpy(audio).unfold(0, m.window_size, m.window_shift).numpy()
  136 +
  137 + print(
  138 + "samples",
  139 + samples.shape,
  140 + samples.mean(),
  141 + samples.sum(),
  142 + samples[:3, :3].sum(axis=-1),
  143 + )
  144 +
  145 + if (
  146 + audio.shape[0] < m.window_size
  147 + or (audio.shape[0] - m.window_size) % m.window_shift > 0
  148 + ):
  149 + has_last_chunk = True
  150 + else:
  151 + has_last_chunk = False
  152 +
  153 + num_chunks = samples.shape[0]
  154 + batch_size = 32
  155 + output = []
  156 + for i in range(0, num_chunks, batch_size):
  157 + start = i
  158 + end = i + batch_size
  159 + # it's perfectly ok to use end > num_chunks
  160 + y = m(samples[start:end])
  161 + output.append(y)
  162 +
  163 + if has_last_chunk:
  164 + last_chunk = audio[num_chunks * m.window_shift :] # noqa
  165 + pad_size = m.window_size - last_chunk.shape[0]
  166 + last_chunk = np.pad(last_chunk, (0, pad_size))
  167 + last_chunk = np.expand_dims(last_chunk, axis=0)
  168 + y = m(last_chunk)
  169 + output.append(y)
  170 +
  171 + y = np.vstack(output)
  172 + # y: (num_chunks, num_frames, num_classes)
  173 +
  174 + mapping = get_powerset_mapping(
  175 + num_classes=m.num_classes,
  176 + num_speakers=m.num_speakers,
  177 + powerset_max_classes=m.powerset_max_classes,
  178 + )
  179 + labels = to_multi_label(y, mapping=mapping)
  180 + # labels: (num_chunks, num_frames, num_speakers)
  181 +
  182 + # binary classification
  183 + labels = np.max(labels, axis=-1)
  184 + # labels: (num_chunk, num_frames)
  185 +
  186 + num_frames = (
  187 + int(
  188 + (m.window_size + (labels.shape[0] - 1) * m.window_shift)
  189 + / m.receptive_field_shift
  190 + )
  191 + + 1
  192 + )
  193 +
  194 + count = np.zeros((num_frames,))
  195 + classification = np.zeros((num_frames,))
  196 + weight = np.hamming(labels.shape[1])
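   + # Overlap-add: accumulate Hamming-weighted per-chunk labels into a single
   + # frame-level track and later normalize by the accumulated weights, so frames
   + # covered by several overlapping chunks are averaged smoothly.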
  197 +
  198 + for i in range(labels.shape[0]):
  199 + this_chunk = labels[i]
  200 + start = int(i * m.window_shift / m.receptive_field_shift + 0.5)
  201 + end = start + this_chunk.shape[0]
  202 +
  203 + classification[start:end] += this_chunk * weight
  204 + count[start:end] += weight
  205 +
  206 + classification /= np.maximum(count, 1e-12)
  207 +
  208 + if has_last_chunk:
  209 + stop_frame = int(audio.shape[0] / m.receptive_field_shift)
  210 + classification = classification[:stop_frame]
  211 +
  212 + classification = classification.tolist()
  213 +
  214 + onset = 0.5
  215 + offset = 0.5
  216 +
  217 + is_active = classification[0] > onset
  218 + start = 0 if is_active else None
  219 +
  220 + scale = m.receptive_field_shift / m.sample_rate
  221 + scale_offset = m.receptive_field_size / m.sample_rate * 0.5
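   + # Frame i is reported at time i * receptive_field_shift / sample_rate plus half
   + # of the receptive field duration, i.e. roughly the center of its receptive field.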
  222 +
  223 + for i in range(len(classification)):
  224 + if is_active:
  225 + if classification[i] < offset:
  226 + print(
  227 + f"{start*scale + scale_offset:.3f} -- {i*scale + scale_offset:.3f}"
  228 + )
  229 + is_active = False
  230 + else:
  231 + if classification[i] > onset:
  232 + start = i
  233 + is_active = True
  234 +
  235 + if is_active:
  236 + print(
  237 + f"{start*scale + scale_offset:.3f} -- {(len(classification)-1)*scale + scale_offset:.3f}"
  238 + )
  239 +
  240 +
  241 +if __name__ == "__main__":
  242 + main()
scripts/pyannote/segmentation/vad-torch.py

  1 +#!/usr/bin/env python3
  2 +
  3 +import torch
  4 +from pyannote.audio import Model
  5 +from pyannote.audio.pipelines import (
  6 + VoiceActivityDetection as VoiceActivityDetectionPipeline,
  7 +)
  8 +
  9 +
  10 +@torch.no_grad()
  11 +def main():
  12 + # Please download it from
  13 + # https://huggingface.co/csukuangfj/pyannote-models/tree/main/segmentation-3.0
  14 + pt_filename = "./pytorch_model.bin"
  15 + model = Model.from_pretrained(pt_filename)
  16 + model.eval()
  17 +
  18 + pipeline = VoiceActivityDetectionPipeline(segmentation=model)
  19 +
  20 + # https://huggingface.co/pyannote/voice-activity-detection/blob/main/config.yaml
  21 + # https://github.com/pyannote/pyannote-audio/issues/1215
  22 + initial_params = {
  23 + "min_duration_on": 0.0,
  24 + "min_duration_off": 0.0,
  25 + }
  26 + pipeline.onset = 0.5
  27 + pipeline.offset = 0.5
  28 +
  29 + pipeline.instantiate(initial_params)
  30 +
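   + # Running the reference pyannote pipeline on the same test wave makes it easy to
   + # compare its segments with the output of ./vad-onnx.py (see ./run.sh).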
  31 + # wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
  32 + t = pipeline("./lei-jun-test.wav")
  33 + print(type(t))
  34 + print(t)
  35 +
  36 +
  37 +if __name__ == "__main__":
  38 + main()