Add working example for MNIST (MVP)

Thiago Crepaldi 2020-10-19 15:12:13 -07:00
parent f1b5c25b2d
commit 3524fb04e8
3 changed files with 186 additions and 162 deletions


@@ -15,10 +15,10 @@ ONNX_OPSET_VERSION = 12
class ORTModule(torch.nn.Module):
def __init__(self, module):
print(f'ORTModule.__init__() was called')
assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module"
super(ORTModule, self).__init__()
# User will interact with it (debugging, etc)
# User module is wrapped to use its initializers and save computed gradients
self._original_module = module
# Forward pass
@@ -31,10 +31,11 @@ class ORTModule(torch.nn.Module):
# Backward pass
self._onnx_backward = None
self._backward_session = None
self._onnx_backward_initializers_desc = []
self._onnx_backward_inputs_desc = []
self._onnx_backward_outputs_desc = []
def forward(self, *input, **kwargs):
print(f'ORTModule.forward() was called')
if not self._onnx_forward:
original_forward_graph = ORTModule._get_forward_graph(self._original_module, *input, **kwargs)
gradient_graph = ORTModule._build_gradient_graph(original_forward_graph)
@@ -45,118 +46,132 @@ class ORTModule(torch.nn.Module):
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString())
# TODO: debug only
self._save_onnx_graph(self._onnx_forward, 'ortmodule_forward_mnist.onnx')
self._save_onnx_graph(self._onnx_backward, 'ortmodule_backward_mnist.onnx')
# Forward I/O description
if not self._onnx_forward_initializers_desc:
self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
print(f'Forward initializers: {self._onnx_forward_initializers_desc}')
if not self._onnx_forward_inputs_desc:
self._onnx_forward_inputs_desc = self._get_input_from_graph(self._onnx_forward)
print(f'Forward inputs: {self._onnx_forward_inputs_desc}')
if not self._onnx_forward_outputs_desc:
self._onnx_forward_outputs_desc = self._get_output_from_graph(self._onnx_forward)
print(f'Forward outputs: {self._onnx_forward_outputs_desc}')
# TODO: debug only
print(f'Initializers: {self._onnx_forward_initializers_desc}')
print(f'Inputs: {self._onnx_forward_inputs_desc}')
print(f'Outputs: {self._onnx_forward_outputs_desc}')
# Backward I/O description
if not self._onnx_backward_initializers_desc:
self._onnx_backward_initializers_desc = self._get_initializer_from_graph(self._onnx_backward)
print(f'Backward initializers: {self._onnx_backward_initializers_desc}')
if not self._onnx_backward_inputs_desc:
self._onnx_backward_inputs_desc = self._get_input_from_graph(self._onnx_backward)
print(f'Backward inputs: {self._onnx_backward_inputs_desc}')
if not self._onnx_backward_outputs_desc:
self._onnx_backward_outputs_desc = self._get_output_from_graph(self._onnx_backward)
print(f'Backward outputs: {self._onnx_backward_outputs_desc}')
# Use a custom torch.autograd.Function to associate self.backward_graph as the
# gradient implementation for self.forward_graph.
class _ORTModuleFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, *input, **kwargs):
print(f'_ORTModuleFunction.forward() was called...')
# Note: A potential optimization would be to detect which of inputs and weights
# require a gradient.
# intermediates, outputs = self._run_forward_graph(inputs) # inputs, weights)
outputs = self._run_forward_graph(self._prepare_forward_input(*input, **kwargs)) # inputs, weights)
outputs = [torch.nn.Parameter(torch.from_numpy(out)) for out in outputs]
# TODO: Potential optimization is to detect which inputs and weights require gradients
input_with_initializer = self._prepare_forward_input_ort(*input, **kwargs)
outputs = self._run_forward_graph(input_with_initializer)
outputs = tuple(torch.from_numpy(out) for out in outputs)
# TODO: Properly save intermediate tensors and remove them from model output
ctx.save_for_backward([(input, kwargs), outputs[1]])
# outputs = [outputs[0]]
# TODO: Properly save dynamic number of intermediate tensors and remove them from model output
# Tensors that need to have gradients tracked can't be saved by `save_for_backward`
# saved_tensors ==> input1, fc2.weight, 7
ctx.save_for_backward(*[input[0], input[3], outputs[1]])
outputs = [outputs[0]]
# TODO: Properly support original module output format
if len(outputs) == 1:
return outputs[0]
return tuple(outputs)
return outputs
@staticmethod
def backward(ctx, *grad_output):
print(f'_ORTModuleFunction.backward() was called')
input_and_kwargs, intermediates = ctx.saved_tensors
# grad_inputs, grad_weights = self._run_backward_graph(
# grad_output, intermediates)
# return grad_inputs, grad_weights
# TODO: Properly restore dynamic number of intermediate tensors
# saved_tensors ==> input1, fc2.weight, 7
saved_tensors = ctx.saved_tensors
grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output])
grad_weights = [torch.from_numpy(grad) for grad in grad_weights]
# TODO: backward must return grad tensors in the same order forward does
# [input1_grad, fc1.weight_grad, fc1.bias_grad, fc2.weight_grad, fc2.bias_grad]
return tuple([torch.tensor([1.]), grad_weights[1], grad_weights[0], grad_weights[2], grad_weights[3]])
return _ORTModuleFunction.apply(*input, **kwargs)
return _ORTModuleFunction.apply(*self._prepare_forward_input_autograd(*input, **kwargs))
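The pattern above, reduced to a standalone sketch for readers unfamiliar with it. This is not the commit's code: the positional feed-building, the session arguments, and the assumption that the backward graph consumes (saved tensors..., grad_output) in order are all simplifications for illustration.

import torch

class _SessionFunction(torch.autograd.Function):
    # Minimal bridge between two onnxruntime.InferenceSession objects and
    # PyTorch autograd. Assumes the forward graph's declared inputs line up
    # positionally with *tensors, and the backward graph's inputs line up
    # with (saved tensors..., grad_output).
    @staticmethod
    def forward(ctx, fwd_session, bwd_session, *tensors):
        ctx.bwd_session = bwd_session
        ctx.save_for_backward(*tensors)
        feeds = {i.name: t.detach().cpu().numpy()
                 for i, t in zip(fwd_session.get_inputs(), tensors)}
        return torch.from_numpy(fwd_session.run(None, feeds)[0])

    @staticmethod
    def backward(ctx, grad_output):
        tensors = (*ctx.saved_tensors, grad_output)
        feeds = {i.name: t.detach().cpu().numpy()
                 for i, t in zip(ctx.bwd_session.get_inputs(), tensors)}
        grads = ctx.bwd_session.run(None, feeds)
        # One gradient per tensor argument; the session arguments get None.
        return (None, None, *(torch.from_numpy(g) for g in grads))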
def _prepare_forward_input(self, *input, **kwargs):
# Dictionary containing both inputs and initializers
input_with_initializer = {}
def _prepare_forward_input_autograd(self, *input, **kwargs):
# List containing both user inputs and initializers, in this order
input_with_initializer = []
# Inputs
for idx, input_data in enumerate(self._forward_session.get_inputs()):
input_with_initializer.update({input_data.name: input[idx].cpu().numpy()})
input_with_initializer.append(input[idx])
# Initializers
for idx, param in enumerate(self._original_module.named_parameters()):
input_with_initializer.update({param[0]: param[1].detach().numpy()})
input_with_initializer.append(param[1])
# TODO: [input1, fc1.weight, fc1.bias, fc2.weight, fc2.bias]
return input_with_initializer
def _prepare_forward_input_ort(self, *inputs):
# Dictionary containing both inputs and initializers
input_with_initializer = {}
# TODO: [input1, fc1.weight, fc1.bias, fc2.weight, fc2.bias]
# Inputs
inputs_len = 0
for idx, input_data in enumerate(self._forward_session.get_inputs()):
inputs_len += 1
input_with_initializer.update({input_data.name: inputs[idx].cpu().numpy()})
# Initializers
for param in self._original_module.named_parameters():
input_with_initializer.update({param[0]: inputs[inputs_len].detach().numpy()})
inputs_len += 1
return input_with_initializer
def _prepare_backward_input(self, grad_output, intermediates, *inputs, **kwargs):
def _prepare_backward_input(self, *inputs, **kwargs):
# Dictionary containing initializers
input_with_initializer = {}
# User input
# TODO: How to determine which user input to feed to backward
for idx, input_data in enumerate(self._forward_session.get_inputs()):
input_with_initializer.update({input_data.name: inputs[idx].cpu().numpy()})
# for idx, input_data in enumerate(self._forward_session.get_inputs()):
# input_with_initializer.update({input_data.name: inputs[idx].cpu().numpy()})
input_with_initializer.update({'input1' : inputs[0].detach().numpy()})
# Initializers
# TODO: How to determine which initializer (subset) to be used
for idx, param in enumerate(self._original_module.named_parameters()):
if param[0] == 'fc2.weight':
input_with_initializer.update({param[0]: param[1].detach().numpy()})
# Grad output
# TODO: How to determine grad_output name?
input_with_initializer.update({'probability_grad': grad_output.detach().numpy()})
# for idx, param in enumerate(self._original_module.named_parameters()):
# input_with_initializer.update({param[0]: param[1].detach().numpy()})
input_with_initializer.update({'fc2.weight' : inputs[1].detach().numpy()})
# Intermediates
# TODO: How to determine intermediates name?
input_with_initializer.update({'7': intermediates.detach().numpy()})
input_with_initializer.update({'7': inputs[2].detach().numpy()})
# Grad output
# TODO: How to determine grad_output name?
input_with_initializer.update({'probability_grad': inputs[3].detach().numpy()})
return input_with_initializer
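A possible direction for the "How to determine ... name?" TODOs above: read the feed names off the backward session instead of hard-coding 'input1', 'fc2.weight', '7' and 'probability_grad'. A minimal sketch (not in this commit), assuming the backward graph declares its inputs in the same order the tensors are passed in:

    def _prepare_backward_input_generic(self, *inputs):
        # Hypothetical replacement for the hard-coded feed names above.
        feeds = {}
        for graph_input, tensor in zip(self._backward_session.get_inputs(), inputs):
            feeds[graph_input.name] = tensor.detach().cpu().numpy()
        return feeds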
def _run_forward_graph(self, data_with_initializer): # input, weights):
print(f'_run_forward_graph was called...')
return self._forward_session.run(None, data_with_initializer)
def _run_backward_graph(self, grad_output, intermediates, *inputs, **kwargs):
# Use an InferenceSession to execute self.backward_graph.
# Return gradient tensors for inputs and weights.
print(f'_run_backward_graph was called...')
data = self._prepare_backward_input(grad_output, intermediates, *inputs, **kwargs)
def _run_backward_graph(self, *inputs, **kwargs):
data = self._prepare_backward_input(*inputs, **kwargs)
# TODO: Hack to guarantee output order from InferenceSession.run()
return self._backward_session.run(['fc1.bias_grad', 'fc1.weight_grad', 'fc2.weight_grad', 'fc2.bias_grad'], data)
# return self._backward_session.run(None, data)
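The output-order hack above could be removed the same way, by asking the session for its output names up front; InferenceSession.run() returns results in the order the names are requested. A sketch, not part of this commit:

    grad_names = [out.name for out in self._backward_session.get_outputs()]
    grads = dict(zip(grad_names, self._backward_session.run(grad_names, data)))
    # Callers can then select e.g. grads['fc1.weight_grad'] by name instead
    # of relying on positional order.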
@staticmethod
def _get_forward_graph(module, module_input):
print(f'_get_forward_graph was called...')
# TODO: PyTorch module must be exported to ONNX and split
# Hard-coding with MNIST stub for MVP
# Export torch.nn.Module to ONNX with initializers as input
# f = io.BytesIO()
# torch.onnx.export(module, module_input, f, verbose=True,
# opset_version=ONNX_OPSET_VERSION,
# _retain_param_name=True,
# training=torch.onnx.TrainingMode.TRAINING,
# keep_initializers_as_inputs=True,
# export_params=True)
# return onnx.load_model_from_string(f.getvalue())
return onnx.load('./model_with_training_forward_sliced.onnx')
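For reference, the export path sketched in the comments above would look roughly like this once the hard-coded MNIST stub is dropped. The flag set mirrors the commented-out call; nothing here is new API:

import io
import onnx
import torch

def _export_to_onnx(module, module_input, opset=12):
    # Keep initializers as graph inputs so the training session can be fed
    # updated weights on every step.
    f = io.BytesIO()
    torch.onnx.export(module, module_input, f,
                      opset_version=opset,
                      _retain_param_name=True,
                      training=torch.onnx.TrainingMode.TRAINING,
                      keep_initializers_as_inputs=True,
                      export_params=True)
    return onnx.load_model_from_string(f.getvalue())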
def _get_initializer_from_graph(self, graph):
@@ -200,6 +215,7 @@ class ORTModule(torch.nn.Module):
for elem in graph.graph.output:
for initializer in self._onnx_forward_initializers_desc:
if elem.name == initializer['name']:
# skip initializers
break
else:
name = elem.name
@@ -236,7 +252,6 @@ class ORTModule(torch.nn.Module):
@staticmethod
def _build_gradient_graph(forward_graph):
print(f'_build_gradient_graph was called...')
# TODO: Invoke the C++ GradientBuilder implementation via pybind.
# Return an ONNX graph that contains the forward and backward nodes, which takes the
# following inputs:
@@ -248,7 +263,6 @@ class ORTModule(torch.nn.Module):
@staticmethod
def _split_forward_and_backward(gradient_graph):
print(f'_split_forward_and_backward was called...')
# TODO: Split the result of _build_gradient_graph into two subgraphs:
# * A forward graph that takes module inputs and weights as input, and produces module
# outputs and (“stashed”) intermediate tensors as output.
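One concrete way to perform the split described above, assuming the combined gradient graph has been serialized to disk: recent onnx releases ship onnx.utils.extract_model, which carves out a subgraph by naming its boundary tensors. The file and tensor names below are illustrative only, borrowed from the MNIST stub used elsewhere in this commit:

import onnx.utils

# Forward subgraph: user input plus weights in, prediction plus the
# stashed intermediate ('7') out.
onnx.utils.extract_model(
    'gradient_graph.onnx', 'forward_sliced.onnx',
    input_names=['input1', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'],
    output_names=['probability', '7'])

# Backward subgraph: the stashed tensors and the output gradient in,
# weight gradients out.
onnx.utils.extract_model(
    'gradient_graph.onnx', 'backward_sliced.onnx',
    input_names=['input1', 'fc2.weight', '7', 'probability_grad'],
    output_names=['fc1.weight_grad', 'fc1.bias_grad',
                  'fc2.weight_grad', 'fc2.bias_grad'])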


@@ -1,14 +1,10 @@
import argparse
import torch
from torchvision import datasets, transforms
import torchviz
from onnxruntime import set_seed
import onnxruntime
from onnxruntime.training import ORTModule
import _test_commons
import _test_helpers
class NeuralNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
@@ -24,112 +20,127 @@ class NeuralNet(torch.nn.Module):
out = self.fc2(out)
return out
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--pytorch-only', action='store_true', default=False,
help='disables ONNX Runtime training')
args = parser.parse_args()
# Model architecture
lr = 1e-4
batch_size = 20
seed = 42
torch.manual_seed(seed)
set_seed(seed)
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
print('Training MNIST on ORTModule....')
if not args.pytorch_only:
model = ORTModule(model)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Data loader
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=batch_size,
shuffle=True)
# Training loop
loss = float('inf')
def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
model.train()
for iteration, (data, target) in enumerate(train_loader):
if iteration == 1:
print(f'Final loss is {loss}')
if iteration == args.train_steps:
break
data, target = data.to(device), target.to(device)
data = data.reshape(data.shape[0], -1)
optimizer.zero_grad()
if args.pytorch_only:
print("Using PyTorch-only API")
probability = model(data)
else:
probability = model(data)
if args.view_graphs:
import torchviz
pytorch_backward_graph = torchviz.make_dot(probability, params=dict(list(model.named_parameters())))
print(f'probability.grad_fn={probability.grad_fn}')
print(f'probability.grad_fn.next_functions={probability.grad_fn.next_functions}')
# pytorch_backward_graph.view()
probability.retain_grad()
else:
print("Using ONNX Runtime Flexible API")
probability, intermediates = model(data)
probability.requires_grad_(True)
pytorch_backward_graph.view()
print(f'Output from forward has shape {probability.size()}')
loss = criterion(probability, target)
loss = loss_fn(probability, target)
loss.backward()
print(f'***** probability.grad[0]={probability.grad[0]}')
if args.pytorch_only:
print(f'***** (PYTORCH) fc1.bias_grad[0] BEFORE {model.fc1.bias.data[0].item()}')
print(f'***** (PYTORCH) fc1.weight_grad[0][0] BEFORE {model.fc1.weight.data[0][0].item()}')
print(f'***** (PYTORCH) fc2.bias_grad[0] BEFORE {model.fc2.bias.data[0].item()}')
print(f'***** (PYTORCH) fc2.weight_grad[0][0] BEFORE {model.fc2.weight.data[0][0].item()}')
else:
# import pdb; pdb.set_trace()
# Fake backward call to test backprop graph
# TODO: The model output *order* changes from one ONNX export to the next
fc1_bias_grad, fc1_weight_grad, fc2_weight_grad, fc2_bias_grad = model._run_backward_graph(probability.grad, intermediates, data)
fc1_bias_grad = torch.from_numpy(fc1_bias_grad).requires_grad_(True)
fc2_bias_grad = torch.from_numpy(fc2_bias_grad).requires_grad_(True)
fc1_weight_grad = torch.from_numpy(fc1_weight_grad).requires_grad_(True)
fc2_weight_grad = torch.from_numpy(fc2_weight_grad).requires_grad_(True)
fc1_bias_grad.retain_grad()
fc1_weight_grad.retain_grad()
fc2_bias_grad.retain_grad()
fc2_weight_grad.retain_grad()
print(f'***** (ONNX Runtime) fc1_bias_grad[0] BEFORE {model._original_module.fc1.bias.data[0].item()}')
print(f'***** (ONNX Runtime) fc1_weight_grad[0][0] BEFORE {model._original_module.fc1.weight.data[0][0].item()}')
print(f'***** (ONNX Runtime) fc2_bias_grad[0] BEFORE {model._original_module.fc2.bias.data[0].item()}')
print(f'***** (ONNX Runtime) fc2_weight_grad[0][0] BEFORE {model._original_module.fc2.weight.data[0][0].item()}')
print(f'***** (ONNX Runtime) fc1_bias_grad[0] AFTER {fc1_bias_grad[0].item()}')
print(f'***** (ONNX Runtime) fc1_weight_grad[0][0] AFTER {fc1_weight_grad[0][0]}')
print(f'***** (ONNX Runtime) fc2_bias_grad[0] AFTER {fc2_bias_grad[0].item()}')
print(f'***** (ONNX Runtime) fc2_weight_grad[0][0] AFTER {fc2_weight_grad[0][0].item()}')
model._original_module.fc1.bias.data = fc1_bias_grad.data
model._original_module.fc1.weight.data = fc1_weight_grad.data
model._original_module.fc2.bias.data = fc2_bias_grad.data
model._original_module.fc2.weight.data = fc2_weight_grad.data
print(f'Output from backward has the following shapes after update:')
print(f'fc1_bias_grad={fc1_bias_grad.size()}')
print(f'fc2_bias_grad={fc2_bias_grad.size()}')
print(f'fc1_weight_grad={fc1_weight_grad.size()}')
print(f'fc2_weight_grad={fc2_weight_grad.size()}')
optimizer.step()
if args.pytorch_only:
print(f'***** (PYTORCH) fc1.bias_grad[0] AFTER {model.fc1.bias.data[0].item()}')
print(f'***** (PYTORCH) fc1.weight_grad[0][0] AFTER {model.fc1.weight.data[0][0].item()}')
print(f'***** (PYTORCH) fc2.bias_grad[0] AFTER {model.fc2.bias.data[0].item()}')
print(f'***** (PYTORCH) fc2.weight_grad[0][0] AFTER {model.fc2.weight.data[0][0].item()}')
if iteration == 0:
print(f'Initial loss is {loss}')
print('Tah dah!')
# Stats
if iteration % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, iteration * len(data), len(train_loader.dataset),
100. * iteration / len(train_loader), loss))
def test(args, model, device, loss_fn, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
data = data.reshape(data.shape[0], -1)
output = model(data)
# Stats
test_loss += loss_fn(output, target, False).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def my_loss(x, target, is_train=True):
if is_train:
return torch.nn.CrossEntropyLoss()(x, target)
else:
return torch.nn.CrossEntropyLoss(reduction='sum')(x, target)
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
help='learning rate (default: 0.001)')
parser.add_argument('--batch-size', type=int, default=20, metavar='N',
help='input batch size for training (default: 20)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--pytorch-only', action='store_true', default=False,
help='disables ONNX Runtime training')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
help='how many batches to wait before logging training status (default: 100)')
parser.add_argument('--view-graphs', action='store_true', default=False,
help='views forward and backward graphs')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
args = parser.parse_args()
# Common setup
torch.manual_seed(args.seed)
onnxruntime.set_seed(args.seed)
# TODO: CUDA support is broken due to copying from PyTorch into ORT
# if not args.no_cuda and torch.cuda.is_available():
# device = "cuda"
# else:
# device = "cpu"
device = 'cpu'
# Data loader
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.batch_size,
shuffle=True)
if args.test_batch_size > 0:
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.test_batch_size, shuffle=True)
# Model architecture
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
if not args.pytorch_only:
print('Training MNIST on ORTModule....')
model = ORTModule(model)
else:
print('Training MNIST on vanilla PyTorch....')
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
# Train loop
for epoch in range(1, args.epochs + 1):
train(args, model, device, optimizer, my_loss, train_loader, epoch)
if args.test_batch_size > 0:
test(args, model, device, my_loss, test_loader)
if __name__ == '__main__':
main()
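A minimal smoke test for the wrapper this script exercises, skipping the MNIST download. Bear in mind that at this MVP stage ORTModule is hard-wired to the pre-sliced MNIST ONNX files, so this sketch only runs where those files exist; shapes match the NeuralNet(784, 500, 10) defined above:

import torch
from onnxruntime.training import ORTModule

model = ORTModule(NeuralNet(input_size=784, hidden_size=500, num_classes=10))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
data = torch.randn(20, 784)           # batch of 20 flattened 28x28 images
target = torch.randint(0, 10, (20,))  # random class labels

optimizer.zero_grad()
loss = torch.nn.CrossEntropyLoss()(model(data), target)
loss.backward()
optimizer.step()
print(f'one-step loss: {loss.item():.4f}')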


@@ -122,7 +122,6 @@ def main():
train(args, model, device, train_loader, optimizer, epoch)
if args.test_batch_size > 0:
test(model, device, test_loader)
optimizer.step()
# Save model
if args.save_path: