Fix initializer name only when wrapper is applied (#4920)
* Fix initializer name only when wrapper is applied
* fix inspect import
This commit is contained in:
parent d792af776d
commit 6dd4af3936

1 changed file with 49 additions and 55 deletions
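Background for the fix, before the diff: the exporter wraps the user's model in a helper module that stores it as self.model_, so every parameter name, and hence every exported ONNX initializer name, picks up a 'model_.' prefix. A minimal sketch of that effect (the toy Net/Wrapper names are illustrative, not part of this commit):

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.fc(x)

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model_ = model  # attribute name becomes a 'model_.' prefix on parameter names

    def forward(self, x):
        return self.model_(x)

print([n for n, _ in Net().named_parameters()])           # ['fc.weight', 'fc.bias']
print([n for n, _ in Wrapper(Net()).named_parameters()])  # ['model_.fc.weight', 'model_.fc.bias']

Before this change the prefix was stripped unconditionally after export; the diff below performs the rename only when one of the wrapper classes was actually applied.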
@@ -199,6 +199,45 @@ def dtype_torch_to_numpy(torch_dtype):
     else:
         raise Exception("Torch type to numpy type mapping unavailable for: " + str(torch_dtype))
 
+class model_loss_cls(torch.nn.Module):
+    def __init__(self, model, loss_fn):
+        super(model_loss_cls, self).__init__()
+        self.model_ = model
+        self.loss_fn_ = loss_fn
+
+    def forward(self, *inputs):
+        # here we assume input can be unpacked into input and label
+        input, label = inputs[:-1], inputs[-1]
+        preds = self.model_(*input)
+        return self.loss_fn_(preds, label), preds
+
+class WrapModel(torch.nn.Module):
+    def __init__(self, model, loss_fn, input_names):
+        super(WrapModel, self).__init__()
+        self.model_ = model
+        self.loss_fn_ = loss_fn
+        self.input_names_ = input_names
+
+    def forward(self, *inputs):
+        import inspect
+        # *inputs is given by torch trace. It is in the order of input_names.
+        # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names.
+        sig = inspect.signature(self.model_.forward)
+        ordered_list_keys = list(sig.parameters.keys())
+
+        input_dict = {}
+        for key in sig.parameters.keys():
+            if key in self.input_names_:
+                input_dict[key] = inputs[self.input_names_.index(key)]
+
+        model_out = self.model_(**input_dict)
+        if self.loss_fn_ is None:
+            return model_out
+
+        label = inputs[-1]
+        preds = model_out
+        return self.loss_fn_(preds, label), preds
+
 def wrap_for_input_match(model, loss_fn, input_names):
     import inspect
     sig = inspect.signature(model.forward)
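As a usage note on the classes this hunk hoists to module scope: WrapModel re-binds positional trace inputs to the wrapped model's own parameter names. A small sketch, assuming the WrapModel just added above and an illustrative two-input model:

import torch

class TwoInput(torch.nn.Module):
    # forward signature order deliberately differs from input_names below
    def forward(self, attention_mask, input_ids):
        return (input_ids + attention_mask).sum()

wrapped = WrapModel(TwoInput(), loss_fn=None, input_names=['input_ids', 'attention_mask'])
# Positional inputs follow input_names; forward() reorders them by name.
out = wrapped(torch.ones(2, 3), torch.zeros(2, 3))  # input_ids=ones, attention_mask=zeros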
@@ -211,18 +250,6 @@ def wrap_for_input_match(model, loss_fn, input_names):
         # label shall be the second input to loss_fn.
         ordered_list_keys = [*ordered_list_keys, list(sig_loss.parameters.keys())[1]]
 
-    class model_loss_cls(torch.nn.Module):
-        def __init__(self, model, loss_fn):
-            super(model_loss_cls, self).__init__()
-            self.model_ = model
-            self.loss_fn_ = loss_fn
-
-        def forward(self, *inputs):
-            # here we assume input can be unpacked into input and label
-            input, label = inputs[:-1], inputs[-1]
-            preds = self.model_(*input)
-            return self.loss_fn_(preds, label), preds
-
     # name match is needed only when input_names are a subset
     # of expected inputs (inputs to model and loss_fn combined).
     if len(input_names) > len(ordered_list_keys):
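The surviving context above appends list(sig_loss.parameters.keys())[1] to ordered_list_keys, i.e. it relies on the convention that the label is the second parameter of loss_fn. A quick illustration with an assumed loss function:

import inspect
import torch.nn.functional as F

def my_loss(preds, target):
    return F.mse_loss(preds, target)

sig_loss = inspect.signature(my_loss)
print(list(sig_loss.parameters.keys())[1])  # 'target' -- the label slot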
@@ -248,32 +275,6 @@ def wrap_for_input_match(model, loss_fn, input_names):
     if match:
         return model_loss_cls(model, loss_fn) if loss_fn else model
 
-    class WrapModel(torch.nn.Module):
-        def __init__(self, model, loss_fn, input_names):
-            super(WrapModel, self).__init__()
-            self.model_ = model
-            self.loss_fn_ = loss_fn
-            self.input_names_ = input_names
-
-        def forward(self, *inputs):
-            # *inputs is given by torch trace. It is in the order of input_names.
-            # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names.
-            sig = inspect.signature(self.model_.forward)
-            ordered_list_keys = list(sig.parameters.keys())
-
-            input_dict = {}
-            for key in sig.parameters.keys():
-                if key in self.input_names_:
-                    input_dict[key] = inputs[self.input_names_.index(key)]
-
-            model_out = self.model_(**input_dict)
-            if self.loss_fn_ is None:
-                return model_out
-
-            label = inputs[-1]
-            preds = model_out
-            return self.loss_fn_(preds, label), preds
-
     model = WrapModel(model, loss_fn, input_names)
 
     return model
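With the nested copy deleted, wrap_for_input_match keeps its dispatch: when input_names already line up positionally with the combined model/loss signature it returns model_loss_cls(model, loss_fn) (or the bare model), otherwise the name-reordering WrapModel. A sketch of the positional path, assuming the module-level model_loss_cls from the first hunk and an illustrative model/loss:

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.fc(x)

def loss_fn(preds, label):
    return torch.nn.functional.mse_loss(preds, label)

wrapped = model_loss_cls(Net(), loss_fn)
# Last positional input is treated as the label; the rest go to the model.
loss, preds = wrapped(torch.randn(8, 4), torch.randn(8, 1))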
@@ -362,23 +363,16 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op
     onnx_model = onnx.load_model_from_string(f.getvalue())
 
     # Remove 'model_.' prefix introduced by model wrapper for initializers.
-    replace_name_dict = {}
-    for n in onnx_model.graph.initializer:
-        if n.name.startswith('model_.'):
-            replace_name_dict[n.name] = n.name[len('model_.'):]
-            n.name = replace_name_dict[n.name]
-    for n in onnx_model.graph.node:
-        for i, name in enumerate(n.input):
-            if name in replace_name_dict:
-                n.input[i] = replace_name_dict[name]
-
-    # onnx model initializer may contain non-trainable registered buffers that are not part
-    # of pytorch model named parameteres.
-    named_parameters = model.model_.named_parameters() if hasattr(model, 'model_') else model.named_parameters()
-    assert set([n for n, t in named_parameters]).issubset(
-        set([n.name for n in onnx_model.graph.initializer])), \
-        "Initializer names do not match between PyTorch model and ONNX model, " \
-        "please report a bug to ONNX Runtime."
+    if isinstance(model, WrapModel) or isinstance(model, model_loss_cls):
+        replace_name_dict = {}
+        for n in onnx_model.graph.initializer:
+            if n.name.startswith('model_.'):
+                replace_name_dict[n.name] = n.name[len('model_.'):]
+                n.name = replace_name_dict[n.name]
+        for n in onnx_model.graph.node:
+            for i, name in enumerate(n.input):
+                if name in replace_name_dict:
+                    n.input[i] = replace_name_dict[name]
 
     return onnx_model
 
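The guarded rename above is plain string surgery over initializer and node-input names; here is the same logic on simple Python lists, as a standalone, onnx-free sketch (the function name is illustrative):

def strip_model_prefix(initializer_names, node_inputs):
    # Build the old->new map for prefixed initializers, then patch node inputs.
    replace_name_dict = {}
    renamed = []
    for name in initializer_names:
        if name.startswith('model_.'):
            replace_name_dict[name] = name[len('model_.'):]
            name = replace_name_dict[name]
        renamed.append(name)
    patched = [[x if x not in replace_name_dict else replace_name_dict[x] for x in inputs]
               for inputs in node_inputs]
    return renamed, patched

print(strip_model_prefix(['model_.fc.weight', 'bias'], [['model_.fc.weight', 'X']]))
# (['fc.weight', 'bias'], [['fc.weight', 'X']])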