mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Fix user input order before ORTModule feed it to backend (#7456)
This commit is contained in:
parent
d68cedfa85
commit
3ee63beafa
2 changed files with 149 additions and 14 deletions
|
|
@ -18,6 +18,7 @@ class _InputInfo(object):
|
|||
dynamic_axes=None,
|
||||
schema=None,
|
||||
num_positionals=0,
|
||||
num_positionals_non_none=0,
|
||||
keyword_names=None):
|
||||
self.names = names
|
||||
self.shape = shape
|
||||
|
|
@ -25,31 +26,33 @@ class _InputInfo(object):
|
|||
self.dynamic_axes = dynamic_axes if dynamic_axes else {}
|
||||
self.schema = schema if schema else []
|
||||
self.num_positionals = num_positionals
|
||||
self.num_positionals_non_none = num_positionals_non_none
|
||||
self.keyword_names = keyword_names
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'''_InputInfo class:
|
||||
\tNames: {self.names}
|
||||
\tShape: {self.shape}
|
||||
\tRequire gradient: {self.require_grad_names}
|
||||
\tDynamic axes: {self.dynamic_axes}
|
||||
\tSchema: {self.schema}
|
||||
\t#Positionals: {self.num_positionals}
|
||||
\tKeyword names: {self.keyword_names}'''
|
||||
\tNames: {self.names}
|
||||
\tShape: {self.shape}
|
||||
\tRequire gradient: {self.require_grad_names}
|
||||
\tDynamic axes: {self.dynamic_axes}
|
||||
\tSchema: {self.schema}
|
||||
\t#Positionals (total): {self.num_positionals}
|
||||
\t#Positionals (non-None): {self.num_positionals_non_none}
|
||||
\tKeyword names: {self.keyword_names}'''
|
||||
|
||||
def flatten(self, args, kwargs):
|
||||
'''Flatten args and kwargs in a single tuple of tensors with strict ordering'''
|
||||
|
||||
ret = list(args)
|
||||
for _, kwarg in kwargs.items():
|
||||
ret.append(kwarg)
|
||||
return tuple(ret)
|
||||
ret += [kwargs[name] for name in self.names if name in kwargs]
|
||||
return ret
|
||||
|
||||
def unflatten(self, flat_args):
|
||||
'''Unflatten tuple of tensors into args and kwargs'''
|
||||
|
||||
args = tuple(flat_args[:self.num_positionals])
|
||||
kwargs = {kwarg_name: arg for kwarg_name, arg in zip(self.keyword_names, flat_args[self.num_positionals:])}
|
||||
kwargs = {name: arg for name, arg in zip(self.names[self.num_positionals_non_none:], flat_args[self.num_positionals:]) \
|
||||
if name in self.keyword_names}
|
||||
return args, kwargs
|
||||
|
||||
def _combine_input_buffers_initializers(param_names, onnx_input_names, input_info, buffer_names, inputs, kwargs):
|
||||
|
|
@ -135,10 +138,10 @@ class _TensorStub(object):
|
|||
return result
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, _TensorStub):
|
||||
raise NotImplemented('_TensorStub must only be compared to another _TensorStub instance!')
|
||||
elif not other:
|
||||
if not other:
|
||||
return False
|
||||
elif not isinstance(other, _TensorStub):
|
||||
raise NotImplemented('_TensorStub must only be compared to another _TensorStub instance!')
|
||||
elif self.name != other.name:
|
||||
return False
|
||||
elif self.dtype != other.dtype:
|
||||
|
|
@ -356,6 +359,7 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, inputs, kwarg
|
|||
dynamic_axes=dynamic_axes,
|
||||
schema=schema,
|
||||
num_positionals=len(inputs),
|
||||
num_positionals_non_none=len([i for i in inputs if i is not None]),
|
||||
keyword_names=kwargs.keys())
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2082,3 +2082,134 @@ def test_forward_call_default_input():
|
|||
assert out.item() == 15.0
|
||||
if model.training:
|
||||
out.sum().backward()
|
||||
|
||||
|
||||
def test_forward_call_kwargs_input_unexpected_order():
|
||||
class OrderlyNet(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(OrderlyNet, self).__init__()
|
||||
self.fc1 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, input1=None, input2=None):
|
||||
assert input1.shape != input2.shape
|
||||
input2 = torch.transpose(input2, 0, 1)
|
||||
assert input1.shape == input2.shape
|
||||
|
||||
model_input = input1 + input2
|
||||
out1 = self.fc1(model_input)
|
||||
out1 = self.relu(out1)
|
||||
out2 = self.fc2(out1)
|
||||
return out1, out2
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 784, 500, 10
|
||||
model = OrderlyNet(D_in, H, D_out).to(device)
|
||||
model = ORTModule(model)
|
||||
|
||||
input1 = torch.randn(N, D_in, device=device, requires_grad=False)
|
||||
input2 = torch.randn(D_in, N, device=device, requires_grad=False)
|
||||
|
||||
# Make sure model runs without any exception
|
||||
for i in range(2):
|
||||
# Test both train and inference mode
|
||||
if i % 2 == 0:
|
||||
model.train()
|
||||
else:
|
||||
model.eval()
|
||||
|
||||
# Must work because forward() and dict order match
|
||||
y1, y2 = model(**{'input1': input1, 'input2': input2})
|
||||
assert y1 is not None
|
||||
assert y2 is not None
|
||||
if model.training:
|
||||
loss = y1.sum() + y2.sum()
|
||||
loss.backward()
|
||||
|
||||
# Must work even when forward() and dict order mismatch
|
||||
y1, y2 = model(**{'input2': input2, 'input1': input1})
|
||||
assert y1 is not None
|
||||
assert y2 is not None
|
||||
if model.training:
|
||||
loss = y1.sum() + y2.sum()
|
||||
loss.backward()
|
||||
|
||||
|
||||
def test_forward_call_lots_None():
|
||||
class NoneNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.zeros = torch.nn.Parameter(torch.zeros(1,1))
|
||||
|
||||
def forward(self, a, b, c, d, e, f, y=None, z=None):
|
||||
assert a is not None
|
||||
result = self.zeros.sum() + a
|
||||
if b is not None:
|
||||
result += b
|
||||
if c is not None:
|
||||
result += c
|
||||
if d is not None:
|
||||
result += d
|
||||
if e is not None:
|
||||
result += e
|
||||
if f is not None:
|
||||
result += f
|
||||
if y is not None:
|
||||
result += y
|
||||
if z is not None:
|
||||
result += z
|
||||
return result
|
||||
|
||||
def run_step(expected, a, b, c, d, e, f, y, z):
|
||||
# Force model (re)export to validate (un)flattening with new input
|
||||
# This is needed because for a `forward(self, a, b)`, and
|
||||
# input `forward(a,b)` or `forward(**{'a': a, 'b': b})`,
|
||||
# ORTModule produces the same schema, thus not re-exporting
|
||||
# the model when `forward(a,b)` is used after `forward(**{'a': a, 'b': b})`
|
||||
# or vice-versa
|
||||
model._execution_manager(model._is_training())._onnx_model = None
|
||||
out = model(a,b,c,d,e,f,y,z)
|
||||
assert out is not None
|
||||
assert out.item() == expected
|
||||
if model.training:
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
|
||||
device = 'cuda'
|
||||
model = NoneNet().to(device)
|
||||
model = ORTModule(model)
|
||||
|
||||
a = torch.FloatTensor([1]).to(device)*1
|
||||
b = torch.FloatTensor([1]).to(device)*10
|
||||
c = torch.FloatTensor([1]).to(device)*100
|
||||
d = torch.FloatTensor([1]).to(device)*1000
|
||||
e = torch.FloatTensor([1]).to(device)*10000
|
||||
f = torch.FloatTensor([1]).to(device)*100000
|
||||
y = torch.FloatTensor([1]).to(device)*1000000
|
||||
z = torch.FloatTensor([1]).to(device)*10000000
|
||||
|
||||
# Make sure model runs without any exception
|
||||
for i in range(2):
|
||||
# Test both train and inference mode
|
||||
if i % 2 == 0:
|
||||
model.train()
|
||||
else:
|
||||
model.eval()
|
||||
|
||||
run_step(a.item() + f.item(),
|
||||
a, None, None, None, None, f, None, None, )
|
||||
run_step(a.item() + f.item(),
|
||||
**{'a': a, 'b': None, 'c': None, 'd': None, 'e': None, 'f': f, 'y': None, 'z': None})
|
||||
run_step(a.item() + z.item(),
|
||||
a, None, None, None, None, None, None, z)
|
||||
run_step(a.item() + z.item(),
|
||||
**{'a': a, 'b': None, 'c': None, 'd': None, 'e': None, 'f': None, 'y': None, 'z': z})
|
||||
run_step(a.item() + c.item() + y.item(),
|
||||
a, None, c, None, None, None, y, None)
|
||||
run_step(a.item() + c.item() + y.item(),
|
||||
**{'a': a, 'b': None, 'c': c, 'd': None, 'e': None, 'f': None, 'y': y, 'z': None})
|
||||
run_step(a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(),
|
||||
a, b, c, d, e, f, y, z)
|
||||
run_step(a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(),
|
||||
**{'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'f': f, 'y': y, 'z': z})
|
||||
|
|
|
|||
Loading…
Reference in a new issue