mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Forward pass using InferenceSession on exported ONNX
Although forward pass works, this has the limitation of not working for backward pass due to the lack of intermediate tensors needed for gradient. Next step is to export a training graph and split it manually
This commit is contained in:
parent
a8d549e181
commit
11b69f141e
3 changed files with 275 additions and 0 deletions
|
|
@ -8,3 +8,5 @@ from onnxruntime.capi.training.training_session import TrainingSession
|
|||
from .orttrainer_options import ORTTrainerOptions
|
||||
from .orttrainer import ORTTrainer, TrainStepInfo
|
||||
from . import amp, checkpoint, optim, model_desc_validation
|
||||
|
||||
from .ortmodule import ORTModule
|
||||
209
orttraining/orttraining/python/training/ortmodule.py
Normal file
209
orttraining/orttraining/python/training/ortmodule.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
import copy
|
||||
import io
|
||||
import onnx
|
||||
import onnxruntime
|
||||
import os
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from . import _utils
|
||||
|
||||
|
||||
ONNX_OPSET_VERSION = 12
|
||||
|
||||
|
||||
class ORTModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, module):
|
||||
print(f'ORTModule.__init__() was called')
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
# User will interact with it (debugging, etc)
|
||||
self._original_module = module
|
||||
|
||||
# Forward pass
|
||||
self._onnx_forward = None
|
||||
self._onnx_forward_initializers_desc = []
|
||||
self._onnx_forward_inputs_desc = []
|
||||
self._onnx_forward_outputs_desc = []
|
||||
|
||||
# Backward pass
|
||||
self._onnx_backward = None
|
||||
|
||||
def forward(self, *input, **kwargs):
|
||||
print(f'ORTModule.forward() was called')
|
||||
|
||||
if not self._onnx_forward:
|
||||
original_forward_graph = ORTModule._get_forward_graph(self._original_module, *input, **kwargs)
|
||||
# gradient_graph = ORTModule._build_gradient_graph(original_forward_graph)
|
||||
# self.forward_graph, self.backward_graph = ORTModule._split_forward_and_backward(gradient_graph)
|
||||
self._onnx_forward = original_forward_graph # TODO: hard-coding for MVP
|
||||
self.forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
|
||||
self._save_onnx_graph(self._onnx_forward, 'forward_mnist.onnx')
|
||||
if not self._onnx_forward_initializers_desc:
|
||||
self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
|
||||
if not self._onnx_forward_inputs_desc:
|
||||
self._onnx_forward_inputs_desc = self._get_input_from_graph(self._onnx_forward)
|
||||
if not self._onnx_forward_outputs_desc:
|
||||
self._onnx_forward_outputs_desc = self._get_output_from_graph(self._onnx_forward)
|
||||
|
||||
print(f'Initializers: {self._onnx_forward_initializers_desc}')
|
||||
print(f'Inputs: {self._onnx_forward_inputs_desc}')
|
||||
print(f'Outpus: {self._onnx_forward_outputs_desc}')
|
||||
|
||||
# Use a custom torch.autograd.Function to associate self.backward_graph as the
|
||||
# gradient implementation for self.forward_graph.
|
||||
class _ORTModuleFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, *input, **kwargs):
|
||||
print(f'_ORTModuleFunction.forward() was called...')
|
||||
# Note: A potential optimization would be to detect which of inputs and weights
|
||||
# require a gradient.
|
||||
# intermediates, outputs = self._run_forward_graph(inputs) # inputs, weights)
|
||||
outputs = self._run_forward_graph(*input, **kwargs) # inputs, weights)
|
||||
# ctx.save_for_backward(*intermediates)
|
||||
outputs = [torch.from_numpy(out).requires_grad_(True) for out in outputs]
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return tuple(outputs)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
print(f'_ORTModuleFunction.backward() was called')
|
||||
...
|
||||
# intermediates = ctx.saved_tensors
|
||||
# grad_inputs, grad_weights = self._run_backward_graph(
|
||||
# grad_output, intermediates)
|
||||
# return grad_inputs, grad_weights
|
||||
|
||||
return _ORTModuleFunction.apply(self._prepare_model_input(*input, **kwargs))
|
||||
|
||||
def _prepare_model_input(self, *input, **kwargs):
|
||||
# Dictionary containing both inputs and initializers
|
||||
input_with_initializer = {}
|
||||
|
||||
# Inputs
|
||||
for idx, input_data in enumerate(self.forward_session.get_inputs()):
|
||||
input_with_initializer.update({input_data.name : input[idx].cpu().numpy()})
|
||||
|
||||
# Initializers
|
||||
for idx, param in enumerate(self._original_module.named_parameters()):
|
||||
input_with_initializer.update({param[0] : param[1].detach().numpy()})
|
||||
|
||||
return input_with_initializer
|
||||
|
||||
def _run_forward_graph(self, data_with_initializer): #input, weights):
|
||||
return self.forward_session.run(None, data_with_initializer)
|
||||
|
||||
def _run_backward_graph(self, grad_output, intermediates):
|
||||
# Use an InferenceSession to execute self.backward_graph.
|
||||
# Return gradient tensors for inputs and weights.
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def _get_forward_graph(module, module_input):
|
||||
# Export torch.nn.Module to ONNX with initializers as input
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(module, module_input, f, verbose=True,
|
||||
opset_version=ONNX_OPSET_VERSION,
|
||||
_retain_param_name=True,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
keep_initializers_as_inputs=True,
|
||||
export_params=True)
|
||||
return onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
def _get_initializer_from_graph(self, graph):
|
||||
# TODO: There is a tradefoo between memory footprint and total model export time
|
||||
# Ideally we want to export the model using torch.onnx.export(.., export_params=False, keep_initializers_as_inputs=True)
|
||||
# to obtain an ONNX model with minimal size and initializers as input.
|
||||
# However, this results in (guessing) assuming only initializer's name end with '.weight' and '.bias'.
|
||||
# Otherwise, it is not possible to separate input from initializer after the model is exported
|
||||
# Options are:
|
||||
# 1) If memory footprint is more important, we can export ONNX twice, varying keep_initializers_as_inputs flag
|
||||
# ONNX model is small (400 bytes vs 1.6MB for MNIST), but export takes twice the time
|
||||
# 2) If total export time is more important, we can export ONNX once, using export_params=True
|
||||
# ONNX model is bigger, but export takes half the time
|
||||
|
||||
# As performance is not the main goal in this first deliverable, using approach 2) for simplicity
|
||||
initializers = []
|
||||
for initializer in graph.graph.initializer:
|
||||
name = initializer.name
|
||||
# TODO: Dynamic shape is not being handled yet
|
||||
shape = initializer.dims
|
||||
dtype = _utils.dtype_onnx_to_torch(initializer.data_type)
|
||||
initializers.append({'name' : name, 'shape' : shape, 'dtype' : dtype})
|
||||
return initializers
|
||||
|
||||
def _get_input_from_graph(self, graph):
|
||||
inputs = []
|
||||
for elem in graph.graph.input:
|
||||
for initializer in self._onnx_forward_initializers_desc:
|
||||
if elem.name == initializer['name']:
|
||||
break
|
||||
else:
|
||||
name = elem.name
|
||||
# TODO: Dynamic shape is not being handled yet
|
||||
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
|
||||
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
|
||||
inputs.append({'name' : name, 'shape' : shape, 'dtype' : dtype})
|
||||
return inputs
|
||||
|
||||
def _get_output_from_graph(self, graph):
|
||||
outputs = []
|
||||
for elem in graph.graph.output:
|
||||
for initializer in self._onnx_forward_initializers_desc:
|
||||
if elem.name == initializer['name']:
|
||||
break
|
||||
else:
|
||||
name = elem.name
|
||||
# TODO: Dynamic shape is not being handled yet
|
||||
shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
|
||||
dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
|
||||
outputs.append({'name' : name, 'shape' : shape, 'dtype' : dtype})
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _save_onnx_graph(onnx_graph, path):
|
||||
r"""Persists ONNX model into :py:attr:`path`
|
||||
|
||||
The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard.
|
||||
The graph includes full information, including inference and training metadata.
|
||||
|
||||
Args:
|
||||
onnx_graph (onnx.ModelProto): Either forward or backward graph
|
||||
path (str): Full path, including filename, to save the ONNX model in the filesystem
|
||||
|
||||
Raises:
|
||||
ValueError: raised when `path` is not valid path
|
||||
"""
|
||||
|
||||
assert isinstance(path, str), "'path' must be a valid path string"
|
||||
dir_name = os.path.dirname(path)
|
||||
file_name = os.path.basename(path)
|
||||
if (dir_name and not os.path.exists(dir_name)) or not file_name:
|
||||
warnings.warn("'path' is not valid or does not exist")
|
||||
return
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(onnx_graph.SerializeToString())
|
||||
|
||||
@staticmethod
|
||||
def _build_gradient_graph(forward_graph):
|
||||
# Invoke the C++ GradientBuilder implementation via pybind.
|
||||
# Return an ONNX graph that contains the forward and backward nodes, which takes the
|
||||
# following inputs:
|
||||
# * Module inputs
|
||||
# * Module weights
|
||||
# * Gradients with respect to the module outputs
|
||||
# …and produces gradients with respect to the module inputs and weights.
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def _split_forward_and_backward(gradient_graph):
|
||||
# Split the result of _build_gradient_graph into two subgraphs:
|
||||
# * A forward graph that takes module inputs and weights as input, and produces module
|
||||
# outputs and (“stashed”) intermediate tensors as output.
|
||||
# * A backward graph that takes intermediate tensors, module weights, and gradients
|
||||
# respect to the module outputs as inputs, and produces gradients with respect to the
|
||||
# module inputs and weights.
|
||||
return (None, None)
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
import torch
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
from onnxruntime import set_seed
|
||||
from onnxruntime.training import ORTModule
|
||||
|
||||
import _test_commons
|
||||
import _test_helpers
|
||||
|
||||
|
||||
class NeuralNet(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(NeuralNet, self).__init__()
|
||||
|
||||
self.fc1 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, input1):
|
||||
out = self.fc1(input1)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
# Model architecture
|
||||
lr = 1e-4
|
||||
batch_size=20
|
||||
seed=42
|
||||
|
||||
torch.manual_seed(seed)
|
||||
set_seed(seed)
|
||||
|
||||
|
||||
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
|
||||
model = ORTModule(model)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
|
||||
|
||||
# Data loader
|
||||
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
|
||||
transform=transforms.Compose([transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=batch_size,
|
||||
shuffle=True)
|
||||
|
||||
# Training Loop
|
||||
print('Training MNIST on ORTModule....')
|
||||
loss = float('inf')
|
||||
for iteration, (data, target) in enumerate(train_loader):
|
||||
if iteration == 1:
|
||||
print(f'Final loss is {loss}')
|
||||
break
|
||||
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
print(f'Output from forward has shape {output[0].size()}: {output[0]}')
|
||||
loss = criterion(output, target)
|
||||
# loss.backward()
|
||||
# optimizer.step()
|
||||
|
||||
if iteration == 0:
|
||||
print(f'Initial loss is {loss}')
|
||||
print('Tah dah!')
|
||||
Loading…
Reference in a new issue