mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Fix bug where the output names were sorted lexicographically (#7709)
This commit is contained in:
parent
6c41ed597b
commit
c873f5589d
5 changed files with 46 additions and 6 deletions
|
|
@ -309,4 +309,7 @@ class GraphExecutionManager(ABC):
|
|||
grad_builder_config.graph_transformer_config = self._get_graph_transformer_config()
|
||||
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._loglevel)
|
||||
self._graph_builder = C.OrtModuleGraphBuilder()
|
||||
|
||||
# It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
|
||||
# and are kept as they appear in the exported onnx model.
|
||||
self._graph_builder.initialize(self._onnx_model.SerializeToString(), grad_builder_config)
|
||||
|
|
|
|||
|
|
@ -99,7 +99,6 @@ class InferenceManager(GraphExecutionManager):
|
|||
self._device))
|
||||
|
||||
return _io.unflatten_user_output(self._module_output_schema,
|
||||
self._graph_info.user_output_names,
|
||||
user_outputs)
|
||||
|
||||
def _build_graph(self):
|
||||
|
|
|
|||
|
|
@ -232,7 +232,7 @@ class _TensorStub(object):
|
|||
return True
|
||||
|
||||
|
||||
def unflatten_user_output(output_schema, output_names, outputs):
|
||||
def unflatten_user_output(output_schema, outputs):
|
||||
"""Follows the schema to generate an output that is expected by the user"""
|
||||
|
||||
def _replace_stub_with_tensor_value(user_output, outputs, output_idx):
|
||||
|
|
@ -264,11 +264,11 @@ def unflatten_user_output(output_schema, output_names, outputs):
|
|||
|
||||
return user_output
|
||||
|
||||
# Order the outputs according to the names so that the traversal order is consistent
|
||||
outputs = [x for _, x in sorted(zip(output_names, outputs))]
|
||||
|
||||
# Replace every _TensorStub value in the schema with the torch.Tensor outputs calculated
|
||||
output_schema_copy = copy.deepcopy(output_schema)
|
||||
|
||||
# It is expected that the outputs are ordered in the way defined in the exported onnx model
|
||||
# which is the order in which the output schema was saved.
|
||||
output_idx = [0]
|
||||
user_output = _replace_stub_with_tensor_value(output_schema_copy, outputs, output_idx)
|
||||
return user_output
|
||||
|
|
|
|||
|
|
@ -184,7 +184,6 @@ class TrainingManager(GraphExecutionManager):
|
|||
return tuple(results)
|
||||
|
||||
return _io.unflatten_user_output(self._module_output_schema,
|
||||
self._graph_info.user_output_names,
|
||||
_ORTModuleFunction.apply(
|
||||
*_io._combine_input_buffers_initializers(
|
||||
[param for name, param in self._flattened_module.named_parameters()
|
||||
|
|
|
|||
|
|
@ -2523,3 +2523,42 @@ def test_unused_parameters(model, none_pt_params):
|
|||
out_pt = model(x)
|
||||
out_ort = ort_model(y)
|
||||
_test_helpers.assert_values_are_close(out_ort, out_pt)
|
||||
|
||||
def test_output_order():
|
||||
class OutputOrderNet(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(OutputOrderNet, self).__init__()
|
||||
|
||||
self.fc1 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc2 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc3 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc4 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc5 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc6 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc7 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc8 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc9 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc10 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc11 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.fc12 = torch.nn.Linear(input_size, hidden_size)
|
||||
|
||||
def forward(self, input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11, input12):
|
||||
return self.fc1(input1), self.fc2(input2), self.fc3(input3), \
|
||||
self.fc4(input4), self.fc5(input5), self.fc6(input6), \
|
||||
self.fc7(input7), self.fc8(input8), self.fc9(input9), \
|
||||
self.fc10(input10), self.fc11(input11), self.fc12(input12)
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = OutputOrderNet(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(model))
|
||||
|
||||
x = [torch.randn(N, D_in, device=device) for _ in range(12)]
|
||||
y = copy.deepcopy(x)
|
||||
|
||||
out_pt = model(*x)
|
||||
out_ort = ort_model(*y)
|
||||
|
||||
assert len(out_pt) == len(out_ort)
|
||||
for x, y in zip(out_pt, out_ort):
|
||||
_test_helpers.assert_values_are_close(x, y)
|
||||
|
|
|
|||
Loading…
Reference in a new issue