User dynamic axes in one step beam search output (#8092)

This commit is contained in:
Xiaoyu Liu 2021-06-23 01:41:32 -07:00 committed by GitHub
parent cccd61e3bc
commit 45ce239929
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -610,6 +610,18 @@ class Gpt2BeamSearchHelper(Gpt2Helper):
input_names.append("prev_step_scores")
input_names.extend(past_names)
# add dynamic output axes
present_axes = {1: 'batch_size', 3: 'cur_seq_len'}
dynamic_axes["last_state"] = {0: 'batch_size', 1: 'beam_size'}
for i in range(num_layer):
dynamic_axes["present_" + str(i)] = present_axes
dynamic_axes["output_selected_indices"] = {0: "batch_size", 1: "'beam_size_or_1'"}
dynamic_axes["output_log_probs"] = {0: "batch_size", 1: "'beam_size'"}
dynamic_axes["output_unfinished_sents"] = {0: "batch_size", 1: "'beam_size'"}
dynamic_axes["current_step_results"] = {0: "beam_size_or_1", 1: "total_seq_len"}
dynamic_axes["current_step_scores"] = {0: "beam_size_or_1", 1: "total_seq_len"}
logger.info(
f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}"
)