Add support for gradient clipping, AdamWOptimizer and tensorseq as inputs (#11697)

This commit is contained in:
Baiju Meswani 2022-06-22 10:27:58 -07:00 committed by GitHub
parent f14f0e19ec
commit fac8dae9df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 749 additions and 681 deletions

View file

@ -226,7 +226,7 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
}
++source_iter;
++target_iter;
} //while
} // while
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported OrtValue type to copy between device.");
}
@ -322,9 +322,9 @@ static common::Status CalculateStaticCopyInfoForFetches(const SessionState& sess
// If for some reason using just the device from the allocation plan isn't enough, the following
// would use the NodeInfo from the node producing the output
//
//std::vector<SessionState::NodeInfo> node_info_vec;
//auto status = session_state.GetOutputNodeInfo(output_name, node_info_vec);
//if (status.IsOK()) {
// std::vector<SessionState::NodeInfo> node_info_vec;
// auto status = session_state.GetOutputNodeInfo(output_name, node_info_vec);
// if (status.IsOK()) {
// const auto& node_info = node_info_vec.front(); // only one entry as only one node can produce a given output
// copy_info[idx].source_device = *node_info.device;
//} else {
@ -428,6 +428,11 @@ static void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager
const auto& feed = feeds[i];
if (feed.IsTensor()) {
feed_locations[i] = feed.Get<Tensor>().Location().device;
} else if (feed.IsTensorSequence()) {
const auto& tensor_seq = feed.Get<TensorSeq>();
if (tensor_seq.Size() != std::size_t{0}) {
feed_locations[i] = tensor_seq.Get(0).Location().device;
}
} else if (feed.IsSparseTensor()) {
#if !defined(DISABLE_SPARSE_TENSORS)
feed_locations[i] = feed.Get<SparseTensor>().Location().device;

Binary file not shown.

View file

@ -2,13 +2,8 @@
# Licensed under the MIT License.
# __init__.py
from .building_blocks import Block
from .model import Model, TrainingModel
from .checkpoint_utils import save_checkpoint
from . import loss, optim
from .building_blocks import Block
from .checkpoint_utils import save_checkpoint
from .model import Model, TrainingModel
from .model_accessor import onnx_model
import onnx
_producer_name = "onnxblock offline tooling"
_opset_import = onnx.helper.make_opsetid("com.microsoft", 1)

View file

@ -1,146 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# _building_blocks.py
import copy
import onnx
import onnxruntime.training.onnxblock as onnxblock
import onnxruntime.training.onnxblock.model_accessor as accessor
import onnxruntime.training.onnxblock._graph_utils as graph_utils
class Sub(onnxblock.Model):
    """Block that appends an elementwise Sub node to the onnx model being built."""

    def __init__(self):
        super().__init__()

    def build(self, sub_input_name1, sub_input_name2):
        """Append a Sub node that consumes the two named values.

        The graph outputs are replaced by a single output describing the
        result of the subtraction.

        Args:
            sub_input_name1: name of the first input of the Sub node.
            sub_input_name2: name of the second input of the Sub node.

        Returns:
            The generated name of the Sub node output.
        """
        model = accessor.global_accessor.model

        # The output name is generated before the node name.
        result_name = graph_utils.generate_random_graph_name("sub_output")
        model.graph.node.append(
            onnx.helper.make_node(
                "Sub",
                [sub_input_name1, sub_input_name2],
                [result_name],
                name=graph_utils.generate_random_graph_name("Sub"),
            )
        )

        # The new graph output is a renamed copy of the graph output entry
        # looked up for the first input.
        result_info = copy.deepcopy(graph_utils.get_output_from_output_name(model, sub_input_name1))
        result_info.name = result_name

        # Replace every existing graph output with the Sub result.
        del model.graph.output[:]
        model.graph.output.append(result_info)

        return result_name
class Pow(onnxblock.Model):
    """Block that appends a Pow node with a fixed exponent to the onnx model."""

    def __init__(self, exponent):
        super().__init__()
        # Value stored as a one-element FLOAT initializer when the block is built.
        self._exponent = exponent

    def build(self, pow_input_name):
        """Append a Pow node raising the named value to the stored exponent.

        The graph outputs are replaced by a single output describing the
        result of the Pow node.

        Args:
            pow_input_name: name of the value used as the base.

        Returns:
            The generated name of the Pow node output.
        """
        model = accessor.global_accessor.model

        # Exponent initializer: FLOAT tensor of shape [1].
        exponent_name = graph_utils.generate_random_graph_name("pow_exponent")
        exponent_tensor = onnx.helper.make_tensor(exponent_name, onnx.TensorProto.FLOAT, [1], [self._exponent])
        model.graph.initializer.append(exponent_tensor)

        # The output name is generated before the node name.
        result_name = graph_utils.generate_random_graph_name("pow_output")
        model.graph.node.append(
            onnx.helper.make_node(
                "Pow",
                [pow_input_name, exponent_name],
                [result_name],
                name=graph_utils.generate_random_graph_name("Pow"),
            )
        )

        # The new graph output is a renamed copy of the graph output entry
        # looked up for the base input.
        result_info = copy.deepcopy(graph_utils.get_output_from_output_name(model, pow_input_name))
        result_info.name = result_name

        # Replace every existing graph output with the Pow result.
        del model.graph.output[:]
        model.graph.output.append(result_info)

        return result_name
class _Reduce(onnxblock.Model):
    """Shared implementation for the reduction blocks."""

    def __init__(self):
        super().__init__()

    def _reduce(self, reduce_input_name, reduction_op):
        """Append a reduction node of type ``reduction_op`` to the onnx model.

        The graph outputs are replaced by a single FLOAT output whose shape is
        ``[1]`` repeated once per dimension of the reduced input.
        NOTE(review): that shape presumably relies on the reduction keeping
        its dimensions -- confirm against the operator defaults.

        Args:
            reduce_input_name: name of the value to reduce.
            reduction_op: onnx operator type of the reduction node.

        Returns:
            The generated name of the reduction node output.
        """
        model = accessor.global_accessor.model

        # The output name is generated before the node name.
        result_name = graph_utils.generate_random_graph_name("reduce_output")
        model.graph.node.append(
            onnx.helper.make_node(
                reduction_op,
                [reduce_input_name],
                [result_name],
                name=graph_utils.generate_random_graph_name(reduction_op),
            )
        )

        # Only the rank of the reduced input is needed to describe the output.
        source_info = copy.deepcopy(graph_utils.get_output_from_output_name(model, reduce_input_name))
        rank = len(source_info.type.tensor_type.shape.dim)
        result_info = onnx.helper.make_tensor_value_info(result_name, onnx.TensorProto.FLOAT, [1] * rank)

        # Replace every existing graph output with the reduction result.
        del model.graph.output[:]
        model.graph.output.extend([result_info])

        return result_name
class ReduceMean(_Reduce):
    """Block that appends a ReduceMean node to the onnx model."""

    def __init__(self):
        super().__init__()

    def build(self, reduce_input_name):
        """Append a ReduceMean node for the named value and return its output name."""
        op_type = "ReduceMean"
        return super()._reduce(reduce_input_name, op_type)
class ReduceSum(_Reduce):
    """Block that appends a ReduceSum node to the onnx model."""

    def __init__(self):
        super().__init__()

    def build(self, reduce_input_name):
        """Append a ReduceSum node for the named value and return its output name."""
        op_type = "ReduceSum"
        return super()._reduce(reduce_input_name, op_type)

View file

@ -3,9 +3,10 @@
# _graph_utils.py
import copy
import onnx
import random
import onnx
from onnxruntime.capi._pybind_state import GradientGraphBuilder
@ -21,44 +22,52 @@ def get_output_from_output_name(onnx_model, output_name):
def get_random_number():
"""Return a random number in the range 0, 1000."""
"""Return a random number in the range 0, 100000."""
return random.randint(0, 1000)
return random.randint(0, 100000)
def generate_random_graph_name(token):
"""Return a string that can be used in the graph as a graph attribute name."""
return f"{get_random_number()}.{token}"
return f"onnx::{token}::{get_random_number()}"
def build_gradient_graph(
accessor, user_args_requiring_grad, user_args_not_requiring_grad, output_names
):
def _reorder_outputs(model, user_output_names, args_requiring_gradient_names):
    """Rewrite the graph outputs of ``model`` in a fixed order.

    The user outputs come first, in the order given, followed by one
    ``<arg>_grad`` output per entry of ``args_requiring_gradient_names``, in
    that order. Graph outputs not named by either list are dropped.

    Raises:
        KeyError: if a requested output name is not a current graph output.
    """
    by_name = {}
    for graph_output in model.graph.output:
        by_name[graph_output.name] = graph_output

    # User outputs first, then the gradient outputs.
    reordered = []
    for output_name in user_output_names:
        reordered.append(by_name[output_name])
    reordered.extend(by_name[f"{arg_name}_grad"] for arg_name in args_requiring_gradient_names)

    del model.graph.output[:]
    model.graph.output.extend(reordered)
def build_gradient_graph(accessor, user_args_requiring_grad, user_args_not_requiring_grad, output_names):
"""Builds the gradient graph on top of the given input forward only graph."""
model = accessor.model
# Collect names of parameters that need gradients computed
all_args_requiring_gradient = set()
all_args_requiring_gradient = []
# Move all trainable and non trainable initializers to graph inputs.
# This allows training to pass in the parameters from outside the graph
# so as to share the parameters across multiple sessions.
graph_inputs = model.graph.input
initializers = []
for initializer in model.graph.initializer:
if not initializer.name[0].isdigit():
if not initializer.name.startswith("onnx::"):
# Move only those initializers as inputs that are not local
# to the onnx model. i.e. initializers that are model parameters.
# These are tpically those initializers without any number prefixed
# These are typically those initializers without any onnx:: prefixed
# to their names.
graph_inputs.append(
onnx.helper.make_tensor_value_info(
initializer.name, initializer.data_type, initializer.dims
)
onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)
)
if initializer.name not in user_args_not_requiring_grad:
all_args_requiring_gradient.add(initializer.name)
all_args_requiring_gradient.append(initializer.name)
else:
# All other initializers stay where they were.
initializers.append(initializer)
@ -74,7 +83,7 @@ def build_gradient_graph(
# args_requiring_grad. So, add these arguments to set of arguments
# whose gradient should be built.
for argument_name in user_args_requiring_grad:
all_args_requiring_gradient.add(argument_name)
all_args_requiring_gradient.append(argument_name)
# Assumption is that the first graph output is the loss output
if isinstance(output_names, str):
@ -82,13 +91,16 @@ def build_gradient_graph(
builder = GradientGraphBuilder(
model.SerializeToString(),
set(output_names),
all_args_requiring_gradient,
set(all_args_requiring_gradient),
output_names[0],
)
builder.build()
gradient_model = onnx.load_from_string(builder.get_model())
# copy the gradient graph into the user's graph
# Reorder gradient outputs for the gradient model based on the all_args_requiring_gradient order
_reorder_outputs(gradient_model, output_names, all_args_requiring_gradient)
# copy the gradient model into the user's model
model.CopyFrom(gradient_model)
return all_args_requiring_gradient
@ -117,9 +129,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
"""
# TODO: Avoid hard coded input/output strings
gradient_output_names = {
f"{arg_name}_grad" for arg_name in all_args_requiring_gradient_names
}
gradient_output_names = {f"{arg_name}_grad" for arg_name in all_args_requiring_gradient_names}
graph_inputs = grad_model.graph.input
graph_nodes = grad_model.graph.node
@ -139,12 +149,8 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
# gradient accumulation node inputs and output names
grad_name = graph_output.name
grad_accumulation_buffer_name = (
f"{grad_name}.{gradient_accumulation_name}.{gradient_buffer_name_suffix}"
)
grad_accumulation_output_name = (
f"{grad_name}.{gradient_accumulation_name}.{gradient_output_name_suffix}"
)
grad_accumulation_buffer_name = f"{grad_name}.{gradient_accumulation_name}.{gradient_buffer_name_suffix}"
grad_accumulation_output_name = f"{grad_name}.{gradient_accumulation_name}.{gradient_output_name_suffix}"
# Gradient accumulation node
acc_node = onnx.helper.make_node(
@ -167,9 +173,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
grad_accumulation_output.name = grad_accumulation_output_name
graph_outputs.append(grad_accumulation_output)
lazy_reset_grad_input = onnx.helper.make_tensor_value_info(
lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1]
)
lazy_reset_grad_input = onnx.helper.make_tensor_value_info(lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1])
graph_inputs.append(lazy_reset_grad_input)
del grad_model.graph.output[:]
@ -183,18 +187,18 @@ def get_model_parameters(model, args_not_requiring_gradient):
non_trainable_params = []
for initializer in model.graph.initializer:
# All model parameters should have their names not begin with
# a digit. So, check to see if the initializer's first char
# is a digit. If not, it is either a trainable, or a non
# `onnx::`. So, check to see if the initializer begins with
# `onnx::`. If not, it is either a trainable, or a non
# trainable parameter.
# Note that this assumption can be made because the export logic
# does not change the names of the original model parameters.
# and the original model parameters don't have their names begin
# with a digit.
# `onnx::`.
# On the other hand, const initializers are generated by export
# logic and have a digit prefix.
# logic and have a `onnx::` prefix.
# TODO: validate this assumption. If assumption is not valid,
# the alternative is to enforce the user to provide the parameter names.
if not initializer.name[0].isdigit():
if not initializer.name.startswith("onnx::"):
if initializer.name in args_not_requiring_gradient:
non_trainable_params.append(initializer)
else:
@ -215,9 +219,7 @@ def build_graph_outputs(model, output_names):
if isinstance(output_names, str):
output_names = [output_names]
name_value_info_mapping = {
value_info.name: value_info for value_info in model.graph.value_info
}
name_value_info_mapping = {value_info.name: value_info for value_info in model.graph.value_info}
name_graph_output_mapping = {output.name: output for output in model.graph.output}
# collect all new graph outputs (i.e. graph outputs that are not
@ -229,9 +231,7 @@ def build_graph_outputs(model, output_names):
elif output_name in name_value_info_mapping:
graph_outputs.append(name_value_info_mapping[output_name])
else:
raise LookupError(
f"The provided name {output_name} is not a graph value info or a graph output."
)
raise LookupError(f"The provided name {output_name} is not a graph value info or a graph output.")
# Clear all existing graph outputs
del model.graph.output[:]

View file

@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# _building_blocks.py
# building_blocks.py
from abc import ABC, abstractmethod
@ -38,7 +38,7 @@ class Block(ABC):
class _BinaryOp(Block):
def __init__(self, op_name):
super(_BinaryOp, self).__init__()
super().__init__()
self._op_name = op_name
def build(self, input_name1, input_name2):
@ -68,28 +68,35 @@ class Add(_BinaryOp):
"""Adds Add node to an onnx model."""
def __init__(self):
super(Add, self).__init__("Add")
super().__init__("Add")
class Sub(_BinaryOp):
"""Adds Sub node to an onnx model."""
def __init__(self):
super(Sub, self).__init__("Sub")
super().__init__("Sub")
class Mul(_BinaryOp):
"""Adds Mul node to an onnx model."""
def __init__(self):
super(Mul, self).__init__("Mul")
super().__init__("Mul")
class Div(_BinaryOp):
"""Adds Div node to an onnx model."""
def __init__(self):
super().__init__("Div")
class Pow(Block):
"""Adds Pow node to the onnx model."""
def __init__(self, exponent):
super(Pow, self).__init__()
super().__init__()
self._exponent = exponent
@ -119,8 +126,10 @@ class Pow(Block):
class _UnaryOp(Block):
"""Base class for all nodes that take in a single argument."""
def __init__(self, op_name):
super(_UnaryOp, self).__init__()
super().__init__()
self._op_name = op_name
def build(self, input_name):
@ -150,29 +159,35 @@ class ReduceMean(_UnaryOp):
"""Adds ReduceMean node to the onnx model."""
def __init__(self):
super(ReduceMean, self).__init__("ReduceMean")
super().__init__("ReduceMean")
class ReduceSum(_UnaryOp):
"""Adds ReduceSum node to the onnx model."""
def __init__(self):
super(ReduceSum, self).__init__("ReduceSum")
super().__init__("ReduceSum")
class Sigmoid(_UnaryOp):
"""Adds Sigmoid node to the onnx model."""
def __init__(self):
super(Sigmoid, self).__init__("Sigmoid")
super().__init__("Sigmoid")
class Log(_UnaryOp):
"""Adds Log node to the onnx model."""
def __init__(self):
super(Log, self).__init__("Log")
super().__init__("Log")
class Neg(_UnaryOp):
"""Adds Neg node to the onnx model."""
def __init__(self):
super(Neg, self).__init__("Neg")
super().__init__("Neg")
class Constant(Block):
@ -192,3 +207,109 @@ class Constant(Block):
onnx.helper.make_tensor(initializer_name, onnx.TensorProto.FLOAT, [1], [self._value])
)
return initializer_name
class SequenceConstruct(Block):
    """Block that appends a SequenceConstruct node to the onnx model."""

    def __init__(self):
        super().__init__()

    def build(self, *sequence_input_names):
        """Append a SequenceConstruct node over the given value names.

        Args:
            *sequence_input_names: names of the values used as node inputs,
                in the order they are passed.

        Returns:
            The generated name of the SequenceConstruct node output.
        """
        model = accessor.global_accessor.model

        # The output name is generated before the node name.
        sequence_name = graph_utils.generate_random_graph_name("sequenceconstruct_output")
        node_name = graph_utils.generate_random_graph_name("SequenceConstruct")

        model.graph.node.append(
            onnx.helper.make_node(
                "SequenceConstruct",
                list(sequence_input_names),
                [sequence_name],
                name=node_name,
            )
        )

        return sequence_name
class ReduceAllL2(Block):
    """Block that appends a ReduceAllL2 node to the onnx model.

    The node is created in the com.microsoft domain, so it might not be
    available outside that domain.
    """

    def __init__(self):
        super().__init__()

    def build(self, *reduce_node_input_names):
        """Append a ReduceAllL2 node over the given value names.

        Args:
            *reduce_node_input_names: names of the values used as node inputs.

        Returns:
            The generated name of the ReduceAllL2 node output.
        """
        model = accessor.global_accessor.model

        # The output name is generated before the node name.
        norm_name = graph_utils.generate_random_graph_name("reducealll2_output")
        node_name = graph_utils.generate_random_graph_name("ReduceAllL2")

        model.graph.node.append(
            onnx.helper.make_node(
                "ReduceAllL2",
                list(reduce_node_input_names),
                [norm_name],
                name=node_name,
                domain="com.microsoft",
            )
        )

        # TODO: register shape inference with onnx
        # Until then the output is described by hand as a FLOAT tensor of shape [1].
        norm_info = onnx.helper.make_tensor_value_info(norm_name, onnx.TensorProto.FLOAT, [1])
        model.graph.value_info.append(norm_info)

        return norm_name
class Clip(Block):
    """Block that appends a Clip node to the onnx model.

    Args:
        clip_min: optional lower bound; no min input is wired when None.
        clip_max: optional upper bound; no max input is wired when None.
    """

    def __init__(self, clip_min=None, clip_max=None):
        super().__init__()

        self._min = clip_min
        self._max = clip_max

    def build(self, clip_input_name):
        """Append a Clip node for the named value.

        Each bound that was provided is stored as a one-element FLOAT
        initializer. A bound that is None is passed to the node as an empty
        input name.

        Args:
            clip_input_name: name of the value to clip.

        Returns:
            The generated name of the Clip node output.
        """
        model = accessor.global_accessor.model

        # Min is handled before max, so names are generated in that order.
        bound_input_names = []
        for token, bound in (("clip_min", self._min), ("clip_max", self._max)):
            if bound is None:
                bound_input_names.append("")
                continue
            bound_name = graph_utils.generate_random_graph_name(token)
            model.graph.initializer.append(
                onnx.helper.make_tensor(bound_name, onnx.TensorProto.FLOAT, [1], [bound])
            )
            bound_input_names.append(bound_name)

        # The output name is generated before the node name.
        clipped_name = graph_utils.generate_random_graph_name("clip_output")
        node_name = graph_utils.generate_random_graph_name("Clip")
        model.graph.node.append(
            onnx.helper.make_node(
                "Clip",
                [clip_input_name, *bound_input_names],
                [clipped_name],
                name=node_name,
            )
        )

        return clipped_name

View file

@ -16,6 +16,4 @@ def save_checkpoint(parameters, path_to_checkpoint):
trainable_params, non_trainable_params = parameters
trainable_params = [param.SerializeToString() for param in trainable_params]
non_trainable_params = [param.SerializeToString() for param in non_trainable_params]
_internal_save_checkpoint(
trainable_params, non_trainable_params, path_to_checkpoint
)
_internal_save_checkpoint(trainable_params, non_trainable_params, path_to_checkpoint)

View file

@ -3,11 +3,12 @@
# loss.py
import copy
import onnx
import onnxruntime.training.onnxblock.model_accessor as accessor
import onnxruntime.training.onnxblock.building_blocks as building_blocks
import onnxruntime.training.onnxblock._graph_utils as graph_utils
import onnxruntime.training.onnxblock.building_blocks as building_blocks
import onnxruntime.training.onnxblock.model_accessor as accessor
class MSELoss(building_blocks.Block):
@ -25,11 +26,7 @@ class MSELoss(building_blocks.Block):
if reduction != "mean" and reduction != "sum":
raise RuntimeError(f"Reduction {reduction} not supported.")
self._reduce = (
building_blocks.ReduceMean()
if reduction == "mean"
else building_blocks.ReduceSum()
)
self._reduce = building_blocks.ReduceMean() if reduction == "mean" else building_blocks.ReduceSum()
self._sub = building_blocks.Sub()
self._square = building_blocks.Pow(2.0)
@ -53,9 +50,7 @@ class MSELoss(building_blocks.Block):
# create a new graph input. this is the target input needed to compare
# the graph output against to calculate loss.
# TODO: Move input creation outside of the blocks.
target_input = copy.deepcopy(
graph_utils.get_output_from_output_name(onnx_model, loss_input_name)
)
target_input = copy.deepcopy(graph_utils.get_output_from_output_name(onnx_model, loss_input_name))
target_input.name = target_name
onnx_model.graph.input.append(target_input)
@ -106,15 +101,11 @@ class CrossEntropyLoss(building_blocks.Block):
weight_name = graph_utils.generate_random_graph_name("celoss.weight")
if self._weight is not None:
onnx_model.graph.initializer.append(
onnx.numpy_helper.from_array(self._weight, weight_name)
)
onnx_model.graph.initializer.append(onnx.numpy_helper.from_array(self._weight, weight_name))
# create a new graph input. this is the labels input needed to compare
# the graph output against to calculate loss.
labels_input = copy.deepcopy(
graph_utils.get_output_from_output_name(onnx_model, scores_input_name)
)
labels_input = copy.deepcopy(graph_utils.get_output_from_output_name(onnx_model, scores_input_name))
labels_input.name = labels_name
labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT32
# if the predictions are (num_examples x num_classes)
@ -163,11 +154,7 @@ class BCEWithLogitsLoss(building_blocks.Block):
raise RuntimeError(f"Reduction {reduction} not supported.")
self._weight = weight
self._reduce = (
building_blocks.ReduceMean()
if reduction == "mean"
else building_blocks.ReduceSum()
)
self._reduce = building_blocks.ReduceMean() if reduction == "mean" else building_blocks.ReduceSum()
self._pos_weight = pos_weight
self._sigmoid = building_blocks.Sigmoid()
@ -198,38 +185,24 @@ class BCEWithLogitsLoss(building_blocks.Block):
# create the graph initializers for pos_weight, weight, and the sub operands ([1])
pos_weight_name = graph_utils.generate_random_graph_name("bceloss.pos_weight")
if self._pos_weight is not None:
onnx_model.graph.initializer.append(
onnx.numpy_helper.from_array(self._pos_weight, pos_weight_name)
)
onnx_model.graph.initializer.append(onnx.numpy_helper.from_array(self._pos_weight, pos_weight_name))
weight_name = graph_utils.generate_random_graph_name("bceloss.weight")
if self._weight is not None:
onnx_model.graph.initializer.append(
onnx.numpy_helper.from_array(self._weight, weight_name)
)
onnx_model.graph.initializer.append(onnx.numpy_helper.from_array(self._weight, weight_name))
sub_ones_operand_name1 = graph_utils.generate_random_graph_name(
"bceloss.sub_ones"
)
sub_ones_operand_name1 = graph_utils.generate_random_graph_name("bceloss.sub_ones")
onnx_model.graph.initializer.append(
onnx.helper.make_tensor(
sub_ones_operand_name1, onnx.TensorProto.FLOAT, [1], [1.0]
)
)
sub_ones_operand_name2 = graph_utils.generate_random_graph_name(
"bceloss.sub_ones"
onnx.helper.make_tensor(sub_ones_operand_name1, onnx.TensorProto.FLOAT, [1], [1.0])
)
sub_ones_operand_name2 = graph_utils.generate_random_graph_name("bceloss.sub_ones")
onnx_model.graph.initializer.append(
onnx.helper.make_tensor(
sub_ones_operand_name2, onnx.TensorProto.FLOAT, [1], [1.0]
)
onnx.helper.make_tensor(sub_ones_operand_name2, onnx.TensorProto.FLOAT, [1], [1.0])
)
# create a new graph input. this is the target input needed to compare
# the graph output against to calculate loss.
target_input = copy.deepcopy(
graph_utils.get_output_from_output_name(onnx_model, loss_input_name)
)
target_input = copy.deepcopy(graph_utils.get_output_from_output_name(onnx_model, loss_input_name))
target_input.name = target_name
onnx_model.graph.input.append(target_input)

View file

@ -3,10 +3,12 @@
# model.py
from abc import abstractmethod
import onnx
import onnxruntime.training.onnxblock.model_accessor as accessor
import onnxruntime.training.onnxblock.building_blocks as building_blocks
import onnxruntime.training.onnxblock._graph_utils as graph_utils
import onnxruntime.training.onnxblock.building_blocks as building_blocks
import onnxruntime.training.onnxblock.model_accessor as accessor
class Model(building_blocks.Block):
@ -32,9 +34,7 @@ class Model(building_blocks.Block):
output = self.build(*args, **kwargs)
# Perform shape inference
model_with_shapes = onnx.shape_inference.infer_shapes(
accessor.global_accessor.model
)
model_with_shapes = onnx.shape_inference.infer_shapes(accessor.global_accessor.model)
accessor.global_accessor.model.CopyFrom(model_with_shapes)
# Build the graph outputs
@ -85,10 +85,7 @@ class TrainingModel(building_blocks.Block):
is built, an exception will be raised.
"""
if self._parameters is None:
raise RuntimeError(
"Please build the training model first before trying to "
"retrieve the parameters."
)
raise RuntimeError("Please build the training model first before trying to " "retrieve the parameters.")
return self._parameters
@ -108,9 +105,7 @@ class TrainingModel(building_blocks.Block):
output = self.build(*args, **kwargs)
# Perform shape inference
model_with_shapes = onnx.shape_inference.infer_shapes(
accessor.global_accessor.model
)
model_with_shapes = onnx.shape_inference.infer_shapes(accessor.global_accessor.model)
accessor.global_accessor.model.CopyFrom(model_with_shapes)
# Build the graph outputs
@ -122,6 +117,7 @@ class TrainingModel(building_blocks.Block):
accessor.global_accessor.model, self._arg_not_requiring_grad
)
# build the gradient graph
all_args_requiring_gradient_names = graph_utils.build_gradient_graph(
accessor.global_accessor,
self._arg_requiring_grad,
@ -130,9 +126,7 @@ class TrainingModel(building_blocks.Block):
)
# add gradient accumulation nodes
graph_utils.build_gradient_accumulation_graph(
accessor.global_accessor.model, all_args_requiring_gradient_names
)
graph_utils.build_gradient_accumulation_graph(accessor.global_accessor.model, all_args_requiring_gradient_names)
# validate and check the model
onnx.checker.check_model(accessor.global_accessor.model, True)

View file

@ -1,16 +1,41 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# global_model.py
# model_accessor.py
from contextlib import contextmanager
import onnx
class ModelAccessor:
"""This class stores the onnx model that is manipulated by the onnx blocks."""
def __init__(self, model):
self.model = model
self.eval_model = None
self._model = model
self._eval_model = None
@property
def model(self):
"""ModelAccessor property that gets the modified model."""
if self._model is None:
raise RuntimeError(
"The onnx model was not set. Please use the context manager onnxblock.onnx_model to create the model."
)
return self._model
@property
def eval_model(self):
"""ModelAccessor property that gets the eval model."""
if self._eval_model is None:
raise RuntimeError("The eval onnx model was not set.")
return self._eval_model
@eval_model.setter
def eval_model(self, value):
"""ModelAccessor property that sets the eval model."""
self._eval_model = value
# This variable resides in the global namespace.
@ -26,6 +51,14 @@ def onnx_model(model=None):
Manages the construction and destruction of the global model.
"""
global global_accessor
if global_accessor is not None:
raise RuntimeError("Base onnx model already exists. Cannot create multiple ModelAccessors.")
# If the user did not provide a model, then assume that they want to build from scratch.
# It is the duty of the caller to fill the model however they deem fit.
if model is None:
model = onnx.ModelProto()
global_accessor = ModelAccessor(model)
try:
yield global_accessor

View file

@ -2,4 +2,4 @@
# Licensed under the MIT License.
# __init__.py
from .optim import AdamW
from .optim import AdamW, ClipGradNorm

View file

@ -2,160 +2,216 @@
# Licensed under the MIT License.
# optim.py
import copy
import onnx
import onnxruntime.training.onnxblock as onnxblock
import onnxruntime.training.onnxblock._graph_utils as graph_utils
import onnxruntime.training.onnxblock.building_blocks as building_blocks
import onnxruntime.training.onnxblock.model as model
import onnxruntime.training.onnxblock.model_accessor as accessor
# TODO: Find a better place for these constants
_PRODUCER_NAME = "onnxblock offline tooling"
_OPSET_IMPORTS = [onnx.helper.make_opsetid("com.microsoft", 1), onnx.helper.make_opsetid("", 14)]
class AdamW(onnxblock.Model):
"""Builds AdamW optimizer onnxblock for the given training model."""
def __init__(
self, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0
):
super(AdamW, self).__init__()
class AdamWOptimizer(building_blocks.Block):
"""Adds an AdamWOptimizer node to the onnx model."""
def __init__(self, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0):
super().__init__()
self._bias_correction = bias_correction
self._betas = betas
self._eps = eps
self._weight_decay = weight_decay
self._max_norm_clip = 1.0
def build( # pylint: disable=too-many-arguments
self,
learning_rate_name,
step_name,
parameter_sequence_name,
gradient_sequence_name,
first_order_moment_sequence_name,
second_order_moment_sequence_name,
):
"""Adds the AdamWOptimizer node to the model."""
# get the model to manipulate
onnx_model = accessor.global_accessor.model
# define the node attributes
node_attributes = {
"alpha": self._betas[0], # beta1
"beta": self._betas[1], # beta2
"epsilon": self._eps, # epsilon
"weight_decay": self._weight_decay, # weight decay
"correct_bias": 1 if self._bias_correction else 0, # bias_correction
"adam_mode": 1, # adam mode (1 for hf/transformers/AdamW)
}
# add the adamw node to the onnx model
adamw_input_names = [
learning_rate_name, # learning rate
step_name, # training step
parameter_sequence_name, # param to be updated
gradient_sequence_name, # gradient of the param to be used for update
first_order_moment_sequence_name, # first order moment for this param
second_order_moment_sequence_name, # second order moment for this param
]
adamw_output_name = graph_utils.generate_random_graph_name("adamw.updated_flag")
adamw_output_names = [adamw_output_name]
adamw_node = onnx.helper.make_node(
"AdamWOptimizer",
adamw_input_names,
adamw_output_names,
name=graph_utils.generate_random_graph_name("AdamWOptimizer"),
domain="com.microsoft",
**node_attributes,
)
onnx_model.graph.node.append(adamw_node)
return adamw_output_name
class ClipGradNorm(building_blocks.Block):
    """Builds a sub graph that clips gradients by their combined l2 norm.

    Every gradient is multiplied by the same coefficient,
    ``max_norm / (total_norm + 1e-6)`` clipped to at most 1.0, where
    ``total_norm`` is the output of a ReduceAllL2 node over all the gradients.

    Args:
        max_norm: float indicating the max norm of the gradients.

    Returns:
        A list with the output names of the gradients after clipping, in the
        same order as the input gradient names.
    """

    def __init__(self, max_norm):
        super().__init__()

        self._max_norm = max_norm

        # Building blocks composed by build().
        self._reduce = building_blocks.ReduceAllL2()
        self._add = building_blocks.Add()
        self._div = building_blocks.Div()
        self._mul = building_blocks.Mul()
        self._clip = building_blocks.Clip(clip_max=1.0)

    def build(self, *gradient_names):
        """Adds a clip grad norm sub graph to the onnx model."""
        model = accessor.global_accessor.model

        # Constant initializers: the epsilon added to the total norm and the max norm.
        eps_name = graph_utils.generate_random_graph_name("add_eps")
        eps_tensor = onnx.helper.make_tensor(eps_name, onnx.TensorProto.FLOAT, [1], [1e-6])
        model.graph.initializer.append(eps_tensor)

        max_norm_name = graph_utils.generate_random_graph_name("max_norm")
        max_norm_tensor = onnx.helper.make_tensor(max_norm_name, onnx.TensorProto.FLOAT, [1], [self._max_norm])
        model.graph.initializer.append(max_norm_tensor)

        # coefficient = clip(max_norm / (total_norm + eps), max=1.0)
        total_norm_name = self._reduce(*gradient_names)
        stabilized_norm_name = self._add(total_norm_name, eps_name)
        ratio_name = self._div(max_norm_name, stabilized_norm_name)
        coefficient_name = self._clip(ratio_name)

        # Scale every gradient by the shared coefficient.
        clipped_names = []
        for gradient_name in gradient_names:
            clipped_names.append(self._mul(gradient_name, coefficient_name))
        return clipped_names
class AdamW(model.Model):
    """Builds AdamW optimizer onnxblock for the given training parameters.

    Creates a block that updates the model parameters based on the calculated
    gradient following the AdamW algorithm. The parameters and both moments
    are graph inputs of tensor sequence type, while every gradient is an
    individual tensor input named ``<param name>_grad``.

    Args:
        bias_correction: bool indicating whether to perform bias correction.
        betas: AdamW decay rate hyperparameters.
        eps: term added to the denominator for computing the moments.
        weight_decay: AdamW weight decay
        clip_grad (optional): an instance of the ClipGradNorm. If not provided,
                              gradient clipping will not be done.

    Returns:
        Returns the name of the (INT64, shape [1]) graph output produced by
        the AdamWOptimizer node.
    """

    def __init__(
        self, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, clip_grad=None
    ):  # pylint: disable=too-many-arguments
        super().__init__()

        self._adamw = AdamWOptimizer(
            bias_correction=bias_correction,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )
        self._clip_grad = clip_grad
        # Packs the individual (optionally clipped) gradients into one sequence.
        self._sc = building_blocks.SequenceConstruct()

    def build(self, parameters):
        """Returns an AdamW optimizer model based on the input parameters.

        Args:
            parameters: a pair whose first element is the collection of
                trainable parameters; the second element is not used. The
                first trainable parameter supplies the element type of the
                sequence inputs, so at least one trainable parameter is
                required.

        Returns:
            The name of the graph output produced by the AdamWOptimizer node.
        """

        # get the model to manipulate and update its namespace
        onnx_model = accessor.global_accessor.model
        onnx_model.graph.name = "AdamW Optimizer Model"
        onnx_model.producer_name = _PRODUCER_NAME
        onnx_model.opset_import.extend(_OPSET_IMPORTS)
        onnx_model.ir_version = onnx.IR_VERSION

        # TODO: Avoid hard coded input/output strings
        learning_rate_name = "learning_rate"
        step_name = "step"
        params_name = "params"
        first_order_moments_name = "first_order_moments"
        second_order_moments_name = "second_order_moments"
        gradient_suffix = "_grad"

        trainable_parameters, _ = parameters

        # create the graph inputs for the lr, step, params, grads, moments
        onnx_model.graph.input.extend(
            [
                onnx.helper.make_tensor_value_info(learning_rate_name, onnx.TensorProto.FLOAT, [1]),
                onnx.helper.make_tensor_value_info(step_name, onnx.TensorProto.INT64, [1]),
            ]
        )

        # Prepare the tensor sequence inputs for params and moments
        for input_name in [params_name, first_order_moments_name, second_order_moments_name]:
            onnx_model.graph.input.append(
                onnx.helper.make_tensor_sequence_value_info(input_name, trainable_parameters[0].data_type, None)
            )

        # TODO: Make the grads as a tensor sequence input after implementing clip grad
        # normalization implementation which takes in a tensor sequence.
        grad_names = []
        for param in trainable_parameters:
            grad_names.append(f"{param.name}{gradient_suffix}")
            onnx_model.graph.input.append(
                onnx.helper.make_tensor_value_info(grad_names[-1], param.data_type, param.dims)
            )

        # Clip the gradients if needed
        if self._clip_grad is not None:
            grad_names = self._clip_grad(*grad_names)

        # Run multi tensor AdamWOptimizer
        updated_flag_name = self._adamw(
            learning_rate_name,
            step_name,
            params_name,
            self._sc(*grad_names),
            first_order_moments_name,
            second_order_moments_name,
        )

        # Create the graph outputs
        onnx_model.graph.output.append(
            onnx.helper.make_tensor_value_info(updated_flag_name, onnx.TensorProto.INT64, [1])
        )

        return updated_flag_name

View file

@ -421,8 +421,12 @@ def test_bcewithlogits_loss_training_graph_execution():
assert np.allclose(ort_grad, _to_numpy(pt_param.grad))
@pytest.mark.parametrize("graph", [SimpleTrainingModelWithMSELoss, SimpleTrainingModelWithCrossEntropyLoss])
def test_adamw_optimizer_composition(graph):
@pytest.mark.parametrize(
"graph",
[SimpleTrainingModelWithMSELoss, SimpleTrainingModelWithCrossEntropyLoss, SimpleTrainingModelWithBCEWithLogitsLoss],
)
@pytest.mark.parametrize("grad_clipping", [None, onnxblock.optim.ClipGradNorm(2.5)])
def test_adamw_optimizer_composition(graph, grad_clipping):
# Given
device = "cuda"
N, D_in, H, D_out = 64, 784, 500, 10
@ -433,13 +437,15 @@ def test_adamw_optimizer_composition(graph):
with onnxblock.onnx_model(onnx_model):
_ = simple_model(onnx_model.graph.output[0].name)
optimizer = onnxblock.optim.AdamW()
optimizer = onnxblock.optim.AdamW(clip_grad=grad_clipping)
with onnxblock.onnx_model() as accessor:
_ = optimizer(simple_model.parameters())
optimizer_model = accessor.model
assert optimizer_model
# TODO: Add a test for correctness when creation of ortvalues of
# tensor seq is possible on cuda
def test_adamw_optimizer_execution():
# Given
device = "cuda"
@ -455,14 +461,12 @@ def test_adamw_optimizer_execution():
optimizer = onnxblock.optim.AdamW()
with onnxblock.onnx_model() as accessor:
_ = optimizer(simple_model.parameters())
output_name = optimizer(simple_model.parameters())
optimizer_model = accessor.model
learning_rate = 0.001
step = 1
ort_output_names = []
for name, _ in pt_model.named_parameters():
ort_output_names.append(f"{name}.out")
ort_output_names = [output_name]
def mse_loss(prediction, target):
loss = torch.nn.MSELoss()
@ -478,12 +482,15 @@ def test_adamw_optimizer_execution():
ort_inputs = {
"learning_rate": np.full(1, learning_rate, dtype=np.float32),
"step": np.full(1, step, dtype=np.int64),
"params": [],
"first_order_moments": [],
"second_order_moments": [],
}
for name, param in pt_model.named_parameters():
ort_inputs[name] = _to_numpy(copy.deepcopy(param))
ort_inputs[f"{name}_grad.accumulation.out"] = _to_numpy(copy.deepcopy(param.grad))
ort_inputs[f"{name}.exp_avg"] = _to_numpy(torch.zeros_like(param))
ort_inputs[f"{name}.exp_avg_sq"] = _to_numpy(torch.zeros_like(param))
ort_inputs["params"].append(_to_numpy(copy.deepcopy(param)))
ort_inputs[f"{name}_grad"] = _to_numpy(copy.deepcopy(param.grad))
ort_inputs["first_order_moments"].append(_to_numpy(torch.zeros_like(param)))
ort_inputs["second_order_moments"].append(_to_numpy(torch.zeros_like(param)))
# Then no error occurs when executing the model
ort_session = onnxruntime.InferenceSession(onnx_fo.name, providers=C.get_available_providers())
@ -642,3 +649,64 @@ def test_weighted_average_model_composition(model_type):
weighted_model = WeightedAvg(random.random(), random.random())
with onnxblock.onnx_model(onnx_model):
_ = weighted_model(onnx_model.graph.output[0].name, onnx_model.graph.output[1].name)
def test_grad_clipping_execution():
    """Checks that the ClipGradNorm sub graph matches torch.nn.utils.clip_grad_norm_.

    A model containing only the gradient clipping block is built and executed
    with the gradients of a PyTorch model; the outputs are compared against
    the gradients clipped in place by PyTorch with the same max norm.
    """
    # Given
    device = "cuda"
    max_norm = 2.5
    N, D_in, H, D_out = 64, 784, 500, 10
    pt_model, _ = _get_models(device, N, D_in, H, D_out)
    x = torch.randn(N, D_in, device=device)
    target = torch.randn(N, D_out, device=device)

    # Prepare the onnx model with only grad clipping
    onnx_model = onnx.ModelProto()
    onnx_model.graph.name = "AdamW Optimizer Model"
    onnx_model.producer_name = "grad clipping test"
    onnx_model.opset_import.extend(onnxblock.optim.optim._OPSET_IMPORTS)
    onnx_model.ir_version = onnx.IR_VERSION

    class GradClippingModel(onnxblock.Model):
        """Model whose only content is the gradient clipping block."""

        def __init__(self, max_norm):
            self._grad_clip = onnxblock.optim.ClipGradNorm(max_norm)

        def build(self, *grad_names):
            return self._grad_clip(*grad_names)

    # One graph input per parameter gradient, in parameter order.
    named_params = list(pt_model.named_parameters())
    grad_names = [f"{name}_grad" for name, _ in named_params]
    onnx_model.graph.input.extend(
        [
            onnx.helper.make_tensor_value_info(grad_name, onnx.TensorProto.FLOAT, param.shape)
            for grad_name, (_, param) in zip(grad_names, named_params)
        ]
    )

    grad_clip = GradClippingModel(max_norm)

    with onnxblock.onnx_model(onnx_model):
        ort_output_names = grad_clip(*grad_names)

    # When
    with tempfile.NamedTemporaryFile(suffix=".onnx") as onnx_fo:
        onnx.save(onnx_model, onnx_fo.name)

        loss = torch.nn.MSELoss()(pt_model(x), target)
        loss.backward()

        # Snapshot the unclipped gradients before PyTorch clips them in place.
        ort_inputs = {
            f"{name}_grad": _to_numpy(copy.deepcopy(param.grad)) for name, param in pt_model.named_parameters()
        }

        torch.nn.utils.clip_grad_norm_(pt_model.parameters(), max_norm)

        # Then no error occurs when executing the model
        ort_session = onnxruntime.InferenceSession(onnx_fo.name, providers=C.get_available_providers())
        ort_outs = ort_session.run(ort_output_names, ort_inputs)

        # assert all the gradients are close
        for clipped_by_ort, clipped_by_torch in zip(ort_outs, pt_model.parameters()):
            assert np.allclose(clipped_by_ort, _to_numpy(clipped_by_torch.grad))

View file

@ -19,7 +19,11 @@
#include "core/platform/path_lib.h"
#include "orttraining/core/framework/checkpoint_common.h"
#include "orttraining/training_api/include/interfaces.h"
#include "orttraining/training_api/include/module.h"
#include "orttraining/training_api/include/optimizer.h"
#include "orttraining/training_api/include/checkpoint_property.h"
#include "orttraining/training_api/include/checkpoint.h"
#include "orttraining/training_api/include/lr_scheduler.h"
#include "test/test_environment.h"
#include "test/util/include/asserts.h"
@ -27,6 +31,7 @@
#include "test/util/include/test/test_environment.h"
#include "orttraining/test/training_api/common/synthetic_data_loader.h"
#include "orttraining/test/training_api/core/data_utils.h"
#include "default_providers.h"
using onnxruntime::test::TemporaryDirectory;
using namespace onnxruntime::training::api;
@ -159,7 +164,9 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) {
* Save Optimizer states into ORT checkpoint files,
* Then load it into ORT, compare with the initial optimizer states values.
*/
TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CPU) {
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) {
/// Phase 1 - Test Preparation
/// Prepare the data and dest folder for saving checkpoint.
/// Also cooked the data for test result comparison.
@ -201,15 +208,17 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CPU) {
onnxruntime::SessionOptions session_option;
std::unique_ptr<Environment> env;
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
std::vector<std::shared_ptr<IExecutionProvider>> cuda_provider{onnxruntime::test::DefaultCudaExecutionProvider()};
auto model = std::make_unique<Module>(model_uri, named_parameters, session_option,
*env);
auto optimizer = Optimizer(optim_uri, model->NamedParameters(), session_option, *env);
*env, cuda_provider);
auto optimizer = std::make_unique<Optimizer>(optim_uri, model->NamedParameters(), session_option,
*env, cuda_provider);
/// Phase 2 - Run Optimizer.GetStateDict and call save checkpoint APIs.
/// And check the result checkpoint files.
CheckpointState checkpoint_state;
ORT_ENFORCE(optimizer.GetStateDict(checkpoint_state.optimizer_checkpoint_state).IsOK());
ORT_ENFORCE(optimizer->GetStateDict(checkpoint_state.optimizer_checkpoint_state).IsOK());
// Remove the temporary directory if it already exists.
auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir");
@ -286,6 +295,8 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CPU) {
}
}
#endif
/**
* Create PropertyBag with sets of properties,
* Save properties into ORT checkpoint files,

View file

@ -20,6 +20,25 @@ void OrtValueToVec(const OrtValue& val, std::vector<T>& output) {
output.assign(val_ptr, val_ptr + num_elem);
}
template <typename T>
void CudaOrtValueToCpuVec(const OrtValue& val, std::vector<T>& output,
std::shared_ptr<IExecutionProvider> cuda_provider,
std::shared_ptr<IExecutionProvider> cpu_provider) {
const Tensor& src_tensor = val.Get<Tensor>();
auto allocator = cpu_provider->GetAllocator(0, OrtMemTypeDefault);
ORT_ENFORCE(allocator, "Cpu allocator is a nullptr.");
auto dst_tensor = std::make_unique<Tensor>(src_tensor.DataType(), src_tensor.Shape(), allocator);
auto data_transfer = cuda_provider->GetDataTransfer();
ORT_ENFORCE(data_transfer, "Cuda data transfer is a nullptr.");
ORT_THROW_IF_ERROR(data_transfer->CopyTensor(src_tensor, *dst_tensor));
const T* val_ptr = dst_tensor->template Data<T>();
output.assign(val_ptr, val_ptr + src_tensor.Shape().Size());
}
} // namespace test
} // namespace training
} // namespace onnxruntime

View file

@ -11,8 +11,13 @@
#include "core/common/path_utils.h"
#include "core/framework/tensorprotoutils.h"
#include "orttraining/training_api/include/utils.h"
#include "orttraining/training_api/include/interfaces.h"
#include "orttraining/training_api/include/module.h"
#include "orttraining/training_api/include/optimizer.h"
#include "orttraining/training_api/include/checkpoint_property.h"
#include "orttraining/training_api/include/checkpoint.h"
#include "orttraining/training_api/include/lr_scheduler.h"
#include "orttraining/test/training_api/core/data_utils.h"
#include "default_providers.h"
using json = nlohmann::json;
using namespace onnxruntime::training;
@ -53,7 +58,7 @@ TEST(TrainingApiTest, ModuleTrainStep) {
std::unique_ptr<Environment> env;
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
auto model = std::make_unique<Module>(model_uri, state.module_checkpoint_state.named_parameters, session_option,
*env);
*env, std::vector<std::shared_ptr<IExecutionProvider>>());
OrtValue input, target;
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
@ -95,6 +100,8 @@ TEST(TrainingApiTest, ModuleTrainStep) {
}
}
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(TrainingApiTest, OptimStep) {
auto model_uri = MODEL_FOLDER "gradient_graph.onnx";
auto optim_uri = MODEL_FOLDER "adamw.onnx";
@ -105,10 +112,14 @@ TEST(TrainingApiTest, OptimStep) {
onnxruntime::SessionOptions session_option;
std::unique_ptr<Environment> env;
std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCudaExecutionProvider()};
std::shared_ptr<IExecutionProvider> cuda_provider = providers.front();
std::shared_ptr<IExecutionProvider> cpu_provider = onnxruntime::test::DefaultCpuExecutionProvider();
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
auto model = std::make_unique<Module>(model_uri, state.module_checkpoint_state.named_parameters, session_option,
*env);
auto optim = std::make_unique<Optimizer>(optim_uri, model->NamedParameters(), session_option, *env);
*env, providers);
auto optim = std::make_unique<Optimizer>(optim_uri, model->NamedParameters(), session_option,
*env, providers);
OrtValue input, target;
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
@ -125,8 +136,11 @@ TEST(TrainingApiTest, OptimStep) {
optimizer_states.group_named_optimizer_states["group0"]->param_named_optimizer_states.at(param_name);
OrtValue& moment_1 = param_state.momentum_named_states.at("momentum0");
std::vector<float> param_vec_before_optimizer_step;
CudaOrtValueToCpuVec(model->NamedParameters().at(param_name)->Data(), param_vec_before_optimizer_step,
cuda_provider, cpu_provider);
std::vector<float> moment_1_vec;
OrtValueToVec(moment_1, moment_1_vec);
CudaOrtValueToCpuVec(moment_1, moment_1_vec, cuda_provider, cpu_provider);
for (size_t i = 0; i < moment_1_vec.size(); i++) {
ORT_ENFORCE(moment_1_vec[i] == 0.0f);
}
@ -136,14 +150,25 @@ TEST(TrainingApiTest, OptimStep) {
std::vector<OrtValue>& inputs = *it;
std::vector<OrtValue> fetches;
ORT_ENFORCE(model->TrainStep(inputs, fetches).IsOK());
std::vector<float> grads;
CudaOrtValueToCpuVec(model->NamedParameters().at(param_name)->Gradient(), grads,
cuda_provider, cpu_provider);
ORT_ENFORCE(optim->Step().IsOK());
// get optim state and check if it is updated
OrtValueToVec(moment_1, moment_1_vec);
CudaOrtValueToCpuVec(moment_1, moment_1_vec, cuda_provider, cpu_provider);
for (size_t i = 0; i < moment_1_vec.size(); i++) {
if (moment_1_vec[i] != 0.0f) {
// moment was updated
break;
if (grads[i] != 0.0f) {
ORT_ENFORCE(moment_1_vec[i] != 0.0f);
}
}
std::vector<float> param_vec_after_optimizer_step;
CudaOrtValueToCpuVec(model->NamedParameters().at(param_name)->Data(), param_vec_after_optimizer_step,
cuda_provider, cpu_provider);
for (size_t i = 0; i < param_vec_after_optimizer_step.size(); ++i) {
if (grads[i] != 0.0f && moment_1_vec[i] != 0.0f) {
ORT_ENFORCE(param_vec_after_optimizer_step[i] != param_vec_before_optimizer_step[i]);
}
}
}
@ -167,9 +192,11 @@ void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t
onnxruntime::SessionOptions session_option;
std::unique_ptr<Environment> env;
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
const std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCudaExecutionProvider()};
auto model = std::make_unique<Module>(model_uri, state.module_checkpoint_state.named_parameters, session_option,
*env);
auto optim = std::make_shared<Optimizer>(optim_uri, model->NamedParameters(), session_option, *env);
*env, providers);
auto optim = std::make_shared<Optimizer>(optim_uri, model->NamedParameters(), session_option,
*env, providers);
OrtValue input, target;
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
@ -259,6 +286,8 @@ TEST(TrainingApiTest, LinearLRScheduler_WarmUp200Step_ResumeFromCheckpoint_Test)
TestLRSchduler("warmup_linear_scheduler_warmupstep-200_restored.json", initial_lr, total_step_count, 200);
}
#endif
} // namespace
} // namespace test
} // namespace training

View file

@ -1,172 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
* @brief Temporary classes bridging between public C interfaces and internal ones.
* Currently the implementation does not cover all targeted public training APIs; more changes are expected to be added.
*
* TODO: For the longer term, we should rearrange the following interfaces to follow the way the existing APIs are defined and exposed
* in include/onnxruntime/core/session/onnxruntime_c_api.h and include/onnxruntime/core/session/onnxruntime_cxx_api.h.
*
* This is an intermediate header file for our training api internal development usage.
*
*/
#pragma once
#include <onnxruntime_cxx_api.h>
#include "core/session/inference_session.h"
#include "core/session/environment.h"
#include "orttraining/training_api/include/module.h"
#include "orttraining/training_api/include/optimizer.h"
#include "orttraining/training_api/include/lr_scheduler.h"
#include "orttraining/training_api/include/checkpoint_property.h"
#include "orttraining/training_api/include/checkpoint.h"
using onnxruntime::training::api::LinearLRScheduler;
using onnxruntime::training::api::Module;
using onnxruntime::training::api::ModuleCheckpointState;
using onnxruntime::training::api::Optimizer;
using onnxruntime::training::api::OptimizerCheckpointState;
using onnxruntime::training::api::Parameter;
namespace Ort {
namespace {
/**
 * Fills `ortvalue_list` with a copy of the OrtValue referenced by each entry of `ort_value_list`.
 * Any previous content of `ortvalue_list` is discarded.
 *
 * NOTE(review): the cast below reads the Ort::Value array as an array of OrtValue pointers, i.e. it
 * presumes an Ort::Value consists of exactly one `OrtValue*` - confirm against onnxruntime_cxx_api.h.
 */
void ToOrtValue(const std::vector<Ort::Value>& ort_value_list, std::vector<OrtValue>& ortvalue_list) {
  ortvalue_list.clear();

  const size_t count = ort_value_list.size();
  ortvalue_list.reserve(count);

  const ::OrtValue* const* handles = reinterpret_cast<const ::OrtValue* const*>(ort_value_list.data());
  for (const ::OrtValue* const* handle = handles; handle != handles + count; ++handle) {
    ortvalue_list.push_back(**handle);
  }
}
void FromOrtValue(std::vector<OrtValue>& ortvalue_list, std::vector<Ort::Value>& ort_value_list) {
// Clean the output.
ort_value_list.clear();
size_t output_names_len = ortvalue_list.size();
for (size_t i = 0; i < output_names_len; ++i)
ort_value_list.emplace_back(nullptr);
Ort::Value* ort_value_outputs_ptr = ort_value_list.data();
auto ortvalue_outputs_ptr = reinterpret_cast<OrtValue**>(ort_value_outputs_ptr);
for (size_t i = 0; i != output_names_len; ++i) {
OrtValue& value = ortvalue_list[i];
// Ort::Value will release the pointer once it goes out of scope.
ortvalue_outputs_ptr[i] = new OrtValue(value);
}
}
} // namespace
/**
 * Wrapper owning an onnxruntime::training::api::Module. It accepts Ort::Value inputs/outputs and
 * reports success of the forwarded calls as bool instead of Status.
 */
struct OrtModule {
 public:
  /**
   * Creates the wrapped Module from the train model, the parameters and an optional eval model.
   * `session_options` and `env` are reinterpreted as the internal SessionOptions / Environment types.
   */
  OrtModule(
      OrtEnv* env, OrtSessionOptions* session_options,
      const std::string& train_model_path_or_bytes,
      std::unordered_map<std::string, std::shared_ptr<onnxruntime::training::api::Parameter>>& parameters,
      const std::optional<std::string>& eval_model_path_or_bytes = std::nullopt)
      : module_(std::make_unique<onnxruntime::training::api::Module>(
            train_model_path_or_bytes,
            parameters,
            *reinterpret_cast<::onnxruntime::SessionOptions*>(session_options),
            *reinterpret_cast<::onnxruntime::Environment*>(env),
            eval_model_path_or_bytes)) {
  }

  /** Returns the named parameters of the wrapped module. */
  std::unordered_map<std::string, std::shared_ptr<onnxruntime::training::api::Parameter>> NamedParameters() const {
    return module_->NamedParameters();
  }

  /** Forwards to Module::ResetGrad; true when the returned status is OK. */
  bool ResetGrad() {
    const auto status = module_->ResetGrad();
    return status.IsOK();
  }

  /** Runs one train step. On success `outputs` is replaced by the fetches; on failure it is left untouched. */
  bool TrainStep(const std::vector<Ort::Value>& inputs, std::vector<Ort::Value>& outputs) {
    return RunStep(inputs, outputs,
                   [this](auto& feeds, auto& fetches) { return module_->TrainStep(feeds, fetches); });
  }

  /** Runs one eval step. On success `outputs` is replaced by the fetches; on failure it is left untouched. */
  bool EvalStep(const std::vector<Ort::Value>& inputs, std::vector<Ort::Value>& outputs) {
    return RunStep(inputs, outputs,
                   [this](auto& feeds, auto& fetches) { return module_->EvalStep(feeds, fetches); });
  }

  /** Forwards to Module::GetStateDict; true when the returned status is OK. */
  bool GetStateDict(onnxruntime::training::api::ModuleCheckpointState& module_checkpoint_states) {
    const auto status = module_->GetStateDict(module_checkpoint_states);
    return status.IsOK();
  }

 private:
  // Shared driver of TrainStep/EvalStep: converts the inputs, invokes `step` and converts the fetches.
  template <typename StepFn>
  static bool RunStep(const std::vector<Ort::Value>& inputs, std::vector<Ort::Value>& outputs, StepFn&& step) {
    std::vector<OrtValue> feeds;
    ToOrtValue(inputs, feeds);

    std::vector<OrtValue> fetches;
    if (!step(feeds, fetches).IsOK()) {
      return false;
    }

    // FromOrtValue cleans the output before filling it.
    FromOrtValue(fetches, outputs);
    return true;
  }

  std::unique_ptr<onnxruntime::training::api::Module> module_;
};
/**
 * Wrapper owning (shared) an onnxruntime::training::api::Optimizer and reporting success of the
 * forwarded calls as bool instead of Status.
 */
struct OrtOptimizer {
  // OrtLinearLRScheduler reads the wrapped optimizer directly.
  friend struct OrtLinearLRScheduler;

  /**
   * Creates the wrapped Optimizer from the optimizer model and the parameters.
   * `session_options` and `env` are reinterpreted as the internal SessionOptions / Environment types.
   */
  OrtOptimizer(
      OrtEnv* env, OrtSessionOptions* session_options,
      const std::string& optim_path_or_bytes,
      const std::unordered_map<std::string, std::shared_ptr<onnxruntime::training::api::Parameter>>& parameters)
      : optimizer_(std::make_shared<onnxruntime::training::api::Optimizer>(
            optim_path_or_bytes,
            parameters,
            *reinterpret_cast<::onnxruntime::SessionOptions*>(session_options),
            *reinterpret_cast<::onnxruntime::Environment*>(env))) {
  }

  /** Forwards to Optimizer::Step; true when the returned status is OK. */
  bool Step() {
    const auto status = optimizer_->Step();
    return status.IsOK();
  }

  /** Forwards to Optimizer::GetStateDict; true when the returned status is OK. */
  bool GetStateDict(onnxruntime::training::api::OptimizerCheckpointState& optimizer_checkpoint_states) {
    const auto status = optimizer_->GetStateDict(optimizer_checkpoint_states);
    return status.IsOK();
  }

 private:
  std::shared_ptr<onnxruntime::training::api::Optimizer> optimizer_;
};
/**
 * Wrapper owning a LinearLRScheduler created for the optimizer held by an OrtOptimizer. It also keeps
 * a shared reference to that optimizer.
 */
struct OrtLinearLRScheduler {
  /**
   * Creates the scheduler for the optimizer wrapped by `optimizer` with the given step counts.
   * Members are initialized in declaration order (scheduler first), hence the scheduler is built
   * from `optimizer.optimizer_` rather than from `optim_`.
   */
  OrtLinearLRScheduler(OrtOptimizer& optimizer, int64_t warmup_step_count, int64_t total_step_count)
      : linear_lr_scheduler_(
            std::make_unique<LinearLRScheduler>(optimizer.optimizer_, warmup_step_count, total_step_count)),
        optim_(optimizer.optimizer_) {
  }

  /** Forwards to LinearLRScheduler::Step; true when the returned status is OK. */
  bool Step() {
    const auto status = linear_lr_scheduler_->Step();
    return status.IsOK();
  }

 private:
  std::unique_ptr<onnxruntime::training::api::LinearLRScheduler> linear_lr_scheduler_;
  std::shared_ptr<onnxruntime::training::api::Optimizer> optim_;
};
} // namespace Ort

View file

@ -61,6 +61,7 @@ struct Module {
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& named_parameters,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
const std::optional<std::string>& eval_model_path_or_bytes = std::nullopt);
// Return the trainable/nontrainable parameters

View file

@ -59,7 +59,8 @@ struct Optimizer {
Optimizer(const std::string& optim_path_or_bytes,
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& named_parameters,
const onnxruntime::SessionOptions& session_options,
const Environment& env);
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers);
Status Step();

View file

@ -12,16 +12,23 @@ namespace training {
namespace api {
using namespace common;
/**
 * Immutable bundle of the model identifiers handed to a TrainingSession: the train model, and
 * optionally the eval model and the optimizer model.
 */
struct ModelIdentifiers {
  // Identifier of the training model.
  const std::string train_model;
  // Identifier of the eval model, if any.
  const std::optional<std::string> eval_model;
  // Identifier of the optimizer model, if any.
  const std::optional<std::string> optim_model;

  /** Stores copies of the given identifiers. */
  ModelIdentifiers(const std::string& train_model_uri,
                   const std::optional<std::string>& eval_model_uri,
                   const std::optional<std::string>& optim_model_uri)
      : train_model(train_model_uri),
        eval_model(eval_model_uri),
        optim_model(optim_model_uri) {
  }
};
// Wrapper on top of module and optimizer classes and is the only class exposed via capis
class TrainingSession {
public:
TrainingSession(const Environment& session_env,
const SessionOptions& session_options,
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& parameters);
Status Initialize(const std::string& train_model_uri,
const std::optional<std::string>& eval_model_uri,
const std::optional<std::string>& optim_model_uri);
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& parameters,
const ModelIdentifiers& model_identifiers);
size_t GetTrainModeOutputCount() const noexcept;
@ -44,8 +51,6 @@ class TrainingSession {
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TrainingSession);
const Environment& environment_;
SessionOptions session_options_;
const std::unordered_map<std::string, std::shared_ptr<Parameter>> named_parameters_;
std::unique_ptr<Module> module_;
std::unique_ptr<Optimizer> optimizer_;

View file

@ -55,10 +55,14 @@ Module::Module(const std::string& train_model_path_or_bytes,
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& named_parameters,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
const std::optional<std::string>& eval_model_path_or_bytes) : named_parameters_{named_parameters} {
// Create session for training model
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
const std::optional<std::string>& eval_model_path_or_bytes)
: named_parameters_{named_parameters} {
train_sess_ = std::make_unique<onnxruntime::InferenceSession>(session_options, env);
ORT_THROW_IF_ERROR(train_sess_->Load(train_model_path_or_bytes));
for (const auto& provider : providers) {
ORT_THROW_IF_ERROR(train_sess_->RegisterExecutionProvider(provider));
}
ORT_THROW_IF_ERROR(train_sess_->Initialize());
// Extract model input and output names
@ -123,7 +127,6 @@ Module::Module(const std::string& train_model_path_or_bytes,
auto target_tensor = std::make_unique<Tensor>(param_data_tensor.DataType(), param_data_tensor.Shape(), target_allocator);
ORT_THROW_IF_ERROR(train_sess_state.GetDataTransferMgr().CopyTensor(param_data_tensor, *target_tensor.get()));
auto ml_tensor_type = DataTypeImpl::GetType<Tensor>();
// TODO test the original buffer is released.
param_data.Init(target_tensor.release(), ml_tensor_type, ml_tensor_type->GetDeleteFunc());
}

View file

@ -9,6 +9,20 @@
#include "orttraining/training_api/include/training_session.h"
#include "core/session/abi_session_options_impl.h"
namespace {
/**
 * Creates one execution provider per factory.
 *
 * @param provider_factories Factories to create the providers from; every entry is dereferenced, so
 *                           none of them may be null.
 * @return The created providers, in the order of their factories.
 */
std::vector<std::shared_ptr<onnxruntime::IExecutionProvider>> CreateProviders(
    const std::vector<std::shared_ptr<onnxruntime::IExecutionProviderFactory>>& provider_factories) {
  std::vector<std::shared_ptr<onnxruntime::IExecutionProvider>> execution_providers;
  // The final size is known up front; avoid reallocations while filling.
  execution_providers.reserve(provider_factories.size());

  for (const auto& factory : provider_factories) {
    // CreateProvider() already yields a temporary, wrapping it in std::move is redundant.
    execution_providers.emplace_back(factory->CreateProvider());
  }

  return execution_providers;
}
} // namespace
ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
@ -20,16 +34,18 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_
*out = nullptr;
ORT_TRY {
using ProvidersType = std::vector<std::shared_ptr<onnxruntime::IExecutionProvider>>;
train_sess = std::make_unique<onnxruntime::training::api::TrainingSession>(
env->GetEnvironment(),
options == nullptr ? onnxruntime::SessionOptions() : options->value,
chkpt_state->module_checkpoint_state.named_parameters);
ORT_API_RETURN_IF_STATUS_NOT_OK(train_sess->Initialize(train_model_path,
eval_model_path ? std::optional<std::string>{eval_model_path}
: std::nullopt,
optimizer_model_path ? std::optional<std::string>{optimizer_model_path}
: std::nullopt));
options == nullptr ? ProvidersType() : CreateProviders(options->provider_factories),
chkpt_state->module_checkpoint_state.named_parameters,
onnxruntime::training::api::ModelIdentifiers{
train_model_path,
eval_model_path ? std::optional<std::string>{eval_model_path}
: std::nullopt,
optimizer_model_path ? std::optional<std::string>{optimizer_model_path}
: std::nullopt});
*out = reinterpret_cast<OrtTrainingSession*>(train_sess.release());
}

View file

@ -20,9 +20,14 @@ const std::string GROUP_ZERO_NAME = "group0";
// TODO: don't hard code the state names, should get the state names according to the optimizer types.
// TODO: Consolidate with frontend tooling
// Suffixes passed to utils::GetParamNameFromSuffix in ConstructInputs to recognise the
// two per-parameter moment input names (presumably "<param>.exp_avg" and
// "<param>.exp_avg_sq" -- confirm against the helper).
const std::vector<std::string> MOMENT_SUFFIXES{".exp_avg", ".exp_avg_sq"};
// Keys used to look up a parameter's entries in momentum_named_states:
// index 0 is read for the first order moment, index 1 for the second order moment.
const std::vector<std::string> MOMENT_STATE_NAMES{"momentum0", "momentum1"};
// Names the optimizer graph inputs must carry; the Optimizer constructor enforces them
// by position: [0] learning rate, [1] step, [2] params and, for AdamW,
// [3] first order moments, [4] second order moments.
constexpr char LearningRateName[] = "learning_rate";
constexpr char StepName[] = "step";
constexpr char ParamsName[] = "params";
constexpr char FirstOrderMomentsName[] = "first_order_moments";
constexpr char SecondOrderMomentsName[] = "second_order_moments";
} // namespace
Status Optimizer::GenerateMomentumNamedStates() {
@ -49,35 +54,70 @@ Status Optimizer::ConstructInputs() {
if (optimizer_type_ == OptimizerType::AdamW) {
auto& param_named_optimizer_states = optimizer_state_.param_named_optimizer_states;
std::string param_name;
std::vector<std::string> param_names, grad_names, moment1_names, moment2_names, user_inputs;
for (size_t i = 2; i < input_names_.size(); i++) {
std::string& name = input_names_[i];
auto it = named_parameters_.find(name);
if (it != named_parameters_.end()) { // is param
param_names.push_back(name);
inputs_.push_back(it->second->Data());
} else if (utils::GetParamNameFromGradient(name, param_name)) {
grad_names.emplace_back(name);
// assert param_name is valid.
auto it = named_parameters_.find(param_name);
ORT_ENFORCE(it != named_parameters_.end(), "Unknown param: ", param_name, " for field: ", name);
inputs_.push_back(it->second->Gradient());
} else if (utils::GetParamNameFromSuffix(name, MOMENT_SUFFIXES[0], param_name)) {
moment1_names.push_back(name);
auto it = named_parameters_.find(param_name);
ORT_ENFORCE(it != named_parameters_.end(), "Unknown param: ", param_name, " for field: ", name);
inputs_.push_back(param_named_optimizer_states.at(param_name).momentum_named_states.at(MOMENT_STATE_NAMES[0]));
} else if (utils::GetParamNameFromSuffix(name, MOMENT_SUFFIXES[1], param_name)) {
moment2_names.push_back(name);
auto it = named_parameters_.find(param_name);
ORT_ENFORCE(it != named_parameters_.end(), "Unknown param: ", param_name, " for field: ", name);
inputs_.push_back(param_named_optimizer_states.at(param_name).momentum_named_states.at(MOMENT_STATE_NAMES[1]));
std::vector<Tensor> params, first_order_moments, second_order_moments;
// TODO: Change to tensor seq implementation once clip grad norm op
// that accepts tensor seq as input for gradients is complete.
std::vector<OrtValue> grads;
// Input names 0-4 are reserved for lr, step, params, first order moments, second order moments
// input names 5 onwards are all the gradient names.
// Collect all the inputs based on the gradient names order.
for (size_t i = 5; i < input_names_.size(); i++) {
std::string param_name;
if (utils::GetParamNameFromGradient(input_names_[i], param_name)) {
const auto named_parameter_it = named_parameters_.find(param_name);
ORT_ENFORCE(named_parameter_it != named_parameters_.end(),
"Unknown param: ", param_name, " for field: ", input_names_[i]);
// Collect the gradients as ortvalues
grads.push_back(named_parameter_it->second->Gradient());
// Collect parameters and prepare for tensorseq creation
auto* param_tensor = named_parameter_it->second->Data().GetMutable<Tensor>();
params.emplace_back(
Tensor(param_tensor->DataType(), param_tensor->Shape(),
param_tensor->MutableDataRaw(), param_tensor->Location()));
// Collect first order moments and prepare for tensorseq creation
auto* first_order_moment_tensor = param_named_optimizer_states.at(param_name)
.momentum_named_states.at(MOMENT_STATE_NAMES[0])
.GetMutable<Tensor>();
first_order_moments.emplace_back(
Tensor(first_order_moment_tensor->DataType(), first_order_moment_tensor->Shape(),
first_order_moment_tensor->MutableDataRaw(), first_order_moment_tensor->Location()));
// Collect second order moments and prepare for tensorseq creation
auto* second_order_moment_tensor = param_named_optimizer_states.at(param_name)
.momentum_named_states.at(MOMENT_STATE_NAMES[1])
.GetMutable<Tensor>();
second_order_moments.emplace_back(
Tensor(second_order_moment_tensor->DataType(), second_order_moment_tensor->Shape(),
second_order_moment_tensor->MutableDataRaw(), second_order_moment_tensor->Location()));
} else {
ORT_ENFORCE("This is an invalid graph. Optimizer graph contains unknown user input:", name);
ORT_ENFORCE(
false, "This is an invalid graph. Optimizer graph contains unknown user input:", input_names_[i]);
}
ORT_ENFORCE(inputs_.back().IsAllocated() && inputs_.back().IsTensor(), "Uninitialized tensor data for ", name);
}
// Packs `tensors` into a single TensorSeq and appends it to `*inputs` as an OrtValue.
//   tensors: non-empty container of tensors (called below with std::vector<Tensor>);
//            it is moved from, so its contents must not be used after the call.
//   inputs:  pointer to the container that receives the new OrtValue.
// Fails through ORT_ENFORCE when `tensors` is empty.
const auto tensorseq_inserter = [](auto& tensors, auto* inputs) {
// front() is read on the next line to pick the element type, so reject an empty container first.
ORT_ENFORCE(!tensors.empty(), "Tensors cannot be empty while building a tensor sequence.");
// The sequence's element type is taken from the first tensor only.
auto tensor_seq = std::make_unique<TensorSeq>(tensors.front().DataType());
tensor_seq->SetElements(std::move(tensors));
// The released raw pointer is handed to the OrtValue together with the TensorSeq
// delete function (presumably what frees the sequence later -- confirm against OrtValue).
inputs->emplace_back(
OrtValue(tensor_seq.release(), DataTypeImpl::GetType<TensorSeq>(),
DataTypeImpl::GetType<TensorSeq>()->GetDeleteFunc()));
};
// Add the params and moments as tensorseq ortvalues to inputs
tensorseq_inserter(params, &inputs_);
tensorseq_inserter(first_order_moments, &inputs_);
tensorseq_inserter(second_order_moments, &inputs_);
// Add the gradients as ortvalues to inputs
inputs_.insert(inputs_.end(),
std::make_move_iterator(grads.begin()),
std::make_move_iterator(grads.end()));
}
// Add other optimizer reordering logic here
return Status::OK();
@ -86,17 +126,26 @@ Status Optimizer::ConstructInputs() {
Optimizer::Optimizer(const std::string& optim_path_or_bytes,
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& named_parameters,
const onnxruntime::SessionOptions& session_options,
const Environment& env) : named_parameters_(named_parameters) {
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers)
: named_parameters_(named_parameters) {
optim_sess_ = std::move(std::make_unique<InferenceSession>(session_options, env));
for (const auto& execution_provider : providers) {
ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider));
}
ORT_THROW_IF_ERROR(optim_sess_->Load(optim_path_or_bytes));
ORT_THROW_IF_ERROR(optim_sess_->Initialize());
utils::GetGraphInputOutputNames(optim_sess_, input_names_, output_names_);
ORT_ENFORCE(input_names_[0] == "learning_rate"); // TODO: make this better
ORT_ENFORCE(input_names_[1] == "step"); // TODO: make this better
ORT_ENFORCE(input_names_[0] == LearningRateName); // TODO: make this better
ORT_ENFORCE(input_names_[1] == StepName); // TODO: make this better
ORT_ENFORCE(input_names_[2] == ParamsName); // TODO: make this better
if (optimizer_type_ == OptimizerType::AdamW) {
ORT_ENFORCE(input_names_[3] == FirstOrderMomentsName); // TODO: make this better
ORT_ENFORCE(input_names_[4] == SecondOrderMomentsName); // TODO: make this better
ORT_THROW_IF_ERROR(GenerateMomentumNamedStates());
} else {
ORT_THROW("Unsupported optimizer type");
@ -107,7 +156,10 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes,
// Builds the optimizer feeds (learning rate, step, then everything in inputs_) and,
// after the run, updates optimizer_state_.step from the first output.
// Returns Status::OK(); a failing status from the run is raised via ORT_THROW_IF_ERROR.
// NOTE(review): this text is a diff rendering without +/- markers, so a replaced
// statement and its replacement appear back to back, and the lines elided at the
// hunk header below (where `status` and `outputs` come from) are not visible here.
Status Optimizer::Step() {
OrtValue learning_rate_input, step_input;
utils::WrapInOrtValue<float>(optimizer_state_.learning_rate, &learning_rate_input);
// NOTE(review): looks like the pre-change form of the step feed, superseded by the
// `step + 1` call further down -- confirm against the actual file.
utils::WrapInOrtValue<int64_t>(optimizer_state_.step, &step_input);
// Use step count + 1 before running optimizer step.
// This is necessary since bias correction uses the step
// as a power. Using power of 0 is wrong.
utils::WrapInOrtValue<int64_t>(optimizer_state_.step + 1, &step_input);
// Feed order: learning rate, step, followed by the entries of inputs_.
std::vector<OrtValue> feeds({learning_rate_input, step_input});
feeds.insert(feeds.end(), inputs_.begin(), inputs_.end());
@ -116,8 +168,9 @ Status Optimizer::Step() {
ORT_THROW_IF_ERROR(status);
// extract step output and update
// TODO: need to remove hardcoding
// NOTE(review): apparently the pre-change update (direct assignment from the output),
// replaced by the conditional increment below -- confirm against the actual file.
optimizer_state_.step = utils::GetValue<int64_t>(outputs[0]);
// The stored step count advances only when the first output equals 1.
if (utils::GetValue<int64_t>(outputs[0]) == 1LL) {
optimizer_state_.step++;
}
return Status::OK();
}

View file

@ -9,23 +9,17 @@ namespace api {
TrainingSession::TrainingSession(const Environment& session_env,
const SessionOptions& session_options,
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& parameters)
: environment_(session_env),
session_options_{session_options},
named_parameters_{parameters} {}
Status TrainingSession::Initialize(const std::string& train_model_uri, const std::optional<std::string>& eval_model_uri,
const std::optional<std::string>& optim_model_uri) {
module_ = std::move(std::make_unique<Module>(train_model_uri, named_parameters_, session_options_,
environment_, eval_model_uri));
if (optim_model_uri.has_value()) {
optimizer_ = std::move(std::make_unique<Optimizer>(optim_model_uri.value(), named_parameters_,
session_options_, environment_));
}
return Status::OK();
}
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& parameters,
const ModelIdentifiers& model_identifiers)
: named_parameters_{parameters},
module_{std::make_unique<Module>(model_identifiers.train_model, named_parameters_,
session_options, session_env, providers, model_identifiers.eval_model)},
optimizer_{model_identifiers.optim_model.has_value()
? std::make_unique<Optimizer>(
model_identifiers.optim_model.value(), named_parameters_,
session_options, session_env, providers)
: std::unique_ptr<Optimizer>()} {}
size_t TrainingSession::GetTrainModeOutputCount() const noexcept {
return module_->GetTrainModeOutputCount();

View file

@ -63,12 +63,22 @@ Status OrtValueLike(const SessionState& sess_state, const OrtValue& input_val, O
auto element_type = param_tensor.DataType();
auto p_tensor = std::make_unique<Tensor>(element_type, shape, allocator);
// TODO: handle CUDA memset
if (tensor_location.device.Type() == OrtDevice::CPU ||
tensor_location.mem_type == OrtMemTypeCPUInput ||
tensor_location.mem_type == OrtMemTypeCPUOutput) {
memset(p_tensor->MutableDataRaw(), 0, p_tensor->SizeInBytes());
} else if (tensor_location.device.Type() == OrtDevice::GPU) {
// Use a tensor on cpu and copy it over to the desired device using
// the data transfer manager.
AllocatorPtr cpu_allocator = sess_state.GetAllocator(OrtDevice());
auto p_cpu_tensor = std::make_unique<Tensor>(element_type, shape, cpu_allocator);
memset(p_cpu_tensor->MutableDataRaw(), 0, p_cpu_tensor->SizeInBytes());
// No need to free the cpu buffer, tensor destructor takes care of it using the cpu_allocator
ORT_THROW_IF_ERROR(sess_state.GetDataTransferMgr().CopyTensor(*p_cpu_tensor, *p_tensor));
} else {
ORT_THROW("Cannot create tensor on device ", tensor_location.device.Type());
}
output_val.Init(p_tensor.release(),
DataTypeImpl::GetType<Tensor>(),
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());

View file

@ -17,7 +17,7 @@ extend_skip_glob = [
convention = "google"
[tool.pylint.'MESSAGES CONTROL']
disable = ["format", "line-too-long", "import-error", "no-name-in-module"]
disable = ["format", "line-too-long", "import-error", "no-name-in-module", "fixme", "too-few-public-methods"]
[tool.pyright]
exclude = ["onnxruntime/core/flatbuffers/*"]

View file

@ -374,7 +374,7 @@ requirements_file = "requirements.txt"
local_version = None
enable_training = parse_arg_remove_boolean(sys.argv, "--enable_training")
enable_training_on_device = parse_arg_remove_boolean(sys.argv, '--enable_training_on_device')
enable_training_on_device = parse_arg_remove_boolean(sys.argv, "--enable_training_on_device")
disable_auditwheel_repair = parse_arg_remove_boolean(sys.argv, "--disable_auditwheel_repair")
default_training_package_device = parse_arg_remove_boolean(sys.argv, "--default_training_package_device")
@ -421,7 +421,9 @@ if enable_training:
]
)
if enable_training_on_device:
packages.append('onnxruntime.training.onnxblock')
packages.append("onnxruntime.training.onnxblock")
packages.append("onnxruntime.training.onnxblock.loss")
packages.append("onnxruntime.training.onnxblock.optim")
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]

View file

@ -179,8 +179,7 @@ def parse_arguments():
parser.add_argument(
"--enable_training_torch_interop", action="store_true", help="Enable training kernels interop with torch."
)
parser.add_argument(
"--enable_training_on_device", action='store_true', help="Enable on device training in ORT.")
parser.add_argument("--enable_training_on_device", action="store_true", help="Enable on device training in ORT.")
parser.add_argument("--disable_nccl", action="store_true", help="Disable Nccl.")
parser.add_argument("--mpi_home", help="Path to MPI installation dir")
parser.add_argument("--nccl_home", help="Path to NCCL installation dir")