add_meta_data.py 2.0 KB
#!/usr/bin/env python3
# Copyright    2025  Xiaomi Corp.        (authors: Fangjun Kuang)

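"""
Attach meta data to ./gtcrn_simple.onnx and save the model in place.

Model inputs and outputs, as printed by show() (inputs above the "-----"
line, outputs below it):
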
NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16])
-----
NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16])
"""

import onnx
import onnxruntime as ort


def show(filename):
    """Print the input and output NodeArgs of an ONNX model.

    Every input is printed first, followed by a "-----" separator line and
    then every output.

    Args:
      filename: Path of the ONNX model file to inspect.
    """
    opts = ort.SessionOptions()
    # Raise the log threshold so onnxruntime's lower-severity messages do not
    # clutter the printed listing.
    opts.log_severity_level = 3
    session = ort.InferenceSession(filename, opts)

    model_inputs = session.get_inputs()
    for node_arg in model_inputs:
        print(node_arg)

    print("-----")

    model_outputs = session.get_outputs()
    for node_arg in model_outputs:
        print(node_arg)


def main(filename="./gtcrn_simple.onnx"):
    """Replace the meta data of a GTCRN ONNX model and save it in place.

    The model's inputs/outputs are printed first (via show()). Then every
    existing entry in ``metadata_props`` is removed and replaced by the
    key/value pairs below, and the model is written back to *filename*,
    overwriting the original file.

    Args:
      filename: Path of the ONNX model to update. Defaults to
        ``./gtcrn_simple.onnx``. The meta data written below describes
        gtcrn_simple regardless of which path is given.
    """
    show(filename)
    model = onnx.load(filename)

    # metadata_props values are strings, so every value here is passed
    # through str() when it is written below.
    meta_data = {
        "model_type": "gtcrn",
        "comment": "gtcrn_simple",
        "version": 1,
        "sample_rate": 16000,
        "model_url": "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx",
        "maintainer": "k2-fsa",
        "comment2": "Please see also https://github.com/Xiaobin-Rong/gtcrn",
        # Comma-separated dims of the conv_cache / tra_cache / inter_cache
        # model inputs (see the NodeArg listing in the module docstring).
        "conv_cache_shape": "2,1,16,16,33",
        "tra_cache_shape": "2,3,1,1,16",
        "inter_cache_shape": "2,1,33,16",
        # Presumably the STFT settings a caller must use; note that
        # n_fft // 2 + 1 == 257 matches dim 1 of the mix/enh tensors.
        "n_fft": 512,
        "hop_length": 256,
        "window_length": 512,
        "window_type": "hann_sqrt",
    }

    print(model.metadata_props)

    # Drop all existing entries so that only the keys above remain.
    # Slice deletion clears a protobuf repeated field in one step.
    del model.metadata_props[:]

    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)
    print("--------------------")

    print(model.metadata_props)

    onnx.save(model, filename)


if __name__ == "__main__":
    # Run as a script: main() returns None, so the process exits with status 0.
    raise SystemExit(main())