export-rknn.py 3.0 KB
#!/usr/bin/env python3
# Copyright (c)  2025  Xiaomi Corporation (authors: Fangjun Kuang)

import argparse
import logging
from pathlib import Path

from rknn.api import RKNN

logging.basicConfig(level=logging.WARNING)

g_platforms = [
    #  "rv1103",
    #  "rv1103b",
    #  "rv1106",
    #  "rk2118",
    "rk3562",
    "rk3566",
    "rk3568",
    "rk3576",
    "rk3588",
]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--target-platform",
        type=str,
        required=True,
        help=f"Supported values are: {','.join(g_platforms)}",
    )

    parser.add_argument(
        "--in-model",
        type=str,
        required=True,
        help="Path to the input onnx model",
    )

    parser.add_argument(
        "--out-model",
        type=str,
        required=True,
        help="Path to the output rknn model",
    )

    return parser


def get_meta_data(model: str):
    import onnxruntime

    session_opts = onnxruntime.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1

    m = onnxruntime.InferenceSession(
        model,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    for i in m.get_inputs():
        print(i)

    print("-----")

    for i in m.get_outputs():
        print(i)
    print()

    meta = m.get_modelmeta().custom_metadata_map
    s = ""
    sep = ""
    for key, value in meta.items():
        s = s + sep + f"{key}={value}"
        sep = ";"
    assert len(s) < 1024

    return s


def export_rknn(rknn, filename):
    ret = rknn.export_rknn(filename)
    if ret != 0:
        exit("Export rknn model to {filename} failed!")


def init_model(filename: str, target_platform: str, custom_string=None):
    rknn = RKNN(verbose=False)

    rknn.config(
        optimization_level=0,
        target_platform=target_platform,
        custom_string=custom_string,
    )
    if not Path(filename).is_file():
        exit(f"{filename} does not exist")

    ret = rknn.load_onnx(model=filename)
    if ret != 0:
        exit(f"Load model {filename} failed!")

    ret = rknn.build(do_quantization=False)
    if ret != 0:
        exit("Build model {filename} failed!")

    return rknn


class RKNNModel:
    def __init__(
        self,
        model: str,
        target_platform: str,
    ):
        meta = get_meta_data(model)
        print(meta)

        self.model = init_model(
            model,
            target_platform=target_platform,
            custom_string=meta,
        )

    def export_rknn(self, model):
        export_rknn(self.model, model)

    def release(self):
        self.model.release()


def main():
    args = get_parser().parse_args()
    print(vars(args))

    model = RKNNModel(
        model=args.in_model,
        target_platform=args.target_platform,
    )

    model.export_rknn(
        model=args.out_model,
    )

    model.release()


if __name__ == "__main__":
    main()