Forward pass using InferenceSession on exported ONNX

Although forward pass works, this has the limitation of not working for
backward pass due to the lack of intermediate tensors needed for
gradient.

Next step is to export a training graph and split it manually
This commit is contained in:
Thiago Crepaldi 2020-10-05 09:32:43 -07:00
parent a8d549e181
commit 11b69f141e
3 changed files with 275 additions and 0 deletions

View file

@ -8,3 +8,5 @@ from onnxruntime.capi.training.training_session import TrainingSession
from .orttrainer_options import ORTTrainerOptions
from .orttrainer import ORTTrainer, TrainStepInfo
from . import amp, checkpoint, optim, model_desc_validation
from .ortmodule import ORTModule

View file

@ -0,0 +1,209 @@
import copy
import io
import onnx
import onnxruntime
import os
import torch
import warnings
from . import _utils
ONNX_OPSET_VERSION = 12
class ORTModule(torch.nn.Module):
def __init__(self, module):
print(f'ORTModule.__init__() was called')
super(ORTModule, self).__init__()
# User will interact with it (debugging, etc)
self._original_module = module
# Forward pass
self._onnx_forward = None
self._onnx_forward_initializers_desc = []
self._onnx_forward_inputs_desc = []
self._onnx_forward_outputs_desc = []
# Backward pass
self._onnx_backward = None
def forward(self, *input, **kwargs):
print(f'ORTModule.forward() was called')
if not self._onnx_forward:
original_forward_graph = ORTModule._get_forward_graph(self._original_module, *input, **kwargs)
# gradient_graph = ORTModule._build_gradient_graph(original_forward_graph)
# self.forward_graph, self.backward_graph = ORTModule._split_forward_and_backward(gradient_graph)
self._onnx_forward = original_forward_graph # TODO: hard-coding for MVP
self.forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
self._save_onnx_graph(self._onnx_forward, 'forward_mnist.onnx')
if not self._onnx_forward_initializers_desc:
self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
if not self._onnx_forward_inputs_desc:
self._onnx_forward_inputs_desc = self._get_input_from_graph(self._onnx_forward)
if not self._onnx_forward_outputs_desc:
self._onnx_forward_outputs_desc = self._get_output_from_graph(self._onnx_forward)
print(f'Initializers: {self._onnx_forward_initializers_desc}')
print(f'Inputs: {self._onnx_forward_inputs_desc}')
print(f'Outpus: {self._onnx_forward_outputs_desc}')
# Use a custom torch.autograd.Function to associate self.backward_graph as the
# gradient implementation for self.forward_graph.
class _ORTModuleFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, *input, **kwargs):
print(f'_ORTModuleFunction.forward() was called...')
# Note: A potential optimization would be to detect which of inputs and weights
# require a gradient.
# intermediates, outputs = self._run_forward_graph(inputs) # inputs, weights)
outputs = self._run_forward_graph(*input, **kwargs) # inputs, weights)
# ctx.save_for_backward(*intermediates)
outputs = [torch.from_numpy(out).requires_grad_(True) for out in outputs]
if len(outputs) == 1:
return outputs[0]
return tuple(outputs)
@staticmethod
def backward(ctx, grad_output):
print(f'_ORTModuleFunction.backward() was called')
...
# intermediates = ctx.saved_tensors
# grad_inputs, grad_weights = self._run_backward_graph(
# grad_output, intermediates)
# return grad_inputs, grad_weights
return _ORTModuleFunction.apply(self._prepare_model_input(*input, **kwargs))
def _prepare_model_input(self, *input, **kwargs):
# Dictionary containing both inputs and initializers
input_with_initializer = {}
# Inputs
for idx, input_data in enumerate(self.forward_session.get_inputs()):
input_with_initializer.update({input_data.name : input[idx].cpu().numpy()})
# Initializers
for idx, param in enumerate(self._original_module.named_parameters()):
input_with_initializer.update({param[0] : param[1].detach().numpy()})
return input_with_initializer
def _run_forward_graph(self, data_with_initializer): #input, weights):
return self.forward_session.run(None, data_with_initializer)
def _run_backward_graph(self, grad_output, intermediates):
# Use an InferenceSession to execute self.backward_graph.
# Return gradient tensors for inputs and weights.
...
@staticmethod
def _get_forward_graph(module, module_input):
# Export torch.nn.Module to ONNX with initializers as input
f = io.BytesIO()
torch.onnx.export(module, module_input, f, verbose=True,
opset_version=ONNX_OPSET_VERSION,
_retain_param_name=True,
training=torch.onnx.TrainingMode.TRAINING,
keep_initializers_as_inputs=True,
export_params=True)
return onnx.load_model_from_string(f.getvalue())
def _get_initializer_from_graph(self, graph):
# TODO: There is a tradefoo between memory footprint and total model export time
# Ideally we want to export the model using torch.onnx.export(.., export_params=False, keep_initializers_as_inputs=True)
# to obtain an ONNX model with minimal size and initializers as input.
# However, this results in (guessing) assuming only initializer's name end with '.weight' and '.bias'.
# Otherwise, it is not possible to separate input from initializer after the model is exported
# Options are:
# 1) If memory footprint is more important, we can export ONNX twice, varying keep_initializers_as_inputs flag
# ONNX model is small (400 bytes vs 1.6MB for MNIST), but export takes twice the time
# 2) If total export time is more important, we can export ONNX once, using export_params=True
# ONNX model is bigger, but export takes half the time
# As performance is not the main goal in this first deliverable, using approach 2) for simplicity
initializers = []
for initializer in graph.graph.initializer:
name = initializer.name
# TODO: Dynamic shape is not being handled yet
shape = initializer.dims
dtype = _utils.dtype_onnx_to_torch(initializer.data_type)
initializers.append({'name' : name, 'shape' : shape, 'dtype' : dtype})
return initializers
def _get_input_from_graph(self, graph):
inputs = []
for elem in graph.graph.input:
for initializer in self._onnx_forward_initializers_desc:
if elem.name == initializer['name']:
break
else:
name = elem.name
# TODO: Dynamic shape is not being handled yet
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
inputs.append({'name' : name, 'shape' : shape, 'dtype' : dtype})
return inputs
def _get_output_from_graph(self, graph):
outputs = []
for elem in graph.graph.output:
for initializer in self._onnx_forward_initializers_desc:
if elem.name == initializer['name']:
break
else:
name = elem.name
# TODO: Dynamic shape is not being handled yet
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
outputs.append({'name' : name, 'shape' : shape, 'dtype' : dtype})
return outputs
@staticmethod
def _save_onnx_graph(onnx_graph, path):
r"""Persists ONNX model into :py:attr:`path`
The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard.
The graph includes full information, including inference and training metadata.
Args:
onnx_graph (onnx.ModelProto): Either forward or backward graph
path (str): Full path, including filename, to save the ONNX model in the filesystem
Raises:
ValueError: raised when `path` is not valid path
"""
assert isinstance(path, str), "'path' must be a valid path string"
dir_name = os.path.dirname(path)
file_name = os.path.basename(path)
if (dir_name and not os.path.exists(dir_name)) or not file_name:
warnings.warn("'path' is not valid or does not exist")
return
with open(path, "wb") as f:
f.write(onnx_graph.SerializeToString())
@staticmethod
def _build_gradient_graph(forward_graph):
# Invoke the C++ GradientBuilder implementation via pybind.
# Return an ONNX graph that contains the forward and backward nodes, which takes the
# following inputs:
# * Module inputs
# * Module weights
# * Gradients with respect to the module outputs
# …and produces gradients with respect to the module inputs and weights.
...
@staticmethod
def _split_forward_and_backward(gradient_graph):
# Split the result of _build_gradient_graph into two subgraphs:
# * A forward graph that takes module inputs and weights as input, and produces module
# outputs and (“stashed”) intermediate tensors as output.
# * A backward graph that takes intermediate tensors, module weights, and gradients
# respect to the module outputs as inputs, and produces gradients with respect to the
# module inputs and weights.
return (None, None)

View file

@ -0,0 +1,64 @@
import torch
from torchvision import datasets, transforms
from onnxruntime import set_seed
from onnxruntime.training import ORTModule
import _test_commons
import _test_helpers
class NeuralNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNet, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
def forward(self, input1):
out = self.fc1(input1)
out = self.relu(out)
out = self.fc2(out)
return out
# Model architecture
lr = 1e-4
batch_size=20
seed=42
torch.manual_seed(seed)
set_seed(seed)
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
model = ORTModule(model)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Data loader
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=batch_size,
shuffle=True)
#TrainingLoop
print('Training MNIST on ORTModule....')
loss = float('inf')
for iteration, (data, target) in enumerate(train_loader):
if iteration == 1:
print(f'Final loss is {loss}')
break
data = data.reshape(data.shape[0], -1)
optimizer.zero_grad()
output = model(data)
print(f'Output from forward has shape {output[0].size()}: {output[0]}')
loss = criterion(output, target)
# loss.backward()
# optimizer.step()
if iteration == 0:
print(f'Initial loss is {loss}')
print('Tah dah!')