Fangjun Kuang
Committed by GitHub

Add meta data to NeMo canary ONNX models (#2351)

@@ -62,22 +62,7 @@ jobs: @@ -62,22 +62,7 @@ jobs:
62 d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8 62 d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
63 mkdir -p $d 63 mkdir -p $d
64 cp encoder.int8.onnx $d 64 cp encoder.int8.onnx $d
65 - cp decoder.fp16.onnx $d  
66 - cp tokens.txt $d  
67 -  
68 - mkdir $d/test_wavs  
69 - cp de.wav $d/test_wavs  
70 - cp en.wav $d/test_wavs  
71 -  
72 - tar cjfv $d.tar.bz2 $d  
73 -  
74 - - name: Collect files (fp16)  
75 - shell: bash  
76 - run: |  
77 - d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16  
78 - mkdir -p $d  
79 - cp encoder.fp16.onnx $d  
80 - cp decoder.fp16.onnx $d 65 + cp decoder.int8.onnx $d
81 cp tokens.txt $d 66 cp tokens.txt $d
82 67
83 mkdir $d/test_wavs 68 mkdir $d/test_wavs
@@ -101,7 +86,6 @@ jobs: @@ -101,7 +86,6 @@ jobs:
101 models=( 86 models=(
102 sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr 87 sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
103 sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8 88 sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
104 - sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16  
105 ) 89 )
106 90
107 for m in ${models[@]}; do 91 for m in ${models[@]}; do
1 #!/usr/bin/env python3 1 #!/usr/bin/env python3
2 # Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) 2 # Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
3 3
  4 +"""
  5 +<|en|>
  6 +<|pnc|>
  7 +<|noitn|>
  8 +<|nodiarize|>
  9 +<|notimestamp|>
  10 +"""
  11 +
4 import os 12 import os
5 -from typing import Tuple 13 +from typing import Dict, Tuple
6 14
7 import nemo 15 import nemo
8 -import onnxmltools 16 +import onnx
9 import torch 17 import torch
10 from nemo.collections.common.parts import NEG_INF 18 from nemo.collections.common.parts import NEG_INF
11 -from onnxmltools.utils.float16_converter import convert_float_to_float16  
12 from onnxruntime.quantization import QuantType, quantize_dynamic 19 from onnxruntime.quantization import QuantType, quantize_dynamic
13 20
14 """ 21 """
@@ -64,10 +71,25 @@ nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask @@ -64,10 +71,25 @@ nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
64 from nemo.collections.asr.models import EncDecMultiTaskModel 71 from nemo.collections.asr.models import EncDecMultiTaskModel
65 72
66 73
67 -def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):  
68 - onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)  
69 - onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)  
70 - onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) 74 +def add_meta_data(filename: str, meta_data: Dict[str, str]):
  75 + """Add meta data to an ONNX model. It is changed in-place.
  76 +
  77 + Args:
  78 + filename:
  79 + Filename of the ONNX model to be changed.
  80 + meta_data:
  81 + Key-value pairs.
  82 + """
  83 + model = onnx.load(filename)
  84 + while len(model.metadata_props):
  85 + model.metadata_props.pop()
  86 +
  87 + for key, value in meta_data.items():
  88 + meta = model.metadata_props.add()
  89 + meta.key = key
  90 + meta.value = str(value)
  91 +
  92 + onnx.save(model, filename)
