From 39ac95b2fc08ec4d75b6f9f59abe74de22c4907d Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Mon, 16 Nov 2020 09:17:27 +0000 Subject: [PATCH] add io binding --- .../module_gradient_graph_builder.cc | 52 +++--- .../framework/module_gradient_graph_builder.h | 7 +- .../orttraining/python/training/ortmodule.py | 174 ++++++++++++------ .../orttraining_test_ortmodule_iobinding.py | 161 ++++++++++++++++ run_ortmodule_mvp_mnist_iobinding.sh | 18 ++ 5 files changed, 324 insertions(+), 88 deletions(-) create mode 100644 orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py create mode 100644 run_ortmodule_mvp_mnist_iobinding.sh diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc index 338580ab6e..55df8a0a25 100644 --- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc @@ -14,8 +14,7 @@ namespace training { using namespace onnxruntime::common; -void GetInputAndOutputNames(const Node& node, - std::unordered_set& input_names, +void GetInputAndOutputNames(const Node& node, std::unordered_set& input_names, std::unordered_set& output_names) { std::for_each(node.InputDefs().begin(), node.InputDefs().end(), [&input_names](const NodeArg* node_arg) { input_names.insert(node_arg->Name()); }); @@ -63,7 +62,8 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, split_graphs_info_.user_output_names.emplace_back(node_arg->Name()); } - split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(), config.initializer_names_to_train.end()); + split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(), + config.initializer_names_to_train.end()); // Register and apply transformers for pre-training. const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{}; @@ -76,8 +76,9 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, config.input_names_require_grad.begin(), config.input_names_require_grad.end(), std::inserter(x_node_arg_names, x_node_arg_names.begin())); auto add_transformers = [&](TransformerLevel level) { + std::unordered_map updated_weight_names{}; auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers( - level, x_node_arg_names, graph_transformer_config, *cpu_execution_provider, {}); + level, x_node_arg_names, graph_transformer_config, *cpu_execution_provider, updated_weight_names, {}); for (auto& entry : transformers_to_register) { graph_transformation_mgr.Register(std::move(entry), level); } @@ -101,13 +102,11 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, GradientGraphConfiguration gradient_graph_config{}; gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad; gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs; - std::unordered_set y_node_arg_names(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end()); - GradientGraphBuilder grad_graph_builder(&model_->MainGraph(), - y_node_arg_names, - x_node_arg_names, - "", // not support loss name for now. - gradient_graph_config, - *logger_); + std::unordered_set y_node_arg_names(split_graphs_info_.user_output_names.begin(), + split_graphs_info_.user_output_names.end()); + GradientGraphBuilder grad_graph_builder(&model_->MainGraph(), y_node_arg_names, x_node_arg_names, + "", // not support loss name for now. + gradient_graph_config, *logger_); ORT_RETURN_IF_ERROR(grad_graph_builder.Build()); // Fix inputs/outputs related to gradients. @@ -152,6 +151,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream, for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) { std::string initializer_gradient_name = initializer_name + "_grad"; if (output_names.find(initializer_gradient_name) != output_names.end()) { + split_graphs_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name); output_args.emplace_back(gradient_graph.GetNodeArg(initializer_gradient_name)); } } @@ -188,17 +188,11 @@ std::string SerializeModel(const std::shared_ptr& model, con return model_str; } -std::string ModuleGradientGraphBuilder::GetGradientModel() const { - return SerializeModel(model_, "gradient"); -} +std::string ModuleGradientGraphBuilder::GetGradientModel() const { return SerializeModel(model_, "gradient"); } -std::string ModuleGradientGraphBuilder::GetForwardModel() const { - return SerializeModel(forward_model_, "forward"); -} +std::string ModuleGradientGraphBuilder::GetForwardModel() const { return SerializeModel(forward_model_, "forward"); } -std::string ModuleGradientGraphBuilder::GetBackwardModel() const { - return SerializeModel(backward_model_, "backward"); -} +std::string ModuleGradientGraphBuilder::GetBackwardModel() const { return SerializeModel(backward_model_, "backward"); } Status ModuleGradientGraphBuilder::Split() { // Get forward model, also collect some information for backward model generation. @@ -253,8 +247,8 @@ Status ModuleGradientGraphBuilder::Split() { // Add intermediate args to forward graph outputs. for (const auto& intermediate_arg_name : intermediate_arg_names) { // Ignore the user outputs. - if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(), intermediate_arg_name) - == split_graphs_info_.user_output_names.end()) { + if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(), + intermediate_arg_name) == split_graphs_info_.user_output_names.end()) { split_graphs_info_.intermediate_tensor_names.emplace_back(intermediate_arg_name); forward_output_args.emplace_back(forward_graph.GetNodeArg(intermediate_arg_name)); } @@ -264,7 +258,8 @@ Status ModuleGradientGraphBuilder::Split() { // Resolve the forward graph, keep the trainable initializers for now. Graph::ResolveOptions options; - std::unordered_set initializer_names_to_train_set(split_graphs_info_.initializer_names_to_train.begin(), split_graphs_info_.initializer_names_to_train.end()); + std::unordered_set initializer_names_to_train_set(split_graphs_info_.initializer_names_to_train.begin(), + split_graphs_info_.initializer_names_to_train.end()); options.initializer_names_to_preserve = &initializer_names_to_train_set; forward_graph.Resolve(options); @@ -292,15 +287,9 @@ Status ModuleGradientGraphBuilder::Split() { } } - // Grad of user outputs to backward graph inputs. - for (const auto& output_grad_name : split_graphs_info_.backward_output_grad_names) { - backward_input_args.emplace_back(backward_graph.GetNodeArg(output_grad_name)); - } - // Add initializer args to backward graph inputs if any node uses them. for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) { // Some initializers will be inputs for backward graph. - split_graphs_info_.initializer_grad_names_to_train.emplace_back(initializer_name + "_grad"); if (backward_input_names.find(initializer_name) != backward_input_names.end()) { split_graphs_info_.backward_intializer_names_as_input.emplace_back(initializer_name); backward_input_args.emplace_back(backward_graph.GetNodeArg(initializer_name)); @@ -315,6 +304,11 @@ Status ModuleGradientGraphBuilder::Split() { backward_input_args.emplace_back(intermediate_node_arg); } + // Grad of user outputs to backward graph inputs. + for (const auto& output_grad_name : split_graphs_info_.backward_output_grad_names) { + backward_input_args.emplace_back(backward_graph.GetNodeArg(output_grad_name)); + } + backward_graph.SetInputs(backward_input_args); // Exclude user outputs from the backward graph. diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.h b/orttraining/orttraining/core/framework/module_gradient_graph_builder.h index dbd5325869..33abc38e84 100644 --- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.h @@ -44,15 +44,12 @@ struct SplitGraphsInfo { class ModuleGradientGraphBuilder { public: - Status BuildAndSplit(std::istream& model_istream, - const ModuleGradientGraphBuilderConfiguration& config); + Status BuildAndSplit(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config); std::string GetGradientModel() const; std::string GetForwardModel() const; std::string GetBackwardModel() const; - SplitGraphsInfo GetSplitGraphsInfo() const { - return split_graphs_info_; - } + SplitGraphsInfo GetSplitGraphsInfo() const { return split_graphs_info_; } private: Status Split(); diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index cf4ffa6358..7b29d3af11 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -6,6 +6,7 @@ import onnxruntime import os import torch import warnings +import numpy as np from inspect import signature from onnxruntime.capi import _pybind_state as C @@ -15,12 +16,44 @@ from . import _utils ONNX_OPSET_VERSION = 12 +def get_device_index(device): + if type(device) == str: + # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 + device = torch.device(device) + return 0 if device.index is None else device.index + + +def prepare_io_binding(io_binding, inputs, model, output_buffers, device): + idx = 0 + for value_info in model.graph.input: + io_binding.bind_input(value_info.name, inputs[idx].device.type, get_device_index(inputs[idx].device), + _utils.dtype_torch_to_numpy(inputs[idx].dtype), list(inputs[idx].size()), + inputs[idx].data_ptr()) + idx += 1 + + for value_info in model.graph.output: + name = value_info.name + output_tensor = output_buffers[name] + io_binding.bind_output(name, output_tensor.device.type, get_device_index(device), + _utils.dtype_torch_to_numpy(output_tensor.dtype), list(output_tensor.size()), + output_tensor.data_ptr()) + + +def value_info_to_buffer_tensor(value_info, device): + shape = [dim.dim_value for dim in value_info.type.tensor_type.shape.dim] + dtype = _utils.dtype_onnx_to_torch(value_info.type.tensor_type.elem_type) + return torch.zeros(shape, device=device, dtype=dtype) + + class ORTModule(torch.nn.Module): - def __init__(self, module): + def __init__(self, module, device="cpu", use_iobinding=False): assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module" super(ORTModule, self).__init__() + self._device = device + self._use_iobinding = use_iobinding + # User module is wrapped to use its initializers and save computed gradients self._original_module = module self._onnx_training = None @@ -52,9 +85,10 @@ class ORTModule(torch.nn.Module): if not self._onnx_forward: self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs) grad_builder_config = C.ModuleGradientGraphBuilderConfiguration() - self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config) # TODO: PyTorch exporter bug: changes the initializer order - self._onnx_graphs_info.initializer_grad_names_to_train = [ p[0]+'_grad' for p in self._original_module.named_parameters()] + # Use the order in original module + initializer_names = [p[0] for p in self._original_module.named_parameters()] + self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, initializer_names) if self._save_onnx: onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx') @@ -65,9 +99,19 @@ class ORTModule(torch.nn.Module): # TODO: Consider moving this to the backend. We don't want to append '_grad' to get correct tensor names self._onnx_graphs_types = ORTModule._get_io_info_from_onnx_graph(self._onnx_forward, self._onnx_graphs_info) - # TODO: hard-coding to CPU only - self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString(), providers=['CPUExecutionProvider']) - self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString(), providers=['CPUExecutionProvider']) + execution_providers = ['CPUExecutionProvider'] if self._device == 'cpu' else ['CUDAExecutionProvider', 'CPUExecutionProvider'] + self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString(), providers=execution_providers) + self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString(), providers=execution_providers) + + if self._use_iobinding: + self._forward_io_binding = self._forward_session.io_binding() + self._forward_output_buffers = {} + for output in self._onnx_forward.graph.output: + self._forward_output_buffers[output.name] = value_info_to_buffer_tensor(output, self._device) + self._backward_io_binding = self._backward_session.io_binding() + self._backward_output_buffers = {} + for output in self._onnx_backward.graph.output: + self._backward_output_buffers[output.name] = value_info_to_buffer_tensor(output, self._device) # Use a custom torch.autograd.Function to associate self.backward_graph as the # gradient implementation for self.forward_graph. @@ -85,30 +129,44 @@ class ORTModule(torch.nn.Module): * Intermediate tensors ''' - # Convert input to dict of torch tensors - data_dict = self._convert_forward_input_list_to_dict(*inputs) + if not self._use_iobinding: + # Convert input to dict of torch tensors + data_dict = self._convert_forward_input_list_to_dict(*inputs) - # Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement) - data_dict_numpy = self._convert_dict_torch_to_numpy(data_dict) + # Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement) + data_dict_numpy = self._convert_dict_torch_to_numpy(data_dict) - # Feed forward - outputs, intermediate = self._run_forward_graph(data_dict_numpy) - outputs = tuple(torch.from_numpy(item) for item in outputs) + # Feed forward + outputs, intermediate = self._run_forward_graph(data_dict_numpy) + outputs = tuple(torch.from_numpy(item) for item in outputs) - # Save input, initializers and intermediate tensors to be used during backward - user_input = self._onnx_graphs_info.user_input_names - backward_user_input = self._onnx_graphs_info.backward_user_input_names - ctx_input = tuple(data_dict[name] for name in user_input if name in backward_user_input) - forward_initializer = self._onnx_graphs_info.initializer_names_to_train - backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input - ctx_initializer = tuple(data_dict[name] for name in forward_initializer if name in backward_intializer) - intermediate = tuple(torch.from_numpy(item) for item in intermediate) - ctx.save_for_backward(*[*ctx_input, *ctx_initializer, *intermediate]) + # Save input, initializers and intermediate tensors to be used during backward + user_input = self._onnx_graphs_info.user_input_names + backward_user_input = self._onnx_graphs_info.backward_user_input_names + ctx_input = tuple(data_dict[name] for name in user_input if name in backward_user_input) + forward_initializer = self._onnx_graphs_info.initializer_names_to_train + backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input + ctx_initializer = tuple(data_dict[name] for name in forward_initializer if name in backward_intializer) + intermediate = tuple(torch.from_numpy(item) for item in intermediate) + ctx.save_for_backward(*[*ctx_input, *ctx_initializer, *intermediate]) - # TODO: Support original module output (currently dict is not supported) - if len(outputs) == 1: - return outputs[0] - return outputs + # TODO: Support original module output (currently dict is not supported) + if len(outputs) == 1: + return outputs[0] + return outputs + + # Use IO binding. + prepare_io_binding(self._forward_io_binding, inputs, self._onnx_forward, self._forward_output_buffers, self._device) + self._forward_session.run_with_iobinding(self._forward_io_binding) + + forward_input_dict = self._convert_forward_input_list_to_dict(*inputs) + ctx_inputs = tuple(forward_input_dict[name] for name in self._onnx_graphs_info.backward_user_input_names) + ctx_initializers = tuple(forward_input_dict[name] for name in self._onnx_graphs_info.backward_intializer_names_as_input) + ctx_intermediates = tuple(self._forward_output_buffers[name] for name in self._onnx_graphs_info.intermediate_tensor_names) + ctx.save_for_backward(*[*ctx_inputs, *ctx_initializers, *ctx_intermediates]) + + outputs = tuple(self._forward_output_buffers[name] for name in self._onnx_graphs_info.user_output_names) + return outputs[0] if len(outputs) == 1 else outputs @staticmethod def backward(ctx, *grad_output): @@ -120,12 +178,23 @@ class ORTModule(torch.nn.Module): TODO: Input gradient is hard-coded to torch.tensor([1.]) ''' - saved_tensors = ctx.saved_tensors - grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output]) + if not self._use_iobinding: + saved_tensors = ctx.saved_tensors + grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output]) - result = [torch.tensor([1])]* len(self._onnx_graphs_info.user_input_names) - result += [torch.from_numpy(grad) for grad in grad_weights] - return tuple(result) + result = [torch.tensor([1])]* len(self._onnx_graphs_info.user_input_names) + result += [torch.from_numpy(grad) for grad in grad_weights] + return tuple(result) + + # Use IO binding. + grad_output_dict = dict(zip(self._onnx_graphs_info.user_output_grad_names, grad_output)) + backward_grad_output = tuple(grad_output_dict[name] for name in self._onnx_graphs_info.backward_output_grad_names) + prepare_io_binding(self._backward_io_binding, [*ctx.saved_tensors, *backward_grad_output], self._onnx_backward, self._backward_output_buffers, self._device) + self._backward_session.run_with_iobinding(self._backward_io_binding) + + results = [torch.tensor([1])] * len(self._onnx_graphs_info.user_input_names) + results += [self._backward_output_buffers[name] for name in self._onnx_graphs_info.initializer_grad_names_to_train] + return tuple(results) proc_inputs = [data for data in inputs if data is not None] return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*proc_inputs, **kwargs)) @@ -297,13 +366,16 @@ class ORTModule(torch.nn.Module): @staticmethod - def _build_fw_bw_grad_graphs(forward_graph, config): + def _build_fw_bw_grad_graphs(forward_graph, config, initializer_names=[]): '''Adds gradient nodes on top of an existing ONNX graph (with training flag)''' if not config.initializer_names_to_train: - initializer_names_to_train = [] - for initializer in forward_graph.graph.initializer: - initializer_names_to_train.append(initializer.name) - config.initializer_names_to_train = initializer_names_to_train + if not initializer_names: + initializer_names_to_train = [] + for initializer in forward_graph.graph.initializer: + initializer_names_to_train.append(initializer.name) + config.initializer_names_to_train = initializer_names_to_train + else: + config.initializer_names_to_train = initializer_names # TODO: Add support to input with grad required config.input_names_require_grad = [] @@ -323,27 +395,21 @@ class ORTModule(torch.nn.Module): @staticmethod def _get_io_info_from_onnx_graph(model, graphs_info): - type_map = {} - for name in graphs_info.user_input_names: - type_map[name] = None - for name in graphs_info.initializer_names_to_train: - type_map[name] = None - for name in graphs_info.user_output_names: - type_map[name] = None - for name in graphs_info.backward_user_input_names: - type_map[name] = None - for name in graphs_info.backward_intializer_names_as_input: - type_map[name] = None - for name in graphs_info.intermediate_tensor_names: - type_map[name] = None - for name in graphs_info.user_output_grad_names: - type_map[name] = None - for name in graphs_info.backward_output_grad_names: - type_map[name] = None + type_map = {key: None for key in [ + *graphs_info.user_input_names, + *graphs_info.initializer_names_to_train, + *graphs_info.initializer_grad_names_to_train, + *graphs_info.user_output_names, + *graphs_info.intermediate_tensor_names, + *graphs_info.user_output_grad_names + ]} for input in model.graph.input: if input.name in type_map and type_map[input.name] is None: type_map[input.name] = input.type + input_grad_name = input.name + '_grad' + if input_grad_name in type_map and type_map[input_grad_name] is None: + type_map[input_grad_name] = input.type for output in model.graph.output: if output.name in type_map and type_map[output.name] is None: @@ -352,4 +418,4 @@ class ORTModule(torch.nn.Module): if output_grad_name in type_map and type_map[output_grad_name] is None: type_map[output_grad_name] = output.type - return type_map \ No newline at end of file + return type_map diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py new file mode 100644 index 0000000000..81ebdcf36a --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py @@ -0,0 +1,161 @@ +import argparse +import logging +import torch +from torchvision import datasets, transforms + +import onnxruntime +from onnxruntime.training import ORTModule + + +class NeuralNet(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(NeuralNet, self).__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, input1): + out = self.fc1(input1) + out = self.relu(out) + out = self.fc2(out) + return out + + +def train(args, model, device, optimizer, loss_fn, train_loader, epoch): + model.train() + for iteration, (data, target) in enumerate(train_loader): + if iteration == args.train_steps: + break + data, target = data.to(device), target.to(device) + data = data.reshape(data.shape[0], -1) + + optimizer.zero_grad() + if args.pytorch_only: + probability = model(data) + else: + probability = model(data) + + if args.view_graphs: + import torchviz + pytorch_backward_graph = torchviz.make_dot(probability, params=dict(list(model.named_parameters()))) + pytorch_backward_graph.view() + + loss = loss_fn(probability, target) + loss.backward() + optimizer.step() + + # Stats + if iteration % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, iteration * len(data), len(train_loader.dataset), + 100. * iteration / len(train_loader), loss)) + + +def test(args, model, device, loss_fn, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + data = data.reshape(data.shape[0], -1) + output = model(data) + + # Stats + test_loss += loss_fn(output, target, False).item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_loss /= len(test_loader.dataset) + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) + +def my_loss(x, target, is_train=True): + if is_train: + return torch.nn.CrossEntropyLoss()(x, target) + else: + return torch.nn.CrossEntropyLoss(reduction='sum')(x, target) + +def main(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--train-steps', type=int, default=-1, metavar='N', + help='number of steps to train. Set -1 to run through whole dataset (default: -1)') + parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') + parser.add_argument('--batch-size', type=int, default=20, metavar='N', + help='input batch size for training (default: 20)') + parser.add_argument('--test-batch-size', type=int, default=20, metavar='N', + help='input batch size for testing (default: 20)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--use_iobinding', action='store_true', default=False, + help='use IO binding') + parser.add_argument('--seed', type=int, default=42, metavar='S', + help='random seed (default: 42)') + parser.add_argument('--pytorch-only', action='store_true', default=False, + help='disables ONNX Runtime training') + parser.add_argument('--log-interval', type=int, default=100, metavar='N', + help='how many batches to wait before logging training status (default: 100)') + parser.add_argument('--view-graphs', action='store_true', default=False, + help='views forward and backward graphs') + parser.add_argument('--epochs', type=int, default=10, metavar='N', + help='number of epochs to train (default: 10)') + parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING', + help='Log level (default: WARNING)') + + args = parser.parse_args() + + + # Common setup + torch.manual_seed(args.seed) + onnxruntime.set_seed(args.seed) + + # TODO: CUDA support is broken due to copying from PyTorch into ORT + if not args.no_cuda and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + # device = 'cpu' + + ## Data loader + train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True, + transform=transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,))])), + batch_size=args.batch_size, + shuffle=True) + if args.test_batch_size > 0: + test_loader = torch.utils.data.DataLoader( + datasets.MNIST('./data', train=False, transform=transforms.Compose([ + transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), + batch_size=args.test_batch_size, shuffle=True) + + # Model architecture + model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device) + if not args.pytorch_only: + print('Training MNIST on ORTModule....') + model = ORTModule(model, device, args.use_iobinding) + + # TODO: change it to False to stop saving ONNX models + model._save_onnx = True + model._save_onnx_prefix = 'MNIST' + + # Set log level + numeric_level = getattr(logging, args.log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError('Invalid log level: %s' % args.log_level) + logging.basicConfig(level=numeric_level) + else: + print('Training MNIST on vanilla PyTorch....') + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) + + # Train loop + for epoch in range(1, args.epochs + 1): + train(args, model, device, optimizer, my_loss, train_loader, epoch) + if args.test_batch_size > 0: + test(args, model, device, my_loss, test_loader) + + +if __name__ == '__main__': + main() diff --git a/run_ortmodule_mvp_mnist_iobinding.sh b/run_ortmodule_mvp_mnist_iobinding.sh new file mode 100644 index 0000000000..45b25f9630 --- /dev/null +++ b/run_ortmodule_mvp_mnist_iobinding.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +cur_dir=$(basename `pwd`) + +if [[ ${cur_dir} != "RelWithDebInfo" ]] +then + echo "Going to build folder (aka build/Linux/RelWithDebInfo)" + cd build/Linux/RelWithDebInfo +fi + +echo "Exporting PYTHONPATH to use build dir as onnxruntime package" +export PYTHONPATH=$(pwd) + +echo "Copying PyTorch frontend source-code to build folder" +cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/ + +echo "Running Flexible API (ORTModule)" +python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py --epochs 10 --log-interval 100 --log-level=DEBUG --use_iobinding