Port #4920 into the new pytorch frontend (#4965)

This commit is contained in:
Thiago Crepaldi 2020-09-01 19:00:49 -07:00 committed by GitHub
parent d30dd41c0e
commit f38f2d5b54
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -375,13 +375,41 @@ class ORTTrainer(object):
results = [session_run_results[o_desc.name] for o_desc in self.model_desc.outputs]
return results[0] if len (results) == 1 else results
def _combine_torch_model_with_loss_fn_and_wrap_input(self):
def _convert_torch_model_loss_fn_to_onnx(self, inputs, device):
# Dynamic axes
dynamic_axes = {}
for input in self.model_desc.inputs:
symbolic_axis = {}
for i, axis in enumerate(input.shape):
if isinstance(axis, str):
symbolic_axis[i] = axis
if len(symbolic_axis):
dynamic_axes[input.name] = symbolic_axis
for output in self.model_desc.outputs:
symbolic_axis = {}
for i, axis in enumerate(output.shape):
if isinstance(axis, str):
symbolic_axis[i] = axis
if len(symbolic_axis):
dynamic_axes[output.name] = symbolic_axis
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
if isinstance(inputs, dict):
sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs]
elif isinstance(inputs, (list, tuple)):
sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs)]
else:
raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")
# PyTorch ONNX exporter does not match argument names
# This is an issue because the ONNX graph depends on all inputs being specified
# Validate loss_fn
if self.loss_fn:
sig_loss = signature(self.loss_fn)
if len(sig_loss.parameters) != 2:
raise RuntimeError(
"loss function should take two arguments - predict and label.")
raise RuntimeError("loss function should take two arguments - predict and label.")
# Basic input names from model
input_names = [input.name for input in self.model_desc.inputs]
@ -416,39 +444,7 @@ class ORTTrainer(object):
preds = model_out
return self.loss_fn(preds, label), preds
return CombineTorchModelLossFnWrapInput(self._torch_model, self.loss_fn, input_names)
def _convert_torch_model_loss_fn_to_onnx(self, inputs, device):
# Dynamic axes
dynamic_axes = {}
for input in self.model_desc.inputs:
symbolic_axis = {}
for i, axis in enumerate(input.shape):
if isinstance(axis, str):
symbolic_axis[i] = axis
if len(symbolic_axis):
dynamic_axes[input.name] = symbolic_axis
for output in self.model_desc.outputs:
symbolic_axis = {}
for i, axis in enumerate(output.shape):
if isinstance(axis, str):
symbolic_axis[i] = axis
if len(symbolic_axis):
dynamic_axes[output.name] = symbolic_axis
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
if isinstance(inputs, dict):
sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs]
elif isinstance(inputs, (list, tuple)):
sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs)]
else:
raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")
# PyTorch ONNX exporter does not match argument names
# This is an issue because the ONNX graph depends on all inputs being specified
model = self._combine_torch_model_with_loss_fn_and_wrap_input()
model = CombineTorchModelLossFnWrapInput(self._torch_model, self.loss_fn, input_names)
# Do an inference to grab output types
model.eval()
@ -493,23 +489,16 @@ class ORTTrainer(object):
onnx_model = onnx.load_model_from_string(f.getvalue())
# Remove 'model.' prefix introduced by CombineTorchModelLossFnWrapInput class
replace_name_dict = {}
for n in onnx_model.graph.initializer:
if n.name.startswith('model.'):
replace_name_dict[n.name] = n.name[len('model.'):]
n.name = replace_name_dict[n.name]
for n in onnx_model.graph.node:
for i, name in enumerate(n.input):
if name in replace_name_dict:
n.input[i] = replace_name_dict[name]
# ONNX model initializers may contain non-trainable registered buffers
# that are not part of PyTorch model named parameters
named_parameters = model.model.named_parameters() if hasattr(model, 'model') else model.named_parameters()
assert set([n for n, t in named_parameters]).issubset(
set([n.name for n in onnx_model.graph.initializer])), \
"Initializer names do not match between PyTorch model and ONNX model, " \
"please report a bug to ONNX Runtime."
if isinstance(model, CombineTorchModelLossFnWrapInput):
replace_name_dict = {}
for n in onnx_model.graph.initializer:
if n.name.startswith('model.'):
replace_name_dict[n.name] = n.name[len('model.'):]
n.name = replace_name_dict[n.name]
for n in onnx_model.graph.node:
for i, name in enumerate(n.input):
if name in replace_name_dict:
n.input[i] = replace_name_dict[name]
return onnx_model