Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
lucaelin
2025-07-06 12:24:06 +0200
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2025-07-06 18:24:06 +0800
Commit
5ebb71909bcd97104f8cfaceba68547607c38342
5ebb7190
1 parent
d70b7895
fix(canary): use dynamo export, single input_ids and avoid 0/1 specialization (#2348)
隐藏空白字符变更
内嵌
并排对比
正在显示
2 个修改的文件
包含
21 行增加
和
22 行删除
scripts/nemo/canary/export_onnx_180m_flash.py
scripts/nemo/canary/test_180m_flash.py
scripts/nemo/canary/export_onnx_180m_flash.py
查看文件 @
5ebb719
...
...
@@ -197,12 +197,12 @@ def export_decoder(canary_model):
decoder
=
DecoderWrapper
(
canary_model
)
decoder_input_ids
=
torch
.
tensor
([[
1
,
0
]],
dtype
=
torch
.
int32
)
decoder_mems_list_0
=
torch
.
zeros
(
1
,
1
,
1024
)
decoder_mems_list_1
=
torch
.
zeros
(
1
,
1
,
1024
)
decoder_mems_list_2
=
torch
.
zeros
(
1
,
1
,
1024
)
decoder_mems_list_3
=
torch
.
zeros
(
1
,
1
,
1024
)
decoder_mems_list_4
=
torch
.
zeros
(
1
,
1
,
1024
)
decoder_mems_list_5
=
torch
.
zeros
(
1
,
1
,
1024
)
decoder_mems_list_0
=
torch
.
zeros
(
1
,
10
,
1024
)
decoder_mems_list_1
=
torch
.
zeros
(
1
,
10
,
1024
)
decoder_mems_list_2
=
torch
.
zeros
(
1
,
10
,
1024
)
decoder_mems_list_3
=
torch
.
zeros
(
1
,
10
,
1024
)
decoder_mems_list_4
=
torch
.
zeros
(
1
,
10
,
1024
)
decoder_mems_list_5
=
torch
.
zeros
(
1
,
10
,
1024
)
enc_states
=
torch
.
zeros
(
1
,
1000
,
1024
)
enc_mask
=
torch
.
ones
(
1
,
1000
)
.
bool
()
...
...
@@ -221,7 +221,9 @@ def export_decoder(canary_model):
enc_mask
,
),
"decoder.onnx"
,
opset_version
=
14
,
dynamo
=
True
,
opset_version
=
18
,
external_data
=
False
,
input_names
=
[
"decoder_input_ids"
,
"decoder_mems_list_0"
,
...
...
@@ -272,13 +274,11 @@ def main():
export_decoder
(
canary_model
)
for
m
in
[
"encoder"
,
"decoder"
]:
if
m
==
"encoder"
:
# we don't quantize the decoder with int8 since the accuracy drops
quantize_dynamic
(
model_input
=
f
"./{m}.onnx"
,
model_output
=
f
"./{m}.int8.onnx"
,
weight_type
=
QuantType
.
QUInt8
,
)
quantize_dynamic
(
model_input
=
f
"./{m}.onnx"
,
model_output
=
f
"./{m}.int8.onnx"
,
weight_type
=
QuantType
.
QUInt8
,
)
export_onnx_fp16
(
f
"{m}.onnx"
,
f
"{m}.fp16.onnx"
)
...
...
scripts/nemo/canary/test_180m_flash.py
查看文件 @
5ebb719
...
...
@@ -263,16 +263,15 @@ def main():
decoder_input_ids
.
append
(
token2id
[
"<|notimestamp|>"
])
decoder_input_ids
.
append
(
token2id
[
"<|nodiarize|>"
])
decoder_input_ids
.
append
(
0
)
decoder_mems_list
=
[
np
.
zeros
((
1
,
0
,
1024
),
dtype
=
np
.
float32
)
for
_
in
range
(
6
)]
logits
,
decoder_mems_list
=
model
.
run_decoder
(
np
.
array
([
decoder_input_ids
],
dtype
=
np
.
int32
),
decoder_mems_list
,
enc_states
,
enc_masks
,
)
for
pos
,
decoder_input_id
in
enumerate
(
decoder_input_ids
):
logits
,
decoder_mems_list
=
model
.
run_decoder
(
np
.
array
([[
decoder_input_id
,
pos
]],
dtype
=
np
.
int32
),
decoder_mems_list
,
enc_states
,
enc_masks
,
)
tokens
=
[
logits
.
argmax
()]
print
(
"decoder_input_ids"
,
decoder_input_ids
)
eos
=
token2id
[
"<|endoftext|>"
]
...
...
请
注册
或
登录
后发表评论