Committed by
GitHub
Add meta data to NeMo canary ONNX models (#2351)
正在显示
4 个修改的文件
包含
87 行增加
和
68 行删除
| @@ -62,22 +62,7 @@ jobs: | @@ -62,22 +62,7 @@ jobs: | ||
| 62 | d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8 | 62 | d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8 |
| 63 | mkdir -p $d | 63 | mkdir -p $d |
| 64 | cp encoder.int8.onnx $d | 64 | cp encoder.int8.onnx $d |
| 65 | - cp decoder.fp16.onnx $d | ||
| 66 | - cp tokens.txt $d | ||
| 67 | - | ||
| 68 | - mkdir $d/test_wavs | ||
| 69 | - cp de.wav $d/test_wavs | ||
| 70 | - cp en.wav $d/test_wavs | ||
| 71 | - | ||
| 72 | - tar cjfv $d.tar.bz2 $d | ||
| 73 | - | ||
| 74 | - - name: Collect files (fp16) | ||
| 75 | - shell: bash | ||
| 76 | - run: | | ||
| 77 | - d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16 | ||
| 78 | - mkdir -p $d | ||
| 79 | - cp encoder.fp16.onnx $d | ||
| 80 | - cp decoder.fp16.onnx $d | 65 | + cp decoder.int8.onnx $d |
| 81 | cp tokens.txt $d | 66 | cp tokens.txt $d |
| 82 | 67 | ||
| 83 | mkdir $d/test_wavs | 68 | mkdir $d/test_wavs |
| @@ -101,7 +86,6 @@ jobs: | @@ -101,7 +86,6 @@ jobs: | ||
| 101 | models=( | 86 | models=( |
| 102 | sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr | 87 | sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr |
| 103 | sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8 | 88 | sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8 |
| 104 | - sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16 | ||
| 105 | ) | 89 | ) |
| 106 | 90 | ||
| 107 | for m in ${models[@]}; do | 91 | for m in ${models[@]}; do |
| 1 | #!/usr/bin/env python3 | 1 | #!/usr/bin/env python3 |
| 2 | # Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | 2 | # Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) |
| 3 | 3 | ||
| 4 | +""" | ||
| 5 | +<|en|> | ||
| 6 | +<|pnc|> | ||
| 7 | +<|noitn|> | ||
| 8 | +<|nodiarize|> | ||
| 9 | +<|notimestamp|> | ||
| 10 | +""" | ||
| 11 | + | ||
| 4 | import os | 12 | import os |
| 5 | -from typing import Tuple | 13 | +from typing import Dict, Tuple |
| 6 | 14 | ||
| 7 | import nemo | 15 | import nemo |
| 8 | -import onnxmltools | 16 | +import onnx |
| 9 | import torch | 17 | import torch |
| 10 | from nemo.collections.common.parts import NEG_INF | 18 | from nemo.collections.common.parts import NEG_INF |
| 11 | -from onnxmltools.utils.float16_converter import convert_float_to_float16 | ||
| 12 | from onnxruntime.quantization import QuantType, quantize_dynamic | 19 | from onnxruntime.quantization import QuantType, quantize_dynamic |
| 13 | 20 | ||
| 14 | """ | 21 | """ |
| @@ -64,10 +71,25 @@ nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask | @@ -64,10 +71,25 @@ nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask | ||
| 64 | from nemo.collections.asr.models import EncDecMultiTaskModel | 71 | from nemo.collections.asr.models import EncDecMultiTaskModel |
| 65 | 72 | ||
| 66 | 73 | ||
| 67 | -def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): | ||
| 68 | - onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) | ||
| 69 | - onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) | ||
| 70 | - onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) | 74 | +def add_meta_data(filename: str, meta_data: Dict[str, str]): |
| 75 | + """Set the meta data of an ONNX model, discarding any existing entries. The file is changed in-place. | ||
| 76 | + | ||
| 77 | + Args: | ||
| 78 | + filename: | ||
| 79 | + Filename of the ONNX model to be changed. | ||
| 80 | + meta_data: | ||
| 81 | + Key-value pairs. | ||
| 82 | + """ | ||
| 83 | + model = onnx.load(filename) | ||
| 84 | + while len(model.metadata_props): | ||
| 85 | + model.metadata_props.pop() | ||
| 86 | + | ||
| 87 | + for key, value in meta_data.items(): | ||
| 88 | + meta = model.metadata_props.add() | ||
| 89 | + meta.key = key | ||
| 90 | + meta.value = str(value) | ||
| 91 | + | ||
| 92 | + onnx.save(model, filename) | ||
| 71 | 93 | ||
| 72 | 94 | ||
| 73 | def lens_to_mask(lens, max_length): | 95 | def lens_to_mask(lens, max_length): |
| @@ -222,7 +244,7 @@ def export_decoder(canary_model): | @@ -222,7 +244,7 @@ def export_decoder(canary_model): | ||
| 222 | ), | 244 | ), |
| 223 | "decoder.onnx", | 245 | "decoder.onnx", |
| 224 | dynamo=True, | 246 | dynamo=True, |
| 225 | - opset_version=18, | 247 | + opset_version=14, |
| 226 | external_data=False, | 248 | external_data=False, |
| 227 | input_names=[ | 249 | input_names=[ |
| 228 | "decoder_input_ids", | 250 | "decoder_input_ids", |
| @@ -269,6 +291,29 @@ def export_tokens(canary_model): | @@ -269,6 +291,29 @@ def export_tokens(canary_model): | ||
| 269 | @torch.no_grad() | 291 | @torch.no_grad() |
| 270 | def main(): | 292 | def main(): |
| 271 | canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash") | 293 | canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash") |
| 294 | + canary_model.eval() | ||
| 295 | + | ||
| 296 | + preprocessor = canary_model.cfg["preprocessor"] | ||
| 297 | + sample_rate = preprocessor["sample_rate"] | ||
| 298 | + normalize_type = preprocessor["normalize"] | ||
| 299 | + window_size = preprocessor["window_size"] # seconds (0.025 s == 25 ms) | ||
| 300 | + window_stride = preprocessor["window_stride"] # seconds (0.01 s == 10 ms) | ||
| 301 | + window = preprocessor["window"] | ||
| 302 | + features = preprocessor["features"] | ||
| 303 | + n_fft = preprocessor["n_fft"] | ||
| 304 | + vocab_size = canary_model.tokenizer.vocab_size # 5248 | ||
| 305 | + | ||
| 306 | + subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"] | ||
| 307 | + | ||
| 308 | + assert sample_rate == 16000, sample_rate | ||
| 309 | + assert normalize_type == "per_feature", normalize_type | ||
| 310 | + assert window_size == 0.025, window_size | ||
| 311 | + assert window_stride == 0.01, window_stride | ||
| 312 | + assert window == "hann", window | ||
| 313 | + assert features == 128, features | ||
| 314 | + assert n_fft == 512, n_fft | ||
| 315 | + assert subsampling_factor == 8, subsampling_factor | ||
| 316 | + | ||
| 272 | export_tokens(canary_model) | 317 | export_tokens(canary_model) |
| 273 | export_encoder(canary_model) | 318 | export_encoder(canary_model) |
| 274 | export_decoder(canary_model) | 319 | export_decoder(canary_model) |
| @@ -280,7 +325,32 @@ def main(): | @@ -280,7 +325,32 @@ def main(): | ||
| 280 | weight_type=QuantType.QUInt8, | 325 | weight_type=QuantType.QUInt8, |
| 281 | ) | 326 | ) |
| 282 | 327 | ||
| 283 | - export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx") | 328 | + meta_data = { |
| 329 | + "vocab_size": vocab_size, | ||
| 330 | + "normalize_type": normalize_type, | ||
| 331 | + "subsampling_factor": subsampling_factor, | ||
| 332 | + "model_type": "EncDecMultiTaskModel", | ||
| 333 | + "version": "1", | ||
| 334 | + "model_author": "NeMo", | ||
| 335 | + "url": "https://huggingface.co/nvidia/canary-180m-flash", | ||
| 336 | + "feat_dim": features, | ||
| 337 | + } | ||
| 338 | + | ||
| 339 | + add_meta_data("encoder.onnx", meta_data) | ||
| 340 | + add_meta_data("encoder.int8.onnx", meta_data) | ||
| 341 | + | ||
| 342 | + """ | ||
| 343 | + To fix the following error with onnxruntime 1.17.1 and 1.16.3: | ||
| 344 | + | ||
| 345 | + onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :FAIL : Load model from ./decoder.int8.onnx failed:/Users/runner/work/1/s/onnxruntime/core/graph/model.cc:150 onnxruntime::Model::Model(onnx::ModelProto &&, const onnxruntime::PathString &, const onnxruntime::IOnnxRuntimeOpSchemaRegistryList *, const logging::Logger &, const onnxruntime::ModelOptions &) | ||
| 346 | + Unsupported model IR version: 10, max supported IR version: 9 | ||
| 347 | + """ | ||
| 348 | + for filename in ["./decoder.onnx", "./decoder.int8.onnx"]: | ||
| 349 | + model = onnx.load(filename) | ||
| 350 | + print("old", model.ir_version) | ||
| 351 | + model.ir_version = 9 | ||
| 352 | + print("new", model.ir_version) | ||
| 353 | + onnx.save(model, filename) | ||
| 284 | 354 | ||
| 285 | os.system("ls -lh *.onnx") | 355 | os.system("ls -lh *.onnx") |
| 286 | 356 |
| @@ -19,8 +19,8 @@ pip install \ | @@ -19,8 +19,8 @@ pip install \ | ||
| 19 | kaldi-native-fbank \ | 19 | kaldi-native-fbank \ |
| 20 | librosa \ | 20 | librosa \ |
| 21 | onnx==1.17.0 \ | 21 | onnx==1.17.0 \ |
| 22 | - onnxmltools \ | ||
| 23 | onnxruntime==1.17.1 \ | 22 | onnxruntime==1.17.1 \ |
| 23 | + onnxscript \ | ||
| 24 | soundfile | 24 | soundfile |
| 25 | 25 | ||
| 26 | python3 ./export_onnx_180m_flash.py | 26 | python3 ./export_onnx_180m_flash.py |
| @@ -66,7 +66,7 @@ log "-----int8------" | @@ -66,7 +66,7 @@ log "-----int8------" | ||
| 66 | 66 | ||
| 67 | python3 ./test_180m_flash.py \ | 67 | python3 ./test_180m_flash.py \ |
| 68 | --encoder ./encoder.int8.onnx \ | 68 | --encoder ./encoder.int8.onnx \ |
| 69 | - --decoder ./decoder.fp16.onnx \ | 69 | + --decoder ./decoder.int8.onnx \ |
| 70 | --source-lang en \ | 70 | --source-lang en \ |
| 71 | --target-lang en \ | 71 | --target-lang en \ |
| 72 | --tokens ./tokens.txt \ | 72 | --tokens ./tokens.txt \ |
| @@ -74,7 +74,7 @@ python3 ./test_180m_flash.py \ | @@ -74,7 +74,7 @@ python3 ./test_180m_flash.py \ | ||
| 74 | 74 | ||
| 75 | python3 ./test_180m_flash.py \ | 75 | python3 ./test_180m_flash.py \ |
| 76 | --encoder ./encoder.int8.onnx \ | 76 | --encoder ./encoder.int8.onnx \ |
| 77 | - --decoder ./decoder.fp16.onnx \ | 77 | + --decoder ./decoder.int8.onnx \ |
| 78 | --source-lang en \ | 78 | --source-lang en \ |
| 79 | --target-lang de \ | 79 | --target-lang de \ |
| 80 | --tokens ./tokens.txt \ | 80 | --tokens ./tokens.txt \ |
| @@ -82,7 +82,7 @@ python3 ./test_180m_flash.py \ | @@ -82,7 +82,7 @@ python3 ./test_180m_flash.py \ | ||
| 82 | 82 | ||
| 83 | python3 ./test_180m_flash.py \ | 83 | python3 ./test_180m_flash.py \ |
| 84 | --encoder ./encoder.int8.onnx \ | 84 | --encoder ./encoder.int8.onnx \ |
| 85 | - --decoder ./decoder.fp16.onnx \ | 85 | + --decoder ./decoder.int8.onnx \ |
| 86 | --source-lang de \ | 86 | --source-lang de \ |
| 87 | --target-lang de \ | 87 | --target-lang de \ |
| 88 | --tokens ./tokens.txt \ | 88 | --tokens ./tokens.txt \ |
| @@ -90,41 +90,7 @@ python3 ./test_180m_flash.py \ | @@ -90,41 +90,7 @@ python3 ./test_180m_flash.py \ | ||
| 90 | 90 | ||
| 91 | python3 ./test_180m_flash.py \ | 91 | python3 ./test_180m_flash.py \ |
| 92 | --encoder ./encoder.int8.onnx \ | 92 | --encoder ./encoder.int8.onnx \ |
| 93 | - --decoder ./decoder.fp16.onnx \ | ||
| 94 | - --source-lang de \ | ||
| 95 | - --target-lang en \ | ||
| 96 | - --tokens ./tokens.txt \ | ||
| 97 | - --wav ./de.wav | ||
| 98 | - | ||
| 99 | -log "-----fp16------" | ||
| 100 | - | ||
| 101 | -python3 ./test_180m_flash.py \ | ||
| 102 | - --encoder ./encoder.fp16.onnx \ | ||
| 103 | - --decoder ./decoder.fp16.onnx \ | ||
| 104 | - --source-lang en \ | ||
| 105 | - --target-lang en \ | ||
| 106 | - --tokens ./tokens.txt \ | ||
| 107 | - --wav ./en.wav | ||
| 108 | - | ||
| 109 | -python3 ./test_180m_flash.py \ | ||
| 110 | - --encoder ./encoder.fp16.onnx \ | ||
| 111 | - --decoder ./decoder.fp16.onnx \ | ||
| 112 | - --source-lang en \ | ||
| 113 | - --target-lang de \ | ||
| 114 | - --tokens ./tokens.txt \ | ||
| 115 | - --wav ./en.wav | ||
| 116 | - | ||
| 117 | -python3 ./test_180m_flash.py \ | ||
| 118 | - --encoder ./encoder.fp16.onnx \ | ||
| 119 | - --decoder ./decoder.fp16.onnx \ | ||
| 120 | - --source-lang de \ | ||
| 121 | - --target-lang de \ | ||
| 122 | - --tokens ./tokens.txt \ | ||
| 123 | - --wav ./de.wav | ||
| 124 | - | ||
| 125 | -python3 ./test_180m_flash.py \ | ||
| 126 | - --encoder ./encoder.fp16.onnx \ | ||
| 127 | - --decoder ./decoder.fp16.onnx \ | 93 | + --decoder ./decoder.int8.onnx \ |
| 128 | --source-lang de \ | 94 | --source-lang de \ |
| 129 | --target-lang en \ | 95 | --target-lang en \ |
| 130 | --tokens ./tokens.txt \ | 96 | --tokens ./tokens.txt \ |
| @@ -79,8 +79,7 @@ class OnnxModel: | @@ -79,8 +79,7 @@ class OnnxModel: | ||
| 79 | ) | 79 | ) |
| 80 | 80 | ||
| 81 | meta = self.encoder.get_modelmeta().custom_metadata_map | 81 | meta = self.encoder.get_modelmeta().custom_metadata_map |
| 82 | - # self.normalize_type = meta["normalize_type"] | ||
| 83 | - self.normalize_type = "per_feature" | 82 | + self.normalize_type = meta["normalize_type"] |
| 84 | print(meta) | 83 | print(meta) |
| 85 | 84 | ||
| 86 | def init_decoder(self, decoder): | 85 | def init_decoder(self, decoder): |
| @@ -267,7 +266,7 @@ def main(): | @@ -267,7 +266,7 @@ def main(): | ||
| 267 | 266 | ||
| 268 | for pos, decoder_input_id in enumerate(decoder_input_ids): | 267 | for pos, decoder_input_id in enumerate(decoder_input_ids): |
| 269 | logits, decoder_mems_list = model.run_decoder( | 268 | logits, decoder_mems_list = model.run_decoder( |
| 270 | - np.array([[decoder_input_id,pos]], dtype=np.int32), | 269 | + np.array([[decoder_input_id, pos]], dtype=np.int32), |
| 271 | decoder_mems_list, | 270 | decoder_mems_list, |
| 272 | enc_states, | 271 | enc_states, |
| 273 | enc_masks, | 272 | enc_masks, |
-
请 注册 或 登录 后发表评论