fix data_ptr assertion error for past_sequence_length=0 in GPT-2 (#6284)

fix io binding crash for past_sequence_length=0
This commit is contained in:
Tianlei Wu 2021-01-07 23:43:50 -08:00 committed by GitHub
parent 7fc827a8a1
commit ac5ca2bbe0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -407,8 +407,14 @@ class Gpt2Helper:
if past is not None:
for i, past_i in enumerate(past):
assert past_i.is_contiguous()
io_binding.bind_input(f'past_{i}', past_i.device.type, 0, float_type, list(past_i.size()),
past_i.data_ptr())
data_ptr = past_i.data_ptr()
if data_ptr == 0:
# When past_sequence_length is 0, its data_ptr will be zero. IO Binding asserts that data_ptr shall not be zero.
# Here we workaround and pass data pointer of input_ids. Actual data is not used for past so it does not matter.
data_ptr = input_ids.data_ptr()
io_binding.bind_input(f'past_{i}', past_i.device.type, 0, float_type, list(past_i.size()), data_ptr)
if attention_mask is not None:
assert attention_mask.is_contiguous()