Committed by GitHub
fix(canary): use dynamo export, single input_ids and avoid 0/1 specialization (#2348)
正在显示 2 个修改的文件，包含 21 行增加和 22 行删除
| @@ -197,12 +197,12 @@ def export_decoder(canary_model): | @@ -197,12 +197,12 @@ def export_decoder(canary_model): | ||
| 197 | decoder = DecoderWrapper(canary_model) | 197 | decoder = DecoderWrapper(canary_model) |
| 198 | decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32) | 198 | decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32) |
| 199 | 199 | ||
| 200 | - decoder_mems_list_0 = torch.zeros(1, 1, 1024) | ||
| 201 | - decoder_mems_list_1 = torch.zeros(1, 1, 1024) | ||
| 202 | - decoder_mems_list_2 = torch.zeros(1, 1, 1024) | ||
| 203 | - decoder_mems_list_3 = torch.zeros(1, 1, 1024) | ||
| 204 | - decoder_mems_list_4 = torch.zeros(1, 1, 1024) | ||
| 205 | - decoder_mems_list_5 = torch.zeros(1, 1, 1024) | 200 | + decoder_mems_list_0 = torch.zeros(1, 10, 1024) |
| 201 | + decoder_mems_list_1 = torch.zeros(1, 10, 1024) | ||
| 202 | + decoder_mems_list_2 = torch.zeros(1, 10, 1024) | ||
| 203 | + decoder_mems_list_3 = torch.zeros(1, 10, 1024) | ||
| 204 | + decoder_mems_list_4 = torch.zeros(1, 10, 1024) | ||
| 205 | + decoder_mems_list_5 = torch.zeros(1, 10, 1024) | ||
| 206 | 206 | ||
| 207 | enc_states = torch.zeros(1, 1000, 1024) | 207 | enc_states = torch.zeros(1, 1000, 1024) |
| 208 | enc_mask = torch.ones(1, 1000).bool() | 208 | enc_mask = torch.ones(1, 1000).bool() |
| @@ -221,7 +221,9 @@ def export_decoder(canary_model): | @@ -221,7 +221,9 @@ def export_decoder(canary_model): | ||
| 221 | enc_mask, | 221 | enc_mask, |
| 222 | ), | 222 | ), |
| 223 | "decoder.onnx", | 223 | "decoder.onnx", |
| 224 | - opset_version=14, | 224 | + dynamo=True, |
| 225 | + opset_version=18, | ||
| 226 | + external_data=False, | ||
| 225 | input_names=[ | 227 | input_names=[ |
| 226 | "decoder_input_ids", | 228 | "decoder_input_ids", |
| 227 | "decoder_mems_list_0", | 229 | "decoder_mems_list_0", |
| @@ -272,13 +274,11 @@ def main(): | @@ -272,13 +274,11 @@ def main(): | ||
| 272 | export_decoder(canary_model) | 274 | export_decoder(canary_model) |
| 273 | 275 | ||
| 274 | for m in ["encoder", "decoder"]: | 276 | for m in ["encoder", "decoder"]: |
| 275 | - if m == "encoder": | ||
| 276 | - # we don't quantize the decoder with int8 since the accuracy drops | ||
| 277 | - quantize_dynamic( | ||
| 278 | - model_input=f"./{m}.onnx", | ||
| 279 | - model_output=f"./{m}.int8.onnx", | ||
| 280 | - weight_type=QuantType.QUInt8, | ||
| 281 | - ) | 277 | + quantize_dynamic( |
| 278 | + model_input=f"./{m}.onnx", | ||
| 279 | + model_output=f"./{m}.int8.onnx", | ||
| 280 | + weight_type=QuantType.QUInt8, | ||
| 281 | + ) | ||
| 282 | 282 | ||
| 283 | export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx") | 283 | export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx") |
| 284 | 284 |
| @@ -263,16 +263,15 @@ def main(): | @@ -263,16 +263,15 @@ def main(): | ||
| 263 | decoder_input_ids.append(token2id["<|notimestamp|>"]) | 263 | decoder_input_ids.append(token2id["<|notimestamp|>"]) |
| 264 | decoder_input_ids.append(token2id["<|nodiarize|>"]) | 264 | decoder_input_ids.append(token2id["<|nodiarize|>"]) |
| 265 | 265 | ||
| 266 | - decoder_input_ids.append(0) | ||
| 267 | - | ||
| 268 | decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)] | 266 | decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)] |
| 269 | 267 | ||
| 270 | - logits, decoder_mems_list = model.run_decoder( | ||
| 271 | - np.array([decoder_input_ids], dtype=np.int32), | ||
| 272 | - decoder_mems_list, | ||
| 273 | - enc_states, | ||
| 274 | - enc_masks, | ||
| 275 | - ) | 268 | + for pos, decoder_input_id in enumerate(decoder_input_ids): |
| 269 | + logits, decoder_mems_list = model.run_decoder( | ||
| 270 | + np.array([[decoder_input_id,pos]], dtype=np.int32), | ||
| 271 | + decoder_mems_list, | ||
| 272 | + enc_states, | ||
| 273 | + enc_masks, | ||
| 274 | + ) | ||
| 276 | tokens = [logits.argmax()] | 275 | tokens = [logits.argmax()] |
| 277 | print("decoder_input_ids", decoder_input_ids) | 276 | print("decoder_input_ids", decoder_input_ids) |
| 278 | eos = token2id["<|endoftext|>"] | 277 | eos = token2id["<|endoftext|>"] |
-
请 注册 或 登录 后发表评论