dynamic_quantization.py
2.0 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import onnxmltools
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic
def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
      argv: Optional list of argument strings to parse. When it is None
        (the default), argparse reads the arguments from sys.argv[1:].

    Returns:
      An argparse.Namespace with the string attributes ``input``,
      ``output_fp16`` and ``output_int8``.

    Raises:
      SystemExit: Raised by argparse when a required option is missing
        or an unknown option is given.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Write a float16 copy and a dynamically quantized copy "
            "of an ONNX model."
        ),
    )

    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to the input ONNX model.",
    )

    parser.add_argument(
        "--output-fp16",
        type=str,
        required=True,
        help="Path where the float16 model is saved.",
    )

    parser.add_argument(
        "--output-int8",
        type=str,
        required=True,
        help="Path where the dynamically quantized model is saved.",
    )

    return parser.parse_args(argv)
# For op_block_list, see also
# https://github.com/microsoft/onnxruntime/blob/089c52e4522491312e6839af146a276f2351972e/onnxruntime/python/tools/transformers/float16.py#L115
#
# Errors reported by onnxruntime that involve the two op types listed in
# op_block_list below:
#
# (1) RandomNormalLike
#
# libc++abi: terminating with uncaught exception of type Ort::Exception:
# Type Error: Type (tensor(float16)) of output arg (/dp/RandomNormalLike_output_0)
# of node (/dp/RandomNormalLike) does not match expected type (tensor(float)).
#
# (2) Range
#
# libc++abi: terminating with uncaught exception of type Ort::Exception:
# This is an invalid model. Type Error: Type 'tensor(float16)' of input
# parameter (/enc_p/encoder/attn_layers.0/Constant_84_output_0) of
# operator (Range) in node (/Range_1) is invalid.
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    """Load an ONNX model, convert it to float16 and save the result.

    Args:
      onnx_fp32_path: Path of the model file to load.
      onnx_fp16_path: Path the converted model is written to.
    """
    # Op types handed to the converter as its op_block_list.
    blocked_ops = ["RandomNormalLike", "Range"]

    model = onnxmltools.utils.load_model(onnx_fp32_path)
    model = convert_float_to_float16(
        model,
        keep_io_types=True,
        op_block_list=blocked_ops,
    )

    onnxmltools.utils.save_model(model, onnx_fp16_path)
def main():
    """Write a dynamically quantized copy and a float16 copy of the input model.

    The paths come from the command line (see get_args()). The parsed
    arguments are printed before any conversion starts.
    """
    args = get_args()
    print(args)

    src = args.input

    # Dynamic quantization runs first; weight_type is QUInt8.
    quantize_dynamic(
        model_input=src,
        model_output=args.output_int8,
        weight_type=QuantType.QUInt8,
    )

    # Then the float16 conversion of the same input file.
    export_onnx_fp16(src, args.output_fp16)


# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()