diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py
index 0ccc8a54f3..493c413777 100644
--- a/orttraining/orttraining/python/training/ortmodule.py
+++ b/orttraining/orttraining/python/training/ortmodule.py
@@ -15,10 +15,10 @@ ONNX_OPSET_VERSION = 12
 class ORTModule(torch.nn.Module):
     def __init__(self, module):
-        print(f'ORTModule.__init__() was called')
         assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module"
         super(ORTModule, self).__init__()
 
-        # User will interact with it (debugging, etc)
+
+        # User module is wrapped to use its initializers and save computed gradients
         self._original_module = module
 
         # Forward pass
@@ -31,10 +31,11 @@ class ORTModule(torch.nn.Module):
         # Backward pass
         self._onnx_backward = None
         self._backward_session = None
+        self._onnx_backward_initializers_desc = []
+        self._onnx_backward_inputs_desc = []
+        self._onnx_backward_outputs_desc = []
 
     def forward(self, *input, **kwargs):
-        print(f'ORTModule.forward() was called')
-
         if not self._onnx_forward:
             original_forward_graph = ORTModule._get_forward_graph(self._original_module, *input, **kwargs)
             gradient_graph = ORTModule._build_gradient_graph(original_forward_graph)
@@ -45,118 +46,132 @@ class ORTModule(torch.nn.Module):
             self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
             self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString())
 
-        # TODO: debug only
-        self._save_onnx_graph(self._onnx_forward, 'ortmodule_forward_mnist.onnx')
-        self._save_onnx_graph(self._onnx_backward, 'ortmodule_backward_mnist.onnx')
-
+        # Forward I/O description
         if not self._onnx_forward_initializers_desc:
             self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
+            print(f'Forward initializers: {self._onnx_forward_initializers_desc}')
         if not self._onnx_forward_inputs_desc:
             self._onnx_forward_inputs_desc = self._get_input_from_graph(self._onnx_forward)
+            print(f'Forward inputs: {self._onnx_forward_inputs_desc}')
         if not self._onnx_forward_outputs_desc:
             self._onnx_forward_outputs_desc = self._get_output_from_graph(self._onnx_forward)
+            print(f'Forward outputs: {self._onnx_forward_outputs_desc}')
 
-        # TODO: debug only
-        print(f'Initializers: {self._onnx_forward_initializers_desc}')
-        print(f'Inputs: {self._onnx_forward_inputs_desc}')
-        print(f'Outpus: {self._onnx_forward_outputs_desc}')
+        # Backward I/O description
+        if not self._onnx_backward_initializers_desc:
+            self._onnx_backward_initializers_desc = self._get_initializer_from_graph(self._onnx_backward)
+            print(f'Backward initializers: {self._onnx_backward_initializers_desc}')
+        if not self._onnx_backward_inputs_desc:
+            self._onnx_backward_inputs_desc = self._get_input_from_graph(self._onnx_backward)
+            print(f'Backward inputs: {self._onnx_backward_inputs_desc}')
+        if not self._onnx_backward_outputs_desc:
+            self._onnx_backward_outputs_desc = self._get_output_from_graph(self._onnx_backward)
+            print(f'Backward outputs: {self._onnx_backward_outputs_desc}')
 
         # Use a custom torch.autograd.Function to associate self.backward_graph as the
         # gradient implementation for self.forward_graph.
         class _ORTModuleFunction(torch.autograd.Function):
             @staticmethod
             def forward(ctx, *input, **kwargs):
-                print(f'_ORTModuleFunction.forward() was called...')
-                # Note: A potential optimization would be to detect which of inputs and weights
-                #       require a gradient.
-                # intermediates, outputs = self._run_forward_graph(inputs) # inputs, weights)
-                outputs = self._run_forward_graph(self._prepare_forward_input(*input, **kwargs)) # inputs, weights)
-                outputs = [torch.nn.Parameter(torch.from_numpy(out)) for out in outputs]
+                # TODO: Potential optimization is to detect which inputs and weights require gradients
+                input_with_initializer = self._prepare_forward_input_ort(*input, **kwargs)
+                outputs = self._run_forward_graph(input_with_initializer)
+                outputs = tuple(torch.from_numpy(out) for out in outputs)
 
-                # TODO: Properly save intermediate tensors and remove them from model output
-                ctx.save_for_backward([(input, kwargs), outputs[1]])
-                # outputs = [outputs[0]]
+                # TODO: Properly save dynamic number of intermediate tensors and remove them from model output
+                # Tensors that need to have gradients tracked can't be saved by `save_for_backward`
+                # saved_tensors ==> input1, fc2.weight, 7
+                ctx.save_for_backward(*[input[0], input[3], outputs[1]])
+                outputs = [outputs[0]]
 
                 # TODO: Properly support original module output format
                 if len(outputs) == 1:
                     return outputs[0]
-                return tuple(outputs)
+                return outputs
 
             @staticmethod
             def backward(ctx, *grad_output):
-                print(f'_ORTModuleFunction.backward() was called')
-                input_and_kwargs, intermediates = ctx.saved_tensors
-                # grad_inputs, grad_weights = self._run_backward_graph(
-                #     grad_output, intermediates)
-                # return grad_inputs, grad_weights
+                # TODO: Properly restore dynamic number of intermediate tensors
+                # saved_tensors ==> input1, fc2.weight, 7
+                saved_tensors = ctx.saved_tensors
+                grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output])
+                grad_weights = [torch.from_numpy(grad) for grad in grad_weights]
+                # TODO: backward must return grad tensors in the same order as forward receives its inputs
+                # [input1_grad, fc1.weight_grad, fc1.bias_grad, fc2.weight_grad, fc2.bias_grad]
+                return tuple([torch.tensor([1.]), grad_weights[1], grad_weights[0], grad_weights[2], grad_weights[3]])
 
-        return _ORTModuleFunction.apply(*input, **kwargs)
+        return _ORTModuleFunction.apply(*self._prepare_forward_input_autograd(*input, **kwargs))
 
-    def _prepare_forward_input(self, *input, **kwargs):
-        # Dictionary containing both inputs and initializers
-        input_with_initializer = {}
+    def _prepare_forward_input_autograd(self, *input, **kwargs):
+        # List containing both user inputs and initializers, in this order
+        input_with_initializer = []
 
         # Inputs
         for idx, input_data in enumerate(self._forward_session.get_inputs()):
-            input_with_initializer.update({input_data.name: input[idx].cpu().numpy()})
+            input_with_initializer.append(input[idx])
 
         # Initializers
         for idx, param in enumerate(self._original_module.named_parameters()):
-            input_with_initializer.update({param[0]: param[1].detach().numpy()})
+            input_with_initializer.append(param[1])
+
+        # TODO: [input1, fc1.weight, fc1.bias, fc2.weight, fc2.bias]
+        return input_with_initializer
+
+    def _prepare_forward_input_ort(self, *inputs):
+        # Dictionary containing both inputs and initializers
+        input_with_initializer = {}
+
+        # TODO: [input1, fc1.weight, fc1.bias, fc2.weight, fc2.bias]
+        # Inputs
+        inputs_len = 0
+        for idx, input_data in enumerate(self._forward_session.get_inputs()):
+            inputs_len += 1
+            input_with_initializer.update({input_data.name: inputs[idx].cpu().numpy()})
+
+        # Initializers
+        for param in self._original_module.named_parameters():
+            input_with_initializer.update({param[0]: inputs[inputs_len].detach().numpy()})
+            inputs_len += 1
 
         return input_with_initializer
 
-    def _prepare_backward_input(self, grad_output, intermediates, *inputs, **kwargs):
+    def _prepare_backward_input(self, *inputs, **kwargs):
         # Dictionary containing initializers
         input_with_initializer = {}
 
         # User input
         # TODO: How to determine which user input to feed to backward
-        for idx, input_data in enumerate(self._forward_session.get_inputs()):
-            input_with_initializer.update({input_data.name: inputs[idx].cpu().numpy()})
+        # for idx, input_data in enumerate(self._forward_session.get_inputs()):
+        #     input_with_initializer.update({input_data.name: inputs[idx].cpu().numpy()})
+        input_with_initializer.update({'input1': inputs[0].detach().numpy()})
 
         # Initializers
         # TODO: How to determine which initializer (subset) to be used
-        for idx, param in enumerate(self._original_module.named_parameters()):
-            if param[0] == 'fc2.weight':
-                input_with_initializer.update({param[0]: param[1].detach().numpy()})
-
-        # Grad output
-        # TODO: How to determine grad_output name?
-        input_with_initializer.update({'probability_grad': grad_output.detach().numpy()})
+        # for idx, param in enumerate(self._original_module.named_parameters()):
+        #     input_with_initializer.update({param[0]: param[1].detach().numpy()})
+        input_with_initializer.update({'fc2.weight': inputs[1].detach().numpy()})
 
         # Intermediates
         # TODO: How to determine intermediates name?
-        input_with_initializer.update({'7': intermediates.detach().numpy()})
+        input_with_initializer.update({'7': inputs[2].detach().numpy()})
+
+        # Grad output
+        # TODO: How to determine grad_output name?
+        input_with_initializer.update({'probability_grad': inputs[3].detach().numpy()})
 
         return input_with_initializer
 
     def _run_forward_graph(self, data_with_initializer): # input, weights):
-        print(f'_run_forward_graph was called...')
         return self._forward_session.run(None, data_with_initializer)
 
-    def _run_backward_graph(self, grad_output, intermediates, *inputs, **kwargs):
-        # Use an InferenceSession to execute self.backward_graph.
-        # Return gradient tensors for inputs and weights.
-        print(f'_run_backward_graph was called...')
-        data = self._prepare_backward_input(grad_output, intermediates, *inputs, **kwargs)
+    def _run_backward_graph(self, *inputs, **kwargs):
+        data = self._prepare_backward_input(*inputs, **kwargs)
 
         # TODO: Hack to guarantee output order from InferenceSession.run()
         return self._backward_session.run(['fc1.bias_grad', 'fc1.weight_grad', 'fc2.weight_grad', 'fc2.bias_grad'], data)
-        # return self._backward_session.run(None, data)
 
     @staticmethod
     def _get_forward_graph(module, module_input):
-        print(f'_get_forward_graph was called...')
         # TODO: Pytorch module must be exported to ONNX and split
         # Hard-coding with MNIST stub for MVP
-        # Export torch.nn.Module to ONNX with initializers as input
-        # f = io.BytesIO()
-        # torch.onnx.export(module, module_input, f, verbose=True,
-        #                   opset_version=ONNX_OPSET_VERSION,
-        #                   _retain_param_name=True,
-        #                   training=torch.onnx.TrainingMode.TRAINING,
-        #                   keep_initializers_as_inputs=True,
-        #                   export_params=True)
-        # return onnx.load_model_from_string(f.getvalue())
         return onnx.load('./model_with_training_forward_sliced.onnx')
 
     def _get_initializer_from_graph(self, graph):
@@ -200,6 +215,7 @@ class ORTModule(torch.nn.Module):
         for elem in graph.graph.output:
             for initializer in self._onnx_forward_initializers_desc:
                 if elem.name == initializer['name']:
+                    # skip initializers
                     break
             else:
                 name = elem.name
@@ -236,7 +252,6 @@ class ORTModule(torch.nn.Module):
 
     @staticmethod
     def _build_gradient_graph(forward_graph):
-        print(f'_build_gradient_graph was called...')
         # TODO: Invoke the C++ GradientBuilder implementation via pybind.
         # Return an ONNX graph that contains the forward and backward nodes, which takes the
         # following inputs:
@@ -248,7 +263,6 @@ class ORTModule(torch.nn.Module):
 
     @staticmethod
     def _split_forward_and_backward(gradient_graph):
-        print(f'_split_forward_and_backward was called...')
         # TODO: Split the result of _build_gradient_graph into two subgraphs:
         # * A forward graph that takes module inputs and weights as input, and produces module
         #   outputs and (“stashed”) intermediate tensors as output.
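[Review note] `_get_forward_graph` currently returns a hard-coded sliced model for the MVP. For reference, here is a minimal sketch of the intended export path, reconstructed from the commented-out lines in that hunk; `export_forward_graph` is an illustrative name, not part of this patch:

```python
import io

import onnx
import torch


def export_forward_graph(module, module_input, opset_version=12):
    # Export the user's torch.nn.Module with initializers kept as graph
    # inputs, so the runtime can be fed the current weights on every call.
    f = io.BytesIO()
    torch.onnx.export(module, module_input, f, verbose=True,
                      opset_version=opset_version,
                      _retain_param_name=True,
                      training=torch.onnx.TrainingMode.TRAINING,
                      keep_initializers_as_inputs=True,
                      export_params=True)
    return onnx.load_model_from_string(f.getvalue())
```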
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py
index 725b5e2d9b..6af177b80a 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py
@@ -1,14 +1,10 @@
 import argparse
 import torch
 from torchvision import datasets, transforms
-import torchviz
 
-from onnxruntime import set_seed
+import onnxruntime
 from onnxruntime.training import ORTModule
 
-import _test_commons
-import _test_helpers
-
 
 class NeuralNet(torch.nn.Module):
     def __init__(self, input_size, hidden_size, num_classes):
@@ -24,112 +20,127 @@ class NeuralNet(torch.nn.Module):
         out = self.fc2(out)
         return out
 
-def main():
-    #Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-    parser.add_argument('--pytorch-only', action='store_true', default=False,
-                        help='disables ONNX Runtime training')
-    args = parser.parse_args()
 
-    # Model architecture
-    lr = 1e-4
-    batch_size=20
-    seed=42
-
-    torch.manual_seed(seed)
-    set_seed(seed)
-
-
-    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
-    print('Training MNIST on ORTModule....')
-    if not args.pytorch_only:
-        model = ORTModule(model)
-
-    criterion = torch.nn.CrossEntropyLoss()
-    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
-
-    # Data loader
-    train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
-                                                              transform=transforms.Compose([transforms.ToTensor(),
-                                                                                            transforms.Normalize((0.1307,), (0.3081,))])),
-                                               batch_size=batch_size,
-                                               shuffle=True)
-
-    # Training Loop
-    loss = float('inf')
+def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
+    model.train()
     for iteration, (data, target) in enumerate(train_loader):
-        if iteration == 1:
-            print(f'Final loss is {loss}')
+        if iteration == args.train_steps:
             break
-
+        data, target = data.to(device), target.to(device)
         data = data.reshape(data.shape[0], -1)
+        optimizer.zero_grad()
 
         if args.pytorch_only:
-            print("Using PyTorch-only API")
+            probability = model(data)
+        else:
             probability = model(data)
+            if args.view_graphs:
+                import torchviz
                 pytorch_backward_graph = torchviz.make_dot(probability, params=dict(list(model.named_parameters())))
-            print(f'probability.grad_fn={probability.grad_fn}')
-            print(f'probability.grad_fn.next_functions={probability.grad_fn.next_functions}')
-            # pytorch_backward_graph.view()
-            probability.retain_grad()
-        else:
-            print("Using ONNX Runtime Flexible API")
-            probability, intermediates = model(data)
-            probability.requires_grad_(True)
+                pytorch_backward_graph.view()
 
-        print(f'Output from forward has shape {probability.size()}')
-        loss = criterion(probability, target)
+        loss = loss_fn(probability, target)
         loss.backward()
-        print(f'***** probability.grad[0]={probability.grad[0]}')
-
-        if args.pytorch_only:
-            print(f'***** (PYTORCH) fc1.bias_grad[0] BEFORE {model.fc1.bias.data[0].item()}')
-            print(f'***** (PYTORCH) fc1.weight_grad[0][0] BEFORE {model.fc1.weight.data[0][0].item()}')
-            print(f'***** (PYTORCH) fc2.bias_grad[0] BEFORE {model.fc2.bias.data[0].item()}')
-            print(f'***** (PYTORCH) fc2.weight_grad[0][0] BEFORE {model.fc2.weight.data[0][0].item()}')
-        else:
-            # import pdb; pdb.set_trace()
-            # Fake backward call to test backprop graph
-            # TODO: The model output *order* is changing from ONNX export to ONNX export
-            fc1_bias_grad, fc1_weight_grad, fc2_weight_grad, fc2_bias_grad = model._run_backward_graph(probability.grad, intermediates, data)
-            fc1_bias_grad = torch.from_numpy(fc1_bias_grad).requires_grad_(True)
-            fc2_bias_grad = torch.from_numpy(fc2_bias_grad).requires_grad_(True)
-            fc1_weight_grad = torch.from_numpy(fc1_weight_grad).requires_grad_(True)
-            fc2_weight_grad = torch.from_numpy(fc2_weight_grad).requires_grad_(True)
-            fc1_bias_grad.retain_grad()
-            fc1_weight_grad.retain_grad()
-            fc2_bias_grad.retain_grad()
-            fc2_weight_grad.retain_grad()
-
-            print(f'***** (ONNX Runtime) fc1_bias_grad[0] BEFORE {model._original_module.fc1.bias.data[0].item()}')
-            print(f'***** (ONNX Runtime) fc1_weight_grad[0][0] BEFORE {model._original_module.fc1.weight.data[0][0].item()}')
-            print(f'***** (ONNX Runtime) fc2_bias_grad[0] BEFORE {model._original_module.fc2.bias.data[0].item()}')
-            print(f'***** (ONNX Runtime) fc2_weight_grad[0][0] BEFORE {model._original_module.fc2.weight.data[0][0].item()}')
-            print(f'***** (ONNX Runtime) fc1_bias_grad[0] AFTER {fc1_bias_grad[0].item()}')
-            print(f'***** (ONNX Runtime) fc1_weight_grad[0][0] AFTER {fc1_weight_grad[0][0]}')
-            print(f'***** (ONNX Runtime) fc2_bias_grad[0] AFTER {fc2_bias_grad[0].item()}')
-            print(f'***** (ONNX Runtime) fc2_weight_grad[0][0] AFTER {fc2_weight_grad[0][0].item()}')
-            model._original_module.fc1.bias.data = fc1_bias_grad.data
-            model._original_module.fc1.weight.data = fc1_weight_grad.data
-            model._original_module.fc2.bias.data = fc2_bias_grad.data
-            model._original_module.fc2.weight.data = fc2_weight_grad.data
-
-            print(f'Output from backaward has the following shapes after update:')
-            print(f'fc1_bias_grad={fc1_bias_grad.size()}')
-            print(f'fc2_bias_grad={fc2_bias_grad.size()}')
-            print(f'fc1_weight_grad={fc1_weight_grad.size()}')
-            print(f'fc2_weight_grad={fc2_weight_grad.size()}')
-
         optimizer.step()
-        if args.pytorch_only:
-            print(f'***** (PYTORCH) fc1.bias_grad[0] AFTER {model.fc1.bias.data[0].item()}')
-            print(f'***** (PYTORCH) fc1.weight_grad[0][0] AFTER {model.fc1.weight.data[0][0].item()}')
-            print(f'***** (PYTORCH) fc2.bias_grad[0] AFTER {model.fc2.bias.data[0].item()}')
-            print(f'***** (PYTORCH) fc2.weight_grad[0][0] AFTER {model.fc2.weight.data[0][0].item()}')
-        if iteration == 0:
-            print(f'Initial loss is {loss}')
-    print('Tah dah!')
+        # Stats
+        if iteration % args.log_interval == 0:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, iteration * len(data), len(train_loader.dataset),
+                100. * iteration / len(train_loader), loss))
+
+
+def test(args, model, device, loss_fn, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            data = data.reshape(data.shape[0], -1)
+            output = model(data)
+
+            # Stats
+            test_loss += loss_fn(output, target, False).item()
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+        test_loss, correct, len(test_loader.dataset),
+        100. * correct / len(test_loader.dataset)))
+
+
+def my_loss(x, target, is_train=True):
+    if is_train:
+        return torch.nn.CrossEntropyLoss()(x, target)
+    else:
+        return torch.nn.CrossEntropyLoss(reduction='sum')(x, target)
+
+
+def main():
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+    parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
+                        help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
+    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
+                        help='learning rate (default: 0.001)')
+    parser.add_argument('--batch-size', type=int, default=20, metavar='N',
+                        help='input batch size for training (default: 20)')
+    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+                        help='input batch size for testing (default: 1000)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=42, metavar='S',
+                        help='random seed (default: 42)')
+    parser.add_argument('--pytorch-only', action='store_true', default=False,
+                        help='disables ONNX Runtime training')
+    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
+                        help='how many batches to wait before logging training status (default: 100)')
+    parser.add_argument('--view-graphs', action='store_true', default=False,
+                        help='views forward and backward graphs')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    args = parser.parse_args()
+
+    # Common setup
+    torch.manual_seed(args.seed)
+    onnxruntime.set_seed(args.seed)
+
+    # TODO: CUDA support is broken due to copying from PyTorch into ORT
+    # if not args.no_cuda and torch.cuda.is_available():
+    #     device = "cuda"
+    # else:
+    #     device = "cpu"
+    device = 'cpu'
+
+    # Data loader
+    train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
+                                                              transform=transforms.Compose([transforms.ToTensor(),
+                                                                                            transforms.Normalize((0.1307,), (0.3081,))])),
+                                               batch_size=args.batch_size,
+                                               shuffle=True)
+    if args.test_batch_size > 0:
+        test_loader = torch.utils.data.DataLoader(
+            datasets.MNIST('./data', train=False, transform=transforms.Compose([
+                transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
+            batch_size=args.test_batch_size, shuffle=True)
+
+    # Model architecture
+    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
+    if not args.pytorch_only:
+        print('Training MNIST on ORTModule....')
+        model = ORTModule(model)
+    else:
+        print('Training MNIST on vanilla PyTorch....')
+    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
+
+    # Train loop
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, device, optimizer, my_loss, train_loader, epoch)
+        if args.test_batch_size > 0:
+            test(args, model, device, my_loss, test_loader)
+
 
 if __name__ == '__main__':
     main()
diff --git a/samples/python/mnist/pytorch_mnist.py b/samples/python/mnist/pytorch_mnist.py
index b1bc599531..f6cdb8be3f 100644
--- a/samples/python/mnist/pytorch_mnist.py
+++ b/samples/python/mnist/pytorch_mnist.py
@@ -122,7 +122,6 @@ def main():
         train(args, model, device, train_loader, optimizer, epoch)
         if args.test_batch_size > 0:
             test(model, device, test_loader)
-        optimizer.step()
 
     # Save model
     if args.save_path:
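[Review note] For readers unfamiliar with the mechanism `_ORTModuleFunction` builds on, below is a minimal, self-contained sketch of the `torch.autograd.Function` pattern, with a toy matmul standing in for the ONNX Runtime forward and backward sessions. All names here are illustrative and not part of this patch:

```python
import torch


class _ToyOrtFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        # Stand-in for self._run_forward_graph(): compute outside autograd
        # and stash whatever the backward graph will need.
        ctx.save_for_backward(x, weight)
        return x @ weight

    @staticmethod
    def backward(ctx, grad_output):
        # Stand-in for self._run_backward_graph(): must return one gradient
        # per forward() input, in the same order as forward() received them.
        x, weight = ctx.saved_tensors
        return grad_output @ weight.t(), x.t() @ grad_output


x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
_ToyOrtFunction.apply(x, w).sum().backward()  # populates x.grad and w.grad
```

This is exactly the contract the patch's TODOs circle around: which tensors to pass through `save_for_backward`, and returning gradients in the order of forward's inputs.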