Resolve issue where a registered buffer was parsed incorrectly as a user input (#7617)

This commit is contained in:
baijumeswani 2021-05-10 19:04:27 -07:00 committed by GitHub
parent a684e9aa52
commit 08fbfe9607
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 9 deletions

View file

@ -91,19 +91,27 @@ def _combine_input_buffers_initializers(param_names, onnx_input_names, input_inf
if name in kwargs and kwargs[name] is not None:
# Only use keywords coming from user that are expected by ONNX model
inp = kwargs[name]
elif input_idx < len(non_none_inputs):
# Only use positionals coming from user that are expected by ONNX model
if name != input_info.names[input_idx]:
# When ONNX drops unused inputs, get correct index from user input
input_idx = input_info.names.index(name)
inp = non_none_inputs[input_idx]
elif input_idx >= len(non_none_inputs):
if inp is None:
try:
# Only use positionals coming from user that are expected by ONNX model
# if input_idx >= len(input_info.names), IndexError will be thrown
if name != input_info.names[input_idx]:
# When ONNX drops unused inputs, get correct index from user input
# if name is not in input_info.names, ValueError will be thrown
input_idx = input_info.names.index(name)
inp = non_none_inputs[input_idx]
except (IndexError, ValueError):
# ONNX input name is not present in input_info.names.
pass
if inp is None:
# Registered buffers are translated to user_input+initializer in ONNX
try:
inp = buffer_names_dict[name]
except KeyError:
raise KeyError(f'Registered buffer name {name} not found.')
# ONNX input name is not present in the registered buffer dict.
pass
if inp is not None:
if _PrimitiveType.is_primitive_type(inp):
@ -111,6 +119,7 @@ def _combine_input_buffers_initializers(param_names, onnx_input_names, input_inf
result.append(inp)
else:
raise RuntimeError(f'Input is present in ONNX graph but not provided: {name}.')
# Initializers
result.extend([param[1] for param in param_names])
return result
@ -255,7 +264,9 @@ def _parse_outputs_and_extract_names_and_dynamic_axes(module_output):
if output is None:
return
elif isinstance(output, torch.Tensor):
output_name = f'output{output_idx[0]}'
# Naming the outputs with a hyphen ensures that there can be no input with the same
# name, preventing collisions with other NodeArgs (for example an input to forward called output0)
output_name = f'output-{output_idx[0]}'
output_idx[0] += 1
output_names.append(output_name)
output_dynamic_axes[output_name] = {}

View file

@ -2328,3 +2328,37 @@ def test_changing_bool_input_re_exports_model(bool_arguments):
exported_model2 = ort_model._execution_manager(ort_model._is_training())._onnx_model
assert exported_model1 != exported_model2
def test_model_with_registered_buffer_and_dropped_parameters():
    """Run one forward pass through an ORTModule-wrapped model that owns a registered buffer.

    The wrapped model's ``forward`` takes ``bool_argument``, which is used only
    for Python-level branching (presumably so that it is dropped from the
    exported ONNX graph inputs -- confirm against the exporter), and ``input1``,
    which feeds the linear layers. The test makes no value assertions; it
    passes as long as the forward call does not raise.
    """

    class ModelWithBufferAndDroppedParameter(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super().__init__()
            self.fc1 = torch.nn.Linear(input_size, hidden_size)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Linear(hidden_size, num_classes)
            self.register_buffer("buffer", torch.ones(num_classes))

        def forward(self, bool_argument, input1):
            # Both branches start with fc1 and end by adding the buffer;
            # they differ only in whether relu runs before or after fc2.
            hidden = self.fc1(input1)
            if bool_argument:
                logits = self.fc2(self.relu(hidden))
            else:
                logits = self.relu(self.fc2(hidden))
            return logits + self.buffer

    device = 'cuda'
    batch_size, in_features, hidden_features, out_features = 64, 784, 500, 10
    ort_model = ORTModule(
        ModelWithBufferAndDroppedParameter(in_features, hidden_features, out_features).to(device)
    )

    flag = torch.tensor(True)
    sample = torch.randn(batch_size, in_features, device=device)

    # Ensure that no exceptions are raised
    ort_model(flag, sample)