Fangjun Kuang

Reduce whisper decoder file size with onnx export (#328)

@@ -200,10 +200,25 @@ class TextDecoderTensorCache(nn.Module):
 
         x = self.textDecoder.ln(x)
 
-        logits = (
-            x
-            @ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        if False:
+            # x.shape (1, 3, 384)
+            # weight.shape (51684, 384)
+
+            logits = (
+                x
+                @ torch.transpose(
+                    self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1
+                )
+            ).float()
+        else:
+            logits = (
+                torch.matmul(
+                    self.textDecoder.token_embedding.weight.to(x.dtype),
+                    x.permute(0, 2, 1),
+                )
+                .permute(0, 2, 1)
+                .float()
+            )
 
         return logits, n_layer_self_k_cache, n_layer_self_v_cache
 
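The rewrite above is the core of the change: with the `x @ W.T` form, `torch.onnx.export` presumably bakes a transposed copy of the tied token-embedding matrix into the graph as an extra initializer, inflating the decoder file; expressing the same product as `(W @ x^T)^T` lets the exported graph reference the weight as-is. A minimal sketch (not from the commit) checking that the two formulations agree numerically, using the shapes from the diff comments:

import torch

# Shapes mirror the comments in the diff: x is (1, 3, 384) and the
# tied embedding weight is (51684, 384).
x = torch.rand(1, 3, 384)
weight = torch.rand(51684, 384)

# Old form: x @ W^T, which exports a transposed copy of the weight.
old = (x @ torch.transpose(weight, 0, 1)).float()

# New form: (W @ x^T)^T, which reuses the weight without a transpose.
new = torch.matmul(weight, x.permute(0, 2, 1)).permute(0, 2, 1).float()

assert torch.allclose(old, new, atol=1e-4)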
@@ -246,6 +261,19 @@ def main():
     opset_version = 13
 
     model = whisper.load_model(name)
+    print(
+        f"number of model parameters: {name}",
+        sum(p.numel() for p in model.parameters()),
+    )
+    print(
+        f"number of encoder parameters: {name}",
+        sum(p.numel() for p in model.encoder.parameters()),
+    )
+    print(
+        f"number of decoder parameters: {name}",
+        sum(p.numel() for p in model.decoder.parameters()),
+    )
+
     convert_tokens(name=name, model=model)
 
     # write tokens
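The new print statements report parameter counts for the whole model and for the encoder and decoder separately. To relate those counts to on-disk size, one could list the largest initializers in the exported graph; a hedged sketch, where the file name "tiny.en-decoder.onnx" is an assumption about what the script writes:

import onnx
from onnx import numpy_helper

# Assumed output name; substitute whatever file the script produces.
m = onnx.load("tiny.en-decoder.onnx")

# Rank the baked-in weight tensors (initializers) by size in bytes.
sizes = sorted(
    (numpy_helper.to_array(init).nbytes, init.name)
    for init in m.graph.initializer
)
for nbytes, name in reversed(sizes[-5:]):
    print(f"{name}: {nbytes / 1024 / 1024:.1f} MB")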
@@ -419,7 +447,7 @@ def main():
         },
     )
 
-    if 'large' in args.model:
+    if "large" in args.model:
         # it causes errors for large models, so skip it.
         return
     # Generate int8 quantization models
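The int8 step that this hunk guards (skipped for large models) presumably relies on onnxruntime's dynamic quantization. A minimal sketch under that assumption, with illustrative file names:

from onnxruntime.quantization import QuantType, quantize_dynamic

# Rewrites the float32 weights as int8, shrinking the file roughly 4x.
quantize_dynamic(
    model_input="tiny.en-decoder.onnx",
    model_output="tiny.en-decoder.int8.onnx",
    weight_type=QuantType.QInt8,
)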