add_meta_data.py 3.2 KB
#!/usr/bin/env python3
# Copyright      2023  Xiaomi Corp.        (authors: Fangjun Kuang)

"""
This script adds metadata, in place, to a wespeaker ONNX model so that it can be used in sherpa-onnx.

Usage:
./add_meta_data.py --model ./voxceleb_resnet34.onnx  --language English
"""

import argparse
from pathlib import Path
from typing import Dict

import onnx
import onnxruntime


def get_args():
    """Parse the command-line options of this script.

    Returns:
      An ``argparse.Namespace`` with the attributes ``model``, ``language``,
      ``url``, ``comment`` and ``sample_rate``.
    """
    wespeaker_models_url = (
        "https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md"
    )

    # Each entry is (flag, keyword arguments for ``add_argument``); options
    # are registered in the order they appear here.
    option_table = (
        (
            "--model",
            dict(
                type=str,
                required=True,
                help="Path to the input onnx model. Example value: model.onnx",
            ),
        ),
        (
            "--language",
            dict(
                type=str,
                required=True,
                help=(
                    "Supported language of the input model.\n"
                    "        Example value: Chinese, English.\n"
                    "        "
                ),
            ),
        ),
        (
            "--url",
            dict(
                type=str,
                default=wespeaker_models_url,
                help="Where the model is downloaded",
            ),
        ),
        (
            "--comment",
            dict(
                type=str,
                default="no comment",
                help="Comment about the model",
            ),
        ),
        (
            "--sample-rate",
            dict(
                type=int,
                default=16000,
                help="Sample rate expected by the model",
            ),
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, object]):
    """Add meta data to an ONNX model. It is changed in-place.

    If a key in ``meta_data`` already exists in the model's
    ``metadata_props``, its value is overwritten instead of appending a
    second entry with the same key. This keeps keys unique when the script
    is run more than once on the same file.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs. Each value is converted with ``str()`` before it is
        stored, so non-string values (e.g. ints) are accepted.
    """
    model = onnx.load(filename)

    # Index the entries that are already present so that a repeated key
    # updates the existing entry rather than creating a duplicate.
    existing = {prop.key: prop for prop in model.metadata_props}

    for key, value in meta_data.items():
        meta = existing.get(key)
        if meta is None:
            meta = model.metadata_props.add()
            meta.key = key
            existing[key] = meta
        meta.value = str(value)

    onnx.save(model, filename)


def get_output_dim(filename, feat_dim: int = 80) -> int:
    """Return the size of the second output dimension of an ONNX model.

    The model must have exactly one input with shape ``["B", "T", feat_dim]``
    and exactly one output whose first dimension is named ``"B"``. The input
    and output signatures are printed for inspection.

    Args:
      filename:
        Path to the ONNX model (``str`` or ``pathlib.Path``).
      feat_dim:
        Expected size of the last input dimension. Defaults to 80.

    Returns:
      The static integer size of the output's second dimension.

    Raises:
      ValueError: If the model's inputs/outputs do not match the expected
        layout, or if the output dimension is not a fixed integer.
    """
    filename = str(filename)
    session_opts = onnxruntime.SessionOptions()
    session_opts.log_severity_level = 3  # error level
    # Only the graph signature is inspected, so the CPU provider is enough.
    # Passing ``providers`` explicitly avoids the error onnxruntime >= 1.9
    # raises on builds with several execution providers enabled.
    sess = onnxruntime.InferenceSession(
        filename,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    inputs = sess.get_inputs()
    outputs = sess.get_outputs()

    for i in inputs:
        print(i)

    print("----------")

    for o in outputs:
        print(o)

    print("----------")

    # Explicit checks instead of ``assert`` so validation still runs when
    # Python is started with ``-O``.
    if len(inputs) != 1:
        raise ValueError(f"Expected exactly 1 model input, got {len(inputs)}")
    if len(outputs) != 1:
        raise ValueError(f"Expected exactly 1 model output, got {len(outputs)}")

    in_shape = inputs[0].shape
    out_shape = outputs[0].shape

    if in_shape[:2] != ["B", "T"]:
        raise ValueError(f"Expected input shape to start with ['B', 'T'], got {in_shape}")
    if len(in_shape) < 3 or in_shape[2] != feat_dim:
        raise ValueError(f"Expected input shape ['B', 'T', {feat_dim}], got {in_shape}")

    if len(out_shape) < 2 or out_shape[0] != "B":
        raise ValueError(f"Expected output shape ['B', dim], got {out_shape}")

    output_dim = out_shape[1]
    # A symbolic (string) or unknown (None) dimension cannot be written as
    # the integer ``output_dim`` metadata.
    if not isinstance(output_dim, int):
        raise ValueError(f"Output dimension must be a fixed integer, got {output_dim!r}")

    return output_dim


def main():
    """Validate the command-line options and write the metadata into the model.

    Raises:
      ValueError: If the model file is missing, or if ``--language``,
        ``--url`` or ``--sample-rate`` hold invalid values.
    """
    args = get_args()
    model = Path(args.model)
    language = args.language
    url = args.url
    comment = args.comment
    sample_rate = args.sample_rate

    if not model.is_file():
        raise ValueError(f"{model} does not exist")

    # Explicit checks rather than ``assert`` so they are not stripped when
    # Python runs with ``-O``.
    if not language:
        raise ValueError("--language must be a non-empty string")
    if not url:
        raise ValueError("--url must be a non-empty string")
    if sample_rate <= 0:
        raise ValueError(f"--sample-rate must be positive, got {sample_rate}")

    output_dim = get_output_dim(model)

    # all models from wespeaker expect input samples in the range
    # [-32768, 32767], so sample normalization is disabled (0).
    normalize_samples = 0

    meta_data = {
        "framework": "wespeaker",
        "language": language,
        "url": url,
        "comment": comment,
        "sample_rate": sample_rate,
        "output_dim": output_dim,
        "normalize_samples": normalize_samples,
    }
    print(meta_data)
    add_meta_data(filename=str(model), meta_data=meta_data)


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when
    # it is imported as a module.
    main()