diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 6624d46df3..439d780683 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -309,4 +309,7 @@ class GraphExecutionManager(ABC): grad_builder_config.graph_transformer_config = self._get_graph_transformer_config() grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._loglevel) self._graph_builder = C.OrtModuleGraphBuilder() + + # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way + # and are kept as they appear in the exported onnx model. self._graph_builder.initialize(self._onnx_model.SerializeToString(), grad_builder_config) diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 3ca6514248..260dacf55b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -99,7 +99,6 @@ class InferenceManager(GraphExecutionManager): self._device)) return _io.unflatten_user_output(self._module_output_schema, - self._graph_info.user_output_names, user_outputs) def _build_graph(self): diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index f50a931be2..1b88c3996d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -232,7 +232,7 @@ class _TensorStub(object): return True -def unflatten_user_output(output_schema, output_names, outputs): +def unflatten_user_output(output_schema, outputs): """Follows the schema to generate an output that is expected by the user""" def _replace_stub_with_tensor_value(user_output, outputs, output_idx): @@ -264,11 +264,11 @@ def unflatten_user_output(output_schema, output_names, outputs): return user_output - # Order the outputs according to the names so that the traversal order is consistent - outputs = [x for _, x in sorted(zip(output_names, outputs))] - # Replace every _TensorStub value in the schema with the torch.Tensor outputs calculated output_schema_copy = copy.deepcopy(output_schema) + + # It is expected that the outputs are ordered in the way defined in the exported onnx model + # which is the order in which the output schema was saved. output_idx = [0] user_output = _replace_stub_with_tensor_value(output_schema_copy, outputs, output_idx) return user_output diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 848340b8fa..f6e618d2f8 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -184,7 +184,6 @@ class TrainingManager(GraphExecutionManager): return tuple(results) return _io.unflatten_user_output(self._module_output_schema, - self._graph_info.user_output_names, _ORTModuleFunction.apply( *_io._combine_input_buffers_initializers( [param for name, param in self._flattened_module.named_parameters() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index f2de522115..3cdd36afc0 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -2523,3 +2523,42 @@ def test_unused_parameters(model, none_pt_params): out_pt = model(x) out_ort = ort_model(y) _test_helpers.assert_values_are_close(out_ort, out_pt) + +def test_output_order(): + class OutputOrderNet(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(OutputOrderNet, self).__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.fc2 = torch.nn.Linear(input_size, hidden_size) + self.fc3 = torch.nn.Linear(input_size, hidden_size) + self.fc4 = torch.nn.Linear(input_size, hidden_size) + self.fc5 = torch.nn.Linear(input_size, hidden_size) + self.fc6 = torch.nn.Linear(input_size, hidden_size) + self.fc7 = torch.nn.Linear(input_size, hidden_size) + self.fc8 = torch.nn.Linear(input_size, hidden_size) + self.fc9 = torch.nn.Linear(input_size, hidden_size) + self.fc10 = torch.nn.Linear(input_size, hidden_size) + self.fc11 = torch.nn.Linear(input_size, hidden_size) + self.fc12 = torch.nn.Linear(input_size, hidden_size) + + def forward(self, input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11, input12): + return self.fc1(input1), self.fc2(input2), self.fc3(input3), \ + self.fc4(input4), self.fc5(input5), self.fc6(input6), \ + self.fc7(input7), self.fc8(input8), self.fc9(input9), \ + self.fc10(input10), self.fc11(input11), self.fc12(input12) + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + model = OutputOrderNet(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(model)) + + x = [torch.randn(N, D_in, device=device) for _ in range(12)] + y = copy.deepcopy(x) + + out_pt = model(*x) + out_ort = ort_model(*y) + + assert len(out_pt) == len(out_ort) + for x, y in zip(out_pt, out_ort): + _test_helpers.assert_values_are_close(x, y)