diff --git a/onnxruntime/python/tools/transformers/gpt2_beamsearch_helper.py b/onnxruntime/python/tools/transformers/gpt2_beamsearch_helper.py index a3fa8ec875..0001785a1c 100644 --- a/onnxruntime/python/tools/transformers/gpt2_beamsearch_helper.py +++ b/onnxruntime/python/tools/transformers/gpt2_beamsearch_helper.py @@ -610,6 +610,18 @@ class Gpt2BeamSearchHelper(Gpt2Helper): input_names.append("prev_step_scores") input_names.extend(past_names) + # add dynamic output axes + present_axes = {1: 'batch_size', 3: 'cur_seq_len'} + dynamic_axes["last_state"] = {0: 'batch_size', 1: 'beam_size'} + for i in range(num_layer): + dynamic_axes["present_" + str(i)] = present_axes + + dynamic_axes["output_selected_indices"] = {0: "batch_size", 1: "'beam_size_or_1'"} + dynamic_axes["output_log_probs"] = {0: "batch_size", 1: "'beam_size'"} + dynamic_axes["output_unfinished_sents"] = {0: "batch_size", 1: "'beam_size'"} + dynamic_axes["current_step_results"] = {0: "beam_size_or_1", 1: "total_seq_len"} + dynamic_axes["current_step_scores"] = {0: "beam_size_or_1", 1: "total_seq_len"} + logger.info( f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}" )