diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc
index e2828c8051..609a08976f 100644
--- a/orttraining/orttraining/core/session/training_session.cc
+++ b/orttraining/orttraining/core/session/training_session.cc
@@ -346,6 +346,18 @@ Status TrainingSession::ConfigureForTraining(
     }
   }
 
+#if 1
+  // TODO: Do not merge this on master
+  // Save the training model before the optimizer nodes are added.
+  // This makes it easier to manually edit the MNIST model later.
+  if ((IsRootNode(config) || (config.pipeline_config.has_value() &&
+                              DistributedRunContext::GroupId(WorkerGroupType::ModelParallel) == 0)) &&
+      config.model_with_training_graph_path.has_value()) {
+    ORT_IGNORE_RETURN_VALUE(Save(
+        config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD));
+  }
+#endif
+
   // add optimizer or gradient accumulation
   if (config.optimizer_config.has_value()) {
     OptimizerGraphConfig opt_graph_config{};
@@ -412,12 +424,16 @@
   // conflict. It is the user's responsibility to make sure each rank is passed in with a different
   // model_with_training_graph_path value. Also, to avoid writing conflicts, only the ranks in the
   // first pipeline group write the partition file out.
+#if 0
+  // TODO: Do not merge this on master
+  // This is now called above, before the optimizer nodes are added.
   if ((IsRootNode(config) || (config.pipeline_config.has_value() &&
                               DistributedRunContext::GroupId(WorkerGroupType::ModelParallel) == 0)) &&
       config.model_with_training_graph_path.has_value()) {
     ORT_IGNORE_RETURN_VALUE(Save(
         config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD));
   }
+#endif
 
   // After pipeline partition, we need to return the inputs allowed in this partition.
   if (config.pipeline_config.has_value()) {
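The `#if 1` block above snapshots the training graph before the optimizer subgraph is built. A minimal sketch of how that snapshot could be sanity-checked from Python (assuming the hard-coded path that `orttrainer.py` uses later in this patch):

```python
# Load the snapshot written by the block above and confirm that no optimizer
# nodes are present yet; 'SGDOptimizer' is the op this patch configures.
import onnx

model = onnx.load('/home/thiagofc/mnist_onnx/mnist_with_training.onnx')
op_types = {node.op_type for node in model.graph.node}
assert 'SGDOptimizer' not in op_types, 'snapshot was taken after optimizer build'
print(sorted(op_types))
```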
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index 416dae5d82..84a6db113f 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -21,6 +21,9 @@ using namespace onnxruntime::logging;
 using namespace onnxruntime::training;
 
 struct TrainingParameters {
+  std::string model_with_loss_function_path;
+  std::string model_with_training_graph_path;
+
   std::string loss_output_name;
   std::unordered_set<std::string> weights_to_train;
   std::unordered_set<std::string> weights_not_to_train;
@@ -84,6 +87,8 @@ TrainingConfigurationResult ConfigureSessionForTraining(
   }
 
   training::TrainingSession::TrainingConfiguration config{};
+  config.model_with_loss_function_path = parameters.model_with_loss_function_path;
+  config.model_with_training_graph_path = parameters.model_with_training_graph_path;
   config.weight_names_to_train = parameters.weights_to_train;
   config.weight_names_to_not_train = parameters.weights_not_to_train;
   config.immutable_weights = parameters.immutable_weights;
@@ -190,6 +195,8 @@ void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const lo
 void addObjectMethodsForTraining(py::module& m) {
   py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
   parameters.def(py::init())
+      .def_readwrite("model_with_loss_function_path", &TrainingParameters::model_with_loss_function_path)
+      .def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path)
      .def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
      .def_readwrite("immutable_weights", &TrainingParameters::immutable_weights)
      .def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train)
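The two new `TrainingParameters` fields are plain read/write strings, so they can be set from Python like any other field. A minimal sketch (a training-enabled onnxruntime build is assumed; the paths are placeholders):

```python
# Dump the model right after the loss function is added, and again after the
# full training (gradient) graph is built, using the fields bound above.
import onnxruntime

ort_parameters = onnxruntime.TrainingParameters()
ort_parameters.loss_output_name = 'loss'
ort_parameters.model_with_loss_function_path = '/tmp/model_with_loss.onnx'
ort_parameters.model_with_training_graph_path = '/tmp/model_with_training.onnx'
```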
diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py
index 99843a3c2e..4695a6afc8 100644
--- a/orttraining/orttraining/python/training/ortmodule.py
+++ b/orttraining/orttraining/python/training/ortmodule.py
@@ -38,7 +38,33 @@ class ORTModule(torch.nn.Module):
         # gradient_graph = ORTModule._build_gradient_graph(original_forward_graph)
         # self.forward_graph, self.backward_graph = ORTModule._split_forward_and_backward(gradient_graph)
         self._onnx_forward = original_forward_graph  # TODO: hard-coding for MVP
+        # import pdb; pdb.set_trace()
         self.forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
+
+
+        # TrainingParameters
+        # ort_parameters = onnxruntime.TrainingParameters()
+        # ort_parameters.loss_output_name = "loss"
+        # ort_parameters.use_mixed_precision = False
+        # ort_parameters.world_rank = 0
+        # ort_parameters.world_size = 1
+        # ort_parameters.gradient_accumulation_steps = 1
+        # ort_parameters.allreduce_post_accumulation = False
+        # ort_parameters.deepspeed_zero_stage = 0
+        # ort_parameters.enable_grad_norm_clip = False
+        # ort_parameters.set_gradients_as_graph_outputs = False
+        # ort_parameters.use_invertible_layernorm_grad = False
+        # ort_parameters.training_optimizer_name = "SGDOptimizer"
+        # ort_parameters.lr_params_feed_name = "Learning_Rate"
+        # ort_parameters.weights_to_train = trainable_params
+        # ort_parameters.optimizer_attributes_map = optimizer_attributes_map
+        # ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map
+
+        # # SessionOptions
+        # session_options = onnxruntime.SessionOptions()
+        # session_options.use_deterministic_compute = self.options.debug.deterministic_compute
+        # self.forward_session = onnxruntime.TrainingSession(self._onnx_forward.SerializeToString(), ort_parameters, session_options)
+
         self._save_onnx_graph(self._onnx_forward, 'forward_mnist.onnx')
 
         if not self._onnx_forward_initializers_desc:
             self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
@@ -60,9 +86,15 @@ class ORTModule(torch.nn.Module):
             # Note: A potential optimization would be to detect which of inputs and weights
             # require a gradient.
             # intermediates, outputs = self._run_forward_graph(inputs)  # inputs, weights)
+            # import pdb; pdb.set_trace()
             outputs = self._run_forward_graph(*input, **kwargs)  # inputs, weights)
-            # ctx.save_for_backward(*intermediates)
             outputs = [torch.from_numpy(out).requires_grad_(True) for out in outputs]
+
+            # TODO: Properly save intermediate tensors and remove them from the model output
+            ctx.save_for_backward(outputs[1])
+            outputs = [outputs[0]]
+
+            # TODO: Properly support the original module's output format
             if len(outputs) == 1:
                 return outputs[0]
             return tuple(outputs)
@@ -82,6 +114,7 @@ class ORTModule(torch.nn.Module):
 
         # Dictionary containing both inputs and initializers
         input_with_initializer = {}
+        # import pdb; pdb.set_trace()
         # Inputs
         for idx, input_data in enumerate(self.forward_session.get_inputs()):
             input_with_initializer.update({input_data.name: input[idx].cpu().numpy()})
@@ -93,6 +126,7 @@ class ORTModule(torch.nn.Module):
         return input_with_initializer
 
     def _run_forward_graph(self, data_with_initializer):  # input, weights):
+        # import pdb; pdb.set_trace()
         return self.forward_session.run(None, data_with_initializer)
 
     def _run_backward_graph(self, grad_output, intermediates):
@@ -102,15 +136,18 @@ class ORTModule(torch.nn.Module):
 
     @staticmethod
     def _get_forward_graph(module, module_input):
+        # TODO: The PyTorch module must be exported to ONNX and split;
+        # hard-coding the MNIST stub for the MVP.
         # Export torch.nn.Module to ONNX with initializers as input
-        f = io.BytesIO()
-        torch.onnx.export(module, module_input, f, verbose=True,
-                          opset_version=ONNX_OPSET_VERSION,
-                          _retain_param_name=True,
-                          training=torch.onnx.TrainingMode.TRAINING,
-                          keep_initializers_as_inputs=True,
-                          export_params=True)
-        return onnx.load_model_from_string(f.getvalue())
+        # f = io.BytesIO()
+        # torch.onnx.export(module, module_input, f, verbose=True,
+        #                   opset_version=ONNX_OPSET_VERSION,
+        #                   _retain_param_name=True,
+        #                   training=torch.onnx.TrainingMode.TRAINING,
+        #                   keep_initializers_as_inputs=True,
+        #                   export_params=True)
+        # return onnx.load_model_from_string(f.getvalue())
+        return onnx.load('/home/thiagofc/mnist_onnx/mnist_with_training_forward_sliced.onnx')
 
     def _get_initializer_from_graph(self, graph):
         # TODO: There is a tradeoff between memory footprint and total model export time
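The `ctx.save_for_backward(outputs[1])` change above stashes an extra forward output as an intermediate for the backward graph and returns only the user-visible output. A self-contained sketch of that pattern with a stand-in computation (not ORT's final design):

```python
# The forward pass produces (user_output, intermediate); only the intermediate
# is stashed for backward, and only the user output is returned to the caller.
import torch

class _StashIntermediate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        intermediate = x.relu()      # stand-in for an ONNX graph intermediate
        output = intermediate.sum()  # stand-in for the user-visible output
        ctx.save_for_backward(intermediate)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (intermediate,) = ctx.saved_tensors
        # stand-in for feeding the stashed tensor to the backward ONNX graph
        return grad_output * (intermediate > 0).to(grad_output.dtype)

x = torch.randn(4, requires_grad=True)
_StashIntermediate.apply(x).backward()
print(x.grad)  # 1.0 where x > 0, else 0.0
```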
diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py
index fc86cb48e0..cae26436cc 100644
--- a/orttraining/orttraining/python/training/orttrainer.py
+++ b/orttraining/orttraining/python/training/orttrainer.py
@@ -621,6 +621,8 @@ class ORTTrainer(object):
 
         # TrainingParameters
         ort_parameters = ort.TrainingParameters()
+        ort_parameters.model_with_loss_function_path = '/home/thiagofc/mnist_onnx/mnist_with_loss.onnx'
+        ort_parameters.model_with_training_graph_path = '/home/thiagofc/mnist_onnx/mnist_with_training.onnx'
         ort_parameters.loss_output_name = loss_name
         ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled
         ort_parameters.world_rank = self.options.distributed.world_rank
@@ -629,7 +631,7 @@ class ORTTrainer(object):
         ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
         ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage
         ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
-        ort_parameters.set_gradients_as_graph_outputs = False
+        ort_parameters.set_gradients_as_graph_outputs = True
         ort_parameters.use_invertible_layernorm_grad = self.options.utils.invertible_layer_norm_gradient
         ort_parameters.training_optimizer_name = self.optim_config.name
         ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py
index 0f3554c342..96f62f3828 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py
@@ -53,9 +53,10 @@ for iteration, (data, target) in enumerate(train_loader):
     data = data.reshape(data.shape[0], -1)
 
     optimizer.zero_grad()
-    output = model(data)
-    print(f'Output from forward has shape {output[0].size()}: {output[0]}')
-    loss = criterion(output, target)
+    probability = model(data)
+    print(f'Output from forward has shape {probability.size()}: {probability}')
+    # import pdb; pdb.set_trace()
+    loss = criterion(probability, target)
     # loss.backward()
     # optimizer.step()
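Flipping `set_gradients_as_graph_outputs` to `True` makes the `*_grad` tensors regular graph outputs, which is what lets the transform script below find and remove them by name. A small sketch that lists them (the path and names follow the hard-coded MNIST setup in this patch):

```python
# With gradients exposed as graph outputs, they show up in graph.output and
# can be fetched or stripped by name (fc1.weight_grad, fc1.bias_grad, ...).
import onnx

model = onnx.load('/home/thiagofc/mnist_onnx/mnist_with_training.onnx')
grad_outputs = [o.name for o in model.graph.output if o.name.endswith('_grad')]
print(grad_outputs)
```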
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic_transform_model.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic_transform_model.py
new file mode 100644
index 0000000000..7fbc6a2d3e
--- /dev/null
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic_transform_model.py
@@ -0,0 +1,148 @@
+# coding=utf8
+import sys
+
+import onnx
+from onnx import helper, shape_inference
+from onnx import TensorProto
+import numpy as np
+from onnx import numpy_helper
+
+if len(sys.argv) < 2:
+    print("Please give a model path...")
+    exit(1)
+
+input_model_name = sys.argv[1]
+output_forward_model_name = input_model_name[:-5] + '_forward_sliced.onnx'
+output_backward_model_name = input_model_name[:-5] + '_backward_sliced.onnx'
+
+def add_model_input_from_initializer(model, initializer, docstring=None):
+    new_input = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims, docstring)
+    model.graph.input.append(new_input)
+
+def add_model_input(model, name, data_type, dims, docstring=None):
+    new_input = onnx.helper.make_tensor_value_info(name, data_type, dims, docstring)
+    model.graph.input.append(new_input)
+
+def add_output(model, name, data_type=None, docstring=None):
+    new_output = model.graph.value_info.add()
+    new_output.name = name
+    if data_type:
+        new_output.type.CopyFrom(data_type)
+    if docstring:
+        new_output.doc_string = docstring
+    model.graph.output.append(new_output)
+
+def find_model_input(model, input_name):
+    for input in model.graph.input:
+        if input.name == input_name:
+            return input
+    return None
+
+def find_model_output(model, output_name):
+    for output in model.graph.output:
+        if output.name == output_name:
+            return output
+    return None
+
+def find_initializer(model, name):
+    for initializer in model.graph.initializer:
+        if initializer.name == name:
+            return initializer
+    return None
+
+def find_node(model, name):
+    for node in model.graph.node:
+        if node.name == name:
+            return node
+    return None
+
+###############################################################################
+# FORWARD PASS GRAPH ##########################################################
+###############################################################################
+model = onnx.load(input_model_name)
+
+# Remove model inputs
+# They are: label
+node = find_model_input(model, 'label')
+model.graph.input.remove(node)
+
+# Remove model outputs
+# They are: loss and the gradient outputs
+node = find_model_output(model, 'loss')
+model.graph.output.remove(node)
+node = find_model_output(model, 'fc1.bias_grad')
+model.graph.output.remove(node)
+node = find_model_output(model, 'fc1.weight_grad')
+model.graph.output.remove(node)
+node = find_model_output(model, 'fc2.bias_grad')
+model.graph.output.remove(node)
+node = find_model_output(model, 'fc2.weight_grad')
+model.graph.output.remove(node)
+
+# Add inputs with the same name, type and shape as the initializers
+# They are: [fc1.bias, fc1.weight, fc2.bias, fc2.weight]
+node = find_initializer(model, 'fc1.bias')
+add_model_input_from_initializer(model, node, 'thiagofc: add fc1.bias as model input')
+node = find_initializer(model, 'fc1.weight')
+add_model_input_from_initializer(model, node, 'thiagofc: add fc1.weight as model input')
+node = find_initializer(model, 'fc2.bias')
+add_model_input_from_initializer(model, node, 'thiagofc: add fc2.bias as model input')
+node = find_initializer(model, 'fc2.weight')
+add_model_input_from_initializer(model, node, 'thiagofc: add fc2.weight as model input')
+
+# Remove initializers from the model
+# They are: [fc1.bias, fc1.weight, fc2.bias, fc2.weight]
+# TODO: Do this when we are able to distinguish inputs from initializers
+# model.graph.initializer.remove(node)
+# model.graph.initializer.remove(node)
+# model.graph.initializer.remove(node)
+# model.graph.initializer.remove(node)
+
+# Remove backward-related initializers
+# They are: [loss_grad, ZeroConstant]
+node = find_initializer(model, 'loss_grad')
+model.graph.initializer.remove(node)
+node = find_initializer(model, 'ZeroConstant')
+model.graph.initializer.remove(node)
+
+# Remove OPs
+# They are: [SoftmaxCrossEntropyLoss_3, SoftmaxCrossEntropyLoss_3_Grad/SoftmaxCrossEntropyLossGrad_0,
+#            Gemm_2_Grad/ReduceSum_3, Gemm_2_Grad/Identity_4, Gemm_2_Grad/Gemm_2, Gemm_2_Grad/Gemm_1,
+#            Relu_1_Grad/ReluGrad_0, Gemm_0_Grad/Gemm_1, Gemm_0_Grad/ReduceSum_2, Gemm_0_Grad/Identity_3]
+node = find_node(model, 'SoftmaxCrossEntropyLoss_3')
+model.graph.node.remove(node)
+node = find_node(model, 'SoftmaxCrossEntropyLoss_3_Grad/SoftmaxCrossEntropyLossGrad_0')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_2_Grad/ReduceSum_3')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_2_Grad/Identity_4')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_2_Grad/Gemm_2')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_2_Grad/Gemm_1')
+model.graph.node.remove(node)
+node = find_node(model, 'Relu_1_Grad/ReluGrad_0')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_0_Grad/Gemm_1')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_0_Grad/ReduceSum_2')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_0_Grad/Identity_3')
+model.graph.node.remove(node)
+
+# Add new outputs
+# They are: 7
+add_output(model, '7', None, 'thiagofc: add 7 as model output')
+
+with open(output_forward_model_name, "wb") as f:
+    f.write(model.SerializeToString())
+
+
+###############################################################################
+# BACKWARD PASS GRAPH #########################################################
+###############################################################################
+model = onnx.load(input_model_name)
+
+
+
+with open(output_backward_model_name, "wb") as f:
+    f.write(model.SerializeToString())
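The forward slice above removes nodes and outputs one at a time by name. For comparison, a hedged sketch using `onnx.utils.extract_model`, which slices a sub-graph by boundary tensor names and rebuilds the input/output lists itself; the data input name `input.1` and prediction output name `probability` are assumptions, while the weight names and the intermediate `7` come from the script above:

```python
# Extract the forward sub-graph directly by naming its boundary tensors.
import onnx.utils

onnx.utils.extract_model(
    'mnist_with_training.onnx',                 # full training graph
    'mnist_with_training_forward_sliced.onnx',  # forward-only slice
    input_names=['input.1', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'],
    output_names=['probability', '7'],
)
```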
diff --git a/samples/python/mnist/mnist_training.py b/samples/python/mnist/mnist_training.py
index c3c4c86963..3d23987e60 100644
--- a/samples/python/mnist/mnist_training.py
+++ b/samples/python/mnist/mnist_training.py
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from torchvision import datasets, transforms
 
 import onnxruntime
-from onnxruntime.experimental import ORTTrainer, ORTTrainerOptions, optim
+from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim
 
 # Pytorch model
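The sample now imports from `onnxruntime.training` instead of `onnxruntime.experimental`; the trainer API itself is unchanged. A minimal sketch of the updated usage, with a model description matching the 784-to-10 MNIST net (names follow the sample; the exact option keys are assumptions):

```python
import torch
import torch.nn.functional as F
from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim

# Stand-in for the sample's two-layer MNIST net.
model = torch.nn.Sequential(
    torch.nn.Linear(784, 500), torch.nn.ReLU(), torch.nn.Linear(500, 10))

def my_loss(x, target):
    return F.nll_loss(F.log_softmax(x, dim=1), target)

model_desc = {'inputs': [('x', ['batch', 784]),
                         ('target', ['batch'])],
              'outputs': [('loss', [], True),
                          ('probability', ['batch', 10])]}

trainer = ORTTrainer(model, model_desc, optim.SGDConfig(lr=0.01), loss_fn=my_loss,
                     options=ORTTrainerOptions({'device': {'id': 'cpu'}}))
loss, probability = trainer.train_step(torch.randn(64, 784),
                                       torch.randint(0, 10, (64,)))
```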