Fix user input order before ORTModule feed it to backend (#7456)

This commit is contained in:
Thiago Crepaldi 2021-04-28 14:33:40 -07:00 committed by GitHub
parent d68cedfa85
commit 3ee63beafa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 149 additions and 14 deletions

View file

@ -18,6 +18,7 @@ class _InputInfo(object):
dynamic_axes=None,
schema=None,
num_positionals=0,
num_positionals_non_none=0,
keyword_names=None):
self.names = names
self.shape = shape
@ -25,31 +26,33 @@ class _InputInfo(object):
self.dynamic_axes = dynamic_axes if dynamic_axes else {}
self.schema = schema if schema else []
self.num_positionals = num_positionals
self.num_positionals_non_none = num_positionals_non_none
self.keyword_names = keyword_names
def __repr__(self) -> str:
return f'''_InputInfo class:
\tNames: {self.names}
\tShape: {self.shape}
\tRequire gradient: {self.require_grad_names}
\tDynamic axes: {self.dynamic_axes}
\tSchema: {self.schema}
\t#Positionals: {self.num_positionals}
\tKeyword names: {self.keyword_names}'''
\tNames: {self.names}
\tShape: {self.shape}
\tRequire gradient: {self.require_grad_names}
\tDynamic axes: {self.dynamic_axes}
\tSchema: {self.schema}
\t#Positionals (total): {self.num_positionals}
\t#Positionals (non-None): {self.num_positionals_non_none}
\tKeyword names: {self.keyword_names}'''
def flatten(self, args, kwargs):
'''Flatten args and kwargs in a single tuple of tensors with strict ordering'''
ret = list(args)
for _, kwarg in kwargs.items():
ret.append(kwarg)
return tuple(ret)
ret += [kwargs[name] for name in self.names if name in kwargs]
return ret
def unflatten(self, flat_args):
'''Unflatten tuple of tensors into args and kwargs'''
args = tuple(flat_args[:self.num_positionals])
kwargs = {kwarg_name: arg for kwarg_name, arg in zip(self.keyword_names, flat_args[self.num_positionals:])}
kwargs = {name: arg for name, arg in zip(self.names[self.num_positionals_non_none:], flat_args[self.num_positionals:]) \
if name in self.keyword_names}
return args, kwargs
def _combine_input_buffers_initializers(param_names, onnx_input_names, input_info, buffer_names, inputs, kwargs):
@ -135,10 +138,10 @@ class _TensorStub(object):
return result
def __eq__(self, other):
if not isinstance(other, _TensorStub):
raise NotImplemented('_TensorStub must only be compared to another _TensorStub instance!')
elif not other:
if not other:
return False
elif not isinstance(other, _TensorStub):
raise NotImplemented('_TensorStub must only be compared to another _TensorStub instance!')
elif self.name != other.name:
return False
elif self.dtype != other.dtype:
@ -356,6 +359,7 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, inputs, kwarg
dynamic_axes=dynamic_axes,
schema=schema,
num_positionals=len(inputs),
num_positionals_non_none=len([i for i in inputs if i is not None]),
keyword_names=kwargs.keys())

View file

@ -2082,3 +2082,134 @@ def test_forward_call_default_input():
assert out.item() == 15.0
if model.training:
out.sum().backward()
def test_forward_call_kwargs_input_unexpected_order():
class OrderlyNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(OrderlyNet, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
def forward(self, input1=None, input2=None):
assert input1.shape != input2.shape
input2 = torch.transpose(input2, 0, 1)
assert input1.shape == input2.shape
model_input = input1 + input2
out1 = self.fc1(model_input)
out1 = self.relu(out1)
out2 = self.fc2(out1)
return out1, out2
device = 'cuda'
N, D_in, H, D_out = 32, 784, 500, 10
model = OrderlyNet(D_in, H, D_out).to(device)
model = ORTModule(model)
input1 = torch.randn(N, D_in, device=device, requires_grad=False)
input2 = torch.randn(D_in, N, device=device, requires_grad=False)
# Make sure model runs without any exception
for i in range(2):
# Test both train and inference mode
if i % 2 == 0:
model.train()
else:
model.eval()
# Must work because forward() and dict order match
y1, y2 = model(**{'input1': input1, 'input2': input2})
assert y1 is not None
assert y2 is not None
if model.training:
loss = y1.sum() + y2.sum()
loss.backward()
# Must work even when forward() and dict order mismatch
y1, y2 = model(**{'input2': input2, 'input1': input1})
assert y1 is not None
assert y2 is not None
if model.training:
loss = y1.sum() + y2.sum()
loss.backward()
def test_forward_call_lots_None():
class NoneNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.zeros = torch.nn.Parameter(torch.zeros(1,1))
def forward(self, a, b, c, d, e, f, y=None, z=None):
assert a is not None
result = self.zeros.sum() + a
if b is not None:
result += b
if c is not None:
result += c
if d is not None:
result += d
if e is not None:
result += e
if f is not None:
result += f
if y is not None:
result += y
if z is not None:
result += z
return result
def run_step(expected, a, b, c, d, e, f, y, z):
# Force model (re)export to validate (un)flattening with new input
# This is needed because for a `forward(self, a, b)`, and
# input `forward(a,b)` or `forward(**{'a': a, 'b': b})`,
# ORTModule produces the same schema, thus not re-exporting
# the model when `forward(a,b)` is used after `forward(**{'a': a, 'b': b})`
# or vice-versa
model._execution_manager(model._is_training())._onnx_model = None
out = model(a,b,c,d,e,f,y,z)
assert out is not None
assert out.item() == expected
if model.training:
loss = out.sum()
loss.backward()
device = 'cuda'
model = NoneNet().to(device)
model = ORTModule(model)
a = torch.FloatTensor([1]).to(device)*1
b = torch.FloatTensor([1]).to(device)*10
c = torch.FloatTensor([1]).to(device)*100
d = torch.FloatTensor([1]).to(device)*1000
e = torch.FloatTensor([1]).to(device)*10000
f = torch.FloatTensor([1]).to(device)*100000
y = torch.FloatTensor([1]).to(device)*1000000
z = torch.FloatTensor([1]).to(device)*10000000
# Make sure model runs without any exception
for i in range(2):
# Test both train and inference mode
if i % 2 == 0:
model.train()
else:
model.eval()
run_step(a.item() + f.item(),
a, None, None, None, None, f, None, None, )
run_step(a.item() + f.item(),
**{'a': a, 'b': None, 'c': None, 'd': None, 'e': None, 'f': f, 'y': None, 'z': None})
run_step(a.item() + z.item(),
a, None, None, None, None, None, None, z)
run_step(a.item() + z.item(),
**{'a': a, 'b': None, 'c': None, 'd': None, 'e': None, 'f': None, 'y': None, 'z': z})
run_step(a.item() + c.item() + y.item(),
a, None, c, None, None, None, y, None)
run_step(a.item() + c.item() + y.item(),
**{'a': a, 'b': None, 'c': c, 'd': None, 'e': None, 'f': None, 'y': y, 'z': None})
run_step(a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(),
a, b, c, d, e, f, y, z)
run_step(a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(),
**{'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'f': f, 'y': y, 'z': z})