export_onnx_180m_flash.py
11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
"""
<|en|>
<|pnc|>
<|noitn|>
<|nodiarize|>
<|notimestamp|>
"""
import os
from typing import Dict, Tuple
import nemo
import onnx
import torch
from nemo.collections.common.parts import NEG_INF
from onnxruntime.quantization import QuantType, quantize_dynamic
"""
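With NeMo's original form_attention_mask(), onnxruntime fails with:
NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED :
Could not find an implementation for Trilu(14) node with name '/Trilu'
See also https://github.com/microsoft/onnxruntime/issues/16189#issuecomment-1722219631
The original calls torch.tril() on a bool tensor, which becomes a Trilu node
onnxruntime cannot run. So we monkey-patch form_attention_mask() with
fixed_form_attention_mask(), which calls torch.tril() on int64 instead.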
"""
def fixed_form_attention_mask(input_mask, diagonal=None):
    """
    ONNX-friendly drop-in for NeMo's ``form_attention_mask``.

    Combines a key-padding mask with an optional causal ("future") mask and
    converts the result into an additive attention bias.

    Args:
        input_mask: B x L mask; nonzero marks a real token, 0 marks padding.
        diagonal: where the lower-triangular causal mask starts.
            None -- no causal masking, padding only
            0    -- standard decoder / language-model masking
            1    -- query-stream masking as in XLNet

    Returns:
        None if ``input_mask`` is None. Otherwise a float tensor of shape
        B x 1 x L x L (B x 1 x 1 x L when ``diagonal`` is None) that holds
        0 where attention is allowed and ``NEG_INF`` where it is blocked.
    """
    if input_mask is None:
        return None

    seq_len = input_mask.shape[1]
    # (B, L) -> (B, 1, L): every query row shares the same key-padding mask.
    keep = input_mask.to(dtype=bool).unsqueeze(1)

    if diagonal is not None:
        # Build the triangle from int64 ones and only then cast to bool:
        # onnxruntime has no Trilu kernel for bool/int32 inputs.
        ones = torch.ones(
            (1, seq_len, seq_len),
            dtype=torch.int64,
            device=input_mask.device,
        )
        causal = torch.tril(ones, diagonal).bool()
        keep = keep & causal

    blocked = 1 - keep.to(torch.float)
    return (blocked * NEG_INF).unsqueeze(1)
# Swap NeMo's form_attention_mask() for the ONNX-friendly version defined above.
# This only rebinds the attribute on the nemo.collections.common.parts package;
# a module that has already executed `from ... import form_attention_mask`
# keeps its own reference to the old function. That is presumably why
# EncDecMultiTaskModel is imported only after the patch (assumes the NeMo
# decoder modules are not already loaded by the earlier nemo imports --
# NOTE(review): confirm). Keep these two lines in this order.
nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
from nemo.collections.asr.models import EncDecMultiTaskModel
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Replace the metadata of an ONNX model file, rewriting the file in place.

    Any existing ``metadata_props`` entries are discarded first, so afterwards
    the model carries exactly the pairs in ``meta_data``.

    Args:
        filename:
            Path of the ONNX model to rewrite.
        meta_data:
            Key-value pairs to store. Values are converted with ``str()``, so
            non-string values (e.g. ints) are accepted.
    """
    proto = onnx.load(filename)
    proto.ClearField("metadata_props")
    for name, raw_value in meta_data.items():
        entry = proto.metadata_props.add()
        entry.key = name
        entry.value = str(raw_value)
    onnx.save(proto, filename)
def lens_to_mask(lens, max_length):
    """
    Turn a batch of sequence lengths into a boolean validity mask.

    Args:
        lens: 1-D tensor of shape (B,) with the length of each sequence.
        max_length: number of positions (columns) in the mask.

    Returns:
        Bool tensor of shape (B, max_length) whose entry [b, t] is True
        exactly when t < lens[b].
    """
    batch = lens.shape[0]
    steps = torch.arange(max_length, device=lens.device).expand(batch, max_length)
    return steps < lens.unsqueeze(1)
class EncoderWrapper(torch.nn.Module):
    """Export wrapper: the Canary encoder followed by its encoder->decoder projection.

    ``m`` must expose ``encoder`` and ``encoder_decoder_proj``. Both submodules
    are shared with ``m`` (not copied).
    """

    def __init__(self, m):
        super().__init__()
        self.encoder = m.encoder
        self.encoder_decoder_proj = m.encoder_decoder_proj

    def forward(
        self, x: torch.Tensor, x_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode a batch of feature sequences.

        Args:
            x: (N, T, C) time-major input features.
            x_len: (N,) number of valid input frames per utterance.

        Returns:
            - enc_states: (N, T', C') projected encoder output; T' is the
              encoder's output length (presumably T after subsampling)
            - encoded_len: (N,) valid output lengths reported by the encoder
            - enc_mask: (N, T') bool, True on valid output positions
        """
        # The NeMo encoder takes channel-major input: (N, T, C) -> (N, C, T).
        channel_major = x.permute(0, 2, 1)
        encoded, encoded_len = self.encoder(audio_signal=channel_major, length=x_len)
        # Swap back to time-major, then project (presumably to the decoder width).
        time_major = encoded.permute(0, 2, 1)
        enc_states = self.encoder_decoder_proj(time_major)
        num_steps = enc_states.shape[1]
        return enc_states, encoded_len, lens_to_mask(encoded_len, num_steps)
class DecoderWrapper(torch.nn.Module):
    """Single-step wrapper around the Canary transformer decoder for ONNX export.

    Each per-layer decoder cache is passed in and returned as a separate
    tensor, so the exported graph can be driven one step at a time by the
    caller.
    """

    def __init__(self, m):
        super().__init__()
        # NOTE(review): assumes ``m`` is the NeMo EncDecMultiTaskModel passed in
        # by export_decoder(); only these two submodules are kept.
        self.decoder = m.transf_decoder
        self.log_softmax = m.log_softmax
        # We use only greedy search, so there is no need to compute log_softmax
        # (log_softmax is monotonic, so argmax over raw logits picks the same
        # token). self.log_softmax is the same object as m.log_softmax, so this
        # also changes the passed-in model.
        self.log_softmax.mlp.log_softmax = False

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_mems_list_0: torch.Tensor,
        decoder_mems_list_1: torch.Tensor,
        decoder_mems_list_2: torch.Tensor,
        decoder_mems_list_3: torch.Tensor,
        decoder_mems_list_4: torch.Tensor,
        decoder_mems_list_5: torch.Tensor,
        enc_states: torch.Tensor,
        enc_mask: torch.Tensor,
    ):
        """
        Run one decoding step.

        Args:
          decoder_input_ids: (N, num_tokens), torch.int32.
            All columns except the last are the token ids to feed. The LAST
            column is not a token: it is the position of the first fed
            token, passed to the embedding as ``start_pos``. Only row 0's
            value is read, so every row must share the same position.
          decoder_mems_list_i: (N, num_cached, 1024), the decoder caches
            from the previous step (presumably one per decoder layer).
          enc_states: (N, T, 1024) encoder output.
          enc_mask: (N, T) encoder validity mask.
        Returns:
          - logits: (N, 1, vocab_size), raw (un-normalized) scores for the
            last fed position, because log_softmax is disabled in __init__
          - decoder_mems_list: the updated caches returned by the decoder,
            presumably (N, num_cached + num_fed, 1024) each
        """
        # Data-dependent read of the position slot (row 0, last column).
        pos = decoder_input_ids[0][-1].item()
        # Drop the position slot; what remains are the real token ids.
        decoder_input_ids = decoder_input_ids[:, :-1]
        decoder_hidden_states = self.decoder.embedding.forward(
            decoder_input_ids, start_pos=pos
        )
        # All fed tokens are real (no padding inside a step).
        decoder_input_mask = torch.ones_like(decoder_input_ids).float()
        decoder_mems_list = self.decoder.decoder.forward(
            decoder_hidden_states,
            decoder_input_mask,
            enc_states,
            enc_mask,
            [
                decoder_mems_list_0,
                decoder_mems_list_1,
                decoder_mems_list_2,
                decoder_mems_list_3,
                decoder_mems_list_4,
                decoder_mems_list_5,
            ],
            return_mems=True,
        )
        # Score only the newest position. Assumes the last entry of the
        # returned list holds the final layer's states -- NOTE(review): confirm.
        logits = self.log_softmax(hidden_states=decoder_mems_list[-1][:, -1:])
        return logits, decoder_mems_list
def export_encoder(canary_model):
    """Export the encoder plus encoder->decoder projection to ``encoder.onnx``.

    The graph is traced with a random (1, 4000, 128) feature tensor. Batch
    and time are dynamic.

    Inputs:  x (N, T, 128), x_len (N,) int64
    Outputs: enc_states (N, T_out, C), enc_len (N,), enc_mask (N, T_out)

    The encoder subsamples time (main() asserts a subsampling factor of 8),
    so the output time axis is named ``T_out`` instead of reusing ``T``.
    In ONNX, repeating a dimension variable name across inputs and outputs
    declares those dimensions equal, which would be false here.
    """
    encoder = EncoderWrapper(canary_model)
    x = torch.rand(1, 4000, 128)
    x_lens = torch.tensor([x.shape[1]], dtype=torch.int64)
    encoder_filename = "encoder.onnx"
    torch.onnx.export(
        encoder,
        (x, x_lens),
        encoder_filename,
        input_names=["x", "x_len"],
        output_names=["enc_states", "enc_len", "enc_mask"],
        opset_version=14,
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_len": {0: "N"},
            # Subsampled length: distinct from the input frame count T.
            "enc_states": {0: "N", 1: "T_out"},
            "enc_len": {0: "N"},
            "enc_mask": {0: "N", 1: "T_out"},
        },
    )
def export_decoder(canary_model):
    """Export one greedy decoding step of the decoder to ``decoder.onnx``.

    Uses the dynamo exporter with all weights embedded (no external data).
    Tracing inputs are one token (id 1) at position 0, 10 cached positions per
    cache, and 1000 encoder frames.

    Naming note: ``decoder_input_ids`` holds the new tokens plus one trailing
    position slot, so its length generally differs from the cache length.
    It therefore gets its own dimension name (``num_input_ids``) instead of
    sharing ``num_tokens`` with the caches. In ONNX a shared dimension
    variable name declares the dimensions equal.
    """
    # Must match the number of decoder_mems_list_i parameters of
    # DecoderWrapper.forward().
    num_layers = 6
    decoder = DecoderWrapper(canary_model)

    # Last column is the position slot, not a token: token 1 at position 0.
    decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32)
    decoder_mems = [torch.zeros(1, 10, 1024) for _ in range(num_layers)]
    enc_states = torch.zeros(1, 1000, 1024)
    enc_mask = torch.ones(1, 1000).bool()

    mem_names = [f"decoder_mems_list_{i}" for i in range(num_layers)]
    next_mem_names = [f"next_decoder_mem_list_{i}" for i in range(num_layers)]

    dynamic_axes = {
        "decoder_input_ids": {1: "num_input_ids"},
        # All caches share one length, so they share one dimension name.
        **{name: {1: "num_tokens"} for name in mem_names},
        "enc_states": {1: "T"},
        "enc_mask": {1: "T"},
    }

    torch.onnx.export(
        decoder,
        (decoder_input_ids, *decoder_mems, enc_states, enc_mask),
        "decoder.onnx",
        dynamo=True,
        opset_version=14,
        external_data=False,
        input_names=["decoder_input_ids", *mem_names, "enc_states", "enc_mask"],
        output_names=["logits", *next_mem_names],
        dynamic_axes=dynamic_axes,
    )
def export_tokens(canary_model):
    """Write the tokenizer vocabulary to ``./tokens.txt``.

    Each line is ``"<text> <id>"``, where ``<text>`` is what the tokenizer
    decodes for that single id. A leading space is replaced by U+2581 ("▁")
    so that the text column never starts with whitespace.
    """
    marker = "\u2581"
    with open("./tokens.txt", "w", encoding="utf-8") as out:
        tokenizer = canary_model.tokenizer
        for token_id in range(tokenizer.vocab_size):
            text = tokenizer.ids_to_text([token_id])
            if text[0] == " ":
                text = marker + text[1:]
            out.write(f"{text} {token_id}\n")
    print("Saved to tokens.txt")
@torch.no_grad()
def main():
    """Export nvidia/canary-180m-flash to ONNX in the current directory.

    Produces tokens.txt, encoder.onnx / decoder.onnx and their int8
    dynamically quantized variants. Metadata is attached to both encoder
    files, and both decoder files are stamped with ONNX IR version 9.
    """
    canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
    canary_model.eval()

    # Front-end configuration the exported models must be paired with.
    preprocessor = canary_model.cfg["preprocessor"]
    sample_rate = preprocessor["sample_rate"]
    normalize_type = preprocessor["normalize"]
    window_size = preprocessor["window_size"]  # seconds (0.025 s = 25 ms)
    window_stride = preprocessor["window_stride"]  # seconds (0.01 s = 10 ms)
    window = preprocessor["window"]
    features = preprocessor["features"]
    n_fft = preprocessor["n_fft"]
    vocab_size = canary_model.tokenizer.vocab_size  # 5248
    subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"]

    # Sanity-check the checkpoint against the values this script assumes
    # (export_encoder(), for instance, traces with 128-dim features).
    expectations = (
        (sample_rate, 16000),
        (normalize_type, "per_feature"),
        (window_size, 0.025),
        (window_stride, 0.01),
        (window, "hann"),
        (features, 128),
        (n_fft, 512),
        (subsampling_factor, 8),
    )
    for actual, expected in expectations:
        assert actual == expected, actual

    export_tokens(canary_model)
    export_encoder(canary_model)
    export_decoder(canary_model)

    for stem in ("encoder", "decoder"):
        quantize_dynamic(
            model_input=f"./{stem}.onnx",
            model_output=f"./{stem}.int8.onnx",
            weight_type=QuantType.QUInt8,
        )

    meta_data = {
        "vocab_size": vocab_size,
        "normalize_type": normalize_type,
        "subsampling_factor": subsampling_factor,
        "model_type": "EncDecMultiTaskModel",
        "version": "1",
        "model_author": "NeMo",
        "url": "https://huggingface.co/nvidia/canary-180m-flash",
        "feat_dim": features,
    }
    for encoder_file in ("encoder.onnx", "encoder.int8.onnx"):
        add_meta_data(encoder_file, meta_data)

    # The decoder models come out with ONNX IR version 10, which
    # onnxruntime 1.16.3 / 1.17.1 refuse to load:
    #   [ONNXRuntimeError] : 1 :FAIL : Load model from ./decoder.int8.onnx failed:
    #   Unsupported model IR version: 10, max supported IR version: 9
    # Re-stamp both decoder files as IR version 9.
    for decoder_file in ("./decoder.onnx", "./decoder.int8.onnx"):
        proto = onnx.load(decoder_file)
        print("old", proto.ir_version)
        proto.ir_version = 9
        print("new", proto.ir_version)
        onnx.save(proto, decoder_file)

    os.system("ls -lh *.onnx")


if __name__ == "__main__":
    main()