mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
rename past to past_key_values for GPT-2 (#6269)
rename past to past_key_values for transformers 4.*
This commit is contained in:
parent
481a2cdf61
commit
b80e8ce6a5
1 changed file with 10 additions and 4 deletions
|
|
@ -41,7 +41,10 @@ class MyGPT2Model(GPT2Model):
|
|||
super().__init__(config)
|
||||
|
||||
def forward(self, input_ids, position_ids, attention_mask, *past):
|
||||
return super().forward(input_ids, position_ids=position_ids, attention_mask=attention_mask, past=past)
|
||||
return super().forward(input_ids,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past)
|
||||
|
||||
|
||||
class MyGPT2LMHeadModel(GPT2LMHeadModel):
|
||||
|
|
@ -51,7 +54,10 @@ class MyGPT2LMHeadModel(GPT2LMHeadModel):
|
|||
super().__init__(config)
|
||||
|
||||
def forward(self, input_ids, position_ids, attention_mask, *past):
|
||||
return super().forward(input_ids, position_ids=position_ids, attention_mask=attention_mask, past=past)
|
||||
return super().forward(input_ids,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past)
|
||||
|
||||
|
||||
class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel):
|
||||
|
|
@ -63,7 +69,7 @@ class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel):
|
|||
super().__init__(config)
|
||||
|
||||
def forward(self, input_ids, *past):
|
||||
return super().forward(input_ids, past=past)
|
||||
return super().forward(input_ids, past_key_values=past)
|
||||
|
||||
|
||||
# Maps model class name to a tuple of: the model class, the name of the first output, and whether padding is used
|
||||
|
|
@ -568,7 +574,7 @@ class Gpt2Helper:
|
|||
""" Build a path name for given model based on given attributes.
|
||||
"""
|
||||
model_name = model_name_or_path
|
||||
if not re.match('^[\w_-]+$', model_name_or_path): # It is not a name, shall be a path
|
||||
if not re.match(r'^[\w_-]+$', model_name_or_path): # It is not a name, shall be a path
|
||||
assert os.path.isdir(model_name_or_path)
|
||||
model_name = Path(model_name_or_path).parts[-1]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue