diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 1601b1a203..9e8b284bf5 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -224,24 +224,35 @@ def fetch_onnx_inputs_outputs_name( if not num_of_past_key: num_of_past_key = model.config.num_hidden_layers - onnx_inp_names = ("input_ids", "attention_mask") + # filter out constant inputs + onnx_inp_names = tuple( + [torch_input_names[i] for i in range(len(torch_input_names)) if isinstance(onnx_inputs[i], torch.Tensor)] + ) + assert ( + "input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names + ), "input_ids and attention_mask must be existed in inputs" onnx_out_names = ("logits",) onnx_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "attention_mask": {0: "batch_size", 1: "seq_len"}, } + # add dyanmic dimensions for the unkonw inputs + for idx, name in enumerate(onnx_inp_names): + if name not in onnx_dynamic_axes: + unknown_dims = {i: f"{idx}__unknown_dims__{i}" for i in range(onnx_inputs[idx].dim())} + onnx_dynamic_axes[name] = unknown_dims if input_with_past: for i in range(num_of_past_key): - onnx_inp_names += (f"present_key.{i}",) - onnx_inp_names += (f"present_values.{i}",) + onnx_inp_names += (f"past_key_values.{i}.key",) + onnx_inp_names += (f"past_key_values.{i}.value",) onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis if with_past or input_with_past: for i in range(num_of_past_key): - onnx_out_names += (f"past_key.{i}",) - onnx_out_names += (f"past_values.{i}",) + onnx_out_names += (f"present.{i}.key",) + onnx_out_names += (f"present.{i}.value",) onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis