mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
parent
d30dd41c0e
commit
f38f2d5b54
1 changed files with 42 additions and 53 deletions
|
|
@ -375,13 +375,41 @@ class ORTTrainer(object):
|
|||
results = [session_run_results[o_desc.name] for o_desc in self.model_desc.outputs]
|
||||
return results[0] if len (results) == 1 else results
|
||||
|
||||
def _combine_torch_model_with_loss_fn_and_wrap_input(self):
|
||||
def _convert_torch_model_loss_fn_to_onnx(self, inputs, device):
|
||||
# Dynamic axes
|
||||
dynamic_axes = {}
|
||||
for input in self.model_desc.inputs:
|
||||
symbolic_axis = {}
|
||||
for i, axis in enumerate(input.shape):
|
||||
if isinstance(axis, str):
|
||||
symbolic_axis[i] = axis
|
||||
if len(symbolic_axis):
|
||||
dynamic_axes[input.name] = symbolic_axis
|
||||
for output in self.model_desc.outputs:
|
||||
symbolic_axis = {}
|
||||
for i, axis in enumerate(output.shape):
|
||||
if isinstance(axis, str):
|
||||
symbolic_axis[i] = axis
|
||||
if len(symbolic_axis):
|
||||
dynamic_axes[output.name] = symbolic_axis
|
||||
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = [inputs]
|
||||
if isinstance(inputs, dict):
|
||||
sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs]
|
||||
elif isinstance(inputs, (list, tuple)):
|
||||
sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs)]
|
||||
else:
|
||||
raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")
|
||||
|
||||
# PyTorch ONNX exporter does not match argument names
|
||||
# This is an issue because the ONNX graph depends on all inputs to be specified
|
||||
|
||||
# Validate loss_fn
|
||||
if self.loss_fn:
|
||||
sig_loss = signature(self.loss_fn)
|
||||
if len(sig_loss.parameters) != 2:
|
||||
raise RuntimeError(
|
||||
"loss function should take two arguments - predict and label.")
|
||||
raise RuntimeError("loss function should take two arguments - predict and label.")
|
||||
|
||||
# Basic input names from model
|
||||
input_names = [input.name for input in self.model_desc.inputs]
|
||||
|
|
@ -416,39 +444,7 @@ class ORTTrainer(object):
|
|||
preds = model_out
|
||||
return self.loss_fn(preds, label), preds
|
||||
|
||||
return CombineTorchModelLossFnWrapInput(self._torch_model, self.loss_fn, input_names)
|
||||
|
||||
|
||||
def _convert_torch_model_loss_fn_to_onnx(self, inputs, device):
|
||||
# Dynamic axes
|
||||
dynamic_axes = {}
|
||||
for input in self.model_desc.inputs:
|
||||
symbolic_axis = {}
|
||||
for i, axis in enumerate(input.shape):
|
||||
if isinstance(axis, str):
|
||||
symbolic_axis[i] = axis
|
||||
if len(symbolic_axis):
|
||||
dynamic_axes[input.name] = symbolic_axis
|
||||
for output in self.model_desc.outputs:
|
||||
symbolic_axis = {}
|
||||
for i, axis in enumerate(output.shape):
|
||||
if isinstance(axis, str):
|
||||
symbolic_axis[i] = axis
|
||||
if len(symbolic_axis):
|
||||
dynamic_axes[output.name] = symbolic_axis
|
||||
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = [inputs]
|
||||
if isinstance(inputs, dict):
|
||||
sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs]
|
||||
elif isinstance(inputs, (list, tuple)):
|
||||
sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs)]
|
||||
else:
|
||||
raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")
|
||||
|
||||
# PyTorch ONNX exporter does not match argument names
|
||||
# This is an issue because the ONNX graph depends on all inputs to be specified
|
||||
model = self._combine_torch_model_with_loss_fn_and_wrap_input()
|
||||
model = CombineTorchModelLossFnWrapInput(self._torch_model, self.loss_fn, input_names)
|
||||
|
||||
# Do an inference to grab output types
|
||||
model.eval()
|
||||
|
|
@ -493,23 +489,16 @@ class ORTTrainer(object):
|
|||
onnx_model = onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
# Remove 'model.' prefix introduced by CombineTorchModelLossFn class
|
||||
replace_name_dict = {}
|
||||
for n in onnx_model.graph.initializer:
|
||||
if n.name.startswith('model.'):
|
||||
replace_name_dict[n.name] = n.name[len('model.'):]
|
||||
n.name = replace_name_dict[n.name]
|
||||
for n in onnx_model.graph.node:
|
||||
for i, name in enumerate(n.input):
|
||||
if name in replace_name_dict:
|
||||
n.input[i] = replace_name_dict[name]
|
||||
|
||||
# ONNX model initializers may contain non-trainable registered buffers
|
||||
# that are not part of PyTorch model named parameteres
|
||||
named_parameters = model.model.named_parameters() if hasattr(model, 'model') else model.named_parameters()
|
||||
assert set([n for n, t in named_parameters]).issubset(
|
||||
set([n.name for n in onnx_model.graph.initializer])), \
|
||||
"Initializer names do not match between PyTorch model and ONNX model, " \
|
||||
"please report a bug to ONNX Runtime."
|
||||
if isinstance(model, CombineTorchModelLossFnWrapInput):
|
||||
replace_name_dict = {}
|
||||
for n in onnx_model.graph.initializer:
|
||||
if n.name.startswith('model.'):
|
||||
replace_name_dict[n.name] = n.name[len('model.'):]
|
||||
n.name = replace_name_dict[n.name]
|
||||
for n in onnx_model.graph.node:
|
||||
for i, name in enumerate(n.input):
|
||||
if name in replace_name_dict:
|
||||
n.input[i] = replace_name_dict[name]
|
||||
|
||||
return onnx_model
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue