export-onnx.py
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import onnx
import torch
from onnxsim import simplify
from torch import Tensor


def simple_pad(x: Tensor, pad: int) -> Tensor:
    # _0 = torch.slice(torch.slice(torch.slice(x), 1), 2, 1, torch.add(1, pad))
    _0 = x[:, :, 1 : 1 + pad]
    left_pad = torch.flip(_0, [-1])

    # _1 = torch.slice(torch.slice(torch.slice(x), 1), 2, torch.sub(-1, pad), -1)
    _1 = x[:, :, (-1 - pad) : -1]
    right_pad = torch.flip(_1, [-1])

    _2 = torch.cat([left_pad, x, right_pad], 2)
    return _2
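
# A small illustrative check (not part of the original script): simple_pad
# mirrors `pad` frames from each end of the last dimension, i.e. reflection
# padding. For example:
#
#   >>> t = torch.arange(6.0).reshape(1, 1, 6)   # [[[0, 1, 2, 3, 4, 5]]]
#   >>> simple_pad(t, 2)
#   tensor([[[2., 1., 0., 1., 2., 3., 4., 5., 4., 3.]]])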


class MyModule(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m

    def adaptive_normalization_forward(self, spect):
        m = self.m._model.adaptive_normalization
        _0 = simple_pad

        # Note(fangjun): rknn uses fp16 by default, whose max value is 65504,
        # so we need to re-write the computation for spect0
        # spect0 = torch.log1p(torch.mul(spect, 1048576))
        #
        # log(1048576) is about 13.86294, so the large multiply inside the log
        # is replaced by adding a constant after it, keeping intermediate
        # values within fp16 range.
        spect0 = torch.log1p(spect) + 13.86294

        _1 = torch.eq(len(spect0.shape), 2)
        if _1:
            _2 = torch.unsqueeze(spect0, 0)
            spect1 = _2
        else:
            spect1 = spect0

        mean = torch.mean(spect1, [1], True)

        to_pad = m.to_pad
        mean0 = _0(
            mean,
            to_pad,
        )

        filter_ = m.filter_
        mean1 = torch.conv1d(mean0, filter_)
        mean_mean = torch.mean(mean1, [-1], True)
        spect2 = torch.add(spect1, torch.neg(mean_mean))
        return spect2

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        m = self.m._model

        feature_extractor = m.feature_extractor
        x0 = (feature_extractor).forward(
            x,
        )

        norm = self.adaptive_normalization_forward(x0)
        x1 = torch.cat([x0, norm], 1)

        first_layer = m.first_layer
        x2 = (first_layer).forward(
            x1,
        )

        encoder = m.encoder
        x3 = (encoder).forward(
            x2,
        )

        decoder = m.decoder
        x4, h0, c0 = (decoder).forward(
            x3,
            h,
            c,
        )

        _0 = torch.mean(torch.squeeze(x4, 1), [1])
        out = torch.unsqueeze(_0, 1)
        return (out, h0, c0)
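
# Usage sketch (not in the original script): the wrapper can be exercised
# directly in eager PyTorch on a single chunk before exporting, assuming
# ./silero_vad.jit is available as in main() below:
#
#   m = MyModule(torch.jit.load("./silero_vad.jit"))
#   x = torch.rand(1, 512)      # one 512-sample (32 ms) chunk at 16 kHz
#   h = torch.zeros(2, 1, 64)   # LSTM hidden state
#   c = torch.zeros(2, 1, 64)   # LSTM cell state
#   prob, h, c = m(x, h, c)     # speech probability for this chunk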


@torch.no_grad()
def main():
    m = torch.jit.load("./silero_vad.jit")
    m = MyModule(m)

    x = torch.rand((1, 512), dtype=torch.float32)
    h = torch.rand((2, 1, 64), dtype=torch.float32)
    c = torch.rand((2, 1, 64), dtype=torch.float32)

    m = torch.jit.script(m)
    torch.onnx.export(
        m,
        (x, h, c),
        "m.onnx",
        input_names=["x", "h", "c"],
        output_names=["prob", "next_h", "next_c"],
    )

    print("simplifying ...")

    model = onnx.load("m.onnx")

    meta_data = {
        "model_type": "silero-vad-v4",
        "sample_rate": 16000,
        "version": 4,
        "h_shape": "2,1,64",
        "c_shape": "2,1,64",
    }

    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    print("--------------------")
    print(model.metadata_props)

    model_simp, check = simplify(model)

    onnx.save(model_simp, "m.onnx")
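

# Optional verification sketch (not part of the original export flow). It
# assumes onnxruntime is installed and that m.onnx has already been written
# by main(); the function is never called here, so it does not change the
# behavior of this script.
def verify_exported_model(filename: str = "m.onnx"):
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])

    # One 512-sample chunk of 16 kHz audio plus zero-initialized LSTM states
    x = np.random.rand(1, 512).astype(np.float32)
    h = np.zeros((2, 1, 64), dtype=np.float32)
    c = np.zeros((2, 1, 64), dtype=np.float32)

    prob, next_h, next_c = sess.run(
        ["prob", "next_h", "next_c"],
        {"x": x, "h": h, "c": c},
    )
    print("prob", prob.shape, "next_h", next_h.shape, "next_c", next_c.shape)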


if __name__ == "__main__":
    main()