mirror of https://github.com/saymrwulf/onnxruntime.git (synced 2026-05-14 20:48:00 +00:00)

Perform forward pass using training graph with intermediate outputs

This commit is contained in:
parent 11b69f141e
commit 77cefcd6c2

7 changed files with 225 additions and 14 deletions
@@ -346,6 +346,18 @@ Status TrainingSession::ConfigureForTraining(
     }
   }
 
+#if 1
+  // TODO: Do not merge this on master
+  // Saving training model before optimizer nodes are added
+  // This makes it easier to manually edit the MNIST model later
+  if ((IsRootNode(config) || (config.pipeline_config.has_value() &&
+                              DistributedRunContext::GroupId(WorkerGroupType::ModelParallel) == 0)) &&
+      config.model_with_training_graph_path.has_value()) {
+    ORT_IGNORE_RETURN_VALUE(Save(
+        config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD));
+  }
+#endif
+
   // add optimizer or gradient accumulation
   if (config.optimizer_config.has_value()) {
     OptimizerGraphConfig opt_graph_config{};
@@ -412,12 +424,16 @@ Status TrainingSession::ConfigureForTraining(
   // conflict. It is the user's responsibility to make sure a different rank is passed in with a different
   // model_with_training_graph_path value. Also, to avoid
   // writing conflicts, only the ranks in the first pipeline group write the partition file out.
+#if 0
+  // TODO: Do not merge this on master
+  // This is being called above, before optimizer nodes are added
   if ((IsRootNode(config) || (config.pipeline_config.has_value() &&
                               DistributedRunContext::GroupId(WorkerGroupType::ModelParallel) == 0)) &&
       config.model_with_training_graph_path.has_value()) {
     ORT_IGNORE_RETURN_VALUE(Save(
         config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD));
   }
+#endif
 
   // After pipeline partition, we need to return the inputs allowed in this partition.
   if (config.pipeline_config.has_value()) {
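Together with the block added above, the net effect of these two hunks is that model_with_training_graph_path is now written before, rather than after, the optimizer nodes are built. A quick way to confirm what landed in that file is to inspect it with the onnx Python package; a minimal sketch, where the file name is a stand-in for whatever path was configured:

```python
# Inspect the pre-optimizer training graph saved by the block above.
# "mnist_with_training.onnx" stands in for config.model_with_training_graph_path.
import onnx

model = onnx.load("mnist_with_training.onnx")

# Without optimizer nodes, the node list should end at the gradient ops
# (names like Gemm_2_Grad/...) rather than at SGD/Adam update ops.
for node in model.graph.node:
    print(node.op_type, node.name)
```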
@@ -21,6 +21,9 @@ using namespace onnxruntime::logging;
 using namespace onnxruntime::training;
 
 struct TrainingParameters {
+  std::string model_with_loss_function_path;
+  std::string model_with_training_graph_path;
+
   std::string loss_output_name;
   std::unordered_set<std::string> weights_to_train;
   std::unordered_set<std::string> weights_not_to_train;
@@ -84,6 +87,8 @@ TrainingConfigurationResult ConfigureSessionForTraining(
   }
 
   training::TrainingSession::TrainingConfiguration config{};
+  config.model_with_loss_function_path = parameters.model_with_loss_function_path;
+  config.model_with_training_graph_path = parameters.model_with_training_graph_path;
   config.weight_names_to_train = parameters.weights_to_train;
   config.weight_names_to_not_train = parameters.weights_not_to_train;
   config.immutable_weights = parameters.immutable_weights;
@@ -190,6 +195,8 @@ void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const lo
 void addObjectMethodsForTraining(py::module& m) {
   py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
   parameters.def(py::init())
+      .def_readwrite("model_with_loss_function_path", &TrainingParameters::model_with_loss_function_path)
+      .def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path)
       .def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
       .def_readwrite("immutable_weights", &TrainingParameters::immutable_weights)
       .def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train)
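These bindings make the two new paths settable from Python. A hedged usage sketch, assuming a training-enabled onnxruntime build that exposes TrainingParameters (the file names and weight set are illustrative):

```python
import onnxruntime

ort_parameters = onnxruntime.TrainingParameters()
# New in this commit: where to dump the intermediate models for offline inspection.
ort_parameters.model_with_loss_function_path = "mnist_with_loss.onnx"
ort_parameters.model_with_training_graph_path = "mnist_with_training.onnx"
# Pre-existing fields, shown for context.
ort_parameters.loss_output_name = "loss"
ort_parameters.weights_to_train = {"fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"}
```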
@@ -38,7 +38,33 @@ class ORTModule(torch.nn.Module):
         # gradient_graph = ORTModule._build_gradient_graph(original_forward_graph)
         # self.forward_graph, self.backward_graph = ORTModule._split_forward_and_backward(gradient_graph)
         self._onnx_forward = original_forward_graph  # TODO: hard-coding for MVP
+        # import pdb; pdb.set_trace()
         self.forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
 
+
+        # TrainingParameters
+        # ort_parameters = onnxruntime.TrainingParameters()
+        # ort_parameters.loss_output_name = "loss"
+        # ort_parameters.use_mixed_precision = False
+        # ort_parameters.world_rank = 0
+        # ort_parameters.world_size = 1
+        # ort_parameters.gradient_accumulation_steps = 1
+        # ort_parameters.allreduce_post_accumulation = False
+        # ort_parameters.deepspeed_zero_stage = 0
+        # ort_parameters.enable_grad_norm_clip = False
+        # ort_parameters.set_gradients_as_graph_outputs = False
+        # ort_parameters.use_invertible_layernorm_grad = False
+        # ort_parameters.training_optimizer_name = "SGDOptimizer"
+        # ort_parameters.lr_params_feed_name = "Learning_Rate"
+        # ort_parameters.weights_to_train = trainable_params
+        # ort_parameters.optimizer_attributes_map = optimizer_attributes_map
+        # ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map
+
+        # # SessionOptions
+        # session_options = onnxruntime.SessionOptions()
+        # session_options.use_deterministic_compute = self.options.debug.deterministic_compute
+        # self.forward_session = onnxruntime.TrainingSession(self._onnx_forward.SerializeToString(), ort_parameters, session_options)
+
+        self._save_onnx_graph(self._onnx_forward, 'forward_mnist.onnx')
         if not self._onnx_forward_initializers_desc:
             self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward)
@@ -60,9 +86,15 @@ class ORTModule(torch.nn.Module):
             # Note: A potential optimization would be to detect which of inputs and weights
             # require a gradient.
-            intermediates, outputs = self._run_forward_graph(inputs) # inputs, weights)
+            # intermediates, outputs = self._run_forward_graph(inputs) # inputs, weights)
+            # import pdb; pdb.set_trace()
+            outputs = self._run_forward_graph(*input, **kwargs) # inputs, weights)
-            ctx.save_for_backward(*intermediates)
+            # ctx.save_for_backward(*intermediates)
             outputs = [torch.from_numpy(out).requires_grad_(True) for out in outputs]
 
+            # TODO: Properly save intermediate tensors and remove them from model output
+            ctx.save_for_backward(outputs[1])
+            outputs = [outputs[0]]
+
+            # TODO: Properly support original module output format
             if len(outputs) == 1:
                 return outputs[0]
             return tuple(outputs)
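The mechanism being prototyped here is the standard torch.autograd.Function contract: whatever forward stashes with ctx.save_for_backward comes back as ctx.saved_tensors in backward. A self-contained sketch of saving an intermediate activation and returning only the primary output; the computation and shapes are illustrative, not ORTModule's actual graph:

```python
import torch

class ForwardWithIntermediate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        intermediate = x @ weight          # pre-activation, playing the role of output '7' below
        output = torch.relu(intermediate)
        # Save what backward will need; return only the user-visible output.
        ctx.save_for_backward(x, weight, intermediate)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, intermediate = ctx.saved_tensors
        grad_pre = grad_output * (intermediate > 0).to(grad_output.dtype)  # ReluGrad
        return grad_pre @ weight.t(), x.t() @ grad_pre                     # grads w.r.t. x, weight

x = torch.randn(4, 8)
w = torch.randn(8, 3, requires_grad=True)
ForwardWithIntermediate.apply(x, w).sum().backward()
print(w.grad.shape)  # torch.Size([8, 3])
```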
@@ -82,6 +114,7 @@ class ORTModule(torch.nn.Module):
         # Dictionary containing both inputs and initializers
         input_with_initializer = {}
 
+        # import pdb; pdb.set_trace()
         # Inputs
         for idx, input_data in enumerate(self.forward_session.get_inputs()):
             input_with_initializer.update({input_data.name : input[idx].cpu().numpy()})
@@ -93,6 +126,7 @@ class ORTModule(torch.nn.Module):
         return input_with_initializer
 
     def _run_forward_graph(self, data_with_initializer): #input, weights):
+        # import pdb; pdb.set_trace()
         return self.forward_session.run(None, data_with_initializer)
 
     def _run_backward_graph(self, grad_output, intermediates):
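Because the exported forward model keeps its initializers as graph inputs, a single feed dictionary carries both user data and weights, and passing None as the fetch list returns every declared graph output, intermediates included. A minimal sketch of the same pattern against the graph that _save_onnx_graph wrote above; the zero-filled feeds are purely illustrative:

```python
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("forward_mnist.onnx")

# One feed dict for user inputs and initializer-turned-inputs alike.
feeds = {}
for meta in session.get_inputs():
    shape = [d if isinstance(d, int) else 1 for d in meta.shape]  # concretize symbolic dims
    feeds[meta.name] = np.zeros(shape, dtype=np.float32)          # float32 assumed for all inputs

# A fetch list of None returns all graph outputs, including the intermediate '7'.
outputs = session.run(None, feeds)
print([o.shape for o in outputs])
```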
@@ -102,15 +136,18 @@ class ORTModule(torch.nn.Module):
 
     @staticmethod
     def _get_forward_graph(module, module_input):
         # TODO: Pytorch module must be exported to ONNX and split
+        # Hard-coding with MNIST stub for MVP
+        # Export torch.nn.Module to ONNX with initializers as input
-        f = io.BytesIO()
-        torch.onnx.export(module, module_input, f, verbose=True,
-                          opset_version=ONNX_OPSET_VERSION,
-                          _retain_param_name=True,
-                          training=torch.onnx.TrainingMode.TRAINING,
-                          keep_initializers_as_inputs=True,
-                          export_params=True)
-        return onnx.load_model_from_string(f.getvalue())
+        # f = io.BytesIO()
+        # torch.onnx.export(module, module_input, f, verbose=True,
+        #                   opset_version=ONNX_OPSET_VERSION,
+        #                   _retain_param_name=True,
+        #                   training=torch.onnx.TrainingMode.TRAINING,
+        #                   keep_initializers_as_inputs=True,
+        #                   export_params=True)
+        # return onnx.load_model_from_string(f.getvalue())
+        return onnx.load('/home/thiagofc/mnist_onnx/mnist_with_training_forward_sliced.onnx')
 
     def _get_initializer_from_graph(self, graph):
         # TODO: There is a tradeoff between memory footprint and total model export time
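The commented-out branch is the standard export path: torch.onnx.export in TRAINING mode with keep_initializers_as_inputs=True turns every weight into an ordinary graph input. A self-contained sketch under those settings; the module, opset number, and shapes are illustrative assumptions, and the private _retain_param_name flag is omitted:

```python
import io

import onnx
import torch

class TinyMNIST(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

f = io.BytesIO()
torch.onnx.export(TinyMNIST(), torch.randn(1, 784), f,
                  opset_version=12,                            # stand-in for ONNX_OPSET_VERSION
                  training=torch.onnx.TrainingMode.TRAINING,   # keep training-time graph semantics
                  keep_initializers_as_inputs=True,            # weights become graph inputs
                  export_params=True)
model = onnx.load_model_from_string(f.getvalue())
print([i.name for i in model.graph.input])  # data input plus fc1/fc2 weights and biases
```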
@@ -621,6 +621,8 @@ class ORTTrainer(object):
 
         # TrainingParameters
         ort_parameters = ort.TrainingParameters()
+        ort_parameters.model_with_loss_function_path = '/home/thiagofc/mnist_onnx/mnist_with_loss.onnx'
+        ort_parameters.model_with_training_graph_path = '/home/thiagofc/mnist_onnx/mnist_with_training.onnx'
         ort_parameters.loss_output_name = loss_name
         ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled
         ort_parameters.world_rank = self.options.distributed.world_rank
@@ -629,7 +631,7 @@ class ORTTrainer(object):
         ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
         ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage
         ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
-        ort_parameters.set_gradients_as_graph_outputs = False
+        ort_parameters.set_gradients_as_graph_outputs = True
         ort_parameters.use_invertible_layernorm_grad = self.options.utils.invertible_layer_norm_gradient
         ort_parameters.training_optimizer_name = self.optim_config.name
         ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
@@ -53,9 +53,10 @@ for iteration, (data, target) in enumerate(train_loader):
 
     data = data.reshape(data.shape[0], -1)
     optimizer.zero_grad()
-    output = model(data)
-    print(f'Output from forward has shape {output[0].size()}: {output[0]}')
-    loss = criterion(output, target)
+    probability = model(data)
+    print(f'Output from forward has shape {probability.size()}: {probability}')
+    # import pdb; pdb.set_trace()
+    loss = criterion(probability, target)
     # loss.backward()
     # optimizer.step()
 
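For context, once the backward graph is wired up the two commented lines come back and this becomes a standard PyTorch training step. A self-contained sketch of that target shape of the loop, with a stand-in model and batch in place of the MNIST loader:

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(784, 128), torch.nn.ReLU(), torch.nn.Linear(128, 10))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

data = torch.randn(32, 1, 28, 28)        # stand-in batch
target = torch.randint(0, 10, (32,))

data = data.reshape(data.shape[0], -1)   # flatten, as in the script above
optimizer.zero_grad()
probability = model(data)
loss = criterion(probability, target)
loss.backward()                          # still commented out in the diff above
optimizer.step()
print(loss.item())
```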
@@ -0,0 +1,148 @@
+# coding=utf8
+import sys
+
+import onnx
+from onnx import helper, shape_inference
+from onnx import TensorProto
+import numpy as np
+from onnx import numpy_helper
+
+if len(sys.argv) < 2:
+    print("Please give model path...")
+    exit(1)
+
+input_model_name = sys.argv[1]
+output_forward_model_name = input_model_name[:-5] + '_forward_sliced.onnx'
+output_backward_model_name = input_model_name[:-5] + '_backward_sliced.onnx'
+
+def add_model_input_from_initializer(model, initializer, docstring=None):
+    new_input = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims, docstring)
+    model.graph.input.append(new_input)
+
+def add_model_input(model, name, data_type, dims, docstring=None):
+    new_input = onnx.helper.make_tensor_value_info(name, data_type, dims, docstring)
+    model.graph.input.append(new_input)
+
+def add_output(model, name, data_type=None, docstring=None):
+    new_output = model.graph.value_info.add()
+    new_output.name = name
+    if data_type:
+        new_output.type.CopyFrom(data_type)
+    if docstring:
+        new_output.doc_string = docstring
+    model.graph.output.append(new_output)
+
+def find_model_input(model, input_name):
+    for input in model.graph.input:
+        if input.name == input_name:
+            return input
+    return None
+
+def find_model_output(model, output_name):
+    for output in model.graph.output:
+        if output.name == output_name:
+            return output
+    return None
+
+def find_initializer(model, name):
+    for initializer in model.graph.initializer:
+        if initializer.name == name:
+            return initializer
+    return None
+
+def find_node(model, name):
+    for node in model.graph.node:
+        if node.name == name:
+            return node
+    return None
+
+###############################################################################
+# FORWARD PASS GRAPH ##########################################################
+###############################################################################
+model = onnx.load(input_model_name)
+
+# Remove model inputs
+# They are: label
+node = find_model_input(model, 'label')
+model.graph.input.remove(node)
+
+# Remove model outputs
+# They are: loss, fc1.bias_grad, fc1.weight_grad, fc2.bias_grad, fc2.weight_grad
+node = find_model_output(model, 'loss')
+model.graph.output.remove(node)
+node = find_model_output(model, 'fc1.bias_grad')
+model.graph.output.remove(node)
+node = find_model_output(model, 'fc1.weight_grad')
+model.graph.output.remove(node)
+node = find_model_output(model, 'fc2.bias_grad')
+model.graph.output.remove(node)
+node = find_model_output(model, 'fc2.weight_grad')
+model.graph.output.remove(node)
+
+# Add inputs with the same name, type and shape as the initializers
+# They are: [fc1.bias, fc1.weight, fc2.bias, fc2.weight]
+node = find_initializer(model, 'fc1.bias')
+add_model_input_from_initializer(model, node, 'thiagofc: add fc1.bias as model input')
+node = find_initializer(model, 'fc1.weight')
+add_model_input_from_initializer(model, node, 'thiagofc: add fc1.weight as model input')
+node = find_initializer(model, 'fc2.bias')
+add_model_input_from_initializer(model, node, 'thiagofc: add fc2.bias as model input')
+node = find_initializer(model, 'fc2.weight')
+add_model_input_from_initializer(model, node, 'thiagofc: add fc2.weight as model input')
+
+# Remove initializers from model
+# They are: [fc1.bias, fc1.weight, fc2.bias, fc2.weight]
+# TODO: Do this when we are able to distinguish inputs from initializers
+# model.graph.initializer.remove(node)
+# model.graph.initializer.remove(node)
+# model.graph.initializer.remove(node)
+# model.graph.initializer.remove(node)
+
+# Remove backward-related initializers
+# They are: [loss_grad, ZeroConstant]
+node = find_initializer(model, 'loss_grad')
+model.graph.initializer.remove(node)
+node = find_initializer(model, 'ZeroConstant')
+model.graph.initializer.remove(node)
+
+# Remove OPs
+# They are: [SoftmaxCrossEntropyLoss_3, SoftmaxCrossEntropyLoss_3_Grad/SoftmaxCrossEntropyLossGrad_0,
+#            Gemm_2_Grad/ReduceSum_3, Gemm_2_Grad/Identity_4, Gemm_2_Grad/Gemm_2, Gemm_2_Grad/Gemm_1,
+#            Relu_1_Grad/ReluGrad_0, Gemm_0_Grad/Gemm_1, Gemm_0_Grad/ReduceSum_2, Gemm_0_Grad/Identity_3]
+node = find_node(model, 'SoftmaxCrossEntropyLoss_3')
+model.graph.node.remove(node)
+node = find_node(model, 'SoftmaxCrossEntropyLoss_3_Grad/SoftmaxCrossEntropyLossGrad_0')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_2_Grad/ReduceSum_3')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_2_Grad/Identity_4')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_2_Grad/Gemm_2')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_2_Grad/Gemm_1')
+model.graph.node.remove(node)
+node = find_node(model, 'Relu_1_Grad/ReluGrad_0')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_0_Grad/Gemm_1')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_0_Grad/ReduceSum_2')
+model.graph.node.remove(node)
+node = find_node(model, 'Gemm_0_Grad/Identity_3')
+model.graph.node.remove(node)
+
+# Add new outputs
+# They are: 7 (the intermediate tensor the backward pass will need)
+add_output(model, '7', None, 'thiagofc: add 7 as model output')
+
+with open(output_forward_model_name, "wb") as f:
+    f.write(model.SerializeToString())
+
+###############################################################################
+# BACKWARD PASS GRAPH #########################################################
+###############################################################################
+# Backward slicing is not implemented yet; the full training graph is written as-is.
+model = onnx.load(input_model_name)
+
+with open(output_backward_model_name, "wb") as f:
+    f.write(model.SerializeToString())
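After running the script over the saved training model (invocation hypothetical, e.g. python slice_training_graph.py mnist_with_training.onnx), the forward slice can be sanity-checked with the onnx package:

```python
import onnx

sliced = onnx.load("mnist_with_training_forward_sliced.onnx")
print([i.name for i in sliced.graph.input])   # fc1/fc2 weights and biases now appear as inputs
print([o.name for o in sliced.graph.output])  # '7' added as the intermediate output
# Note: onnx.checker.check_model(sliced) may reject the model because output '7'
# is appended without type information; onnxruntime may still resolve it through
# its own shape inference at session creation.
```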
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from torchvision import datasets, transforms
 
 import onnxruntime
-from onnxruntime.experimental import ORTTrainer, ORTTrainerOptions, optim
+from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim
 
 
 # Pytorch model