71 93
72 94
73 def lens_to_mask(lens, max_length): 95 def lens_to_mask(lens, max_length):
@@ -222,7 +244,7 @@ def export_decoder(canary_model): @@ -222,7 +244,7 @@ def export_decoder(canary_model):
222 ), 244 ),
223 "decoder.onnx", 245 "decoder.onnx",
224 dynamo=True, 246 dynamo=True,
225 - opset_version=18, 247 + opset_version=14,
226 external_data=False, 248 external_data=False,
227 input_names=[ 249 input_names=[
228 "decoder_input_ids", 250 "decoder_input_ids",
@@ -269,6 +291,29 @@ def export_tokens(canary_model): @@ -269,6 +291,29 @@ def export_tokens(canary_model):
269 @torch.no_grad() 291 @torch.no_grad()
270 def main(): 292 def main():
271 canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash") 293 canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
  294 + canary_model.eval()
  295 +
  296 + preprocessor = canary_model.cfg["preprocessor"]
  297 + sample_rate = preprocessor["sample_rate"]
  298 + normalize_type = preprocessor["normalize"]
  299 + window_size = preprocessor["window_size"] # ms
  300 + window_stride = preprocessor["window_stride"] # ms
  301 + window = preprocessor["window"]
  302 + features = preprocessor["features"]
  303 + n_fft = preprocessor["n_fft"]
  304 + vocab_size = canary_model.tokenizer.vocab_size # 5248
  305 +
  306 + subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"]
  307 +
  308 + assert sample_rate == 16000, sample_rate
  309 + assert normalize_type == "per_feature", normalize_type
  310 + assert window_size == 0.025, window_size
  311 + assert window_stride == 0.01, window_stride
  312 + assert window == "hann", window
  313 + assert features == 128, features
  314 + assert n_fft == 512, n_fft
  315 + assert subsampling_factor == 8, subsampling_factor
  316 +
272 export_tokens(canary_model) 317 export_tokens(canary_model)
273 export_encoder(canary_model) 318 export_encoder(canary_model)
274 export_decoder(canary_model) 319 export_decoder(canary_model)
@@ -280,7 +325,32 @@ def main(): @@ -280,7 +325,32 @@ def main():
280 weight_type=QuantType.QUInt8, 325 weight_type=QuantType.QUInt8,
281 ) 326 )
282 327
283 - export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx") 328 + meta_data = {
  329 + "vocab_size": vocab_size,
  330 + "normalize_type": normalize_type,
  331 + "subsampling_factor": subsampling_factor,
  332 + "model_type": "EncDecMultiTaskModel",
  333 + "version": "1",
  334 + "model_author": "NeMo",
  335 + "url": "https://huggingface.co/nvidia/canary-180m-flash",
  336 + "feat_dim": features,
  337 + }
  338 +
  339 + add_meta_data("encoder.onnx", meta_data)
  340 + add_meta_data("encoder.int8.onnx", meta_data)
  341 +
  342 + """
  343 + To fix the following error with onnxruntime 1.17.1 and 1.16.3:
  344 +
  345 + onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :FAIL : Load model from ./decoder.int8.onnx failed:/Users/runner/work/1/s/onnxruntime/core/graph/model.cc:150 onnxruntime::Model::Model(onnx::ModelProto &&, const onnxruntime::PathString &, const onnxruntime::IOnnxRuntimeOpSchemaRegistryList *, const logging::Logger &, const onnxruntime::ModelOptions &)
  346 + Unsupported model IR version: 10, max supported IR version: 9
  347 + """
  348 + for filename in ["./decoder.onnx", "./decoder.int8.onnx"]:
  349 + model = onnx.load(filename)
  350 + print("old", model.ir_version)
  351 + model.ir_version = 9
  352 + print("new", model.ir_version)
  353 + onnx.save(model, filename)
284 354
285 os.system("ls -lh *.onnx") 355 os.system("ls -lh *.onnx")
286 356
@@ -19,8 +19,8 @@ pip install \ @@ -19,8 +19,8 @@ pip install \
19 kaldi-native-fbank \ 19 kaldi-native-fbank \
20 librosa \ 20 librosa \
21 onnx==1.17.0 \ 21 onnx==1.17.0 \
22 - onnxmltools \  
23 onnxruntime==1.17.1 \ 22 onnxruntime==1.17.1 \
  23 + onnxscript \
