lucaelin
Committed by GitHub

fix(canary): use dynamo export, single input_ids and avoid 0/1 specialization (#2348)

@@ -197,12 +197,12 @@ def export_decoder(canary_model):
     decoder = DecoderWrapper(canary_model)
     decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32)
 
-    decoder_mems_list_0 = torch.zeros(1, 1, 1024)
-    decoder_mems_list_1 = torch.zeros(1, 1, 1024)
-    decoder_mems_list_2 = torch.zeros(1, 1, 1024)
-    decoder_mems_list_3 = torch.zeros(1, 1, 1024)
-    decoder_mems_list_4 = torch.zeros(1, 1, 1024)
-    decoder_mems_list_5 = torch.zeros(1, 1, 1024)
+    decoder_mems_list_0 = torch.zeros(1, 10, 1024)
+    decoder_mems_list_1 = torch.zeros(1, 10, 1024)
+    decoder_mems_list_2 = torch.zeros(1, 10, 1024)
+    decoder_mems_list_3 = torch.zeros(1, 10, 1024)
+    decoder_mems_list_4 = torch.zeros(1, 10, 1024)
+    decoder_mems_list_5 = torch.zeros(1, 10, 1024)
 
     enc_states = torch.zeros(1, 1000, 1024)
     enc_mask = torch.ones(1, 1000).bool()
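Background on the 1 → 10 change: the torch.export tracer behind the dynamo exporter specializes example dimensions of size 0 or 1 into compile-time constants, so a size-1 dummy cache would freeze the cache length in the exported graph. A minimal sketch of keeping that axis dynamic, using a hypothetical toy module (exact specialization behavior varies by PyTorch version):

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

# With a size-10 example, dim 1 can be declared dynamic; a size-1
# example would be specialized (or rejected) by 0/1 specialization.
ep = torch.export.export(
    Toy(),
    (torch.zeros(1, 10, 1024),),
    dynamic_shapes={"x": {1: torch.export.Dim("seq")}},
)
print(ep)
```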
@@ -221,7 +221,9 @@ def export_decoder(canary_model):
             enc_mask,
         ),
         "decoder.onnx",
-        opset_version=14,
+        dynamo=True,
+        opset_version=18,
+        external_data=False,
         input_names=[
             "decoder_input_ids",
             "decoder_mems_list_0",
@@ -272,13 +274,11 @@ def main():
     export_decoder(canary_model)
 
     for m in ["encoder", "decoder"]:
-        if m == "encoder":
-            # we don't quantize the decoder with int8 since the accuracy drops
-            quantize_dynamic(
-                model_input=f"./{m}.onnx",
-                model_output=f"./{m}.int8.onnx",
-                weight_type=QuantType.QUInt8,
-            )
+        quantize_dynamic(
+            model_input=f"./{m}.onnx",
+            model_output=f"./{m}.int8.onnx",
+            weight_type=QuantType.QUInt8,
+        )
 
         export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
 
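With the encoder-only guard gone, both models now run through onnxruntime's dynamic int8 quantization. A minimal sketch of that call, with placeholder paths:

```python
from onnxruntime.quantization import QuantType, quantize_dynamic

# Rewrites the float weights as unsigned 8-bit; activations are
# quantized dynamically at runtime, so no calibration data is needed.
quantize_dynamic(
    model_input="decoder.onnx",
    model_output="decoder.int8.onnx",
    weight_type=QuantType.QUInt8,
)
```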
@@ -263,16 +263,15 @@ def main():
     decoder_input_ids.append(token2id["<|notimestamp|>"])
     decoder_input_ids.append(token2id["<|nodiarize|>"])
 
-    decoder_input_ids.append(0)
-
     decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]
 
-    logits, decoder_mems_list = model.run_decoder(
-        np.array([decoder_input_ids], dtype=np.int32),
-        decoder_mems_list,
-        enc_states,
-        enc_masks,
-    )
+    for pos, decoder_input_id in enumerate(decoder_input_ids):
+        logits, decoder_mems_list = model.run_decoder(
+            np.array([[decoder_input_id, pos]], dtype=np.int32),
+            decoder_mems_list,
+            enc_states,
+            enc_masks,
+        )
     tokens = [logits.argmax()]
     print("decoder_input_ids", decoder_input_ids)
     eos = token2id["<|endoftext|>"]
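The prompt is now fed as one (token, position) pair per step, so the exported decoder always sees an input of shape (1, 2) while the KV cache accumulates in decoder_mems_list. A standalone sketch of the loop, with run_decoder stubbed out (shapes and vocab size are assumptions):

```python
import numpy as np

def run_decoder(ids, mems, enc_states, enc_masks):
    # Stub for the ONNX decoder session: grows each cache by one
    # step and returns fake logits over an assumed 5000-token vocab.
    step = np.zeros((1, 1, 1024), dtype=np.float32)
    mems = [np.concatenate([m, step], axis=1) for m in mems]
    return np.random.rand(1, 1, 5000).astype(np.float32), mems

decoder_input_ids = [4, 8, 15]  # hypothetical prompt token ids
decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]
enc_states = np.zeros((1, 1000, 1024), dtype=np.float32)
enc_masks = np.ones((1, 1000), dtype=bool)

for pos, decoder_input_id in enumerate(decoder_input_ids):
    logits, decoder_mems_list = run_decoder(
        np.array([[decoder_input_id, pos]], dtype=np.int32),
        decoder_mems_list,
        enc_states,
        enc_masks,
    )

tokens = [int(logits.argmax())]  # greedy pick after the final prompt step
```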