Fix initializer name only when wrapper is applied (#4920)

* Fix initializer name only when wrapper is applied

* fix inspect import
Author: Bowen Bao, 2020-09-04 12:08:07 -07:00 (committed by GitHub)
Parent: d792af776d
Commit: 6dd4af3936


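The diff hoists the two wrapper classes to module scope and makes the initializer renaming in convert_model_loss_fn_to_onnx conditional on those wrappers. As orientation, here is a minimal stand-alone sketch of the renaming step on a hypothetical toy graph; the toy MatMul model is invented for illustration, and only the two renaming loops mirror the patch:

import numpy as np
import onnx
from onnx import helper, numpy_helper, TensorProto

# Hypothetical toy graph with one initializer that carries the 'model_.'
# prefix a wrapper's `self.model_` attribute introduces during export.
w = numpy_helper.from_array(np.ones((2, 2), dtype=np.float32), name='model_.fc.weight')
node = helper.make_node('MatMul', inputs=['x', 'model_.fc.weight'], outputs=['y'])
graph = helper.make_graph(
    [node], 'toy',
    [helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2])],
    [helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2])],
    initializer=[w])
onnx_model = helper.make_model(graph)

# Same logic as the patch: strip the prefix from initializer names, then
# rewrite every node input that still references an old name.
replace_name_dict = {}
for n in onnx_model.graph.initializer:
    if n.name.startswith('model_.'):
        replace_name_dict[n.name] = n.name[len('model_.'):]
        n.name = replace_name_dict[n.name]
for n in onnx_model.graph.node:
    for i, name in enumerate(n.input):
        if name in replace_name_dict:
            n.input[i] = replace_name_dict[name]

print([i.name for i in onnx_model.graph.initializer])  # ['fc.weight']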
@@ -199,6 +199,45 @@ def dtype_torch_to_numpy(torch_dtype):
     else:
         raise Exception("Torch type to numpy type mapping unavailable for: " + str(torch_dtype))
 
+
+class model_loss_cls(torch.nn.Module):
+    def __init__(self, model, loss_fn):
+        super(model_loss_cls, self).__init__()
+        self.model_ = model
+        self.loss_fn_ = loss_fn
+
+    def forward(self, *inputs):
+        # here we assume inputs can be unpacked into model inputs and a label
+        input, label = inputs[:-1], inputs[-1]
+        preds = self.model_(*input)
+        return self.loss_fn_(preds, label), preds
+
+
+class WrapModel(torch.nn.Module):
+    def __init__(self, model, loss_fn, input_names):
+        super(WrapModel, self).__init__()
+        self.model_ = model
+        self.loss_fn_ = loss_fn
+        self.input_names_ = input_names
+
+    def forward(self, *inputs):
+        import inspect
+        # *inputs is given by the torch trace. It is in the order of input_names.
+        # model_ takes inputs in an order (obtainable via inspect.signature(model.forward))
+        # that may differ from input_names.
+        sig = inspect.signature(self.model_.forward)
+        ordered_list_keys = list(sig.parameters.keys())
+
+        input_dict = {}
+        for key in sig.parameters.keys():
+            if key in self.input_names_:
+                input_dict[key] = inputs[self.input_names_.index(key)]
+
+        model_out = self.model_(**input_dict)
+        if self.loss_fn_ is None:
+            return model_out
+
+        label = inputs[-1]
+        preds = model_out
+        return self.loss_fn_(preds, label), preds
+
+
 def wrap_for_input_match(model, loss_fn, input_names):
     import inspect
     sig = inspect.signature(model.forward)
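For intuition about the hoisted WrapModel: the torch trace passes inputs positionally in input_names order, and forward rebinds them by name to the wrapped model's own signature order. A minimal sketch of that matching step, using a hypothetical Net whose parameter order differs from input_names:

import inspect
import torch

class Net(torch.nn.Module):
    # forward declares its parameters in a different order than input_names
    def forward(self, attention_mask, input_ids):
        return input_ids.float().sum() + attention_mask.float().sum()

model = Net()
input_names = ['input_ids', 'attention_mask']  # order used by the trace
inputs = (torch.ones(2, 3), torch.zeros(2, 3))

# The same matching step as WrapModel.forward: walk the model signature
# and pick each argument out of the positional inputs by name.
sig = inspect.signature(model.forward)
input_dict = {key: inputs[input_names.index(key)]
              for key in sig.parameters.keys() if key in input_names}
out = model(**input_dict)  # each tensor lands on the right parameter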
@@ -211,18 +250,6 @@ def wrap_for_input_match(model, loss_fn, input_names):
         # label shall be the second input to loss_fn.
         ordered_list_keys = [*ordered_list_keys, list(sig_loss.parameters.keys())[1]]
 
-    class model_loss_cls(torch.nn.Module):
-        def __init__(self, model, loss_fn):
-            super(model_loss_cls, self).__init__()
-            self.model_ = model
-            self.loss_fn_ = loss_fn
-
-        def forward(self, *inputs):
-            # here we assume inputs can be unpacked into model inputs and a label
-            input, label = inputs[:-1], inputs[-1]
-            preds = self.model_(*input)
-            return self.loss_fn_(preds, label), preds
-
     # name match is needed only when input_names are a subset
     # of expected inputs (inputs to model and loss_fn combined).
     if len(input_names) > len(ordered_list_keys):
@@ -248,32 +275,6 @@ def wrap_for_input_match(model, loss_fn, input_names):
     if match:
         return model_loss_cls(model, loss_fn) if loss_fn else model
 
-    class WrapModel(torch.nn.Module):
-        def __init__(self, model, loss_fn, input_names):
-            super(WrapModel, self).__init__()
-            self.model_ = model
-            self.loss_fn_ = loss_fn
-            self.input_names_ = input_names
-
-        def forward(self, *inputs):
-            # *inputs is given by the torch trace. It is in the order of input_names.
-            # model_ takes inputs in an order (obtainable via inspect.signature(model.forward))
-            # that may differ from input_names.
-            sig = inspect.signature(self.model_.forward)
-            ordered_list_keys = list(sig.parameters.keys())
-
-            input_dict = {}
-            for key in sig.parameters.keys():
-                if key in self.input_names_:
-                    input_dict[key] = inputs[self.input_names_.index(key)]
-
-            model_out = self.model_(**input_dict)
-            if self.loss_fn_ is None:
-                return model_out
-
-            label = inputs[-1]
-            preds = model_out
-            return self.loss_fn_(preds, label), preds
-
     model = WrapModel(model, loss_fn, input_names)
     return model
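After these removals, wrap_for_input_match only chooses between the module-level wrappers. A hypothetical condensed version of that dispatch (the function's subset and index bookkeeping is elided in the hunks above, so this is a sketch of the decision, not the exact logic):

import inspect

def choose_wrapper(model, loss_fn, input_names):
    # Expected argument names: the model's, plus loss_fn's label (its
    # second parameter) when a loss function is supplied.
    expected = list(inspect.signature(model.forward).parameters.keys())
    if loss_fn:
        expected.append(list(inspect.signature(loss_fn).parameters.keys())[1])
    # If the given names already line up positionally, the cheap positional
    # wrapper (or the bare model) suffices; otherwise match by keyword.
    if all(n == k for n, k in zip(input_names, expected)):
        return model_loss_cls(model, loss_fn) if loss_fn else model
    return WrapModel(model, loss_fn, input_names)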
@@ -362,23 +363,16 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op
     onnx_model = onnx.load_model_from_string(f.getvalue())
 
     # Remove 'model_.' prefix introduced by model wrapper for initializers.
-    replace_name_dict = {}
-    for n in onnx_model.graph.initializer:
-        if n.name.startswith('model_.'):
-            replace_name_dict[n.name] = n.name[len('model_.'):]
-            n.name = replace_name_dict[n.name]
-    for n in onnx_model.graph.node:
-        for i, name in enumerate(n.input):
-            if name in replace_name_dict:
-                n.input[i] = replace_name_dict[name]
-
-    # onnx model initializer may contain non-trainable registered buffers that are not part
-    # of pytorch model named parameters.
-    named_parameters = model.model_.named_parameters() if hasattr(model, 'model_') else model.named_parameters()
-    assert set([n for n, t in named_parameters]).issubset(
-        set([n.name for n in onnx_model.graph.initializer])), \
-        "Initializer names do not match between PyTorch model and ONNX model, " \
-        "please report a bug to ONNX Runtime."
+    if isinstance(model, WrapModel) or isinstance(model, model_loss_cls):
+        replace_name_dict = {}
+        for n in onnx_model.graph.initializer:
+            if n.name.startswith('model_.'):
+                replace_name_dict[n.name] = n.name[len('model_.'):]
+                n.name = replace_name_dict[n.name]
+        for n in onnx_model.graph.node:
+            for i, name in enumerate(n.input):
+                if name in replace_name_dict:
+                    n.input[i] = replace_name_dict[name]
 
     return onnx_model
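The isinstance guard is the substance of the fix: the 'model_.' prefix is an export artifact only when the model went through one of the wrappers above (both store the user model as self.model_). A user's own module can produce the same prefix legitimately, as in this hypothetical example, and the old unconditional renaming would have desynchronized its ONNX initializer names from the PyTorch parameter names:

import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super(Inner, self).__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

class UserModel(torch.nn.Module):
    def __init__(self):
        super(UserModel, self).__init__()
        # The user happens to name a submodule 'model_', so the exported
        # initializers are genuinely called 'model_.fc.weight' and so on.
        self.model_ = Inner()

    def forward(self, x):
        return self.model_(x)

print([name for name, _ in UserModel().named_parameters()])
# ['model_.fc.weight', 'model_.fc.bias'] -- stripping the prefix here would
# leave the ONNX initializer names out of sync with these parameter names.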