export-onnx.py
4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import os
from typing import Any, Dict
import onnx
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from pyannote.audio import Model
from pyannote.audio.core.task import Problem, Resolution
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Overwrite the metadata of an ONNX model file in place.

    Args:
      filename:
        Path of the ONNX model to modify.
      meta_data:
        Key-value pairs to store. Values are converted to ``str``.
    """
    model = onnx.load(filename)
    # Clear any pre-existing metadata entries before writing ours.
    del model.metadata_props[:]
    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)
    onnx.save(model, filename)
@torch.no_grad()
def main():
    """Export a pyannote segmentation checkpoint to ONNX (fp32 + int8).

    Reads ``./pytorch_model.bin``, validates the expected model geometry,
    exports ``model.onnx`` with metadata attached, and writes a dynamically
    quantized ``model.int8.onnx``.
    """
    # You can download ./pytorch_model.bin from
    # https://hf-mirror.com/csukuangfj/pyannote-models/tree/main/segmentation-3.0
    # or from
    # https://huggingface.co/Revai/reverb-diarization-v1/tree/main
    pt_filename = "./pytorch_model.bin"
    model = Model.from_pretrained(pt_filename)
    model.eval()

    # Sanity checks: the exporter below assumes this exact model geometry.
    assert model.dimension == 7, model.dimension
    print(model.specifications)
    assert (
        model.specifications.problem == Problem.MONO_LABEL_CLASSIFICATION
    ), model.specifications.problem
    assert (
        model.specifications.resolution == Resolution.FRAME
    ), model.specifications.resolution
    assert model.specifications.duration == 10.0, model.specifications.duration
    assert model.audio.sample_rate == 16000, model.audio.sample_rate

    # (batch, num_channels, num_samples)
    assert list(model.example_input_array.shape) == [
        1,
        1,
        16000 * 10,
    ], model.example_input_array.shape

    example_output = model(model.example_input_array)
    # (batch, num_frames, num_classes)
    assert list(example_output.shape) == [1, 589, 7], example_output.shape

    assert model.receptive_field.step == 0.016875, model.receptive_field.step
    assert model.receptive_field.duration == 0.0619375, model.receptive_field.duration
    assert model.receptive_field.step * 16000 == 270, model.receptive_field.step * 16000
    assert model.receptive_field.duration * 16000 == 991, (
        model.receptive_field.duration * 16000
    )

    opset_version = 13
    filename = "model.onnx"

    # N (batch) and T (num_samples / num_frames) are dynamic.
    torch.onnx.export(
        model,
        model.example_input_array,
        filename,
        opset_version=opset_version,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 2: "T"},
            "y": {0: "N", 1: "T"},
        },
    )

    sample_rate = model.audio.sample_rate
    # All sizes are in samples at `sample_rate` (asserted to be 16000 above).
    window_size = int(model.specifications.duration) * sample_rate
    receptive_field_size = int(model.receptive_field.duration * sample_rate)
    receptive_field_shift = int(model.receptive_field.step * sample_rate)

    # Set SHERPA_ONNX_IS_REVAI to a non-empty value when exporting the
    # Revai reverb-diarization checkpoint instead of pyannote's.
    is_revai = os.getenv("SHERPA_ONNX_IS_REVAI", "")
    if is_revai == "":
        url_1 = "https://huggingface.co/pyannote/segmentation-3.0"
        url_2 = "https://huggingface.co/csukuangfj/pyannote-models/tree/main/segmentation-3.0"
        license_url = (
            "https://huggingface.co/pyannote/segmentation-3.0/blob/main/LICENSE"
        )
        model_author = "pyannote-audio"
    else:
        url_1 = "https://huggingface.co/Revai/reverb-diarization-v1"
        url_2 = "https://huggingface.co/csukuangfj/sherpa-onnx-reverb-diarization-v1"
        license_url = (
            "https://huggingface.co/Revai/reverb-diarization-v1/blob/main/LICENSE"
        )
        model_author = "Revai"

    meta_data = {
        "num_speakers": len(model.specifications.classes),
        "powerset_max_classes": model.specifications.powerset_max_classes,
        "num_classes": model.dimension,
        "sample_rate": sample_rate,
        "window_size": window_size,
        "receptive_field_size": receptive_field_size,
        "receptive_field_shift": receptive_field_shift,
        "model_type": "pyannote-segmentation-3.0",
        "version": "1",
        "model_author": model_author,
        "maintainer": "k2-fsa",
        "url_1": url_1,
        "url_2": url_2,
        "license": license_url,
    }
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")
    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    # Fixed: the first placeholder was garbled to the literal "(unknown)".
    print(f"Saved to {filename} and {filename_int8}")
# Run the export only when executed as a script, not on import.
if __name__ == "__main__":
    main()