mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Resolve issue where a registered buffer was parsed incorrectly as a user input (#7617)
This commit is contained in:
parent
a684e9aa52
commit
08fbfe9607
2 changed files with 54 additions and 9 deletions
|
|
@@ -91,19 +91,27 @@ def _combine_input_buffers_initializers(param_names, onnx_input_names, input_inf
|
|||
if name in kwargs and kwargs[name] is not None:
|
||||
# Only use keywords coming from user that are expected by ONNX model
|
||||
inp = kwargs[name]
|
||||
elif input_idx < len(non_none_inputs):
|
||||
# Only use positionals coming from user that are expected by ONNX model
|
||||
if name != input_info.names[input_idx]:
|
||||
# When ONNX drops unused inputs, get correct index from user input
|
||||
input_idx = input_info.names.index(name)
|
||||
inp = non_none_inputs[input_idx]
|
||||
|
||||
elif input_idx >= len(non_none_inputs):
|
||||
if inp is None:
|
||||
try:
|
||||
# Only use positionals coming from user that are expected by ONNX model
|
||||
# if input_idx >= len(input_info.names), IndexError will be thrown
|
||||
if name != input_info.names[input_idx]:
|
||||
# When ONNX drops unused inputs, get correct index from user input
|
||||
# if name is not in input_info.names, ValueError will be thrown
|
||||
input_idx = input_info.names.index(name)
|
||||
inp = non_none_inputs[input_idx]
|
||||
except (IndexError, ValueError):
|
||||
# ONNX input name is not present in input_info.names.
|
||||
pass
|
||||
|
||||
if inp is None:
|
||||
# Registered buffers are translated to user_input+initializer in ONNX
|
||||
try:
|
||||
inp = buffer_names_dict[name]
|
||||
except KeyError:
|
||||
raise KeyError(f'Registered buffer name {name} not found.')
|
||||
# ONNX input name is not present in the registered buffer dict.
|
||||
pass
|
||||
|
||||
if inp is not None:
|
||||
if _PrimitiveType.is_primitive_type(inp):
|
||||
|
|
@@ -111,6 +119,7 @@ def _combine_input_buffers_initializers(param_names, onnx_input_names, input_inf
|
|||
result.append(inp)
|
||||
else:
|
||||
raise RuntimeError(f'Input is present in ONNX graph but not provided: {name}.')
|
||||
|
||||
# Initializers
|
||||
result.extend([param[1] for param in param_names])
|
||||
return result
|
||||
|
|
@@ -255,7 +264,9 @@ def _parse_outputs_and_extract_names_and_dynamic_axes(module_output):
|
|||
if output is None:
|
||||
return
|
||||
elif isinstance(output, torch.Tensor):
|
||||
output_name = f'output{output_idx[0]}'
|
||||
# Naming the outputs with a hyphen ensures that there can be no input with the same
|
||||
# name, preventing collisions with other NodeArgs (for example an input to forward called output0)
|
||||
output_name = f'output-{output_idx[0]}'
|
||||
output_idx[0] += 1
|
||||
output_names.append(output_name)
|
||||
output_dynamic_axes[output_name] = {}
|
||||
|
|
|
|||
|
|
@@ -2328,3 +2328,37 @@ def test_changing_bool_input_re_exports_model(bool_arguments):
|
|||
exported_model2 = ort_model._execution_manager(ort_model._is_training())._onnx_model
|
||||
|
||||
assert exported_model1 != exported_model2
|
||||
|
||||
def test_model_with_registered_buffer_and_dropped_parameters():
    """Run an ORTModule-wrapped model that owns a registered buffer.

    The wrapped module's ``forward`` takes a boolean tensor, used only to
    pick a Python branch, followed by the data tensor. The test makes no
    assertion on the result: it passes as long as the forward call
    completes without raising.
    """

    class ModelWithBufferAndDroppedParameter(torch.nn.Module):
        """Two linear layers with a ReLU, plus a registered buffer added to the output."""

        def __init__(self, input_size, hidden_size, num_classes):
            super().__init__()

            self.fc1 = torch.nn.Linear(input_size, hidden_size)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Linear(hidden_size, num_classes)
            self.register_buffer("buffer", torch.ones(num_classes))

        def forward(self, bool_argument, input1):
            # The flag only decides whether ReLU is applied before or after fc2.
            if bool_argument:
                out = self.fc2(self.relu(self.fc1(input1)))
            else:
                out = self.relu(self.fc2(self.fc1(input1)))
            return out + self.buffer

    device = 'cuda'
    batch_size, in_features, hidden_features, out_features = 64, 784, 500, 10
    wrapped = ORTModule(
        ModelWithBufferAndDroppedParameter(in_features, hidden_features, out_features).to(device)
    )

    flag = torch.tensor(True)
    features = torch.randn(batch_size, in_features, device=device)

    # Ensure that no exceptions are raised
    wrapped(flag, features)
|
||||
|
|
|
|||
Loading…
Reference in a new issue