rename past to past_key_values for GPT-2 (#6269)

Rename the `past` keyword argument passed to `forward` to `past_key_values` for compatibility with transformers 4.*
This commit is contained in:
Tianlei Wu 2021-01-07 11:12:04 -08:00 committed by GitHub
parent 481a2cdf61
commit b80e8ce6a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -41,7 +41,10 @@ class MyGPT2Model(GPT2Model):
super().__init__(config)
def forward(self, input_ids, position_ids, attention_mask, *past):
return super().forward(input_ids, position_ids=position_ids, attention_mask=attention_mask, past=past)
return super().forward(input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past)
class MyGPT2LMHeadModel(GPT2LMHeadModel):
@@ -51,7 +54,10 @@ class MyGPT2LMHeadModel(GPT2LMHeadModel):
super().__init__(config)
def forward(self, input_ids, position_ids, attention_mask, *past):
return super().forward(input_ids, position_ids=position_ids, attention_mask=attention_mask, past=past)
return super().forward(input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past)
class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel):
@@ -63,7 +69,7 @@ class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel):
super().__init__(config)
def forward(self, input_ids, *past):
return super().forward(input_ids, past=past)
return super().forward(input_ids, past_key_values=past)
# Maps model class name to a tuple of model class, name of first output and use padding or not
@@ -568,7 +574,7 @@ class Gpt2Helper:
""" Build a path name for given model based on given attributes.
"""
model_name = model_name_or_path
if not re.match('^[\w_-]+$', model_name_or_path): # It is not a name, shall be a path
if not re.match(r'^[\w_-]+$', model_name_or_path): # It is not a name, shall be a path
assert os.path.isdir(model_name_or_path)
model_name = Path(model_name_or_path).parts[-1]