Fix bug where the output names were sorted lexicographically (#7709)

This commit is contained in:
baijumeswani 2021-05-17 10:27:20 -07:00 committed by GitHub
parent 6c41ed597b
commit c873f5589d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 6 deletions

View file

@ -309,4 +309,7 @@ class GraphExecutionManager(ABC):
grad_builder_config.graph_transformer_config = self._get_graph_transformer_config()
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._loglevel)
self._graph_builder = C.OrtModuleGraphBuilder()
# It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
# and are kept as they appear in the exported onnx model.
self._graph_builder.initialize(self._onnx_model.SerializeToString(), grad_builder_config)

View file

@ -99,7 +99,6 @@ class InferenceManager(GraphExecutionManager):
self._device))
return _io.unflatten_user_output(self._module_output_schema,
self._graph_info.user_output_names,
user_outputs)
def _build_graph(self):

View file

@ -232,7 +232,7 @@ class _TensorStub(object):
return True
def unflatten_user_output(output_schema, output_names, outputs):
def unflatten_user_output(output_schema, outputs):
"""Follows the schema to generate an output that is expected by the user"""
def _replace_stub_with_tensor_value(user_output, outputs, output_idx):
@ -264,11 +264,11 @@ def unflatten_user_output(output_schema, output_names, outputs):
return user_output
# Order the outputs according to the names so that the traversal order is consistent
outputs = [x for _, x in sorted(zip(output_names, outputs))]
# Replace every _TensorStub value in the schema with the torch.Tensor outputs calculated
output_schema_copy = copy.deepcopy(output_schema)
# It is expected that the outputs are ordered in the way defined in the exported onnx model
# which is the order in which the output schema was saved.
output_idx = [0]
user_output = _replace_stub_with_tensor_value(output_schema_copy, outputs, output_idx)
return user_output

View file

@ -184,7 +184,6 @@ class TrainingManager(GraphExecutionManager):
return tuple(results)
return _io.unflatten_user_output(self._module_output_schema,
self._graph_info.user_output_names,
_ORTModuleFunction.apply(
*_io._combine_input_buffers_initializers(
[param for name, param in self._flattened_module.named_parameters()

View file

@ -2523,3 +2523,42 @@ def test_unused_parameters(model, none_pt_params):
out_pt = model(x)
out_ort = ort_model(y)
_test_helpers.assert_values_are_close(out_ort, out_pt)
def test_output_order():
class OutputOrderNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(OutputOrderNet, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.fc2 = torch.nn.Linear(input_size, hidden_size)
self.fc3 = torch.nn.Linear(input_size, hidden_size)
self.fc4 = torch.nn.Linear(input_size, hidden_size)
self.fc5 = torch.nn.Linear(input_size, hidden_size)
self.fc6 = torch.nn.Linear(input_size, hidden_size)
self.fc7 = torch.nn.Linear(input_size, hidden_size)
self.fc8 = torch.nn.Linear(input_size, hidden_size)
self.fc9 = torch.nn.Linear(input_size, hidden_size)
self.fc10 = torch.nn.Linear(input_size, hidden_size)
self.fc11 = torch.nn.Linear(input_size, hidden_size)
self.fc12 = torch.nn.Linear(input_size, hidden_size)
def forward(self, input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11, input12):
return self.fc1(input1), self.fc2(input2), self.fc3(input3), \
self.fc4(input4), self.fc5(input5), self.fc6(input6), \
self.fc7(input7), self.fc8(input8), self.fc9(input9), \
self.fc10(input10), self.fc11(input11), self.fc12(input12)
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
model = OutputOrderNet(D_in, H, D_out).to(device)
ort_model = ORTModule(copy.deepcopy(model))
x = [torch.randn(N, D_in, device=device) for _ in range(12)]
y = copy.deepcopy(x)
out_pt = model(*x)
out_ort = ort_model(*y)
assert len(out_pt) == len(out_ort)
for x, y in zip(out_pt, out_ort):
_test_helpers.assert_values_are_close(x, y)