From b80e8ce6a5df73423f5421918e64db384f3a830f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 7 Jan 2021 11:12:04 -0800 Subject: [PATCH] rename past to past_key_values for GPT-2 (#6269) rename past to past_key_values for transformers 4.* --- .../python/tools/transformers/gpt2_helper.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/gpt2_helper.py b/onnxruntime/python/tools/transformers/gpt2_helper.py index 9fbd58a386..9a2cdaf9ba 100644 --- a/onnxruntime/python/tools/transformers/gpt2_helper.py +++ b/onnxruntime/python/tools/transformers/gpt2_helper.py @@ -41,7 +41,10 @@ class MyGPT2Model(GPT2Model): super().__init__(config) def forward(self, input_ids, position_ids, attention_mask, *past): - return super().forward(input_ids, position_ids=position_ids, attention_mask=attention_mask, past=past) + return super().forward(input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past) class MyGPT2LMHeadModel(GPT2LMHeadModel): @@ -51,7 +54,10 @@ class MyGPT2LMHeadModel(GPT2LMHeadModel): super().__init__(config) def forward(self, input_ids, position_ids, attention_mask, *past): - return super().forward(input_ids, position_ids=position_ids, attention_mask=attention_mask, past=past) + return super().forward(input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past) class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel): @@ -63,7 +69,7 @@ class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel): super().__init__(config) def forward(self, input_ids, *past): - return super().forward(input_ids, past=past) + return super().forward(input_ids, past_key_values=past) # Maps model class name to a tuple of model class, name of first output and use padding or not @@ -568,7 +574,7 @@ class Gpt2Helper: """ Build a path name for given model based on given attributes. """ model_name = model_name_or_path - if not re.match('^[\w_-]+$', model_name_or_path): # It is not a name, shall be a path + if not re.match(r'^[\w_-]+$', model_name_or_path): # It is not a name, shall be a path assert os.path.isdir(model_name_or_path) model_name = Path(model_name_or_path).parts[-1]