diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py
index 7767544747..b4e1ce866b 100644
--- a/orttraining/orttraining/python/training/ortmodule/_io.py
+++ b/orttraining/orttraining/python/training/ortmodule/_io.py
@@ -18,6 +18,7 @@ class _InputInfo(object):
                  dynamic_axes=None,
                  schema=None,
                  num_positionals=0,
+                 num_positionals_non_none=0,
                  keyword_names=None):
         self.names = names
         self.shape = shape
@@ -25,31 +26,33 @@ class _InputInfo(object):
         self.dynamic_axes = dynamic_axes if dynamic_axes else {}
         self.schema = schema if schema else []
         self.num_positionals = num_positionals
+        self.num_positionals_non_none = num_positionals_non_none
         self.keyword_names = keyword_names
 
     def __repr__(self) -> str:
         return f'''_InputInfo class:
-            \tNames:            {self.names}
-            \tShape:            {self.shape}
-            \tRequire gradient: {self.require_grad_names}
-            \tDynamic axes:     {self.dynamic_axes}
-            \tSchema:           {self.schema}
-            \t#Positionals:     {self.num_positionals}
-            \tKeyword names:    {self.keyword_names}'''
+            \tNames:                   {self.names}
+            \tShape:                   {self.shape}
+            \tRequire gradient:        {self.require_grad_names}
+            \tDynamic axes:            {self.dynamic_axes}
+            \tSchema:                  {self.schema}
+            \t#Positionals (total):    {self.num_positionals}
+            \t#Positionals (non-None): {self.num_positionals_non_none}
+            \tKeyword names:           {self.keyword_names}'''
 
     def flatten(self, args, kwargs):
-        '''Flatten args and kwargs in a single tuple of tensors with strict ordering'''
+        '''Flatten args and kwargs into a single list, with kwargs ordered as in self.names'''
 
         ret = list(args)
-        for _, kwarg in kwargs.items():
-            ret.append(kwarg)
-        return tuple(ret)
+        ret += [kwargs[name] for name in self.names if name in kwargs]
+        return ret
 
     def unflatten(self, flat_args):
         '''Unflatten tuple of tensors into args and kwargs'''
 
         args = tuple(flat_args[:self.num_positionals])
-        kwargs = {kwarg_name: arg for kwarg_name, arg in zip(self.keyword_names, flat_args[self.num_positionals:])}
+        kwargs = {name: arg for name, arg in zip(self.names[self.num_positionals_non_none:], flat_args[self.num_positionals:]) \
+            if name in self.keyword_names}
         return args, kwargs
 
 def _combine_input_buffers_initializers(param_names, onnx_input_names, input_info, buffer_names, inputs, kwargs):
@@ -135,10 +138,12 @@ class _TensorStub(object):
         return result
 
     def __eq__(self, other):
-        if not isinstance(other, _TensorStub):
-            raise NotImplemented('_TensorStub must only be compared to another _TensorStub instance!')
-        elif not other:
+        if not other:
             return False
+        elif not isinstance(other, _TensorStub):
+            # `NotImplemented` is a non-callable constant: calling it raised an unrelated
+            # TypeError and hid this message. Raise TypeError with the intended text instead.
+            raise TypeError('_TensorStub must only be compared to another _TensorStub instance!')
         elif self.name != other.name:
             return False
         elif self.dtype != other.dtype:
@@ -356,6 +361,7 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, inputs, kwarg
                       dynamic_axes=dynamic_axes,
                       schema=schema,
                       num_positionals=len(inputs),
+                      num_positionals_non_none=len([i for i in inputs if i is not None]),
                       keyword_names=kwargs.keys())
 
 
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 99ac87c7f0..f3ca5b8216 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -2082,3 +2082,139 @@ def test_forward_call_default_input():
         assert out.item() == 15.0
         if model.training:
             out.sum().backward()
+
+
+def test_forward_call_kwargs_input_unexpected_order():
+    '''Keyword inputs must be matched by name, whatever order the caller passes them in'''
+    class OrderlyNet(torch.nn.Module):
+        def __init__(self, input_size, hidden_size, num_classes):
+            super(OrderlyNet, self).__init__()
+            self.fc1 = torch.nn.Linear(input_size, hidden_size)
+            self.relu = torch.nn.ReLU()
+            self.fc2 = torch.nn.Linear(hidden_size, num_classes)
+
+        def forward(self, input1=None, input2=None):
+            assert input1.shape != input2.shape
+            input2 = torch.transpose(input2, 0, 1)
+            assert input1.shape == input2.shape
+
+            model_input = input1 + input2
+            out1 = self.fc1(model_input)
+            out1 = self.relu(out1)
+            out2 = self.fc2(out1)
+            return out1, out2
+
+    device = 'cuda'
+    N, D_in, H, D_out = 32, 784, 500, 10
+    model = OrderlyNet(D_in, H, D_out).to(device)
+    model = ORTModule(model)
+
+    input1 = torch.randn(N, D_in, device=device, requires_grad=False)
+    input2 = torch.randn(D_in, N, device=device, requires_grad=False)
+
+    # Make sure model runs without any exception
+    for i in range(2):
+        # Test both train and inference mode
+        if i % 2 == 0:
+            model.train()
+        else:
+            model.eval()
+
+        # Must work because forward() and dict order match
+        y1, y2 = model(**{'input1': input1, 'input2': input2})
+        assert y1 is not None
+        assert y2 is not None
+        if model.training:
+            loss = y1.sum() + y2.sum()
+            loss.backward()
+
+        # Must work even when forward() and dict order mismatch
+        y1, y2 = model(**{'input2': input2, 'input1': input1})
+        assert y1 is not None
+        assert y2 is not None
+        if model.training:
+            loss = y1.sum() + y2.sum()
+            loss.backward()
+
+
+def test_forward_call_lots_None():
+    '''None inputs, in many positions of forward(), must not break ORTModule in train or eval mode'''
+    class NoneNet(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.zeros = torch.nn.Parameter(torch.zeros(1,1))
+
+        def forward(self, a, b, c, d, e, f, y=None, z=None):
+            assert a is not None
+            result = self.zeros.sum() + a
+            if b is not None:
+                result += b
+            if c is not None:
+                result += c
+            if d is not None:
+                result += d
+            if e is not None:
+                result += e
+            if f is not None:
+                result += f
+            if y is not None:
+                result += y
+            if z is not None:
+                result += z
+            return result
+
+    def run_step(expected, a, b, c, d, e, f, y, z):
+        # Force model (re)export to validate (un)flattening with new input
+        # This is needed because for a `forward(self, a, b)`, and
+        # input `forward(a,b)` or `forward(**{'a': a, 'b': b})`,
+        # ORTModule produces the same schema, thus not re-exporting
+        # the model when `forward(a,b)` is used after `forward(**{'a': a, 'b': b})`
+        # or vice-versa
+        model._execution_manager(model._is_training())._onnx_model = None
+        # NOTE(review): the model is always invoked positionally here, so the
+        # `run_step(..., **{...})` call sites below do not pass keywords into
+        # ORTModule itself - confirm whether a keyword invocation was intended.
+        out = model(a,b,c,d,e,f,y,z)
+        assert out is not None
+        assert out.item() == expected
+        if model.training:
+            loss = out.sum()
+            loss.backward()
+
+    device = 'cuda'
+    model = NoneNet().to(device)
+    model = ORTModule(model)
+
+    a = torch.FloatTensor([1]).to(device)*1
+    b = torch.FloatTensor([1]).to(device)*10
+    c = torch.FloatTensor([1]).to(device)*100
+    d = torch.FloatTensor([1]).to(device)*1000
+    e = torch.FloatTensor([1]).to(device)*10000
+    f = torch.FloatTensor([1]).to(device)*100000
+    y = torch.FloatTensor([1]).to(device)*1000000
+    z = torch.FloatTensor([1]).to(device)*10000000
+
+    # Make sure model runs without any exception
+    for i in range(2):
+        # Test both train and inference mode
+        if i % 2 == 0:
+            model.train()
+        else:
+            model.eval()
+
+        run_step(a.item() + f.item(),
+                 a, None, None, None, None, f, None, None, )
+        run_step(a.item() + f.item(),
+                 **{'a': a, 'b': None, 'c': None, 'd': None, 'e': None, 'f': f, 'y': None, 'z': None})
+        run_step(a.item() + z.item(),
+                 a, None, None, None, None, None, None, z)
+        run_step(a.item() + z.item(),
+                 **{'a': a, 'b': None, 'c': None, 'd': None, 'e': None, 'f': None, 'y': None, 'z': z})
+        run_step(a.item() + c.item() + y.item(),
+                 a, None, c, None, None, None, y, None)
+        run_step(a.item() + c.item() + y.item(),
+                 **{'a': a, 'b': None, 'c': c, 'd': None, 'e': None, 'f': None, 'y': y, 'z': None})
+        run_step(a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(),
+                 a, b, c, d, e, f, y, z)
+        run_step(a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(),
+                 **{'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'f': f, 'y': y, 'z': z})