24 soundfile 24 soundfile
25 25
26 python3 ./export_onnx_180m_flash.py 26 python3 ./export_onnx_180m_flash.py
@@ -66,7 +66,7 @@ log "-----int8------" @@ -66,7 +66,7 @@ log "-----int8------"
66 66
67 python3 ./test_180m_flash.py \ 67 python3 ./test_180m_flash.py \
68 --encoder ./encoder.int8.onnx \ 68 --encoder ./encoder.int8.onnx \
69 - --decoder ./decoder.fp16.onnx \ 69 + --decoder ./decoder.int8.onnx \
70 --source-lang en \ 70 --source-lang en \
71 --target-lang en \ 71 --target-lang en \
72 --tokens ./tokens.txt \ 72 --tokens ./tokens.txt \
@@ -74,7 +74,7 @@ python3 ./test_180m_flash.py \ @@ -74,7 +74,7 @@ python3 ./test_180m_flash.py \
74 74
75 python3 ./test_180m_flash.py \ 75 python3 ./test_180m_flash.py \
76 --encoder ./encoder.int8.onnx \ 76 --encoder ./encoder.int8.onnx \
77 - --decoder ./decoder.fp16.onnx \ 77 + --decoder ./decoder.int8.onnx \
78 --source-lang en \ 78 --source-lang en \
79 --target-lang de \ 79 --target-lang de \
80 --tokens ./tokens.txt \ 80 --tokens ./tokens.txt \
@@ -82,7 +82,7 @@ python3 ./test_180m_flash.py \ @@ -82,7 +82,7 @@ python3 ./test_180m_flash.py \
82 82
83 python3 ./test_180m_flash.py \ 83 python3 ./test_180m_flash.py \
84 --encoder ./encoder.int8.onnx \ 84 --encoder ./encoder.int8.onnx \
85 - --decoder ./decoder.fp16.onnx \ 85 + --decoder ./decoder.int8.onnx \
86 --source-lang de \ 86 --source-lang de \
87 --target-lang de \ 87 --target-lang de \
88 --tokens ./tokens.txt \ 88 --tokens ./tokens.txt \
@@ -90,41 +90,7 @@ python3 ./test_180m_flash.py \ @@ -90,41 +90,7 @@ python3 ./test_180m_flash.py \
90 90
91 python3 ./test_180m_flash.py \ 91 python3 ./test_180m_flash.py \
92 --encoder ./encoder.int8.onnx \ 92 --encoder ./encoder.int8.onnx \
93 - --decoder ./decoder.fp16.onnx \  
94 - --source-lang de \  
95 - --target-lang en \  
96 - --tokens ./tokens.txt \  
97 - --wav ./de.wav  
98 -  
99 -log "-----fp16------"  
100 -  
101 -python3 ./test_180m_flash.py \  
102 - --encoder ./encoder.fp16.onnx \  
103 - --decoder ./decoder.fp16.onnx \  
104 - --source-lang en \  
105 - --target-lang en \  
106 - --tokens ./tokens.txt \  
107 - --wav ./en.wav  
108 -  
109 -python3 ./test_180m_flash.py \  
110 - --encoder ./encoder.fp16.onnx \  
111 - --decoder ./decoder.fp16.onnx \  
112 - --source-lang en \  
113 - --target-lang de \  
114 - --tokens ./tokens.txt \  
115 - --wav ./en.wav  
116 -  
117 -python3 ./test_180m_flash.py \  
118 - --encoder ./encoder.fp16.onnx \  
119 - --decoder ./decoder.fp16.onnx \  
120 - --source-lang de \  
121 - --target-lang de \  
122 - --tokens ./tokens.txt \  
123 - --wav ./de.wav  
124 -  
125 -python3 ./test_180m_flash.py \  
126 - --encoder ./encoder.fp16.onnx \  
127 - --decoder ./decoder.fp16.onnx \ 93 + --decoder ./decoder.int8.onnx \
128 --source-lang de \ 94 --source-lang de \
129 --target-lang en \ 95 --target-lang en \
130 --tokens ./tokens.txt \ 96 --tokens ./tokens.txt \
@@ -79,8 +79,7 @@ class OnnxModel: @@ -79,8 +79,7 @@ class OnnxModel:
79 ) 79 )
80 80
81 meta = self.encoder.get_modelmeta().custom_metadata_map 81 meta = self.encoder.get_modelmeta().custom_metadata_map
82 - # self.normalize_type = meta["normalize_type"]  
83 - self.normalize_type = "per_feature" 82 + self.normalize_type = meta["normalize_type"]
84 print(meta) 83 print(meta)
85 84
86 def init_decoder(self, decoder): 85 def init_decoder(self, decoder):
@@ -267,7 +266,7 @@ def main(): @@ -267,7 +266,7 @@ def main():
267 266
268 for pos, decoder_input_id in enumerate(decoder_input_ids): 267 for pos, decoder_input_id in enumerate(decoder_input_ids):
269 logits, decoder_mems_list = model.run_decoder( 268 logits, decoder_mems_list = model.run_decoder(
270 - np.array([[decoder_input_id,pos]], dtype=np.int32), 269 + np.array([[decoder_input_id, pos]], dtype=np.int32),
271 decoder_mems_list, 270 decoder_mems_list,
272 enc_states, 271 enc_states,
273 enc_masks, 272 enc_masks,