mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
User dynamic axes in one step beam search output (#8092)
This commit is contained in:
parent
cccd61e3bc
commit
45ce239929
1 changed files with 12 additions and 0 deletions
|
|
@ -610,6 +610,18 @@ class Gpt2BeamSearchHelper(Gpt2Helper):
|
|||
input_names.append("prev_step_scores")
|
||||
input_names.extend(past_names)
|
||||
|
||||
# add dynamic output axes
|
||||
present_axes = {1: 'batch_size', 3: 'cur_seq_len'}
|
||||
dynamic_axes["last_state"] = {0: 'batch_size', 1: 'beam_size'}
|
||||
for i in range(num_layer):
|
||||
dynamic_axes["present_" + str(i)] = present_axes
|
||||
|
||||
dynamic_axes["output_selected_indices"] = {0: "batch_size", 1: "'beam_size_or_1'"}
|
||||
dynamic_axes["output_log_probs"] = {0: "batch_size", 1: "'beam_size'"}
|
||||
dynamic_axes["output_unfinished_sents"] = {0: "batch_size", 1: "'beam_size'"}
|
||||
dynamic_axes["current_step_results"] = {0: "beam_size_or_1", 1: "total_seq_len"}
|
||||
dynamic_axes["current_step_scores"] = {0: "beam_size_or_1", 1: "total_seq_len"}
|
||||
|
||||
logger.info(
|
||||
f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue