Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-09-20 19:33:26 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-09-20 19:33:26 +0800
Commit
f5c060dd617d2741a70d0477929503bc40b82e5d
f5c060dd
1 parent
6afa9c85
Reduce whisper decoder file size with onnx export (#328)
隐藏空白字符变更
内嵌
并排对比
正在显示
1 个修改的文件
包含
33 行增加
和
5 行删除
scripts/whisper/export-onnx.py
scripts/whisper/export-onnx.py
查看文件 @
f5c060d
...
...
@@ -200,10 +200,25 @@ class TextDecoderTensorCache(nn.Module):
x
=
self
.
textDecoder
.
ln
(
x
)
logits
=
(
x
@
torch
.
transpose
(
self
.
textDecoder
.
token_embedding
.
weight
.
to
(
x
.
dtype
),
0
,
1
)
)
.
float
()
if
False
:
# x.shape (1, 3, 384)
# weight.shape (51684, 384)
logits
=
(
x
@
torch
.
transpose
(
self
.
textDecoder
.
token_embedding
.
weight
.
to
(
x
.
dtype
),
0
,
1
)
)
.
float
()
else
:
logits
=
(
torch
.
matmul
(
self
.
textDecoder
.
token_embedding
.
weight
.
to
(
x
.
dtype
),
x
.
permute
(
0
,
2
,
1
),
)
.
permute
(
0
,
2
,
1
)
.
float
()
)
return
logits
,
n_layer_self_k_cache
,
n_layer_self_v_cache
...
...
@@ -246,6 +261,19 @@ def main():
opset_version
=
13
model
=
whisper
.
load_model
(
name
)
print
(
f
"number of model parameters: {name}"
,
sum
(
p
.
numel
()
for
p
in
model
.
parameters
()),
)
print
(
f
"number of encoder parameters: {name}"
,
sum
(
p
.
numel
()
for
p
in
model
.
encoder
.
parameters
()),
)
print
(
f
"number of decoder parameters: {name}"
,
sum
(
p
.
numel
()
for
p
in
model
.
decoder
.
parameters
()),
)
convert_tokens
(
name
=
name
,
model
=
model
)
# write tokens
...
...
@@ -419,7 +447,7 @@ def main():
},
)
if
'large'
in
args
.
model
:
if
"large"
in
args
.
model
:
# it causes errors for large models, so skip it.
return
# Generate int8 quantization models
...
...
请
注册
或
登录
后发表评论