export_onnx.py 2.3 KB
#!/usr/bin/env python3
# Copyright    2025  Xiaomi Corp.        (authors: Fangjun Kuang)

import onnx
import onnxmltools
import torch
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic

from unet import UNet


def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    """Write a float16 copy of an ONNX model.

    The model at ``onnx_fp32_path`` is loaded, passed through
    ``convert_float_to_float16`` with ``keep_io_types=True``, and the
    result is saved to ``onnx_fp16_path``. The source file is not modified.

    Args:
      onnx_fp32_path:
        Path of the existing model to read.
      onnx_fp16_path:
        Path where the converted model is written.
    """
    onnxmltools.utils.save_model(
        convert_float_to_float16(
            onnxmltools.utils.load_model(onnx_fp32_path),
            keep_io_types=True,
        ),
        onnx_fp16_path,
    )


def add_meta_data(filename, prefix):
    """Replace the metadata of an ONNX model file in place.

    Every metadata entry already present in the model is discarded and a
    fixed set of Spleeter properties is written instead. The model is then
    saved back to ``filename``. The metadata is printed before and after
    the update so the change is visible in the export log.

    Args:
      filename:
        Path of an existing ONNX model. The file is overwritten.
      prefix:
        Stored verbatim as the ``comment`` property (the callers in this
        file pass the stem name).
    """
    meta_data = {
        "model_type": "spleeter",
        # Spleeter's pretrained models operate on 44.1 kHz audio; consumers
        # use this value to decide how to resample their input.
        "sample_rate": 44100,
        "version": 1,
        "model_url": "https://github.com/deezer/spleeter",
        "stems": 2,
        "comment": prefix,
        "model_name": "2stems.tar.gz",
    }
    model = onnx.load(filename)

    print(model.metadata_props)

    # Drop any metadata the exporter added so only our keys remain.
    del model.metadata_props[:]

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        # ONNX metadata values are strings, so ints are stringified here.
        meta.value = str(value)
    print("--------------------")

    print(model.metadata_props)

    onnx.save(model, filename)


def export(model, prefix):
    """Export ``model`` to ONNX in fp32, int8 and fp16 variants.

    Three files are written under ``./2stems``: ``{prefix}.onnx``,
    ``{prefix}.int8.onnx`` and ``{prefix}.fp16.onnx``. Metadata is attached
    to the fp32 file before the other two variants are derived from it.

    Args:
      model:
        The torch module to trace and export.
      prefix:
        Base name of the output files; also passed to ``add_meta_data``.
    """
    fp32_path = f"./2stems/{prefix}.onnx"
    int8_path = f"./2stems/{prefix}.int8.onnx"
    fp16_path = f"./2stems/{prefix}.fp16.onnx"

    # Example input used for tracing. Axis 1 is declared dynamic below, so
    # its size here is only a placeholder.
    num_splits = 1
    dummy_input = torch.rand(2, num_splits, 512, 1024, dtype=torch.float32)

    torch.onnx.export(
        model,
        dummy_input,
        fp32_path,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={"x": {1: "num_splits"}},
        opset_version=13,
    )

    # Must happen before the derived variants are produced: both of them
    # are generated from the fp32 file on disk.
    add_meta_data(fp32_path, prefix)

    quantize_dynamic(
        model_input=fp32_path,
        model_output=int8_path,
        weight_type=QuantType.QUInt8,
    )

    export_onnx_fp16(fp32_path, fp16_path)


@torch.no_grad()
def main():
    """Load the two stem checkpoints and export each one to ONNX.

    Both checkpoints are loaded (on CPU, in eval mode) before any export
    starts, so a missing or broken checkpoint aborts the run before any
    output file is written.
    """
    checkpoints = (
        ("vocals", "./2stems/vocals.pt"),
        ("accompaniment", "./2stems/accompaniment.pt"),
    )

    loaded = []
    for stem, checkpoint_path in checkpoints:
        net = UNet()
        net.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        net.eval()
        loaded.append((stem, net))

    for stem, net in loaded:
        export(net, stem)


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()