diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 7c02082937..b673130ebb 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -91,19 +91,27 @@ def _combine_input_buffers_initializers(param_names, onnx_input_names, input_inf if name in kwargs and kwargs[name] is not None: # Only use keywords coming from user that are expected by ONNX model inp = kwargs[name] - elif input_idx < len(non_none_inputs): - # Only use positionals coming from user that are expected by ONNX model - if name != input_info.names[input_idx]: - # When ONNX drops unused inputs, get correct index from user input - input_idx = input_info.names.index(name) - inp = non_none_inputs[input_idx] - elif input_idx >= len(non_none_inputs): + if inp is None: + try: + # Only use positionals coming from user that are expected by ONNX model + # if input_idx >= len(input_info.names), IndexError will be thrown + if name != input_info.names[input_idx]: + # When ONNX drops unused inputs, get correct index from user input + # if name is not in input_info.names, ValueError will be thrown + input_idx = input_info.names.index(name) + inp = non_none_inputs[input_idx] + except (IndexError, ValueError): + # ONNX input name is not present in input_info.names. + pass + + if inp is None: # Registered buffers are translated to user_input+initializer in ONNX try: inp = buffer_names_dict[name] except KeyError: - raise KeyError(f'Registered buffer name {name} not found.') + # ONNX input name is not present in the registered buffer dict. + pass if inp is not None: if _PrimitiveType.is_primitive_type(inp): @@ -111,6 +119,7 @@ def _combine_input_buffers_initializers(param_names, onnx_input_names, input_inf result.append(inp) else: raise RuntimeError(f'Input is present in ONNX graph but not provided: {name}.') + # Initializers result.extend([param[1] for param in param_names]) return result @@ -255,7 +264,9 @@ def _parse_outputs_and_extract_names_and_dynamic_axes(module_output): if output is None: return elif isinstance(output, torch.Tensor): - output_name = f'output{output_idx[0]}' + # Naming the outputs with a hyphen ensures that there can be no input with the same + # name, preventing collisions with other NodeArgs (for example an input to forward called output0) + output_name = f'output-{output_idx[0]}' output_idx[0] += 1 output_names.append(output_name) output_dynamic_axes[output_name] = {} diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 5d393094e2..df26250ed0 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -2328,3 +2328,37 @@ def test_changing_bool_input_re_exports_model(bool_arguments): exported_model2 = ort_model._execution_manager(ort_model._is_training())._onnx_model assert exported_model1 != exported_model2 + +def test_model_with_registered_buffer_and_dropped_parameters(): + class ModelWithBufferAndDroppedParameter(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(ModelWithBufferAndDroppedParameter, self).__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + self.register_buffer("buffer", torch.ones(num_classes)) + + def forward(self, bool_argument, input1): + if bool_argument: + out = self.fc1(input1) + out = self.relu(out) + out = self.fc2(out) + out = out + self.buffer + else: + out = self.fc1(input1) + out = self.fc2(out) + out = self.relu(out) + out = out + self.buffer + return out + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + model = ModelWithBufferAndDroppedParameter(D_in, H, D_out).to(device) + model = ORTModule(model) + + bool_argument = torch.tensor(True) + x = torch.randn(N, D_in, device=device) + + # Ensure that no exceptions are raised + out = model(bool_argument, x)