Perform forward pass using training graph with intermediate outputs

Thiago Crepaldi 2020-10-07 12:13:29 -07:00
parent 11b69f141e
commit 77cefcd6c2
7 changed files with 225 additions and 14 deletions

View file

@@ -346,6 +346,18 @@ Status TrainingSession::ConfigureForTraining(
}
}
#if 1
// TODO: Do not merge this on master
// Saving the training model before optimizer nodes are added
// This makes it easier to manually edit the MNIST model later
if ((IsRootNode(config) || (config.pipeline_config.has_value() &&
DistributedRunContext::GroupId(WorkerGroupType::ModelParallel) == 0)) &&
config.model_with_training_graph_path.has_value()) {
ORT_IGNORE_RETURN_VALUE(Save(
config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD));
}
#endif
// add optimizer or gradient accumulation
if (config.optimizer_config.has_value()) {
OptimizerGraphConfig opt_graph_config{};
@@ -412,12 +424,16 @@ Status TrainingSession::ConfigureForTraining(
// conflict. It is user's responsibility to make sure different rank is passed in with different
// model_with_training_graph_path value. Also, to avoid writing conflict, only the ranks in first
// pipeline group write the partition file out.
#if 0
// TODO: Do not merge this on master
// This is being called above, before optimizer nodes are added
if ((IsRootNode(config) || (config.pipeline_config.has_value() &&
DistributedRunContext::GroupId(WorkerGroupType::ModelParallel) == 0)) &&
config.model_with_training_graph_path.has_value()) {
ORT_IGNORE_RETURN_VALUE(Save(
config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD));
}
#endif
// After pipeline partition, we need to return the inputs allowed in this partition.
if (config.pipeline_config.has_value()) {

View file

@@ -21,6 +21,9 @@ using namespace onnxruntime::logging;
using namespace onnxruntime::training;
struct TrainingParameters {
std::string model_with_loss_function_path;
std::string model_with_training_graph_path;
std::string loss_output_name;
std::unordered_set<std::string> weights_to_train;
std::unordered_set<std::string> weights_not_to_train;
@@ -84,6 +87,8 @@ TrainingConfigurationResult ConfigureSessionForTraining(
}
training::TrainingSession::TrainingConfiguration config{};
config.model_with_loss_function_path = parameters.model_with_loss_function_path;
config.model_with_training_graph_path = parameters.model_with_training_graph_path;
config.weight_names_to_train = parameters.weights_to_train;
config.weight_names_to_not_train = parameters.weights_not_to_train;
config.immutable_weights = parameters.immutable_weights;
@@ -190,6 +195,8 @@ void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const lo
void addObjectMethodsForTraining(py::module& m) {
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
parameters.def(py::init())
.def_readwrite("model_with_loss_function_path", &TrainingParameters::model_with_loss_function_path)
.def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path)
.def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
.def_readwrite("immutable_weights", &TrainingParameters::immutable_weights)
.def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train)

View file

@@ -38,7 +38,33 @@ class ORTModule(torch.nn.Module):
# gradient_graph = ORTModule._build_gradient_graph(original_forward_graph)
# self.forward_graph, self.backward_graph = ORTModule._split_forward_and_backward(gradient_graph)
self._onnx_forward = original_forward_graph # TODO: hard-coding for MVP
# import pdb; pdb.set_trace()
self.forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
# TrainingParameters
# ort_parameters = onnxruntime.TrainingParameters()
# ort_parameters.loss_output_name = "loss"
# ort_parameters.use_mixed_precision = False
# ort_parameters.world_rank = 0
# ort_parameters.world_size = 1
# ort_parameters.gradient_accumulation_steps = 1
# ort_parameters.allreduce_post_accumulation = False
# ort_parameters.deepspeed_zero_stage = 0
# ort_parameters.enable_grad_norm_clip = False
# ort_parameters.set_gradients_as_graph_outputs = False
# ort_parameters.use_invertible_layernorm_grad = False
# ort_parameters.training_optimizer_name = "SGDOptimizer"
# ort_parameters.lr_params_feed_name = "Learning_Rate"
# ort_parameters.weights_to_train = trainable_params
# ort_parameters.optimizer_attributes_map = optimizer_attributes_map
# ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map
# # SessionOptions
# session_options = onnxruntime.SessionOptions()
# session_options.use_deterministic_compute = self.options.debug.deterministic_compute
# self.forward_session = onnxruntime.TrainingSession(self._onnx_forward.SerializeToString(), ort_parameters, session_options)
self._save_onnx_graph(self._onnx_forward, 'forward_mnist.onnx')
if not self._onnx_forward_initializers_desc:
self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
@@ -60,9 +86,15 @@ class ORTModule(torch.nn.Module):
# Note: A potential optimization would be to detect which of inputs and weights
# require a gradient.
# intermediates, outputs = self._run_forward_graph(inputs) # inputs, weights)
# import pdb; pdb.set_trace()
outputs = self._run_forward_graph(*input, **kwargs) # inputs, weights)
# ctx.save_for_backward(*intermediates)
outputs = [torch.from_numpy(out).requires_grad_(True) for out in outputs]
# TODO: Properly save intermediate tensors and remove them from model output
ctx.save_for_backward(outputs[1])
outputs = [outputs[0]]
# TODO: Properly support original module output format
if len(outputs) == 1:
return outputs[0]
return tuple(outputs)
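
Note: the hunk above follows the usual torch.autograd.Function contract: the extra graph output (the intermediate activation, named '7' in the sliced MNIST graph) is stashed with ctx.save_for_backward so a future backward graph can consume it. A self-contained sketch of that pattern, with illustrative names and plain torch ops standing in for the ONNX sessions (not the actual ORTModule implementation):

    import torch

    class _SketchFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # stand-ins for the sliced ONNX forward graph and its extra output
            activation = x * 2
            prediction = activation.sum(dim=-1)
            ctx.save_for_backward(activation)  # keep the intermediate for backward
            return prediction

        @staticmethod
        def backward(ctx, grad_output):
            (activation,) = ctx.saved_tensors
            # the real backward would feed grad_output plus the saved intermediate
            # into an ONNX backward graph; here the gradient is written analytically
            return grad_output.unsqueeze(-1).expand_as(activation) * 2

    y = _SketchFunction.apply(torch.randn(4, 3, requires_grad=True))
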
@@ -82,6 +114,7 @@ class ORTModule(torch.nn.Module):
# Dictionary containing both inputs and initializers
input_with_initializer = {}
# import pdb; pdb.set_trace()
# Inputs
for idx, input_data in enumerate(self.forward_session.get_inputs()):
input_with_initializer.update({input_data.name : input[idx].cpu().numpy()})
@@ -93,6 +126,7 @@ class ORTModule(torch.nn.Module):
return input_with_initializer
def _run_forward_graph(self, data_with_initializer): #input, weights):
# import pdb; pdb.set_trace()
return self.forward_session.run(None, data_with_initializer)
def _run_backward_graph(self, grad_output, intermediates):
@@ -102,15 +136,18 @@ class ORTModule(torch.nn.Module):
@staticmethod
def _get_forward_graph(module, module_input):
# TODO: PyTorch module must be exported to ONNX and split
# Hard-coding with MNIST stub for MVP
# Export torch.nn.Module to ONNX with initializers as input
f = io.BytesIO()
torch.onnx.export(module, module_input, f, verbose=True,
opset_version=ONNX_OPSET_VERSION,
_retain_param_name=True,
training=torch.onnx.TrainingMode.TRAINING,
keep_initializers_as_inputs=True,
export_params=True)
return onnx.load_model_from_string(f.getvalue())
# f = io.BytesIO()
# torch.onnx.export(module, module_input, f, verbose=True,
# opset_version=ONNX_OPSET_VERSION,
# _retain_param_name=True,
# training=torch.onnx.TrainingMode.TRAINING,
# keep_initializers_as_inputs=True,
# export_params=True)
# return onnx.load_model_from_string(f.getvalue())
return onnx.load('/home/thiagofc/mnist_onnx/mnist_with_training_forward_sliced.onnx')
def _get_initializer_from_graph(self, graph):
# TODO: There is a trade-off between memory footprint and total model export time

View file

@@ -621,6 +621,8 @@ class ORTTrainer(object):
# TrainingParameters
ort_parameters = ort.TrainingParameters()
ort_parameters.model_with_loss_function_path = '/home/thiagofc/mnist_onnx/mnist_with_loss.onnx'
ort_parameters.model_with_training_graph_path = '/home/thiagofc/mnist_onnx/mnist_with_training.onnx'
ort_parameters.loss_output_name = loss_name
ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled
ort_parameters.world_rank = self.options.distributed.world_rank
@@ -629,7 +631,7 @@ class ORTTrainer(object):
ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage
ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
ort_parameters.set_gradients_as_graph_outputs = False
ort_parameters.set_gradients_as_graph_outputs = True
ort_parameters.use_invertible_layernorm_grad = self.options.utils.invertible_layer_norm_gradient
ort_parameters.training_optimizer_name = self.optim_config.name
ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name

View file

@@ -53,9 +53,10 @@ for iteration, (data, target) in enumerate(train_loader):
data = data.reshape(data.shape[0], -1)
optimizer.zero_grad()
output = model(data)
print(f'Output from forward has shape {output[0].size()}: {output[0]}')
loss = criterion(output, target)
probability = model(data)
print(f'Output from forward has shape {probability.size()}: {probability}')
# import pdb; pdb.set_trace()
loss = criterion(probability, target)
# loss.backward()
# optimizer.step()

View file

@@ -0,0 +1,148 @@
# coding=utf8
import sys
import onnx
from onnx import helper, shape_inference
from onnx import TensorProto
import numpy as np
from onnx import numpy_helper
if len(sys.argv) < 2:
print("Please give model path...")
exit(1)
input_model_name = sys.argv[1]
output_forward_model_name = input_model_name[:-5] + '_forward_sliced.onnx'
output_backward_model_name = input_model_name[:-5] + '_backward_sliced.onnx'
def add_model_input_from_initializer(model, initializer, docstring=None):
new_input = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims, docstring)
model.graph.input.append(new_input)
def add_model_input(model, name, data_type, dims, docstring=None):
new_input = onnx.helper.make_tensor_value_info(name, data_type, dims, docstring)
model.graph.input.append(new_input)
def add_output(model, name, data_type=None, docstring=None):
new_output = model.graph.value_info.add()
new_output.name = name
if data_type:
new_output.type.CopyFrom(data_type)
if docstring:
new_output.doc_string = docstring
model.graph.output.append(new_output)
def find_model_input(model, input_name):
for input in model.graph.input:
if input.name == input_name:
return input
return None
def find_model_output(model, output_name):
for output in model.graph.output:
if output.name == output_name:
return output
return None
def find_initializer(model, name):
for initializer in model.graph.initializer:
if initializer.name == name:
return initializer
return None
def find_node(model, name):
for node in model.graph.node:
if node.name == name:
return node
return None
###############################################################################
# FORWARD PASS GRAPH ##########################################################
###############################################################################
model = onnx.load(input_model_name)
# Remove model inputs
# They are: label
node = find_model_input(model, 'label')
model.graph.input.remove(node)
# Remove model outputs
# They are: loss
node = find_model_output(model, 'loss')
model.graph.output.remove(node)
node = find_model_output(model, 'fc1.bias_grad')
model.graph.output.remove(node)
node = find_model_output(model, 'fc1.weight_grad')
model.graph.output.remove(node)
node = find_model_output(model, 'fc2.bias_grad')
model.graph.output.remove(node)
node = find_model_output(model, 'fc2.weight_grad')
model.graph.output.remove(node)
# Add inputs with the same name, type and shape as the initializers
# They are: [fc1.bias, fc1.weight, fc2.bias, fc2.weight]
node = find_initializer(model, 'fc1.bias')
add_model_input_from_initializer(model, node, 'thiagofc: add fc1.bias as model input')
node = find_initializer(model, 'fc1.weight')
add_model_input_from_initializer(model, node, 'thiagofc: add fc1.weight as model input')
node = find_initializer(model, 'fc2.bias')
add_model_input_from_initializer(model, node, 'thiagofc: add fc2.bias as model input')
node = find_initializer(model, 'fc2.weight')
add_model_input_from_initializer(model, node, 'thiagofc: add fc2.weight as model input')
# Remove initializers from model
# They are: [fc1.bias, fc1.weight, fc2.bias, fc2.weight]
# TODO: Do this when we are able to distinguish inputs from initializers
# model.graph.initializer.remove(node)
# model.graph.initializer.remove(node)
# model.graph.initializer.remove(node)
# model.graph.initializer.remove(node)
# Remove backward-related initializers
# They are: [loss_grad, ZeroConstant]
node = find_initializer(model, 'loss_grad')
model.graph.initializer.remove(node)
node = find_initializer(model, 'ZeroConstant')
model.graph.initializer.remove(node)
# Remove OPs
# They are: [SoftmaxCrossEntropyLoss_3, SoftmaxCrossEntropyLoss_3_Grad/SoftmaxCrossEntropyLossGrad_0,
# Gemm_2_Grad/ReduceSum_3, Gemm_2_Grad/Identity_4, Gemm_2_Grad/Gemm_2, Gemm_2_Grad/Gemm_1,
# Relu_1_Grad/ReluGrad_0, Gemm_0_Grad/Gemm_1, Gemm_0_Grad/ReduceSum_2, Gemm_0_Grad/Identity_3]
node = find_node(model, 'SoftmaxCrossEntropyLoss_3')
model.graph.node.remove(node)
node = find_node(model, 'SoftmaxCrossEntropyLoss_3_Grad/SoftmaxCrossEntropyLossGrad_0')
model.graph.node.remove(node)
node = find_node(model, 'Gemm_2_Grad/ReduceSum_3')
model.graph.node.remove(node)
node = find_node(model, 'Gemm_2_Grad/Identity_4')
model.graph.node.remove(node)
node = find_node(model, 'Gemm_2_Grad/Gemm_2')
model.graph.node.remove(node)
node = find_node(model, 'Gemm_2_Grad/Gemm_1')
model.graph.node.remove(node)
node = find_node(model, 'Relu_1_Grad/ReluGrad_0')
model.graph.node.remove(node)
node = find_node(model, 'Gemm_0_Grad/Gemm_1')
model.graph.node.remove(node)
node = find_node(model, 'Gemm_0_Grad/ReduceSum_2')
model.graph.node.remove(node)
node = find_node(model, 'Gemm_0_Grad/Identity_3')
model.graph.node.remove(node)
# Add new outputs:
# They are: 7
add_output(model, '7', None, 'thiagofc: add 7 as model output')
with open(output_forward_model_name, "wb") as f:
f.write(model.SerializeToString())
###############################################################################
# BACKWARD PASS GRAPH #########################################################
###############################################################################
model = onnx.load(input_model_name)
with open(output_backward_model_name, "wb") as f:
f.write(model.SerializeToString())
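
Note: an optional sanity check, not part of this script, that could be run on either sliced model before it is written out, assuming the standard onnx.checker and onnx.shape_inference APIs (shape_inference is already imported above but unused):

    from onnx import checker, shape_inference

    checker.check_model(model)                    # verify the graph surgery left a valid model
    model = shape_inference.infer_shapes(model)   # refresh value_info for the remaining nodes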

View file

@@ -8,7 +8,7 @@ import torch.nn.functional as F
from torchvision import datasets, transforms
import onnxruntime
from onnxruntime.experimental import ORTTrainer, ORTTrainerOptions, optim
from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim
# Pytorch model