mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add working example for MNIST (MVP)
This commit is contained in:
parent
f1b5c25b2d
commit
3524fb04e8
3 changed files with 186 additions and 162 deletions
|
|
@ -15,10 +15,10 @@ ONNX_OPSET_VERSION = 12
|
|||
class ORTModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, module):
|
||||
print(f'ORTModule.__init__() was called')
|
||||
assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module"
|
||||
super(ORTModule, self).__init__()
|
||||
# User will interact with it (debugging, etc)
|
||||
|
||||
# User module is wrapped to use its initializers and save computed gradients
|
||||
self._original_module = module
|
||||
|
||||
# Forward pass
|
||||
|
|
@ -31,10 +31,11 @@ class ORTModule(torch.nn.Module):
|
|||
# Backward pass
|
||||
self._onnx_backward = None
|
||||
self._backward_session = None
|
||||
self._onnx_backward_initializers_desc = []
|
||||
self._onnx_backward_inputs_desc = []
|
||||
self._onnx_backward_outputs_desc = []
|
||||
|
||||
def forward(self, *input, **kwargs):
|
||||
print(f'ORTModule.forward() was called')
|
||||
|
||||
if not self._onnx_forward:
|
||||
original_forward_graph = ORTModule._get_forward_graph(self._original_module, *input, **kwargs)
|
||||
gradient_graph = ORTModule._build_gradient_graph(original_forward_graph)
|
||||
|
|
@ -45,118 +46,132 @@ class ORTModule(torch.nn.Module):
|
|||
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
|
||||
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString())
|
||||
|
||||
# TODO: debug only
|
||||
self._save_onnx_graph(self._onnx_forward, 'ortmodule_forward_mnist.onnx')
|
||||
self._save_onnx_graph(self._onnx_backward, 'ortmodule_backward_mnist.onnx')
|
||||
|
||||
# Forward I/O description
|
||||
if not self._onnx_forward_initializers_desc:
|
||||
self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
|
||||
print(f'Forward initializers: {self._onnx_forward_initializers_desc}')
|
||||
if not self._onnx_forward_inputs_desc:
|
||||
self._onnx_forward_inputs_desc = self._get_input_from_graph(self._onnx_forward)
|
||||
print(f'Forward inputs: {self._onnx_forward_inputs_desc}')
|
||||
if not self._onnx_forward_outputs_desc:
|
||||
self._onnx_forward_outputs_desc = self._get_output_from_graph(self._onnx_forward)
|
||||
print(f'Forward outputs: {self._onnx_forward_outputs_desc}')
|
||||
|
||||
# TODO: debug only
|
||||
print(f'Initializers: {self._onnx_forward_initializers_desc}')
|
||||
print(f'Inputs: {self._onnx_forward_inputs_desc}')
|
||||
print(f'Outpus: {self._onnx_forward_outputs_desc}')
|
||||
# Backward I/O description
|
||||
if not self._onnx_backward_initializers_desc:
|
||||
self._onnx_backward_initializers_desc = self._get_initializer_from_graph(self._onnx_backward)
|
||||
print(f'Backward initializers: {self._onnx_backward_initializers_desc}')
|
||||
if not self._onnx_backward_inputs_desc:
|
||||
self._onnx_backward_inputs_desc = self._get_input_from_graph(self._onnx_backward)
|
||||
print(f'Backward inputs: {self._onnx_forward_inputs_desc}')
|
||||
if not self._onnx_backward_outputs_desc:
|
||||
self._onnx_backward_outputs_desc = self._get_output_from_graph(self._onnx_backward)
|
||||
print(f'Backward outputs: {self._onnx_backward_outputs_desc}')
|
||||
|
||||
# Use a custom torch.autograd.Function to associate self.backward_graph as the
|
||||
# gradient implementation for self.forward_graph.
|
||||
class _ORTModuleFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, *input, **kwargs):
|
||||
print(f'_ORTModuleFunction.forward() was called...')
|
||||
# Note: A potential optimization would be to detect which of inputs and weights
|
||||
# require a gradient.
|
||||
# intermediates, outputs = self._run_forward_graph(inputs) # inputs, weights)
|
||||
outputs = self._run_forward_graph(self._prepare_forward_input(*input, **kwargs)) # inputs, weights)
|
||||
outputs = [torch.nn.Parameter(torch.from_numpy(out)) for out in outputs]
|
||||
# TODO: Potential optimization is to detect which inputs and weights require gradients
|
||||
input_with_initializer = self._prepare_forward_input_ort(*input, **kwargs)
|
||||
outputs = self._run_forward_graph(input_with_initializer)
|
||||
outputs = tuple(torch.from_numpy(out) for out in outputs)
|
||||
|
||||
# TODO: Properly save intermediate tensors and remove them from model output
|
||||
ctx.save_for_backward([(input, kwargs), outputs[1]])
|
||||
# outputs = [outputs[0]]
|
||||
# TODO: Properly save dynamic number of intermediate tensors and remove them from model output
|
||||
# Tensors that need to have gradients tracked can't be saved by `save_for_backward`
|
||||
# saved_tensors ==> input1, fc2.weight, 7
|
||||
ctx.save_for_backward(*[input[0], input[3], outputs[1]])
|
||||
outputs = [outputs[0]]
|
||||
|
||||
# TODO: Properly support original module output format
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return tuple(outputs)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_output):
|
||||
print(f'_ORTModuleFunction.backward() was called')
|
||||
input_and_kwargs, intermediates = ctx.saved_tensors
|
||||
# grad_inputs, grad_weights = self._run_backward_graph(
|
||||
# grad_output, intermediates)
|
||||
# return grad_inputs, grad_weights
|
||||
# TODO: Properly restore dynamic number of intermediate tensors
|
||||
# saved_tensors ==> input1, fc2.weight, 7
|
||||
saved_tensors = ctx.saved_tensors
|
||||
grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output])
|
||||
grad_weights = [torch.from_numpy(grad) for grad in grad_weights]
|
||||
# TODO: backward must return grad tensors in the same order forward does
|
||||
# [input1_grad, fc1.weight_grad, fc1.bias_grad, fc2.weight_grad, fc2.bias_grad]
|
||||
return tuple([torch.tensor([1.]), grad_weights[1], grad_weights[0], grad_weights[2], grad_weights[3]])
|
||||
|
||||
return _ORTModuleFunction.apply(*input, **kwargs)
|
||||
return _ORTModuleFunction.apply(*self._prepare_forward_input_autograd(*input, **kwargs))
|
||||
|
||||
def _prepare_forward_input(self, *input, **kwargs):
|
||||
# Dictionary containing both inputs and initializers
|
||||
input_with_initializer = {}
|
||||
def _prepare_forward_input_autograd(self, *input, **kwargs):
|
||||
# List containing both user inputs and initializers, in this order
|
||||
input_with_initializer = []
|
||||
|
||||
# Inputs
|
||||
for idx, input_data in enumerate(self._forward_session.get_inputs()):
|
||||
input_with_initializer.update({input_data.name: input[idx].cpu().numpy()})
|
||||
input_with_initializer.append(input[idx])
|
||||
|
||||
# Initializers
|
||||
for idx, param in enumerate(self._original_module.named_parameters()):
|
||||
input_with_initializer.update({param[0]: param[1].detach().numpy()})
|
||||
input_with_initializer.append(param[1])
|
||||
|
||||
# TODO: [input1, fc1.weight, fc1.bias, fc2.weight, fc2.bias]
|
||||
return input_with_initializer
|
||||
|
||||
def _prepare_forward_input_ort(self, *inputs):
|
||||
# Dictionary containing both inputs and initializers
|
||||
input_with_initializer = {}
|
||||
|
||||
# TODO: [input1, fc1.weight, fc1.bias, fc2.weight, fc2.bias]
|
||||
# Inputs
|
||||
inputs_len = 0
|
||||
for idx, input_data in enumerate(self._forward_session.get_inputs()):
|
||||
inputs_len += 1
|
||||
input_with_initializer.update({input_data.name: inputs[idx].cpu().numpy()})
|
||||
|
||||
# Initializers
|
||||
for param in self._original_module.named_parameters():
|
||||
input_with_initializer.update({param[0]: inputs[inputs_len].detach().numpy()})
|
||||
inputs_len += 1
|
||||
|
||||
return input_with_initializer
|
||||
|
||||
def _prepare_backward_input(self, grad_output, intermediates, *inputs, **kwargs):
|
||||
def _prepare_backward_input(self, *inputs, **kwargs):
|
||||
# Dictionary containing initializers
|
||||
input_with_initializer = {}
|
||||
|
||||
# User input
|
||||
# TODO: How to determine which user input to feed to backward
|
||||
for idx, input_data in enumerate(self._forward_session.get_inputs()):
|
||||
input_with_initializer.update({input_data.name: inputs[idx].cpu().numpy()})
|
||||
# for idx, input_data in enumerate(self._forward_session.get_inputs()):
|
||||
# input_with_initializer.update({input_data.name: inputs[idx].cpu().numpy()})
|
||||
input_with_initializer.update({'input1' : inputs[0].detach().numpy()})
|
||||
|
||||
# Initializers
|
||||
# TODO: How to determine which initializer (subset) to be used
|
||||
for idx, param in enumerate(self._original_module.named_parameters()):
|
||||
if param[0] == 'fc2.weight':
|
||||
input_with_initializer.update({param[0]: param[1].detach().numpy()})
|
||||
|
||||
# Grad output
|
||||
# TODO: How to determine grad_output name?
|
||||
input_with_initializer.update({'probability_grad': grad_output.detach().numpy()})
|
||||
# for idx, param in enumerate(self._original_module.named_parameters()):
|
||||
# input_with_initializer.update({param[0]: param[1].detach().numpy()})
|
||||
input_with_initializer.update({'fc2.weight' : inputs[1].detach().numpy()})
|
||||
|
||||
# Intermediates
|
||||
# TODO: How to determine intermediates name?
|
||||
input_with_initializer.update({'7': intermediates.detach().numpy()})
|
||||
input_with_initializer.update({'7': inputs[2].detach().numpy()})
|
||||
|
||||
# Grad output
|
||||
# TODO: How to determine grad_output name?
|
||||
input_with_initializer.update({'probability_grad': inputs[3].detach().numpy()})
|
||||
return input_with_initializer
|
||||
|
||||
def _run_forward_graph(self, data_with_initializer): # input, weights):
|
||||
print(f'_run_forward_graph was called...')
|
||||
return self._forward_session.run(None, data_with_initializer)
|
||||
|
||||
def _run_backward_graph(self, grad_output, intermediates, *inputs, **kwargs):
|
||||
# Use an InferenceSession to execute self.backward_graph.
|
||||
# Return gradient tensors for inputs and weights.
|
||||
print(f'_run_backward_graph was called...')
|
||||
data = self._prepare_backward_input(grad_output, intermediates, *inputs, **kwargs)
|
||||
def _run_backward_graph(self, *inputs, **kwargs):
|
||||
data = self._prepare_backward_input(*inputs, **kwargs)
|
||||
# TODO: Hack to guarantee output order from InferenceSession.run()
|
||||
return self._backward_session.run(['fc1.bias_grad', 'fc1.weight_grad', 'fc2.weight_grad', 'fc2.bias_grad'], data)
|
||||
# return self._backward_session.run(None, data)
|
||||
|
||||
@staticmethod
|
||||
def _get_forward_graph(module, module_input):
|
||||
print(f'_get_forward_graph was called...')
|
||||
# TODO: Pytorch module must be exported to ONNX and splitted
|
||||
# Hard-coding with MNIST stub for MVP
|
||||
# Export torch.nn.Module to ONNX with initializers as input
|
||||
# f = io.BytesIO()
|
||||
# torch.onnx.export(module, module_input, f, verbose=True,
|
||||
# opset_version=ONNX_OPSET_VERSION,
|
||||
# _retain_param_name=True,
|
||||
# training=torch.onnx.TrainingMode.TRAINING,
|
||||
# keep_initializers_as_inputs=True,
|
||||
# export_params=True)
|
||||
# return onnx.load_model_from_string(f.getvalue())
|
||||
return onnx.load('./model_with_training_forward_sliced.onnx')
|
||||
|
||||
def _get_initializer_from_graph(self, graph):
|
||||
|
|
@ -200,6 +215,7 @@ class ORTModule(torch.nn.Module):
|
|||
for elem in graph.graph.output:
|
||||
for initializer in self._onnx_forward_initializers_desc:
|
||||
if elem.name == initializer['name']:
|
||||
# skip initializers
|
||||
break
|
||||
else:
|
||||
name = elem.name
|
||||
|
|
@ -236,7 +252,6 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _build_gradient_graph(forward_graph):
|
||||
print(f'_build_gradient_graph was called...')
|
||||
# TODO: Invoke the C++ GradientBuilder implementation via pybind.
|
||||
# Return an ONNX graph that contains the forward and backward nodes, which takes the
|
||||
# following inputs:
|
||||
|
|
@ -248,7 +263,6 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _split_forward_and_backward(gradient_graph):
|
||||
print(f'_split_forward_and_backward was called...')
|
||||
# TODO: Split the result of _build_gradient_graph into two subgraphs:
|
||||
# * A forward graph that takes module inputs and weights as input, and produces module
|
||||
# outputs and (“stashed”) intermediate tensors as output.
|
||||
|
|
|
|||
|
|
@ -1,14 +1,10 @@
|
|||
import argparse
|
||||
import torch
|
||||
from torchvision import datasets, transforms
|
||||
import torchviz
|
||||
|
||||
from onnxruntime import set_seed
|
||||
import onnxruntime
|
||||
from onnxruntime.training import ORTModule
|
||||
|
||||
import _test_commons
|
||||
import _test_helpers
|
||||
|
||||
|
||||
class NeuralNet(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
|
|
@ -24,112 +20,127 @@ class NeuralNet(torch.nn.Module):
|
|||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
def main():
|
||||
#Training settings
|
||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
||||
parser.add_argument('--pytorch-only', action='store_true', default=False,
|
||||
help='disables ONNX Runtime training')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Model architecture
|
||||
lr = 1e-4
|
||||
batch_size=20
|
||||
seed=42
|
||||
|
||||
torch.manual_seed(seed)
|
||||
set_seed(seed)
|
||||
|
||||
|
||||
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
|
||||
print('Training MNIST on ORTModule....')
|
||||
if not args.pytorch_only:
|
||||
model = ORTModule(model)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
|
||||
|
||||
# Data loader
|
||||
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
|
||||
transform=transforms.Compose([transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=batch_size,
|
||||
shuffle=True)
|
||||
# Training Loop
|
||||
loss = float('inf')
|
||||
def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
|
||||
model.train()
|
||||
for iteration, (data, target) in enumerate(train_loader):
|
||||
if iteration == 1:
|
||||
print(f'Final loss is {loss}')
|
||||
if iteration == args.train_steps:
|
||||
break
|
||||
|
||||
data, target = data.to(device), target.to(device)
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
|
||||
optimizer.zero_grad()
|
||||
if args.pytorch_only:
|
||||
print("Using PyTorch-only API")
|
||||
probability = model(data)
|
||||
else:
|
||||
probability = model(data)
|
||||
|
||||
if args.view_graphs:
|
||||
import torchviz
|
||||
pytorch_backward_graph = torchviz.make_dot(probability, params=dict(list(model.named_parameters())))
|
||||
print(f'probability.grad_fn={probability.grad_fn}')
|
||||
print(f'probability.grad_fn.next_functions={probability.grad_fn.next_functions}')
|
||||
# pytorch_backward_graph.view()
|
||||
probability.retain_grad()
|
||||
else:
|
||||
print("Using ONNX Runtime Flexible API")
|
||||
probability, intermediates = model(data)
|
||||
probability.requires_grad_(True)
|
||||
pytorch_backward_graph.view()
|
||||
|
||||
print(f'Output from forward has shape {probability.size()}')
|
||||
loss = criterion(probability, target)
|
||||
loss = loss_fn(probability, target)
|
||||
loss.backward()
|
||||
print(f'***** probability.grad[0]={probability.grad[0]}')
|
||||
|
||||
if args.pytorch_only:
|
||||
print(f'***** (PYTORCH) fc1.bias_grad[0] BEFORE {model.fc1.bias.data[0].item()}')
|
||||
print(f'***** (PYTORCH) fc1.weight_grad[0][0] BEFORE {model.fc1.weight.data[0][0].item()}')
|
||||
print(f'***** (PYTORCH) fc2.bias_grad[0] BEFORE {model.fc2.bias.data[0].item()}')
|
||||
print(f'***** (PYTORCH) fc2.weight_grad[0][0] BEFORE {model.fc2.weight.data[0][0].item()}')
|
||||
else:
|
||||
# import pdb; pdb.set_trace()
|
||||
# Fake backward call to test backprop graph
|
||||
# TODO: The model output *order* is changing from ONNX export to ONNX export
|
||||
fc1_bias_grad, fc1_weight_grad, fc2_weight_grad, fc2_bias_grad = model._run_backward_graph(probability.grad, intermediates, data)
|
||||
fc1_bias_grad = torch.from_numpy(fc1_bias_grad).requires_grad_(True)
|
||||
fc2_bias_grad = torch.from_numpy(fc2_bias_grad).requires_grad_(True)
|
||||
fc1_weight_grad = torch.from_numpy(fc1_weight_grad).requires_grad_(True)
|
||||
fc2_weight_grad = torch.from_numpy(fc2_weight_grad).requires_grad_(True)
|
||||
fc1_bias_grad.retain_grad()
|
||||
fc1_weight_grad.retain_grad()
|
||||
fc2_bias_grad.retain_grad()
|
||||
fc2_weight_grad.retain_grad()
|
||||
|
||||
print(f'***** (ONNX Runtime) fc1_bias_grad[0] BEFORE {model._original_module.fc1.bias.data[0].item()}')
|
||||
print(f'***** (ONNX Runtime) fc1_weight_grad[0][0] BEFORE {model._original_module.fc1.weight.data[0][0].item()}')
|
||||
print(f'***** (ONNX Runtime) fc2_bias_grad[0] BEFORE {model._original_module.fc2.bias.data[0].item()}')
|
||||
print(f'***** (ONNX Runtime) fc2_weight_grad[0][0] BEFORE {model._original_module.fc2.weight.data[0][0].item()}')
|
||||
print(f'***** (ONNX Runtime) fc1_bias_grad[0] AFTER {fc1_bias_grad[0].item()}')
|
||||
print(f'***** (ONNX Runtime) fc1_weight_grad[0][0] AFTER {fc1_weight_grad[0][0]}')
|
||||
print(f'***** (ONNX Runtime) fc2_bias_grad[0] AFTER {fc2_bias_grad[0].item()}')
|
||||
print(f'***** (ONNX Runtime) fc2_weight_grad[0][0] AFTER {fc2_weight_grad[0][0].item()}')
|
||||
model._original_module.fc1.bias.data = fc1_bias_grad.data
|
||||
model._original_module.fc1.weight.data = fc1_weight_grad.data
|
||||
model._original_module.fc2.bias.data = fc2_bias_grad.data
|
||||
model._original_module.fc2.weight.data = fc2_weight_grad.data
|
||||
|
||||
print(f'Output from backaward has the following shapes after update:')
|
||||
print(f'fc1_bias_grad={fc1_bias_grad.size()}')
|
||||
print(f'fc2_bias_grad={fc2_bias_grad.size()}')
|
||||
print(f'fc1_weight_grad={fc1_weight_grad.size()}')
|
||||
print(f'fc2_weight_grad={fc2_weight_grad.size()}')
|
||||
|
||||
optimizer.step()
|
||||
if args.pytorch_only:
|
||||
print(f'***** (PYTORCH) fc1.bias_grad[0] AFTER {model.fc1.bias.data[0].item()}')
|
||||
print(f'***** (PYTORCH) fc1.weight_grad[0][0] AFTER {model.fc1.weight.data[0][0].item()}')
|
||||
print(f'***** (PYTORCH) fc2.bias_grad[0] AFTER {model.fc2.bias.data[0].item()}')
|
||||
print(f'***** (PYTORCH) fc2.weight_grad[0][0] AFTER {model.fc2.weight.data[0][0].item()}')
|
||||
|
||||
if iteration == 0:
|
||||
print(f'Initial loss is {loss}')
|
||||
print('Tah dah!')
|
||||
# Stats
|
||||
if iteration % args.log_interval == 0:
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, iteration * len(data), len(train_loader.dataset),
|
||||
100. * iteration / len(train_loader), loss))
|
||||
|
||||
|
||||
def test(args, model, device, loss_fn, test_loader):
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
output = model(data)
|
||||
|
||||
# Stats
|
||||
test_loss += loss_fn(output, target, False).item()
|
||||
pred = output.argmax(dim=1, keepdim=True)
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
test_loss /= len(test_loader.dataset)
|
||||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
||||
test_loss, correct, len(test_loader.dataset),
|
||||
100. * correct / len(test_loader.dataset)))
|
||||
|
||||
def my_loss(x, target, is_train=True):
|
||||
if is_train:
|
||||
return torch.nn.CrossEntropyLoss()(x, target)
|
||||
else:
|
||||
return torch.nn.CrossEntropyLoss(reduction='sum')(x, target)
|
||||
|
||||
def main():
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
||||
parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
|
||||
help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
|
||||
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
|
||||
help='learning rate (default: 0.001)')
|
||||
parser.add_argument('--batch-size', type=int, default=20, metavar='N',
|
||||
help='input batch size for training (default: 20)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
|
||||
help='input batch size for testing (default: 1000)')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=42, metavar='S',
|
||||
help='random seed (default: 42)')
|
||||
parser.add_argument('--pytorch-only', action='store_true', default=False,
|
||||
help='disables ONNX Runtime training')
|
||||
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
|
||||
help='how many batches to wait before logging training status (default: 100)')
|
||||
parser.add_argument('--view-graphs', action='store_true', default=False,
|
||||
help='views forward and backward graphs')
|
||||
parser.add_argument('--epochs', type=int, default=10, metavar='N',
|
||||
help='number of epochs to train (default: 10)')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Common setup
|
||||
torch.manual_seed(args.seed)
|
||||
onnxruntime.set_seed(args.seed)
|
||||
|
||||
# TODO: CUDA support is broken due to copying from PyTorch into ORT
|
||||
# if not args.no_cuda and torch.cuda.is_available():
|
||||
# device = "cuda"
|
||||
# else:
|
||||
# device = "cpu"
|
||||
device = 'cpu'
|
||||
|
||||
## Data loader
|
||||
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
|
||||
transform=transforms.Compose([transforms.ToTensor(),
|
||||
|
||||
transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True)
|
||||
if args.test_batch_size > 0:
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('./data', train=False, transform=transforms.Compose([
|
||||
transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=args.test_batch_size, shuffle=True)
|
||||
|
||||
# Model architecture
|
||||
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
|
||||
if not args.pytorch_only:
|
||||
print('Training MNIST on ORTModule....')
|
||||
model = ORTModule(model)
|
||||
else:
|
||||
print('Training MNIST on vanilla PyTorch....')
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
|
||||
|
||||
# Train loop
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(args, model, device, optimizer, my_loss, train_loader, epoch)
|
||||
if args.test_batch_size > 0:
|
||||
test(args, model, device, my_loss, test_loader)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -122,7 +122,6 @@ def main():
|
|||
train(args, model, device, train_loader, optimizer, epoch)
|
||||
if args.test_batch_size > 0:
|
||||
test(model, device, test_loader)
|
||||
optimizer.step()
|
||||
|
||||
# Save model
|
||||
if args.save_path:
|
||||
|
|
|
|||
Loading…
Reference in a new issue