#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
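"""
Export the transducer (RNN-T) branch of NVIDIA NeMo's streaming FastConformer
hybrid CTC+Transducer models to ONNX.

The script downloads the pre-trained checkpoint
stt_en_fastconformer_hybrid_large_streaming_{80,480,1040}ms, writes
tokens.txt, exports encoder.onnx, decoder.onnx, and joiner.onnx, and
also produces int8-quantized copies of the three ONNX files.

Usage:
    python3 ./export-onnx-transducer.py --model 480
"""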

import argparse
from typing import Dict

import nemo.collections.asr as nemo_asr
import onnx
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=["80", "480", "1040"],
    )
    return parser.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add metadata to an ONNX model. The file is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

    # Drop any existing metadata entries before adding ours.
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


@torch.no_grad()
def main():
    args = get_args()

    model_name = f"stt_en_fastconformer_hybrid_large_streaming_{args.model}ms"
    asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
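
    # Write the token table: one "<symbol> <id>" pair per line. The blank
    # token <blk> is appended after the vocabulary, so its id equals the
    # vocabulary size.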
    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(asr_model.joint.vocabulary):
            f.write(f"{s} {i}\n")
        f.write(f"<blk> {i+1}\n")
    print("Saved to tokens.txt")
    decoder_type = "rnnt"
    asr_model.change_decoding_strategy(decoder_type=decoder_type)
    asr_model.eval()

    assert asr_model.encoder.streaming_cfg is not None
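
    # Some configs store chunk_size and pre_encode_cache_size as lists,
    # apparently with separate values for the first chunk and for all later
    # chunks; index 1 is taken here as the steady-state value (an assumption
    # about NeMo's cache-aware streaming config).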
    if isinstance(asr_model.encoder.streaming_cfg.chunk_size, list):
        chunk_size = asr_model.encoder.streaming_cfg.chunk_size[1]
    else:
        chunk_size = asr_model.encoder.streaming_cfg.chunk_size

    if isinstance(asr_model.encoder.streaming_cfg.pre_encode_cache_size, list):
        pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]
    else:
        pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size

    window_size = chunk_size + pre_encode_cache_size

    print("chunk_size", chunk_size)
    print("pre_encode_cache_size", pre_encode_cache_size)
    print("window_size", window_size)

    chunk_shift = chunk_size
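
    # The encoder is cache-aware: per layer it carries two caches across
    # chunks. cache_last_channel is (presumably) the self-attention history
    # (last_channel_cache_size frames of d_model features), and
    # cache_last_time the causal-convolution left context
    # (conv_context_size[0] frames per d_model channel).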
    # cache_last_channel: (batch_size, dim1, dim2, dim3)
    cache_last_channel_dim1 = len(asr_model.encoder.layers)
    cache_last_channel_dim2 = asr_model.encoder.streaming_cfg.last_channel_cache_size
    cache_last_channel_dim3 = asr_model.encoder.d_model

    # cache_last_time: (batch_size, dim1, dim2, dim3)
    cache_last_time_dim1 = len(asr_model.encoder.layers)
    cache_last_time_dim2 = asr_model.encoder.d_model
    cache_last_time_dim3 = asr_model.encoder.conv_context_size[0]
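
    # cache_support asks NeMo to export the encoder with the cache tensors
    # as explicit inputs/outputs, which is what a streaming runtime needs.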
    asr_model.set_export_config({"decoder_type": "rnnt", "cache_support": True})

    # Note: calling asr_model.export("model.onnx") would instead treat
    # "model.onnx" as a suffix and generate two files:
    #   encoder-model.onnx
    #   decoder_joint-model.onnx
    # The three components are exported separately here.
    # asr_model.export("model.onnx")
    asr_model.encoder.export("encoder.onnx")
    asr_model.decoder.export("decoder.onnx")
    asr_model.joint.export("joiner.onnx")

    normalize_type = asr_model.cfg.preprocessor.normalize
    if normalize_type == "NA":
        normalize_type = ""
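
    # These key/value pairs are baked into encoder.onnx so that a downstream
    # inference runtime (e.g., sherpa-onnx, which this script appears to
    # accompany) can recover the streaming configuration without loading NeMo.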
    meta_data = {
        "vocab_size": asr_model.decoder.vocab_size,
        "window_size": window_size,
        "chunk_shift": chunk_shift,
        "normalize_type": normalize_type,
        "cache_last_channel_dim1": cache_last_channel_dim1,
        "cache_last_channel_dim2": cache_last_channel_dim2,
        "cache_last_channel_dim3": cache_last_channel_dim3,
        "cache_last_time_dim1": cache_last_time_dim1,
        "cache_last_time_dim2": cache_last_time_dim2,
        "cache_last_time_dim3": cache_last_time_dim3,
        "pred_rnn_layers": asr_model.decoder.pred_rnn_layers,
        "pred_hidden": asr_model.decoder.pred_hidden,
        "subsampling_factor": 8,
        "model_type": "EncDecHybridRNNTCTCBPEModel",
        "version": "1",
        "model_author": "NeMo",
        "url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}",
        "comment": "Only the transducer branch is exported",
    }
    add_meta_data("encoder.onnx", meta_data)
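
    # Also produce int8 variants via dynamic quantization: weights are stored
    # as 8-bit integers, while activations stay in float at runtime.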
    for m in ["encoder", "decoder", "joiner"]:
        quantize_dynamic(
            model_input=f"{m}.onnx",
            model_output=f"{m}.int8.onnx",
            weight_type=QuantType.QUInt8,
        )

    print(meta_data)


if __name__ == "__main__":
    main()