Commit (committed via GitHub): Reduce whisper decoder file size with onnx export (#328)
Showing 1 changed file with 33 additions and 5 deletions.
@@ -200,10 +200,25 @@ class TextDecoderTensorCache(nn.Module):
 
         x = self.textDecoder.ln(x)
 
-        logits = (
-            x
-            @ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        if False:
+            # x.shape (1, 3, 384)
+            # weight.shape (51684, 384)
+
+            logits = (
+                x
+                @ torch.transpose(
+                    self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1
+                )
+            ).float()
+        else:
+            logits = (
+                torch.matmul(
+                    self.textDecoder.token_embedding.weight.to(x.dtype),
+                    x.permute(0, 2, 1),
+                )
+                .permute(0, 2, 1)
+                .float()
+            )
 
         return logits, n_layer_self_k_cache, n_layer_self_v_cache
 
@@ -246,6 +261,19 @@ def main():
     opset_version = 13
 
     model = whisper.load_model(name)
+    print(
+        f"number of model parameters: {name}",
+        sum(p.numel() for p in model.parameters()),
+    )
+    print(
+        f"number of encoder parameters: {name}",
+        sum(p.numel() for p in model.encoder.parameters()),
+    )
+    print(
+        f"number of decoder parameters: {name}",
+        sum(p.numel() for p in model.decoder.parameters()),
+    )
+
     convert_tokens(name=name, model=model)
 
     # write tokens
@@ -419,7 +447,7 @@ def main():
         },
     )
 
-    if 'large' in args.model:
+    if "large" in args.model:
         # it causes errors for large models, so skip it.
         return
     # Generate int8 quantization models
Please register or sign in to post a comment.