Fangjun Kuang

Reduce whisper decoder file size with onnx export (#328)

@@ -200,10 +200,25 @@ class TextDecoderTensorCache(nn.Module):
         x = self.textDecoder.ln(x)
-        logits = (
-            x
-            @ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        if False:
+            # x.shape (1, 3, 384)
+            # weight.shape (51684, 384)
+            logits = (
+                x
+                @ torch.transpose(
+                    self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1
+                )
+            ).float()
+        else:
+            logits = (
+                torch.matmul(
+                    self.textDecoder.token_embedding.weight.to(x.dtype),
+                    x.permute(0, 2, 1),
+                )
+                .permute(0, 2, 1)
+                .float()
+            )

         return logits, n_layer_self_k_cache, n_layer_self_v_cache
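
The rewritten projection above is mathematically equivalent to the old one: for x of shape (1, 3, 384) and an embedding weight W of shape (51684, 384), x @ W^T equals (W @ x.permute(0, 2, 1)).permute(0, 2, 1). The point of the change is presumably that the old form needs a transposed copy of W, which the ONNX exporter can constant-fold into an extra initializer in the decoder graph; a second fp32 copy of a (51684, 384) matrix alone is roughly 51684 * 384 * 4 bytes ≈ 79 MB. The sketch below is a standalone check with random stand-in tensors (not part of the export script) that verifies the equivalence:

import torch

# Toy stand-ins matching the shapes noted in the diff comments:
# x is (batch, tokens, dim) = (1, 3, 384), weight is (vocab, dim) = (51684, 384).
x = torch.randn(1, 3, 384)
weight = torch.randn(51684, 384)

# Old formulation: x @ W^T, which relies on a transposed copy of the weight.
logits_old = (x @ torch.transpose(weight, 0, 1)).float()

# New formulation: matmul against the untransposed weight, permuting x instead.
logits_new = torch.matmul(weight, x.permute(0, 2, 1)).permute(0, 2, 1).float()

# Both yield (1, 3, 51684) logits and agree up to floating-point rounding.
assert logits_old.shape == logits_new.shape == (1, 3, 51684)
assert torch.allclose(logits_old, logits_new, rtol=1e-4, atol=1e-4)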
@@ -246,6 +261,19 @@ def main():
     opset_version = 13
     model = whisper.load_model(name)
+    print(
+        f"number of model parameters: {name}",
+        sum(p.numel() for p in model.parameters()),
+    )
+    print(
+        f"number of encoder parameters: {name}",
+        sum(p.numel() for p in model.encoder.parameters()),
+    )
+    print(
+        f"number of decoder parameters: {name}",
+        sum(p.numel() for p in model.decoder.parameters()),
+    )
+
     convert_tokens(name=name, model=model)

     # write tokens
@@ -419,7 +447,7 @@ def main():
         },
     )
-    if 'large' in args.model:
+    if "large" in args.model:
         # it causes errors for large models, so skip it.
         return

     # Generate int8 quantization models
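
The int8 step that the early return skips for large models is not shown in the diff; it presumably applies onnxruntime's dynamic quantization to the exported encoder and decoder files. A minimal sketch of such a step, with hypothetical file names (the real script's names and options may differ):

from onnxruntime.quantization import QuantType, quantize_dynamic

# Hypothetical file names; quantize the weights of the exported models to int8.
for part in ("encoder", "decoder"):
    quantize_dynamic(
        model_input=f"tiny-{part}.onnx",
        model_output=f"tiny-{part}.int8.onnx",
        weight_type=QuantType.QInt8,
    )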