From ac5ca2bbe0fbfeecd50b1fdc0ca884831fd9d629 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 7 Jan 2021 23:43:50 -0800 Subject: [PATCH] fix data_ptr assertion error for past_sequence_length=0 in GPT-2 (#6284) fix io binding crash for past_sequence_length=0 --- onnxruntime/python/tools/transformers/gpt2_helper.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/gpt2_helper.py b/onnxruntime/python/tools/transformers/gpt2_helper.py index 9a2cdaf9ba..5490c7267a 100644 --- a/onnxruntime/python/tools/transformers/gpt2_helper.py +++ b/onnxruntime/python/tools/transformers/gpt2_helper.py @@ -407,8 +407,14 @@ class Gpt2Helper: if past is not None: for i, past_i in enumerate(past): assert past_i.is_contiguous() - io_binding.bind_input(f'past_{i}', past_i.device.type, 0, float_type, list(past_i.size()), - past_i.data_ptr()) + + data_ptr = past_i.data_ptr() + if data_ptr == 0: + # When past_sequence_length is 0, its data_ptr will be zero. IO Binding asserts that data_ptr shall not be zero. + # Here we workaround and pass data pointer of input_ids. Actual data is not used for past so it does not matter. + data_ptr = input_ids.data_ptr() + + io_binding.bind_input(f'past_{i}', past_i.device.type, 0, float_type, list(past_i.size()), data_ptr) if attention_mask is not None: assert attention_mask.is_contiguous()