#!/usr/bin/env python3
# Copyright    2024  Xiaomi Corp.        (authors: Fangjun Kuang)
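#
# Export a NeMo speaker embedding model (one of the --model choices below) to
# ONNX and attach metadata describing its feature extractor. Usage, e.g.:
#
#   python3 ./export-onnx.py --model titanet_large
#
# which writes nemo_en_titanet_large.onnx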

import argparse
from typing import Dict

import nemo.collections.asr as nemo_asr
import onnx
import torch


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=[
            "speakerverification_speakernet",
            "titanet_large",
            "titanet_small",
            "ecapa_tdnn",
        ],
    )
    return parser.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
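
    Example (a minimal sketch of reading the metadata back, using only the
    ``onnx`` API already imported above)::

        model = onnx.load("nemo_en_titanet_large.onnx")
        meta = {p.key: p.value for p in model.metadata_props}
        print(meta)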
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


@torch.no_grad()
def main():
    args = get_args()
    speaker_model_config = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
        model_name=args.model, return_config=True
    )
    preprocessor_config = speaker_model_config["preprocessor"]

    print(args.model)
    print(speaker_model_config)
    print(preprocessor_config)

    assert preprocessor_config["n_fft"] == 512, preprocessor_config

    assert (
        preprocessor_config["_target_"]
        == "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor"
    ), preprocessor_config

    assert preprocessor_config["frame_splicing"] == 1, preprocessor_config

    speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
        model_name=args.model
    )
    speaker_model.eval()
    filename = f"nemo_en_{args.model}.onnx"
    speaker_model.export(filename)
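    # Optional sanity check (a sketch, not required by the export): open the
    # file with onnxruntime and list the graph's input/output names instead of
    # assuming them, e.g.
    #
    #   import onnxruntime as ort
    #   sess = ort.InferenceSession(filename)
    #   print([i.name for i in sess.get_inputs()])
    #   print([o.name for o in sess.get_outputs()])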

    print(f"Adding metadata to {filename}")

    comment = "This model is from NeMo."
    url = {
        "titanet_large": "https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/titanet_large",
        "titanet_small": "https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/titanet_small",
        "speakerverification_speakernet": "https://ngc.nvidia.com/catalog/models/nvidia:nemo:speakerverification_speakernet",
        "ecapa_tdnn": "https://ngc.nvidia.com/catalog/models/nvidia:nemo:ecapa_tdnn",
    }[args.model]

    language = "English"

    meta_data = {
        "framework": "nemo",
        "language": language,
        "url": url,
        "comment": comment,
        "sample_rate": preprocessor_config["sample_rate"],
        "output_dim": speaker_model_config["decoder"]["emb_sizes"],
        "feature_normalize_type": preprocessor_config["normalize"],
        "window_size_ms": int(float(preprocessor_config["window_size"]) * 1000),
        "window_stride_ms": int(float(preprocessor_config["window_stride"]) * 1000),
        "window_type": preprocessor_config["window"],  # e.g., hann
        "feat_dim": preprocessor_config["features"],
    }
    print(meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)


if __name__ == "__main__":
    main()