Fangjun Kuang
Committed by GitHub

Add meta data to NeMo canary ONNX models (#2351)

@@ -62,22 +62,7 @@ jobs:
          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
          mkdir -p $d
          cp encoder.int8.onnx $d
-         cp decoder.fp16.onnx $d
-         cp tokens.txt $d
-         mkdir $d/test_wavs
-         cp de.wav $d/test_wavs
-         cp en.wav $d/test_wavs
-         tar cjfv $d.tar.bz2 $d
-     - name: Collect files (fp16)
-       shell: bash
-       run: |
-         d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
-         mkdir -p $d
-         cp encoder.fp16.onnx $d
-         cp decoder.fp16.onnx $d
+         cp decoder.int8.onnx $d
          cp tokens.txt $d
          mkdir $d/test_wavs
@@ -101,7 +86,6 @@ jobs:
          models=(
            sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
            sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
-           sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
          )

          for m in ${models[@]}; do
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
"""
<|en|>
<|pnc|>
<|noitn|>
<|nodiarize|>
<|notimestamp|>
"""
import os
-from typing import Tuple
+from typing import Dict, Tuple
import nemo
import onnxmltools
import onnx
import torch
from nemo.collections.common.parts import NEG_INF
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic
"""
@@ -64,10 +71,25 @@ nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
from nemo.collections.asr.models import EncDecMultiTaskModel


def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
    # keep_io_types=True keeps the inputs/outputs in float32 while the
    # internal weights and ops are converted to float16.
    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. The file is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

    # Drop any existing metadata entries before writing ours.
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)
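
# Illustrative sketch (not part of the original script): the keys written by
# add_meta_data() can be read back through onnxruntime via
# get_modelmeta().custom_metadata_map, the same API test_180m_flash.py uses.
def check_meta_data(filename: str):
    import onnxruntime as ort

    session = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])
    # Every value round-trips as a string, e.g. {"vocab_size": "5248", ...}
    print(session.get_modelmeta().custom_metadata_map)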
def lens_to_mask(lens, max_length):
@@ -222,7 +244,7 @@ def export_decoder(canary_model):
        ),
        "decoder.onnx",
        dynamo=True,
-        opset_version=18,
+        opset_version=14,
        external_data=False,
        input_names=[
            "decoder_input_ids",
@@ -269,6 +291,29 @@ def export_tokens(canary_model):
@torch.no_grad()
def main():
    canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
    canary_model.eval()

    preprocessor = canary_model.cfg["preprocessor"]
    sample_rate = preprocessor["sample_rate"]
    normalize_type = preprocessor["normalize"]
    window_size = preprocessor["window_size"]  # in seconds (25 ms)
    window_stride = preprocessor["window_stride"]  # in seconds (10 ms)
    window = preprocessor["window"]
    features = preprocessor["features"]
    n_fft = preprocessor["n_fft"]

    vocab_size = canary_model.tokenizer.vocab_size  # 5248
    subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"]

    assert sample_rate == 16000, sample_rate
    assert normalize_type == "per_feature", normalize_type
    assert window_size == 0.025, window_size
    assert window_stride == 0.01, window_stride
    assert window == "hann", window
    assert features == 128, features
    assert n_fft == 512, n_fft
    assert subsampling_factor == 8, subsampling_factor
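
    # Illustration only (not in the original script): what the pinned values
    # above imply for framing. A 25 ms window with a 10 ms shift yields 100
    # feature frames per second; with the encoder's subsampling factor of 8,
    # one encoder output frame then covers 80 ms of audio.
    win_length = int(window_size * sample_rate)  # 0.025 * 16000 = 400 samples
    hop_length = int(window_stride * sample_rate)  # 0.01 * 16000 = 160 samples
    assert (win_length, hop_length) == (400, 160)
    assert subsampling_factor * hop_length / sample_rate == 0.08  # seconds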
    export_tokens(canary_model)
    export_encoder(canary_model)
    export_decoder(canary_model)
@@ -280,7 +325,32 @@ def main():
            weight_type=QuantType.QUInt8,
        )
        export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
    meta_data = {
        "vocab_size": vocab_size,
        "normalize_type": normalize_type,
        "subsampling_factor": subsampling_factor,
        "model_type": "EncDecMultiTaskModel",
        "version": "1",
        "model_author": "NeMo",
        "url": "https://huggingface.co/nvidia/canary-180m-flash",
        "feat_dim": features,
    }
    add_meta_data("encoder.onnx", meta_data)
    add_meta_data("encoder.int8.onnx", meta_data)
"""
To fix the following error with onnxruntime 1.17.1 and 1.16.3:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :FAIL : Load model from ./decoder.int8.onnx failed:/Users/runner/work/1/s/onnxruntime/core/graph/model.cc:150 onnxruntime::Model::Model(onnx::ModelProto &&, const onnxruntime::PathString &, const onnxruntime::IOnnxRuntimeOpSchemaRegistryList *, const logging::Logger &, const onnxruntime::ModelOptions &)
Unsupported model IR version: 10, max supported IR version: 9
"""
for filename in ["./decoder.onnx", "./decoder.int8.onnx"]:
model = onnx.load(filename)
print("old", model.ir_version)
model.ir_version = 9
print("new", model.ir_version)
onnx.save(model, filename)
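
    # Sketch (not in the original script): confirm the patched decoders now
    # load in onnxruntime instead of failing with "Unsupported model IR
    # version: 10, max supported IR version: 9".
    import onnxruntime as ort

    for filename in ["./decoder.onnx", "./decoder.int8.onnx"]:
        sess = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])
        print(filename, "loads OK, outputs:", [o.name for o in sess.get_outputs()])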
    os.system("ls -lh *.onnx")
@@ -19,8 +19,8 @@ pip install \
  kaldi-native-fbank \
  librosa \
  onnx==1.17.0 \
  onnxmltools \
  onnxruntime==1.17.1 \
  onnxscript \
  soundfile
python3 ./export_onnx_180m_flash.py
@@ -66,7 +66,7 @@ log "-----int8------"
python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
-  --decoder ./decoder.fp16.onnx \
+  --decoder ./decoder.int8.onnx \
  --source-lang en \
  --target-lang en \
  --tokens ./tokens.txt \
@@ -74,7 +74,7 @@ python3 ./test_180m_flash.py \
python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
-  --decoder ./decoder.fp16.onnx \
+  --decoder ./decoder.int8.onnx \
  --source-lang en \
  --target-lang de \
  --tokens ./tokens.txt \
@@ -82,7 +82,7 @@ python3 ./test_180m_flash.py \
python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
-  --decoder ./decoder.fp16.onnx \
+  --decoder ./decoder.int8.onnx \
  --source-lang de \
  --target-lang de \
  --tokens ./tokens.txt \
@@ -90,41 +90,7 @@ python3 ./test_180m_flash.py \
python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
-  --decoder ./decoder.fp16.onnx \
-  --source-lang de \
-  --target-lang en \
-  --tokens ./tokens.txt \
-  --wav ./de.wav
-
-log "-----fp16------"
-
-python3 ./test_180m_flash.py \
-  --encoder ./encoder.fp16.onnx \
-  --decoder ./decoder.fp16.onnx \
-  --source-lang en \
-  --target-lang en \
-  --tokens ./tokens.txt \
-  --wav ./en.wav
-
-python3 ./test_180m_flash.py \
-  --encoder ./encoder.fp16.onnx \
-  --decoder ./decoder.fp16.onnx \
-  --source-lang en \
-  --target-lang de \
-  --tokens ./tokens.txt \
-  --wav ./en.wav
-
-python3 ./test_180m_flash.py \
-  --encoder ./encoder.fp16.onnx \
-  --decoder ./decoder.fp16.onnx \
-  --source-lang de \
-  --target-lang de \
-  --tokens ./tokens.txt \
-  --wav ./de.wav
-
-python3 ./test_180m_flash.py \
-  --encoder ./encoder.fp16.onnx \
-  --decoder ./decoder.fp16.onnx \
+  --decoder ./decoder.int8.onnx \
  --source-lang de \
  --target-lang en \
  --tokens ./tokens.txt \
@@ -79,8 +79,8 @@ class OnnxModel:
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map
-        # self.normalize_type = meta["normalize_type"]
-        self.normalize_type = "per_feature"
+        self.normalize_type = meta["normalize_type"]
+        print(meta)
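        # Note (illustrative, not in the original file): custom_metadata_map
        # values are plain strings, so numeric fields written by the export
        # script must be cast on this side, e.g.
        #     vocab_size = int(meta["vocab_size"])
        #     feat_dim = int(meta["feat_dim"])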
    def init_decoder(self, decoder):
@@ -267,7 +266,7 @@ def main():
    for pos, decoder_input_id in enumerate(decoder_input_ids):
        logits, decoder_mems_list = model.run_decoder(
-            np.array([[decoder_input_id,pos]], dtype=np.int32),
+            np.array([[decoder_input_id, pos]], dtype=np.int32),
            decoder_mems_list,
            enc_states,
            enc_masks,