diff --git a/onnxruntime/python/tools/transformers/gpt2_helper.py b/onnxruntime/python/tools/transformers/gpt2_helper.py index 9fbd58a386..9a2cdaf9ba 100644 --- a/onnxruntime/python/tools/transformers/gpt2_helper.py +++ b/onnxruntime/python/tools/transformers/gpt2_helper.py @@ -41,7 +41,10 @@ class MyGPT2Model(GPT2Model): super().__init__(config) def forward(self, input_ids, position_ids, attention_mask, *past): - return super().forward(input_ids, position_ids=position_ids, attention_mask=attention_mask, past=past) + return super().forward(input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past) class MyGPT2LMHeadModel(GPT2LMHeadModel): @@ -51,7 +54,10 @@ class MyGPT2LMHeadModel(GPT2LMHeadModel): super().__init__(config) def forward(self, input_ids, position_ids, attention_mask, *past): - return super().forward(input_ids, position_ids=position_ids, attention_mask=attention_mask, past=past) + return super().forward(input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past) class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel): @@ -63,7 +69,7 @@ class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel): super().__init__(config) def forward(self, input_ids, *past): - return super().forward(input_ids, past=past) + return super().forward(input_ids, past_key_values=past) # Maps model class name to a tuple of model class, name of first output and use padding or not @@ -568,7 +574,7 @@ class Gpt2Helper: """ Build a path name for given model based on given attributes. """ model_name = model_name_or_path - if not re.match('^[\w_-]+$', model_name_or_path): # It is not a name, shall be a path + if not re.match(r'^[\w_-]+$', model_name_or_path): # It is not a name, shall be a path assert os.path.isdir(model_name_or_path) model_name = Path(model_name_or_path).parts[-1]