Authored by Fangjun Kuang, 2025-07-07 00:12:20 +0800
Committed by GitHub, 2025-07-07 00:12:20 +0800

Commit fce481c125f7343770241f51f05bf687f6aa5849 (fce481c1), 1 parent 25f9cec0

Add meta data to NeMo canary ONNX models (#2351)
Showing 4 changed files with 87 additions and 68 deletions
.github/workflows/export-nemo-canary-180m-flash.yaml
scripts/nemo/canary/export_onnx_180m_flash.py
scripts/nemo/canary/run_180m_flash.sh
scripts/nemo/canary/test_180m_flash.py
.github/workflows/export-nemo-canary-180m-flash.yaml (view file @ fce481c)
...
...
@@ -62,22 +62,7 @@ jobs:
          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
          mkdir -p $d
          cp encoder.int8.onnx $d
          cp decoder.fp16.onnx $d
          cp tokens.txt $d
          mkdir $d/test_wavs
          cp de.wav $d/test_wavs
          cp en.wav $d/test_wavs
          tar cjfv $d.tar.bz2 $d

      - name: Collect files (fp16)
        shell: bash
        run: |
          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
          mkdir -p $d
          cp encoder.fp16.onnx $d
          cp decoder.fp16.onnx $d
          cp decoder.int8.onnx $d
          cp tokens.txt $d
          mkdir $d/test_wavs
...
...
@@ -101,7 +86,6 @@ jobs:
models=(
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
)
for m in ${models[@]}; do
...
...
scripts/nemo/canary/export_onnx_180m_flash.py (view file @ fce481c)
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
"""
<|en|>
<|pnc|>
<|noitn|>
<|nodiarize|>
<|notimestamp|>
"""
import os
from typing import Tuple
from typing import Dict, Tuple

import nemo
import onnxmltools
import onnx
import torch
from nemo.collections.common.parts import NEG_INF
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic
"""
...
...
@@ -64,10 +71,25 @@ nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
from nemo.collections.asr.models import EncDecMultiTaskModel


def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


def lens_to_mask(lens, max_length):
...
...
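Aside (not part of the commit): a quick way to confirm that add_meta_data took effect is to reload the exported file and dump its metadata_props. This is only a verification sketch; encoder.int8.onnx is just one of the models the script produces.

import onnx

# Reload a model that add_meta_data modified in-place and list its metadata.
model = onnx.load("encoder.int8.onnx")
for prop in model.metadata_props:
    # add_meta_data stores every value via str(value), so values come back as strings.
    print(prop.key, "=", prop.value)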
@@ -222,7 +244,7 @@ def export_decoder(canary_model):
        ),
        "decoder.onnx",
        dynamo=True,
        opset_version=18,
        opset_version=14,
        external_data=False,
        input_names=[
            "decoder_input_ids",
...
...
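Aside (not part of the commit): the hunk above changes the opset_version argument of the decoder export (the diff shows 18 and 14 as the two values). A minimal sketch to check which opset an exported decoder actually declares, assuming decoder.onnx sits in the current directory:

import onnx

# Print the opset(s) declared by the exported decoder.
model = onnx.load("decoder.onnx")
for opset in model.opset_import:
    print(opset.domain or "ai.onnx", opset.version)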
@@ -269,6 +291,29 @@ def export_tokens(canary_model):
@torch.no_grad()
def main():
    canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
    canary_model.eval()

    preprocessor = canary_model.cfg["preprocessor"]
    sample_rate = preprocessor["sample_rate"]
    normalize_type = preprocessor["normalize"]
    window_size = preprocessor["window_size"]  # ms
    window_stride = preprocessor["window_stride"]  # ms
    window = preprocessor["window"]
    features = preprocessor["features"]
    n_fft = preprocessor["n_fft"]
    vocab_size = canary_model.tokenizer.vocab_size  # 5248
    subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"]

    assert sample_rate == 16000, sample_rate
    assert normalize_type == "per_feature", normalize_type
    assert window_size == 0.025, window_size
    assert window_stride == 0.01, window_stride
    assert window == "hann", window
    assert features == 128, features
    assert n_fft == 512, n_fft
    assert subsampling_factor == 8, subsampling_factor

    export_tokens(canary_model)
    export_encoder(canary_model)
    export_decoder(canary_model)
...
...
@@ -280,7 +325,32 @@ def main():
            weight_type=QuantType.QUInt8,
        )

        export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")

    meta_data = {
        "vocab_size": vocab_size,
        "normalize_type": normalize_type,
        "subsampling_factor": subsampling_factor,
        "model_type": "EncDecMultiTaskModel",
        "version": "1",
        "model_author": "NeMo",
        "url": "https://huggingface.co/nvidia/canary-180m-flash",
        "feat_dim": features,
    }
    add_meta_data("encoder.onnx", meta_data)
    add_meta_data("encoder.int8.onnx", meta_data)

    """
    To fix the following error with onnxruntime 1.17.1 and 1.16.3:

    onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from ./decoder.int8.onnx failed:/Users/runner/work/1/s/onnxruntime/core/graph/model.cc:150 onnxruntime::Model::Model(onnx::ModelProto &&, const onnxruntime::PathString &, const onnxruntime::IOnnxRuntimeOpSchemaRegistryList *, const logging::Logger &, const onnxruntime::ModelOptions &)
    Unsupported model IR version: 10, max supported IR version: 9
    """
    for filename in ["./decoder.onnx", "./decoder.int8.onnx"]:
        model = onnx.load(filename)
        print("old", model.ir_version)
        model.ir_version = 9
        print("new", model.ir_version)
        onnx.save(model, filename)

    os.system("ls -lh *.onnx")
...
...
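Aside (not part of the commit): after the loop above rewrites ir_version to 9, the decoders are expected to load again with onnxruntime 1.17.1 / 1.16.3. A minimal sanity-check sketch, assuming decoder.int8.onnx exists in the current directory:

import onnxruntime as ort

# If the IR-version workaround worked, constructing the session no longer raises
# "Unsupported model IR version: 10, max supported IR version: 9".
sess = ort.InferenceSession("./decoder.int8.onnx", providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])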
scripts/nemo/canary/run_180m_flash.sh (view file @ fce481c)
...
...
@@ -19,8 +19,8 @@ pip install \
  kaldi-native-fbank \
  librosa \
  onnx==1.17.0 \
  onnxmltools \
  onnxruntime==1.17.1 \
  onnxscript \
  soundfile

python3 ./export_onnx_180m_flash.py
...
...
@@ -66,7 +66,7 @@ log "-----int8------"
python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --decoder ./decoder.int8.onnx \
  --source-lang en \
  --target-lang en \
  --tokens ./tokens.txt \
...
...
@@ -74,7 +74,7 @@ python3 ./test_180m_flash.py \
python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --decoder ./decoder.int8.onnx \
  --source-lang en \
  --target-lang de \
  --tokens ./tokens.txt \
...
...
@@ -82,7 +82,7 @@ python3 ./test_180m_flash.py \
python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --decoder ./decoder.int8.onnx \
  --source-lang de \
  --target-lang de \
  --tokens ./tokens.txt \
...
...
@@ -90,41 +90,7 @@ python3 ./test_180m_flash.py \
python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang de \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./de.wav

log "-----fp16------"

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang en \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang en \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang de \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./de.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --decoder ./decoder.int8.onnx \
  --source-lang de \
  --target-lang en \
  --tokens ./tokens.txt \
...
...
scripts/nemo/canary/test_180m_flash.py (view file @ fce481c)
...
...
@@ -79,8 +79,7 @@ class OnnxModel:
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map
        # self.normalize_type = meta["normalize_type"]
        self.normalize_type = "per_feature"
        self.normalize_type = meta["normalize_type"]
        print(meta)

    def init_decoder(self, decoder):
...
...
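Aside (not part of the commit): with the change above, test_180m_flash.py reads normalize_type from the encoder's metadata instead of hardcoding "per_feature". The other keys written by the export script can be read the same way; a minimal sketch, assuming encoder.int8.onnx is the file produced above (custom_metadata_map maps str to str, so numeric fields need an explicit cast):

import onnxruntime as ort

sess = ort.InferenceSession("./encoder.int8.onnx", providers=["CPUExecutionProvider"])
meta = sess.get_modelmeta().custom_metadata_map

# All values are strings; cast the numeric ones before use.
normalize_type = meta["normalize_type"]
vocab_size = int(meta["vocab_size"])
subsampling_factor = int(meta["subsampling_factor"])
feat_dim = int(meta["feat_dim"])
print(normalize_type, vocab_size, subsampling_factor, feat_dim)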
@@ -267,7 +266,7 @@ def main():
    for pos, decoder_input_id in enumerate(decoder_input_ids):
        logits, decoder_mems_list = model.run_decoder(
            np.array([[decoder_input_id, pos]], dtype=np.int32),
            np.array([[decoder_input_id, pos]], dtype=np.int32),
            decoder_mems_list,
            enc_states,
            enc_masks,
...
...