diff --git a/orttraining/orttraining/python/training/__init__.py b/orttraining/orttraining/python/training/__init__.py
index bd1bd73fc9..179e4c9f13 100644
--- a/orttraining/orttraining/python/training/__init__.py
+++ b/orttraining/orttraining/python/training/__init__.py
@@ -8,3 +8,5 @@ from onnxruntime.capi.training.training_session import TrainingSession
 from .orttrainer_options import ORTTrainerOptions
 from .orttrainer import ORTTrainer, TrainStepInfo
 from . import amp, checkpoint, optim, model_desc_validation
+
+from .ortmodule import ORTModule
diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py
new file mode 100644
index 0000000000..99843a3c2e
--- /dev/null
+++ b/orttraining/orttraining/python/training/ortmodule.py
@@ -0,0 +1,220 @@
+import copy
+import io
+import onnx
+import onnxruntime
+import os
+import torch
+import warnings
+
+from . import _utils
+
+
+ONNX_OPSET_VERSION = 12
+
+
+class ORTModule(torch.nn.Module):
+    """Runs the forward pass of a wrapped torch.nn.Module on ONNX Runtime.
+
+    MVP: the module is exported to ONNX on the first forward() call and executed
+    through an InferenceSession. The gradient graph is not built yet, so the
+    backward pass is only a placeholder.
+    """
+
+    def __init__(self, module):
+        print('ORTModule.__init__() was called')
+        super().__init__()
+
+        # User will interact with it (debugging, etc)
+        self._original_module = module
+
+        # Forward pass
+        self._onnx_forward = None
+        self._onnx_forward_initializers_desc = []
+        self._onnx_forward_inputs_desc = []
+        self._onnx_forward_outputs_desc = []
+
+        # Backward pass
+        self._onnx_backward = None
+
+    def forward(self, *input, **kwargs):
+        print('ORTModule.forward() was called')
+
+        # Export the wrapped module and create the session lazily, on the first call only
+        if not self._onnx_forward:
+            original_forward_graph = ORTModule._get_forward_graph(self._original_module, *input, **kwargs)
+            # gradient_graph = ORTModule._build_gradient_graph(original_forward_graph)
+            # self.forward_graph, self.backward_graph = ORTModule._split_forward_and_backward(gradient_graph)
+            self._onnx_forward = original_forward_graph  # TODO: hard-coding for MVP
+            self.forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
+            self._save_onnx_graph(self._onnx_forward, 'forward_mnist.onnx')
+        if not self._onnx_forward_initializers_desc:
+            self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
+        if not self._onnx_forward_inputs_desc:
+            self._onnx_forward_inputs_desc = self._get_input_from_graph(self._onnx_forward)
+        if not self._onnx_forward_outputs_desc:
+            self._onnx_forward_outputs_desc = self._get_output_from_graph(self._onnx_forward)
+
+        print(f'Initializers: {self._onnx_forward_initializers_desc}')
+        print(f'Inputs: {self._onnx_forward_inputs_desc}')
+        print(f'Outputs: {self._onnx_forward_outputs_desc}')
+
+        # Use a custom torch.autograd.Function to associate self.backward_graph as the
+        # gradient implementation for self.forward_graph.
+        class _ORTModuleFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, *input, **kwargs):
+                print('_ORTModuleFunction.forward() was called...')
+                # Note: A potential optimization would be to detect which of inputs and weights
+                # require a gradient.
+                # intermediates, outputs = self._run_forward_graph(inputs) # inputs, weights)
+                outputs = self._run_forward_graph(*input, **kwargs)  # inputs, weights)
+                # ctx.save_for_backward(*intermediates)
+                outputs = [torch.from_numpy(out).requires_grad_(True) for out in outputs]
+                if len(outputs) == 1:
+                    return outputs[0]
+                return tuple(outputs)
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                print('_ORTModuleFunction.backward() was called')
+                ...
+                # intermediates = ctx.saved_tensors
+                # grad_inputs, grad_weights = self._run_backward_graph(
+                #     grad_output, intermediates)
+                # return grad_inputs, grad_weights
+
+        return _ORTModuleFunction.apply(self._prepare_model_input(*input, **kwargs))
+
+    def _prepare_model_input(self, *input, **kwargs):
+        """Builds the ONNX Runtime feed: user inputs plus current parameter values, keyed by name.
+
+        Tensors are detached and moved to CPU before being converted to numpy.
+        NOTE: `kwargs` are currently ignored.
+        """
+        # Dictionary containing both inputs and initializers
+        input_with_initializer = {}
+
+        # Inputs (positional inputs are matched to session inputs by position)
+        for idx, input_data in enumerate(self.forward_session.get_inputs()):
+            input_with_initializer[input_data.name] = input[idx].detach().cpu().numpy()
+
+        # Initializers
+        for name, param in self._original_module.named_parameters():
+            input_with_initializer[name] = param.detach().cpu().numpy()
+
+        return input_with_initializer
+
+    def _run_forward_graph(self, data_with_initializer):  # input, weights):
+        return self.forward_session.run(None, data_with_initializer)
+
+    def _run_backward_graph(self, grad_output, intermediates):
+        # Use an InferenceSession to execute self.backward_graph.
+        # Return gradient tensors for inputs and weights.
+        ...
+
+    @staticmethod
+    def _get_forward_graph(module, module_input, *extra_inputs):
+        """Exports `module` to ONNX (initializers kept as graph inputs) and returns the ModelProto.
+
+        Any `extra_inputs` are appended to `module_input` so modules with several
+        positional inputs can be exported as well.
+        """
+        export_args = (module_input, *extra_inputs) if extra_inputs else module_input
+        f = io.BytesIO()
+        torch.onnx.export(module, export_args, f, verbose=True,
+                          opset_version=ONNX_OPSET_VERSION,
+                          _retain_param_name=True,
+                          training=torch.onnx.TrainingMode.TRAINING,
+                          keep_initializers_as_inputs=True,
+                          export_params=True)
+        return onnx.load_model_from_string(f.getvalue())
+
+    def _get_initializer_from_graph(self, graph):
+        # TODO: There is a tradeoff between memory footprint and total model export time
+        #   Ideally we want to export the model using torch.onnx.export(.., export_params=False, keep_initializers_as_inputs=True)
+        #   to obtain an ONNX model with minimal size and initializers as input.
+        #   However, this results in (guessing) assuming only initializer's name end with '.weight' and '.bias'.
+        #   Otherwise, it is not possible to separate input from initializer after the model is exported
+        #   Options are:
+        #   1) If memory footprint is more important, we can export ONNX twice, varying keep_initializers_as_inputs flag
+        #       ONNX model is small (400 bytes vs 1.6MB for MNIST), but export takes twice the time
+        #   2) If total export time is more important, we can export ONNX once, using export_params=True
+        #       ONNX model is bigger, but export takes half the time
+
+        # As performance is not the main goal in this first deliverable, using approach 2) for simplicity
+        initializers = []
+        for initializer in graph.graph.initializer:
+            name = initializer.name
+            # TODO: Dynamic shape is not being handled yet
+            shape = list(initializer.dims)
+            dtype = _utils.dtype_onnx_to_torch(initializer.data_type)
+            initializers.append({'name': name, 'shape': shape, 'dtype': dtype})
+        return initializers
+
+    def _describe_non_initializers(self, value_infos):
+        """Returns name/shape/dtype descriptors of `value_infos`, skipping known initializers."""
+        initializer_names = {desc['name'] for desc in self._onnx_forward_initializers_desc}
+        descriptors = []
+        for elem in value_infos:
+            if elem.name in initializer_names:
+                continue
+            # TODO: Dynamic shape is not being handled yet
+            shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim]
+            dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type)
+            descriptors.append({'name': elem.name, 'shape': shape, 'dtype': dtype})
+        return descriptors
+
+    def _get_input_from_graph(self, graph):
+        return self._describe_non_initializers(graph.graph.input)
+
+    def _get_output_from_graph(self, graph):
+        return self._describe_non_initializers(graph.graph.output)
+
+    @staticmethod
+    def _save_onnx_graph(onnx_graph, path):
+        r"""Persists ONNX model into :py:attr:`path`
+
+        The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard.
+        The graph includes full information, including inference and training metadata.
+
+        If the directory part of `path` does not exist, or `path` has no file name,
+        a warning is issued and nothing is written.
+
+        Args:
+            onnx_graph (onnx.ModelProto): Either forward or backward graph
+            path (str): Full path, including filename, to save the ONNX model in the filesystem
+
+        Raises:
+            AssertionError: raised when `path` is not a string
+        """
+
+        assert isinstance(path, str), "'path' must be a valid path string"
+        dir_name = os.path.dirname(path)
+        file_name = os.path.basename(path)
+        if (dir_name and not os.path.exists(dir_name)) or not file_name:
+            warnings.warn("'path' is not valid or does not exist")
+            return
+
+        with open(path, "wb") as f:
+            f.write(onnx_graph.SerializeToString())
+
+    @staticmethod
+    def _build_gradient_graph(forward_graph):
+        # Invoke the C++ GradientBuilder implementation via pybind.
+        # Return an ONNX graph that contains the forward and backward nodes, which takes the
+        # following inputs:
+        # * Module inputs
+        # * Module weights
+        # * Gradients with respect to the module outputs
+        # …and produces gradients with respect to the module inputs and weights.
+        ...
+
+    @staticmethod
+    def _split_forward_and_backward(gradient_graph):
+        # Split the result of _build_gradient_graph into two subgraphs:
+        # * A forward graph that takes module inputs and weights as input, and produces module
+        #   outputs and (“stashed”) intermediate tensors as output.
+        # * A backward graph that takes intermediate tensors, module weights, and gradients
+        #   respect to the module outputs as inputs, and produces gradients with respect to the
+        #   module inputs and weights.
+        return (None, None)
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py
new file mode 100644
index 0000000000..0f3554c342
--- /dev/null
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py
@@ -0,0 +1,64 @@
+import torch
+from torchvision import datasets, transforms
+
+from onnxruntime import set_seed
+from onnxruntime.training import ORTModule
+
+import _test_commons
+import _test_helpers
+
+
+class NeuralNet(torch.nn.Module):
+    def __init__(self, input_size, hidden_size, num_classes):
+        super(NeuralNet, self).__init__()
+
+        self.fc1 = torch.nn.Linear(input_size, hidden_size)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Linear(hidden_size, num_classes)
+
+    def forward(self, input1):
+        out = self.fc1(input1)
+        out = self.relu(out)
+        out = self.fc2(out)
+        return out
+
+# Model architecture
+lr = 1e-4
+batch_size = 20
+seed = 42
+
+torch.manual_seed(seed)
+set_seed(seed)
+
+
+model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
+model = ORTModule(model)
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+
+# Data loader
+train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
+                                                          transform=transforms.Compose([transforms.ToTensor(),
+                                                                                        transforms.Normalize((0.1307,), (0.3081,))])),
+                                           batch_size=batch_size,
+                                           shuffle=True)
+
+# Training Loop
+print('Training MNIST on ORTModule....')
+loss = float('inf')
+for iteration, (data, target) in enumerate(train_loader):
+    if iteration == 1:
+        print(f'Final loss is {loss}')
+        break
+
+    data = data.reshape(data.shape[0], -1)
+    optimizer.zero_grad()
+    output = model(data)
+    print(f'Output from forward has shape {output[0].size()}: {output[0]}')
+    loss = criterion(output, target)
+    # loss.backward()
+    # optimizer.step()
+
+    if iteration == 0:
+        print(f'Initial loss is {loss}')
+print('Tah dah!')