mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
more inputs support for LLM exporter (#19005)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
07d3aed3aa
commit
9876cc7c4f
1 changed files with 16 additions and 5 deletions
|
|
@ -224,24 +224,35 @@ def fetch_onnx_inputs_outputs_name(
|
|||
if not num_of_past_key:
|
||||
num_of_past_key = model.config.num_hidden_layers
|
||||
|
||||
onnx_inp_names = ("input_ids", "attention_mask")
|
||||
# filter out constant inputs
|
||||
onnx_inp_names = tuple(
|
||||
[torch_input_names[i] for i in range(len(torch_input_names)) if isinstance(onnx_inputs[i], torch.Tensor)]
|
||||
)
|
||||
assert (
|
||||
"input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names
|
||||
), "input_ids and attention_mask must be existed in inputs"
|
||||
onnx_out_names = ("logits",)
|
||||
onnx_dynamic_axes = {
|
||||
"input_ids": {0: "batch_size", 1: "seq_len"},
|
||||
"attention_mask": {0: "batch_size", 1: "seq_len"},
|
||||
}
|
||||
# add dyanmic dimensions for the unkonw inputs
|
||||
for idx, name in enumerate(onnx_inp_names):
|
||||
if name not in onnx_dynamic_axes:
|
||||
unknown_dims = {i: f"{idx}__unknown_dims__{i}" for i in range(onnx_inputs[idx].dim())}
|
||||
onnx_dynamic_axes[name] = unknown_dims
|
||||
if input_with_past:
|
||||
for i in range(num_of_past_key):
|
||||
onnx_inp_names += (f"present_key.{i}",)
|
||||
onnx_inp_names += (f"present_values.{i}",)
|
||||
onnx_inp_names += (f"past_key_values.{i}.key",)
|
||||
onnx_inp_names += (f"past_key_values.{i}.value",)
|
||||
|
||||
onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis
|
||||
onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis
|
||||
|
||||
if with_past or input_with_past:
|
||||
for i in range(num_of_past_key):
|
||||
onnx_out_names += (f"past_key.{i}",)
|
||||
onnx_out_names += (f"past_values.{i}",)
|
||||
onnx_out_names += (f"present.{i}.key",)
|
||||
onnx_out_names += (f"present.{i}.value",)
|
||||
onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis
|
||||
onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue