#!/usr/bin/env python3
# Copyright      2025  Xiaomi Corp.        (authors: Fangjun Kuang)

"""
<|en|>
<|pnc|>
<|noitn|>
<|nodiarize|>
<|notimestamp|>
"""

import os
from typing import Dict, Tuple

import nemo
import onnx
import torch
from nemo.collections.common.parts import NEG_INF
from onnxruntime.quantization import QuantType, quantize_dynamic

"""
NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED :
Could not find an implementation for Trilu(14) node with name '/Trilu'

See also https://github.com/microsoft/onnxruntime/issues/16189#issuecomment-1722219631

So below we monkey-patch the original form_attention_mask()
with fixed_form_attention_mask().
"""


def fixed_form_attention_mask(input_mask, diagonal=None):
    """
    Fixed version of NeMo's form_attention_mask(): build an attention mask
    with optional masking of future tokens that must not be attended to
    (e.g., as in a Transformer decoder).

    Args:
        input_mask: binary mask of size B x L with 1s corresponding to valid
            tokens and 0s corresponding to padding tokens
        diagonal: diagonal where triangular future mask starts
            None -- do not mask anything
            0 -- regular translation or language modeling future masking
            1 -- query stream masking as in XLNet architecture
    Returns:
        attention_mask: mask of size B x 1 x L x L with 0s corresponding to
            tokens we plan to attend to and -10000 otherwise
    """

    if input_mask is None:
        return None
    attn_shape = (1, input_mask.shape[1], input_mask.shape[1])
    attn_mask = input_mask.to(dtype=bool).unsqueeze(1)
    if diagonal is not None:
        future_mask = torch.tril(
            torch.ones(
                attn_shape,
                dtype=torch.int64,  # originally torch.bool, but onnxruntime
                # does not implement Trilu (the export of torch.tril) for
                # torch.bool or torch.int32, so use int64 and cast back below
                device=input_mask.device,
            ),
            diagonal,
        ).bool()
        attn_mask = attn_mask & future_mask
    attention_mask = (1 - attn_mask.to(torch.float)) * NEG_INF
    return attention_mask.unsqueeze(1)


nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
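# The import below is placed after the patch above so that NeMo modules that
# do `from nemo.collections.common.parts import form_attention_mask` pick up
# the fixed implementation.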

from nemo.collections.asr.models import EncDecMultiTaskModel


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


def lens_to_mask(lens, max_length):
    """
    Create a boolean mask of shape (batch, max_length) from a tensor of lengths.
    """
    batch_size = lens.shape[0]
    arange = torch.arange(max_length, device=lens.device)
    mask = arange.expand(batch_size, max_length) < lens.unsqueeze(1)
    return mask


class EncoderWrapper(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.encoder = m.encoder
        self.encoder_decoder_proj = m.encoder_decoder_proj

    def forward(
        self, x: torch.Tensor, x_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x: (N, T, C)
          x_len: (N,)
        Returns:
          - enc_states: (N, T, C)
          - encoded_len: (N,)
          - enc_mask: (N, T)
        """
        x = x.permute(0, 2, 1)
        # x: (N, C, T)
        encoded, encoded_len = self.encoder(audio_signal=x, length=x_len)

        enc_states = encoded.permute(0, 2, 1)

        enc_states = self.encoder_decoder_proj(enc_states)

        enc_mask = lens_to_mask(encoded_len, enc_states.shape[1])

        return enc_states, encoded_len, enc_mask


class DecoderWrapper(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.decoder = m.transf_decoder
        self.log_softmax = m.log_softmax

        # We use only greedy search, so there is no need to compute log_softmax
        self.log_softmax.mlp.log_softmax = False

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_mems_list_0: torch.Tensor,
        decoder_mems_list_1: torch.Tensor,
        decoder_mems_list_2: torch.Tensor,
        decoder_mems_list_3: torch.Tensor,
        decoder_mems_list_4: torch.Tensor,
        decoder_mems_list_5: torch.Tensor,
        enc_states: torch.Tensor,
        enc_mask: torch.Tensor,
    ):
        """
        Args:
          decoder_input_ids: (N, num_tokens), torch.int32
          decoder_mems_list_i: (N, num_tokens, 1024)
          enc_states: (N, T, 1024)
          enc_mask: (N, T)
        Returns:
          - logits: (N, 1, vocab_size)
          - decoder_mems_list_i: (N, num_tokens_2, 1024)
        """
        pos = decoder_input_ids[0][-1].item()
        decoder_input_ids = decoder_input_ids[:, :-1]

        decoder_hidden_states = self.decoder.embedding.forward(
            decoder_input_ids, start_pos=pos
        )
        decoder_input_mask = torch.ones_like(decoder_input_ids).float()

        decoder_mems_list = self.decoder.decoder.forward(
            decoder_hidden_states,
            decoder_input_mask,
            enc_states,
            enc_mask,
            [
                decoder_mems_list_0,
                decoder_mems_list_1,
                decoder_mems_list_2,
                decoder_mems_list_3,
                decoder_mems_list_4,
                decoder_mems_list_5,
            ],
            return_mems=True,
        )
        logits = self.log_softmax(hidden_states=decoder_mems_list[-1][:, -1:])

        return logits, decoder_mems_list


def export_encoder(canary_model):
    encoder = EncoderWrapper(canary_model)
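    # Dummy input for export tracing: batch size 1, 4000 feature frames,
    # 128 mel bins; the batch and time axes are marked dynamic below.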
    x = torch.rand(1, 4000, 128)
    x_lens = torch.tensor([x.shape[1]], dtype=torch.int64)

    encoder_filename = "encoder.onnx"
    torch.onnx.export(
        encoder,
        (x, x_lens),
        encoder_filename,
        input_names=["x", "x_len"],
        output_names=["enc_states", "enc_len", "enc_mask"],
        opset_version=14,
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_len": {0: "N"},
            "enc_states": {0: "N", 1: "T"},
            "enc_len": {0: "N"},
            "enc_mask": {0: "N", 1: "T"},
        },
    )


def export_decoder(canary_model):
    decoder = DecoderWrapper(canary_model)
    decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32)
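    # A dummy token id (1) followed by the start position (0), which
    # DecoderWrapper.forward() splits off.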

    decoder_mems_list_0 = torch.zeros(1, 10, 1024)
    decoder_mems_list_1 = torch.zeros(1, 10, 1024)
    decoder_mems_list_2 = torch.zeros(1, 10, 1024)
    decoder_mems_list_3 = torch.zeros(1, 10, 1024)
    decoder_mems_list_4 = torch.zeros(1, 10, 1024)
    decoder_mems_list_5 = torch.zeros(1, 10, 1024)

    enc_states = torch.zeros(1, 1000, 1024)
    enc_mask = torch.ones(1, 1000).bool()
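    # Note: dynamo=True below selects the torch.export-based ONNX exporter
    # (assumes a recent PyTorch); external_data=False keeps the weights
    # inside decoder.onnx instead of a separate external-data file.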

    torch.onnx.export(
        decoder,
        (
            decoder_input_ids,
            decoder_mems_list_0,
            decoder_mems_list_1,
            decoder_mems_list_2,
            decoder_mems_list_3,
            decoder_mems_list_4,
            decoder_mems_list_5,
            enc_states,
            enc_mask,
        ),
        "decoder.onnx",
        dynamo=True,
        opset_version=14,
        external_data=False,
        input_names=[
            "decoder_input_ids",
            "decoder_mems_list_0",
            "decoder_mems_list_1",
            "decoder_mems_list_2",
            "decoder_mems_list_3",
            "decoder_mems_list_4",
            "decoder_mems_list_5",
            "enc_states",
            "enc_mask",
        ],
        output_names=[
            "logits",
            "next_decoder_mem_list_0",
            "next_decoder_mem_list_1",
            "next_decoder_mem_list_2",
            "next_decoder_mem_list_3",
            "next_decoder_mem_list_4",
            "next_decoder_mem_list_5",
        ],
        dynamic_axes={
            "decoder_input_ids": {1: "num_tokens"},
            "decoder_mems_list_0": {1: "num_tokens"},
            "decoder_mems_list_1": {1: "num_tokens"},
            "decoder_mems_list_2": {1: "num_tokens"},
            "decoder_mems_list_3": {1: "num_tokens"},
            "decoder_mems_list_4": {1: "num_tokens"},
            "decoder_mems_list_5": {1: "num_tokens"},
            "enc_states": {1: "T"},
            "enc_mask": {1: "T"},
        },
    )


def export_tokens(canary_model):
    underline = "▁"
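    # "▁" (U+2581) is the SentencePiece word-boundary symbol. ids_to_text()
    # decodes it to a plain space, so it is restored below to keep tokens.txt
    # consistent with the raw vocabulary.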
    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i in range(canary_model.tokenizer.vocab_size):
            s = canary_model.tokenizer.ids_to_text([i])

            if s[0] == " ":
                s = underline + s[1:]

            f.write(f"{s} {i}\n")
        print("Saved to tokens.txt")


@torch.no_grad()
def main():
    canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
    canary_model.eval()

    preprocessor = canary_model.cfg["preprocessor"]
    sample_rate = preprocessor["sample_rate"]
    normalize_type = preprocessor["normalize"]
    window_size = preprocessor["window_size"]  # in seconds (0.025 s = 25 ms)
    window_stride = preprocessor["window_stride"]  # in seconds (0.01 s = 10 ms)
    window = preprocessor["window"]
    features = preprocessor["features"]
    n_fft = preprocessor["n_fft"]
    vocab_size = canary_model.tokenizer.vocab_size  # 5248

    subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"]
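    # With a 10 ms window stride and 8x subsampling, each encoder output
    # frame corresponds to 80 ms of audio.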

    assert sample_rate == 16000, sample_rate
    assert normalize_type == "per_feature", normalize_type
    assert window_size == 0.025, window_size
    assert window_stride == 0.01, window_stride
    assert window == "hann", window
    assert features == 128, features
    assert n_fft == 512, n_fft
    assert subsampling_factor == 8, subsampling_factor

    export_tokens(canary_model)
    export_encoder(canary_model)
    export_decoder(canary_model)

    for m in ["encoder", "decoder"]:
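        # Post-training dynamic quantization: weights are stored as 8-bit
        # unsigned integers; activations are quantized on the fly at runtime.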
        quantize_dynamic(
            model_input=f"./{m}.onnx",
            model_output=f"./{m}.int8.onnx",
            weight_type=QuantType.QUInt8,
        )

    meta_data = {
        "vocab_size": vocab_size,
        "normalize_type": normalize_type,
        "subsampling_factor": subsampling_factor,
        "model_type": "EncDecMultiTaskModel",
        "version": "1",
        "model_author": "NeMo",
        "url": "https://huggingface.co/nvidia/canary-180m-flash",
        "feat_dim": features,
    }

    add_meta_data("encoder.onnx", meta_data)
    add_meta_data("encoder.int8.onnx", meta_data)

    """
    To fix the following error with onnxruntime 1.17.1 and 1.16.3:

    onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :FAIL : Load model from ./decoder.int8.onnx failed:/Users/runner/work/1/s/onnxruntime/core/graph/model.cc:150 onnxruntime::Model::Model(onnx::ModelProto &&, const onnxruntime::PathString &, const onnxruntime::IOnnxRuntimeOpSchemaRegistryList *, const logging::Logger &, const onnxruntime::ModelOptions &)
    Unsupported model IR version: 10, max supported IR version: 9
    """
    for filename in ["./decoder.onnx", "./decoder.int8.onnx"]:
        model = onnx.load(filename)
        print("old", model.ir_version)
        model.ir_version = 9
        print("new", model.ir_version)
        onnx.save(model, filename)

    os.system("ls -lh *.onnx")


if __name__ == "__main__":
    main()