From eebc2cccce62e5893053fcf061f832168c6b2e55 Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi
Date: Wed, 2 Sep 2020 08:57:44 -0700
Subject: [PATCH] Fix fetches when eval_step's input is a subset of train_step's input (#4966)

This PR also includes an MNIST sample that uses the new frontend.
---
 .../python/experimental/orttrainer.py  |   5 +-
 samples/python/mnist/mnist_training.py | 146 ++++++++++++++++++
 2 files changed, 149 insertions(+), 2 deletions(-)
 create mode 100644 samples/python/mnist/mnist_training.py

diff --git a/orttraining/orttraining/python/experimental/orttrainer.py b/orttraining/orttraining/python/experimental/orttrainer.py
index e25e7aec69..1a5267779d 100644
--- a/orttraining/orttraining/python/experimental/orttrainer.py
+++ b/orttraining/orttraining/python/experimental/orttrainer.py
@@ -24,7 +24,7 @@ class TrainStepInfo(object):
     Args:
         optimizer_config (optim._OptimizerConfig): reference to optimizer config
         all_finite (bool, default is True): flag that indicates whether all gradients are still finite after last step
-        fetches (list of str, default is []): list of output names to fetch from train_step/eval_step
+        fetches (list of str, default is []): list of output names to fetch from train_step/eval_step. Set it to [] to restore the default behavior.
         optimization_step (int): indicates the number of optimizations performed. Used for learning rate scheduling
         step (int): indicates current training step. Used for gradient accumulation
 
@@ -678,7 +678,8 @@ class ORTTrainer(object):
             input += (loss_scale, )
             extra_inputs += 1
 
-        assert len(self.model_desc.inputs) + extra_inputs == len(input)
+        # Only assert the input length when fetches is not used
+        assert self._train_step_info.fetches or len(self.model_desc.inputs) + extra_inputs == len(input)
         return input
 
     def _resolve_symbolic_dimensions(self, inputs, inputs_desc, outputs_desc):
diff --git a/samples/python/mnist/mnist_training.py b/samples/python/mnist/mnist_training.py
new file mode 100644
index 0000000000..c3c4c86963
--- /dev/null
+++ b/samples/python/mnist/mnist_training.py
@@ -0,0 +1,146 @@
+# This code is based on https://github.com/pytorch/examples/blob/master/mnist/main.py,
+# modified to train using onnxruntime as the backend on a CUDA device.
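+#
+# It also demonstrates the TrainStepInfo.fetches mechanism: test_with_trainer()
+# temporarily sets trainer._train_step_info.fetches = ['probability'] so that
+# eval_step can be run without feeding the 'label' input.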
+
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import datasets, transforms
+
+import onnxruntime
+from onnxruntime.experimental import ORTTrainer, ORTTrainerOptions, optim
+
+
+# PyTorch model
+class NeuralNet(nn.Module):
+    def __init__(self, input_size, hidden_size, num_classes):
+        super(NeuralNet, self).__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_size, num_classes)
+
+    def forward(self, input1):
+        out = self.fc1(input1)
+        out = self.relu(out)
+        out = self.fc2(out)
+        return out
+
+
+# ONNX Runtime training
+def mnist_model_description():
+    return {'inputs': [('input1', ['batch', 784]),
+                       ('label', ['batch'])],
+            'outputs': [('loss', [], True),
+                        ('probability', ['batch', 10])]}
+
+
+def my_loss(x, target):
+    return F.nll_loss(F.log_softmax(x, dim=1), target)
+
+
+# Helpers
+def train_with_trainer(log_interval, trainer, device, train_loader, epoch):
+    for batch_idx, (data, target) in enumerate(train_loader):
+        # Fetch data
+        data, target = data.to(device), target.to(device)
+        data = data.reshape(data.shape[0], -1)
+
+        # Train step
+        loss, _ = trainer.train_step(data, target)
+
+        # Stats
+        if batch_idx % log_interval == 0:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss))
+
+
+def test_with_trainer(trainer, device, test_loader):
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            # Fetch data
+            data, target = data.to(device), target.to(device)
+            data = data.reshape(data.shape[0], -1)
+
+            # Eval step
+            # Set fetches around eval_step so that 'target' does not have to be passed as input
+            trainer._train_step_info.fetches = ['probability']
+            output = F.log_softmax(trainer.eval_step(data), dim=1)
+            trainer._train_step_info.fetches = []
+
+            # Stats
+            test_loss += F.nll_loss(output, target, reduction='sum').item()
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+        test_loss, correct, len(test_loader.dataset),
+        100. * correct / len(test_loader.dataset)))
+
+
+def main():
+    # Training settings
+    parser = argparse.ArgumentParser(description='MNIST Example')
+    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+                        help='input batch size for training (default: 64)')
+    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+                        help='input batch size for testing (default: 1000)')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
+                        help='learning rate (default: 0.01)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
+                        help='how many batches to wait before logging training status')
+
+    # Basic setup
+    args = parser.parse_args()
+    if not args.no_cuda and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+    torch.manual_seed(args.seed)
+    onnxruntime.set_seed(args.seed)
+
+    # Data loader
+    train_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('./data', train=True, download=True,
+                       transform=transforms.Compose([
+                           transforms.ToTensor(),
+                           transforms.Normalize((0.1307,), (0.3081,))
+                       ])),
+        batch_size=args.batch_size, shuffle=True)
+    test_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('./data', train=False, transform=transforms.Compose([
+            transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
+        batch_size=args.test_batch_size, shuffle=True)
+
+    # Modeling
+    model = NeuralNet(784, 500, 10)
+    model_desc = mnist_model_description()
+    optim_config = optim.SGDConfig(lr=args.lr)
+    opts = ORTTrainerOptions({'device': {'id': device}})
+    trainer = ORTTrainer(model,
+                         model_desc,
+                         optim_config,
+                         loss_fn=my_loss,
+                         options=opts)
+
+    # Train loop
+    for epoch in range(1, args.epochs + 1):
+        train_with_trainer(args.log_interval, trainer,
+                           device, train_loader, epoch)
+        test_with_trainer(trainer, device, test_loader)
+
+
+if __name__ == '__main__':
+    main()
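
A minimal sketch of the pattern this fix enables, using the names from the
sample above: when the model description lists both 'input1' and 'label' as
inputs, eval_step can now be called with only 'input1' by restricting fetches
first:

    trainer._train_step_info.fetches = ['probability']  # fetch only the prediction
    probability = trainer.eval_step(data)               # 'label' no longer required
    trainer._train_step_info.fetches = []               # restore the default behavior

Before this change, ORTTrainer asserted that the number of inputs matched the
model description exactly, so the call above failed even though 'label' is not
needed to compute 'probability'.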