mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
fix data_ptr assertion error for past_sequence_length=0 in GPT-2 (#6284)
fix io binding crash for past_sequence_length=0
This commit is contained in:
parent
7fc827a8a1
commit
ac5ca2bbe0
1 changed file with 8 additions and 2 deletions
|
|
@@ -407,8 +407,14 @@ class Gpt2Helper:
|
|||
if past is not None:
|
||||
for i, past_i in enumerate(past):
|
||||
assert past_i.is_contiguous()
|
||||
io_binding.bind_input(f'past_{i}', past_i.device.type, 0, float_type, list(past_i.size()),
|
||||
past_i.data_ptr())
|
||||
|
||||
data_ptr = past_i.data_ptr()
|
||||
if data_ptr == 0:
|
||||
# When past_sequence_length is 0, the data_ptr of past_i will be zero, but IO Binding asserts that data_ptr is not zero.
|
||||
# As a workaround, pass the data pointer of input_ids instead. The actual data is not used for past, so it does not matter.
|
||||
data_ptr = input_ids.data_ptr()
|
||||
|
||||
io_binding.bind_input(f'past_{i}', past_i.device.type, 0, float_type, list(past_i.size()), data_ptr)
|
||||
|
||||
if attention_mask is not None:
|
||||
assert attention_mask.is_contiguous()
|
||||
|
|
|
|||
Loading…
Reference in a new issue