more inputs support for LLM exporter (#19005)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
wejoncy 2024-01-17 15:46:19 +08:00 committed by GitHub
parent 07d3aed3aa
commit 9876cc7c4f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -224,24 +224,35 @@ def fetch_onnx_inputs_outputs_name(
if not num_of_past_key:
num_of_past_key = model.config.num_hidden_layers
onnx_inp_names = ("input_ids", "attention_mask")
# filter out constant inputs
onnx_inp_names = tuple(
[torch_input_names[i] for i in range(len(torch_input_names)) if isinstance(onnx_inputs[i], torch.Tensor)]
)
assert (
"input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names
), "input_ids and attention_mask must be existed in inputs"
onnx_out_names = ("logits",)
onnx_dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"attention_mask": {0: "batch_size", 1: "seq_len"},
}
# add dyanmic dimensions for the unkonw inputs
for idx, name in enumerate(onnx_inp_names):
if name not in onnx_dynamic_axes:
unknown_dims = {i: f"{idx}__unknown_dims__{i}" for i in range(onnx_inputs[idx].dim())}
onnx_dynamic_axes[name] = unknown_dims
if input_with_past:
for i in range(num_of_past_key):
onnx_inp_names += (f"present_key.{i}",)
onnx_inp_names += (f"present_values.{i}",)
onnx_inp_names += (f"past_key_values.{i}.key",)
onnx_inp_names += (f"past_key_values.{i}.value",)
onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis
onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis
if with_past or input_with_past:
for i in range(num_of_past_key):
onnx_out_names += (f"past_key.{i}",)
onnx_out_names += (f"past_values.{i}",)
onnx_out_names += (f"present.{i}.key",)
onnx_out_names += (f"present.{i}.value",)
onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis
onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis