diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 7c7b78f25c..894e7524be 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -226,7 +226,7 @@ static Status BatchOrCopyMLValue(const SessionState& session_state, } ++source_iter; ++target_iter; - } //while + } // while } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported OrtValue type to copy between device."); } @@ -322,9 +322,9 @@ static common::Status CalculateStaticCopyInfoForFetches(const SessionState& sess // If for some reason using just the device from the allocation plan isn't enough, the following // would use the NodeInfo from the node producing the output // - //std::vector node_info_vec; - //auto status = session_state.GetOutputNodeInfo(output_name, node_info_vec); - //if (status.IsOK()) { + // std::vector node_info_vec; + // auto status = session_state.GetOutputNodeInfo(output_name, node_info_vec); + // if (status.IsOK()) { // const auto& node_info = node_info_vec.front(); // only one entry as only one node can produce a given output // copy_info[idx].source_device = *node_info.device; //} else { @@ -428,6 +428,11 @@ static void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager const auto& feed = feeds[i]; if (feed.IsTensor()) { feed_locations[i] = feed.Get().Location().device; + } else if (feed.IsTensorSequence()) { + const auto& tensor_seq = feed.Get(); + if (tensor_seq.Size() != std::size_t{0}) { + feed_locations[i] = tensor_seq.Get(0).Location().device; + } } else if (feed.IsSparseTensor()) { #if !defined(DISABLE_SPARSE_TENSORS) feed_locations[i] = feed.Get().Location().device; diff --git a/onnxruntime/test/testdata/training_api/adamw.onnx b/onnxruntime/test/testdata/training_api/adamw.onnx index 20581ee1f7..6eb4e83a76 100644 Binary files a/onnxruntime/test/testdata/training_api/adamw.onnx and b/onnxruntime/test/testdata/training_api/adamw.onnx differ diff --git 
a/orttraining/orttraining/python/training/onnxblock/__init__.py b/orttraining/orttraining/python/training/onnxblock/__init__.py index e90f2ff533..96baa2bcab 100644 --- a/orttraining/orttraining/python/training/onnxblock/__init__.py +++ b/orttraining/orttraining/python/training/onnxblock/__init__.py @@ -2,13 +2,8 @@ # Licensed under the MIT License. # __init__.py -from .building_blocks import Block -from .model import Model, TrainingModel -from .checkpoint_utils import save_checkpoint from . import loss, optim +from .building_blocks import Block +from .checkpoint_utils import save_checkpoint +from .model import Model, TrainingModel from .model_accessor import onnx_model - -import onnx - -_producer_name = "onnxblock offline tooling" -_opset_import = onnx.helper.make_opsetid("com.microsoft", 1) diff --git a/orttraining/orttraining/python/training/onnxblock/_building_blocks.py b/orttraining/orttraining/python/training/onnxblock/_building_blocks.py deleted file mode 100644 index 10a8f04c57..0000000000 --- a/orttraining/orttraining/python/training/onnxblock/_building_blocks.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# _building_blocks.py - -import copy -import onnx - -import onnxruntime.training.onnxblock as onnxblock -import onnxruntime.training.onnxblock.model_accessor as accessor -import onnxruntime.training.onnxblock._graph_utils as graph_utils - - -class Sub(onnxblock.Model): - """Adds Sub node to an onnx model.""" - - def __init__(self): - super(Sub, self).__init__() - - def build(self, sub_input_name1, sub_input_name2): - # get the model to manipulate - onnx_model = accessor.global_accessor.model - - # create the graph node for sub - sub_node_input_names = [sub_input_name1, sub_input_name2] - sub_node_output_name = graph_utils.generate_random_graph_name("sub_output") - sub_node_output_names = [sub_node_output_name] - sub_node = onnx.helper.make_node( - "Sub", - sub_node_input_names, - sub_node_output_names, - name=graph_utils.generate_random_graph_name("Sub"), - ) - onnx_model.graph.node.append(sub_node) - - # create the graph output for sub - graph_output = copy.deepcopy( - graph_utils.get_output_from_output_name(onnx_model, sub_input_name1) - ) - graph_output.name = sub_node_output_name - del onnx_model.graph.output[:] - onnx_model.graph.output.append(graph_output) - - return sub_node_output_name - - -class Pow(onnxblock.Model): - """Adds Pow node to the onnx model.""" - - def __init__(self, exponent): - super(Pow, self).__init__() - - self._exponent = exponent - - def build(self, pow_input_name): - # get the model to manipulate - onnx_model = accessor.global_accessor.model - - # create the graph initializer for the exponent - pow_node_exponent_name = graph_utils.generate_random_graph_name("pow_exponent") - onnx_model.graph.initializer.append( - onnx.helper.make_tensor( - pow_node_exponent_name, onnx.TensorProto.FLOAT, [1], [self._exponent] - ) - ) - - # create the graph node for pow - pow_node_input_names = [pow_input_name, pow_node_exponent_name] - pow_node_output_name = graph_utils.generate_random_graph_name("pow_output") - pow_node_output_names = 
[pow_node_output_name] - pow_node = onnx.helper.make_node( - "Pow", - pow_node_input_names, - pow_node_output_names, - name=graph_utils.generate_random_graph_name("Pow"), - ) - onnx_model.graph.node.append(pow_node) - - # create the graph output for pow - graph_output = copy.deepcopy( - graph_utils.get_output_from_output_name(onnx_model, pow_input_name) - ) - graph_output.name = pow_node_output_name - del onnx_model.graph.output[:] - onnx_model.graph.output.append(graph_output) - - return pow_node_output_name - - -class _Reduce(onnxblock.Model): - """Base class for the reduce blocks.""" - - def __init__(self): - super(_Reduce, self).__init__() - - def _reduce(self, reduce_input_name, reduction_op): - # get the model to manipulate - onnx_model = accessor.global_accessor.model - - # create the graph node for reduce - reduce_node_input_names = [reduce_input_name] - reduce_node_output_name = graph_utils.generate_random_graph_name( - "reduce_output" - ) - reduce_node_output_names = [reduce_node_output_name] - reduce_node = onnx.helper.make_node( - reduction_op, - reduce_node_input_names, - reduce_node_output_names, - name=graph_utils.generate_random_graph_name(reduction_op), - ) - onnx_model.graph.node.append(reduce_node) - - # create the graph output for reduce - reduce_input = copy.deepcopy( - graph_utils.get_output_from_output_name(onnx_model, reduce_input_name) - ) - output_rank = len(reduce_input.type.tensor_type.shape.dim) - graph_outputs = [ - onnx.helper.make_tensor_value_info( - reduce_node_output_name, onnx.TensorProto.FLOAT, [1] * output_rank - ) - ] - del onnx_model.graph.output[:] - onnx_model.graph.output.extend(graph_outputs) - - return reduce_node_output_name - - -class ReduceMean(_Reduce): - """Adds ReduceMean node to the onnx model.""" - - def __init__(self): - super(ReduceMean, self).__init__() - - def build(self, reduce_input_name): - return super()._reduce(reduce_input_name, "ReduceMean") - - -class ReduceSum(_Reduce): - """Adds ReduceSum node to 
the onnx model.""" - - def __init__(self): - super(ReduceSum, self).__init__() - - def build(self, reduce_input_name): - return super(ReduceSum, self)._reduce(reduce_input_name, "ReduceSum") diff --git a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py index aa814f5b92..f5291a69c3 100644 --- a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py @@ -3,9 +3,10 @@ # _graph_utils.py import copy -import onnx import random +import onnx + from onnxruntime.capi._pybind_state import GradientGraphBuilder @@ -21,44 +22,52 @@ def get_output_from_output_name(onnx_model, output_name): def get_random_number(): - """Return a random number in the range 0, 1000.""" + """Return a random number in the range 0, 100000.""" - return random.randint(0, 1000) + return random.randint(0, 100000) def generate_random_graph_name(token): """Return a string that can be used in the graph as a graph attribute name.""" - return f"{get_random_number()}.{token}" + return f"onnx::{token}::{get_random_number()}" -def build_gradient_graph( - accessor, user_args_requiring_grad, user_args_not_requiring_grad, output_names -): +def _reorder_outputs(model, user_output_names, args_requiring_gradient_names): + graph_outputs = {output.name: output for output in model.graph.output} + ordered_graph_outputs = [graph_outputs[name] for name in user_output_names] + + for arg in args_requiring_gradient_names: + gradient_name = f"{arg}_grad" + ordered_graph_outputs.append(graph_outputs[gradient_name]) + + del model.graph.output[:] + model.graph.output.extend(ordered_graph_outputs) + + +def build_gradient_graph(accessor, user_args_requiring_grad, user_args_not_requiring_grad, output_names): """Builds the gradient graph on top of the given input forward only graph.""" model = accessor.model # Collect names of parameters that need gradients computed - 
all_args_requiring_gradient = set() + all_args_requiring_gradient = [] # Move all trainable and non trainable initializers to graph inputs. # This allows training to pass in the parameters from outside the graph # so as to share the parameters across multiple sessions. graph_inputs = model.graph.input initializers = [] for initializer in model.graph.initializer: - if not initializer.name[0].isdigit(): + if not initializer.name.startswith("onnx::"): # Move only those initializers as inputs that are not local # to the onnx model. i.e. initializers that are model parameters. - # These are tpically those initializers without any number prefixed + # These are typically those initializers without any onnx:: prefixed # to their names. graph_inputs.append( - onnx.helper.make_tensor_value_info( - initializer.name, initializer.data_type, initializer.dims - ) + onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims) ) if initializer.name not in user_args_not_requiring_grad: - all_args_requiring_gradient.add(initializer.name) + all_args_requiring_gradient.append(initializer.name) else: # All other initializers stay where they were. initializers.append(initializer) @@ -74,7 +83,7 @@ def build_gradient_graph( # args_requiring_grad. So, add these arguments to set of arguments # whose gradient should be built. 
for argument_name in user_args_requiring_grad: - all_args_requiring_gradient.add(argument_name) + all_args_requiring_gradient.append(argument_name) # Assumption is that the first graph output is the loss output if isinstance(output_names, str): @@ -82,13 +91,16 @@ def build_gradient_graph( builder = GradientGraphBuilder( model.SerializeToString(), set(output_names), - all_args_requiring_gradient, + set(all_args_requiring_gradient), output_names[0], ) builder.build() gradient_model = onnx.load_from_string(builder.get_model()) - # copy the gradient graph into the user's graph + # Reorder gradient outputs for the gradient model based on the all_args_requiring_gradient order + _reorder_outputs(gradient_model, output_names, all_args_requiring_gradient) + + # copy the gradient model into the user's model model.CopyFrom(gradient_model) return all_args_requiring_gradient @@ -117,9 +129,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na """ # TODO: Avoid hard coded input/output strings - gradient_output_names = { - f"{arg_name}_grad" for arg_name in all_args_requiring_gradient_names - } + gradient_output_names = {f"{arg_name}_grad" for arg_name in all_args_requiring_gradient_names} graph_inputs = grad_model.graph.input graph_nodes = grad_model.graph.node @@ -139,12 +149,8 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na # gradient accumulation node inputs and output names grad_name = graph_output.name - grad_accumulation_buffer_name = ( - f"{grad_name}.{gradient_accumulation_name}.{gradient_buffer_name_suffix}" - ) - grad_accumulation_output_name = ( - f"{grad_name}.{gradient_accumulation_name}.{gradient_output_name_suffix}" - ) + grad_accumulation_buffer_name = f"{grad_name}.{gradient_accumulation_name}.{gradient_buffer_name_suffix}" + grad_accumulation_output_name = f"{grad_name}.{gradient_accumulation_name}.{gradient_output_name_suffix}" # Gradient accumulation node acc_node = onnx.helper.make_node( @@ 
-167,9 +173,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na grad_accumulation_output.name = grad_accumulation_output_name graph_outputs.append(grad_accumulation_output) - lazy_reset_grad_input = onnx.helper.make_tensor_value_info( - lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1] - ) + lazy_reset_grad_input = onnx.helper.make_tensor_value_info(lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1]) graph_inputs.append(lazy_reset_grad_input) del grad_model.graph.output[:] @@ -183,18 +187,18 @@ def get_model_parameters(model, args_not_requiring_gradient): non_trainable_params = [] for initializer in model.graph.initializer: # All model parameters should have their names not begin with - # a digit. So, check to see if the initializer's first char - # is a digit. If not, it is either a trainable, or a non + # `onnx::`. So, check to see if the initializer begins with + # `onnx::`. If not, it is either a trainable, or a non # trainable parameter. # Note that this assumption can be made because the export logic # does not change the names of the original model parameters. # and the original model parameters don't have their names begin - # with a digit. + # with `onnx::`. # On the other hand, const initializers are generated by export - # logic and have a digit prefix. + # logic and have an `onnx::` prefix. # TODO: validate this assumption. If assumption is not valid, # the alternative is to enforce the user to provide the parameter names. 
- if not initializer.name[0].isdigit(): + if not initializer.name.startswith("onnx::"): if initializer.name in args_not_requiring_gradient: non_trainable_params.append(initializer) else: @@ -215,9 +219,7 @@ def build_graph_outputs(model, output_names): if isinstance(output_names, str): output_names = [output_names] - name_value_info_mapping = { - value_info.name: value_info for value_info in model.graph.value_info - } + name_value_info_mapping = {value_info.name: value_info for value_info in model.graph.value_info} name_graph_output_mapping = {output.name: output for output in model.graph.output} # collect all new graph outputs (i.e. graph outputs that are not @@ -229,9 +231,7 @@ def build_graph_outputs(model, output_names): elif output_name in name_value_info_mapping: graph_outputs.append(name_value_info_mapping[output_name]) else: - raise LookupError( - f"The provided name {output_name} is not a graph value info or a graph output." - ) + raise LookupError(f"The provided name {output_name} is not a graph value info or a graph output.") # Clear all existing graph outputs del model.graph.output[:] diff --git a/orttraining/orttraining/python/training/onnxblock/building_blocks.py b/orttraining/orttraining/python/training/onnxblock/building_blocks.py index f48a519dac..1fd23f2289 100644 --- a/orttraining/orttraining/python/training/onnxblock/building_blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/building_blocks.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-# _building_blocks.py +# building_blocks.py from abc import ABC, abstractmethod @@ -38,7 +38,7 @@ class Block(ABC): class _BinaryOp(Block): def __init__(self, op_name): - super(_BinaryOp, self).__init__() + super().__init__() self._op_name = op_name def build(self, input_name1, input_name2): @@ -68,28 +68,35 @@ class Add(_BinaryOp): """Adds Add node to an onnx model.""" def __init__(self): - super(Add, self).__init__("Add") + super().__init__("Add") class Sub(_BinaryOp): """Adds Sub node to an onnx model.""" def __init__(self): - super(Sub, self).__init__("Sub") + super().__init__("Sub") class Mul(_BinaryOp): """Adds Mul node to an onnx model.""" def __init__(self): - super(Mul, self).__init__("Mul") + super().__init__("Mul") + + +class Div(_BinaryOp): + """Adds Div node to an onnx model.""" + + def __init__(self): + super().__init__("Div") class Pow(Block): """Adds Pow node to the onnx model.""" def __init__(self, exponent): - super(Pow, self).__init__() + super().__init__() self._exponent = exponent @@ -119,8 +126,10 @@ class Pow(Block): class _UnaryOp(Block): + """Base class for all nodes that take in a single argument.""" + def __init__(self, op_name): - super(_UnaryOp, self).__init__() + super().__init__() self._op_name = op_name def build(self, input_name): @@ -150,29 +159,35 @@ class ReduceMean(_UnaryOp): """Adds ReduceMean node to the onnx model.""" def __init__(self): - super(ReduceMean, self).__init__("ReduceMean") + super().__init__("ReduceMean") class ReduceSum(_UnaryOp): """Adds ReduceSum node to the onnx model.""" def __init__(self): - super(ReduceSum, self).__init__("ReduceSum") + super().__init__("ReduceSum") class Sigmoid(_UnaryOp): + """Adds Sigmoid node to the onnx model.""" + def __init__(self): - super(Sigmoid, self).__init__("Sigmoid") + super().__init__("Sigmoid") class Log(_UnaryOp): + """Adds Log node to the onnx model.""" + def __init__(self): - super(Log, self).__init__("Log") + super().__init__("Log") class Neg(_UnaryOp): + """Adds Neg 
node to the onnx model.""" + def __init__(self): - super(Neg, self).__init__("Neg") + super().__init__("Neg") class Constant(Block): @@ -192,3 +207,109 @@ class Constant(Block): onnx.helper.make_tensor(initializer_name, onnx.TensorProto.FLOAT, [1], [self._value]) ) return initializer_name + + +class SequenceConstruct(Block): + """Adds SequenceConstruct node to the onnx model.""" + + def __init__(self): + super().__init__() + + def build(self, *sequence_input_names): + # get the model to manipulate + onnx_model = accessor.global_accessor.model + + # create the graph node for this sequence construct node + sc_node_input_names = list(sequence_input_names) + sc_node_output_name = graph_utils.generate_random_graph_name("sequenceconstruct_output") + sc_node_output_names = [sc_node_output_name] + sc_node = onnx.helper.make_node( + "SequenceConstruct", + sc_node_input_names, + sc_node_output_names, + graph_utils.generate_random_graph_name("SequenceConstruct"), + ) + onnx_model.graph.node.append(sc_node) + + return sc_node_output_name + + +class ReduceAllL2(Block): + """Adds ReduceAllL2 node to the onnx model. + + ReduceAllL2 is a part of the com.microsoft domain and might not be accessible outside this domain. 
+ """ + + def __init__(self): + super().__init__() + + def build(self, *reduce_node_input_names): + # get the model to manipulate + onnx_model = accessor.global_accessor.model + + # create the graph node for this reducealll2 node + reduce_node_input_names = list(reduce_node_input_names) + reduce_node_output_name = graph_utils.generate_random_graph_name("reducealll2_output") + reduce_node_output_names = [reduce_node_output_name] + reduce_node = onnx.helper.make_node( + "ReduceAllL2", + reduce_node_input_names, + reduce_node_output_names, + graph_utils.generate_random_graph_name("ReduceAllL2"), + domain="com.microsoft", + ) + onnx_model.graph.node.append(reduce_node) + # TODO: register shape inference with onnx + onnx_model.graph.value_info.append( + onnx.helper.make_tensor_value_info(reduce_node_output_name, onnx.TensorProto.FLOAT, [1]) + ) + + return reduce_node_output_name + + +class Clip(Block): + """Adds Clip node to the onnx model.""" + + def __init__(self, clip_min=None, clip_max=None): + super().__init__() + + self._min = clip_min + self._max = clip_max + + def build(self, clip_input_name): + # get the model to manipulate + onnx_model = accessor.global_accessor.model + + # create the graph initializer for the clip min + clip_node_min_name = "" + if self._min is not None: + clip_node_min_name = graph_utils.generate_random_graph_name("clip_min") + onnx_model.graph.initializer.append( + onnx.helper.make_tensor(clip_node_min_name, onnx.TensorProto.FLOAT, [1], [self._min]) + ) + + # create the graph initializer for the clip max + clip_node_max_name = "" + if self._max is not None: + clip_node_max_name = graph_utils.generate_random_graph_name("clip_max") + onnx_model.graph.initializer.append( + onnx.helper.make_tensor(clip_node_max_name, onnx.TensorProto.FLOAT, [1], [self._max]) + ) + + # create the graph node for this clip node + clip_node_input_names = [ + clip_input_name, + clip_node_min_name, + clip_node_max_name, + ] + clip_node_output_name = 
graph_utils.generate_random_graph_name("clip_output") + clip_node_output_names = [clip_node_output_name] + clip_node = onnx.helper.make_node( + "Clip", + clip_node_input_names, + clip_node_output_names, + graph_utils.generate_random_graph_name("Clip"), + ) + onnx_model.graph.node.append(clip_node) + + return clip_node_output_name diff --git a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py index 411a8e266c..b635708bc8 100644 --- a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py @@ -16,6 +16,4 @@ def save_checkpoint(parameters, path_to_checkpoint): trainable_params, non_trainable_params = parameters trainable_params = [param.SerializeToString() for param in trainable_params] non_trainable_params = [param.SerializeToString() for param in non_trainable_params] - _internal_save_checkpoint( - trainable_params, non_trainable_params, path_to_checkpoint - ) + _internal_save_checkpoint(trainable_params, non_trainable_params, path_to_checkpoint) diff --git a/orttraining/orttraining/python/training/onnxblock/loss/loss.py b/orttraining/orttraining/python/training/onnxblock/loss/loss.py index a144f56fc9..17b74194a1 100644 --- a/orttraining/orttraining/python/training/onnxblock/loss/loss.py +++ b/orttraining/orttraining/python/training/onnxblock/loss/loss.py @@ -3,11 +3,12 @@ # loss.py import copy + import onnx -import onnxruntime.training.onnxblock.model_accessor as accessor -import onnxruntime.training.onnxblock.building_blocks as building_blocks import onnxruntime.training.onnxblock._graph_utils as graph_utils +import onnxruntime.training.onnxblock.building_blocks as building_blocks +import onnxruntime.training.onnxblock.model_accessor as accessor class MSELoss(building_blocks.Block): @@ -25,11 +26,7 @@ class MSELoss(building_blocks.Block): if reduction != "mean" and reduction != "sum": 
raise RuntimeError(f"Reduction {reduction} not supported.") - self._reduce = ( - building_blocks.ReduceMean() - if reduction == "mean" - else building_blocks.ReduceSum() - ) + self._reduce = building_blocks.ReduceMean() if reduction == "mean" else building_blocks.ReduceSum() self._sub = building_blocks.Sub() self._square = building_blocks.Pow(2.0) @@ -53,9 +50,7 @@ class MSELoss(building_blocks.Block): # create a new graph input. this is the target input needed to compare # the graph output against to calculate loss. # TODO: Move input creation outside of the blocks. - target_input = copy.deepcopy( - graph_utils.get_output_from_output_name(onnx_model, loss_input_name) - ) + target_input = copy.deepcopy(graph_utils.get_output_from_output_name(onnx_model, loss_input_name)) target_input.name = target_name onnx_model.graph.input.append(target_input) @@ -106,15 +101,11 @@ class CrossEntropyLoss(building_blocks.Block): weight_name = graph_utils.generate_random_graph_name("celoss.weight") if self._weight is not None: - onnx_model.graph.initializer.append( - onnx.numpy_helper.from_array(self._weight, weight_name) - ) + onnx_model.graph.initializer.append(onnx.numpy_helper.from_array(self._weight, weight_name)) # create a new graph input. this is the labels input needed to compare # the graph output against to calculate loss. 
- labels_input = copy.deepcopy( - graph_utils.get_output_from_output_name(onnx_model, scores_input_name) - ) + labels_input = copy.deepcopy(graph_utils.get_output_from_output_name(onnx_model, scores_input_name)) labels_input.name = labels_name labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT32 # if the predictions are (num_examples x num_classes) @@ -163,11 +154,7 @@ class BCEWithLogitsLoss(building_blocks.Block): raise RuntimeError(f"Reduction {reduction} not supported.") self._weight = weight - self._reduce = ( - building_blocks.ReduceMean() - if reduction == "mean" - else building_blocks.ReduceSum() - ) + self._reduce = building_blocks.ReduceMean() if reduction == "mean" else building_blocks.ReduceSum() self._pos_weight = pos_weight self._sigmoid = building_blocks.Sigmoid() @@ -198,38 +185,24 @@ class BCEWithLogitsLoss(building_blocks.Block): # create the graph initializers for pos_weight, weight, and the sub operands ([1]) pos_weight_name = graph_utils.generate_random_graph_name("bceloss.pos_weight") if self._pos_weight is not None: - onnx_model.graph.initializer.append( - onnx.numpy_helper.from_array(self._pos_weight, pos_weight_name) - ) + onnx_model.graph.initializer.append(onnx.numpy_helper.from_array(self._pos_weight, pos_weight_name)) weight_name = graph_utils.generate_random_graph_name("bceloss.weight") if self._weight is not None: - onnx_model.graph.initializer.append( - onnx.numpy_helper.from_array(self._weight, weight_name) - ) + onnx_model.graph.initializer.append(onnx.numpy_helper.from_array(self._weight, weight_name)) - sub_ones_operand_name1 = graph_utils.generate_random_graph_name( - "bceloss.sub_ones" - ) + sub_ones_operand_name1 = graph_utils.generate_random_graph_name("bceloss.sub_ones") onnx_model.graph.initializer.append( - onnx.helper.make_tensor( - sub_ones_operand_name1, onnx.TensorProto.FLOAT, [1], [1.0] - ) - ) - sub_ones_operand_name2 = graph_utils.generate_random_graph_name( - "bceloss.sub_ones" + 
onnx.helper.make_tensor(sub_ones_operand_name1, onnx.TensorProto.FLOAT, [1], [1.0]) ) + sub_ones_operand_name2 = graph_utils.generate_random_graph_name("bceloss.sub_ones") onnx_model.graph.initializer.append( - onnx.helper.make_tensor( - sub_ones_operand_name2, onnx.TensorProto.FLOAT, [1], [1.0] - ) + onnx.helper.make_tensor(sub_ones_operand_name2, onnx.TensorProto.FLOAT, [1], [1.0]) ) # create a new graph input. this is the target input needed to compare # the graph output against to calculate loss. - target_input = copy.deepcopy( - graph_utils.get_output_from_output_name(onnx_model, loss_input_name) - ) + target_input = copy.deepcopy(graph_utils.get_output_from_output_name(onnx_model, loss_input_name)) target_input.name = target_name onnx_model.graph.input.append(target_input) diff --git a/orttraining/orttraining/python/training/onnxblock/model.py b/orttraining/orttraining/python/training/onnxblock/model.py index fc292e5d77..5502e3b361 100644 --- a/orttraining/orttraining/python/training/onnxblock/model.py +++ b/orttraining/orttraining/python/training/onnxblock/model.py @@ -3,10 +3,12 @@ # model.py from abc import abstractmethod + import onnx -import onnxruntime.training.onnxblock.model_accessor as accessor -import onnxruntime.training.onnxblock.building_blocks as building_blocks + import onnxruntime.training.onnxblock._graph_utils as graph_utils +import onnxruntime.training.onnxblock.building_blocks as building_blocks +import onnxruntime.training.onnxblock.model_accessor as accessor class Model(building_blocks.Block): @@ -32,9 +34,7 @@ class Model(building_blocks.Block): output = self.build(*args, **kwargs) # Perform shape inference - model_with_shapes = onnx.shape_inference.infer_shapes( - accessor.global_accessor.model - ) + model_with_shapes = onnx.shape_inference.infer_shapes(accessor.global_accessor.model) accessor.global_accessor.model.CopyFrom(model_with_shapes) # Build the graph outputs @@ -85,10 +85,7 @@ class TrainingModel(building_blocks.Block): is 
built, an exception will be raised. """ if self._parameters is None: - raise RuntimeError( - "Please build the training model first before trying to " - "retrieve the parameters." - ) + raise RuntimeError("Please build the training model first before trying to " "retrieve the parameters.") return self._parameters @@ -108,9 +105,7 @@ class TrainingModel(building_blocks.Block): output = self.build(*args, **kwargs) # Perform shape inference - model_with_shapes = onnx.shape_inference.infer_shapes( - accessor.global_accessor.model - ) + model_with_shapes = onnx.shape_inference.infer_shapes(accessor.global_accessor.model) accessor.global_accessor.model.CopyFrom(model_with_shapes) # Build the graph outputs @@ -122,6 +117,7 @@ class TrainingModel(building_blocks.Block): accessor.global_accessor.model, self._arg_not_requiring_grad ) + # build the gradient graph all_args_requiring_gradient_names = graph_utils.build_gradient_graph( accessor.global_accessor, self._arg_requiring_grad, @@ -130,9 +126,7 @@ class TrainingModel(building_blocks.Block): ) # add gradient accumulation nodes - graph_utils.build_gradient_accumulation_graph( - accessor.global_accessor.model, all_args_requiring_gradient_names - ) + graph_utils.build_gradient_accumulation_graph(accessor.global_accessor.model, all_args_requiring_gradient_names) # validate and check the model onnx.checker.check_model(accessor.global_accessor.model, True) diff --git a/orttraining/orttraining/python/training/onnxblock/model_accessor.py b/orttraining/orttraining/python/training/onnxblock/model_accessor.py index 324eae441d..162acd765d 100644 --- a/orttraining/orttraining/python/training/onnxblock/model_accessor.py +++ b/orttraining/orttraining/python/training/onnxblock/model_accessor.py @@ -1,16 +1,41 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-# global_model.py +# model_accessor.py from contextlib import contextmanager +import onnx + class ModelAccessor: """This class stores the onnx model that is manipulated by the onnx blocks.""" def __init__(self, model): - self.model = model - self.eval_model = None + self._model = model + self._eval_model = None + + @property + def model(self): + """ModelAccessor property that gets the modified model.""" + + if self._model is None: + raise RuntimeError( + "The onnx model was not set. Please use the context manager onnxblock.onnx_model to create the model." + ) + return self._model + + @property + def eval_model(self): + """ModelAccessor property that gets the eval model.""" + + if self._eval_model is None: + raise RuntimeError("The eval onnx model was not set.") + return self._eval_model + + @eval_model.setter + def eval_model(self, value): + """ModelAccessor property that sets the eval model.""" + self._eval_model = value # This variable resides in the global namespace. @@ -26,6 +51,14 @@ def onnx_model(model=None): Manages the construction and destruction of the global model. """ global global_accessor + if global_accessor is not None: + raise RuntimeError("Base onnx model already exists. Cannot create multiple ModelAccessors.") + + # If the user did not provide a model, then assume that they want to build from scratch. + # It is the duty of the caller to fill the model however they deem fit. + if model is None: + model = onnx.ModelProto() + global_accessor = ModelAccessor(model) try: yield global_accessor diff --git a/orttraining/orttraining/python/training/onnxblock/optim/__init__.py b/orttraining/orttraining/python/training/onnxblock/optim/__init__.py index 58a2a41f5d..c45ab01c26 100644 --- a/orttraining/orttraining/python/training/onnxblock/optim/__init__.py +++ b/orttraining/orttraining/python/training/onnxblock/optim/__init__.py @@ -2,4 +2,4 @@ # Licensed under the MIT License. 
# __init__.py -from .optim import AdamW +from .optim import AdamW, ClipGradNorm diff --git a/orttraining/orttraining/python/training/onnxblock/optim/optim.py b/orttraining/orttraining/python/training/onnxblock/optim/optim.py index 480c1ace67..4db29d75ab 100644 --- a/orttraining/orttraining/python/training/onnxblock/optim/optim.py +++ b/orttraining/orttraining/python/training/onnxblock/optim/optim.py @@ -2,160 +2,216 @@ # Licensed under the MIT License. # optim.py -import copy import onnx -import onnxruntime.training.onnxblock as onnxblock +import onnxruntime.training.onnxblock._graph_utils as graph_utils +import onnxruntime.training.onnxblock.building_blocks as building_blocks +import onnxruntime.training.onnxblock.model as model import onnxruntime.training.onnxblock.model_accessor as accessor +# TODO: Find a better place for these constants +_PRODUCER_NAME = "onnxblock offline tooling" +_OPSET_IMPORTS = [onnx.helper.make_opsetid("com.microsoft", 1), onnx.helper.make_opsetid("", 14)] -class AdamW(onnxblock.Model): - """Builds AdamW optimizer onnxblock for the given training model.""" - def __init__( - self, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0 - ): - super(AdamW, self).__init__() +class AdamWOptimizer(building_blocks.Block): + """Adds an AdamWOptimizer node to the onnx model.""" + + def __init__(self, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0): + super().__init__() + self._bias_correction = bias_correction self._betas = betas self._eps = eps self._weight_decay = weight_decay - self._max_norm_clip = 1.0 + + def build( # pylint: disable=too-many-arguments + self, + learning_rate_name, + step_name, + parameter_sequence_name, + gradient_sequence_name, + first_order_moment_sequence_name, + second_order_moment_sequence_name, + ): + """Adds the AdamWOptimizer node to the model.""" + + # get the model to manipulate + onnx_model = accessor.global_accessor.model + + # define the node attributes + node_attributes = { 
+ "alpha": self._betas[0], # beta1 + "beta": self._betas[1], # beta2 + "epsilon": self._eps, # epsilon + "weight_decay": self._weight_decay, # weight decay + "correct_bias": 1 if self._bias_correction else 0, # bias_correction + "adam_mode": 1, # adam mode (1 for hf/transformers/AdamW) + } + + # add the adamw node to the onnx model + adamw_input_names = [ + learning_rate_name, # learning rate + step_name, # training step + parameter_sequence_name, # param to be updated + gradient_sequence_name, # gradient of the param to be used for update + first_order_moment_sequence_name, # first order moment for this param + second_order_moment_sequence_name, # second order moment for this param + ] + adamw_output_name = graph_utils.generate_random_graph_name("adamw.updated_flag") + adamw_output_names = [adamw_output_name] + adamw_node = onnx.helper.make_node( + "AdamWOptimizer", + adamw_input_names, + adamw_output_names, + name=graph_utils.generate_random_graph_name("AdamWOptimizer"), + domain="com.microsoft", + **node_attributes, + ) + onnx_model.graph.node.append(adamw_node) + + return adamw_output_name + + +class ClipGradNorm(building_blocks.Block): + """Builds a gradient clipping by norm sub graph for the onnx model. + + Creates a block that performs gradient clipping by l2 norm for the calculated + gradient. + + Args: + max_norm: float indicating the max norm of the gradients. + + Returns: + Returns a string of the output names of the gradients after clipping. 
+ """ + + def __init__(self, max_norm): + super().__init__() + + self._max_norm = max_norm + + self._reduce = building_blocks.ReduceAllL2() + self._add = building_blocks.Add() + self._div = building_blocks.Div() + self._mul = building_blocks.Mul() + self._clip = building_blocks.Clip(clip_max=1.0) + + def build(self, *gradient_names): + """Adds a clip grad norm sub graph to the onnx model.""" + + # get the model to manipulate + onnx_model = accessor.global_accessor.model + + # add the necessary graph initializers + add_node_eps_name = graph_utils.generate_random_graph_name("add_eps") + onnx_model.graph.initializer.append( + onnx.helper.make_tensor(add_node_eps_name, onnx.TensorProto.FLOAT, [1], [1e-6]) + ) + max_norm_name = graph_utils.generate_random_graph_name("max_norm") + onnx_model.graph.initializer.append( + onnx.helper.make_tensor(max_norm_name, onnx.TensorProto.FLOAT, [1], [self._max_norm]) + ) + + # perform gradient clipping + total_norm_name = self._reduce(*gradient_names) + adjusted_total_norm_name = self._add(total_norm_name, add_node_eps_name) + clip_coef_name = self._clip(self._div(max_norm_name, adjusted_total_norm_name)) + return [self._mul(grad_name, clip_coef_name) for grad_name in gradient_names] + + +class AdamW(model.Model): + """Builds AdamW optimizer onnxblock for the given training parameters. + + Creates a block that updates the model parameters based on the calculated + gradient following the AdamW algorithm. + + Args: + bias_correction: bool indicating whether to perform bias correction. + betas: AdamW decay rate hyperparameters. + eps: term added to the denominator for computing the moments. + weight_decay: AdamW weight decay + clip_grad (optional): an instance of the ClipGradNorm. If not provided, + gradient clipping will not be done. + + Returns: + Returns a string of the output names from this optimizer node. 
+ """ + + def __init__( + self, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, clip_grad=None + ): # pylint: disable=too-many-arguments + super().__init__() + + self._adamw = AdamWOptimizer( + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + self._clip_grad = clip_grad + self._sc = building_blocks.SequenceConstruct() def build(self, parameters): - """Returns an AdamW optimizer model based on the input training model.""" + """Returns an AdamW optimizer model based on the input parameters.""" + + # get the model to manipulate and update its namespace + onnx_model = accessor.global_accessor.model + onnx_model.graph.name = "AdamW Optimizer Model" + onnx_model.producer_name = _PRODUCER_NAME + onnx_model.opset_import.extend(_OPSET_IMPORTS) + onnx_model.ir_version = onnx.IR_VERSION # TODO: Avoid hard coded input/output strings learning_rate_name = "learning_rate" step_name = "step" - gradient_output_suffix = "_grad.accumulation.out" - first_order_moment_suffix = "exp_avg" - second_order_moment_fuffix = "exp_avg_sq" - output_name_suffix = "out" + params_name = "params" + first_order_moments_name = "first_order_moments" + second_order_moments_name = "second_order_moments" + gradient_suffix = "_grad" trainable_parameters, _ = parameters - graph_nodes = [] - graph_inputs = [ - onnx.helper.make_tensor_value_info( - learning_rate_name, onnx.TensorProto.FLOAT, [1] - ), - onnx.helper.make_tensor_value_info(step_name, onnx.TensorProto.INT64, [1]), - ] - graph_outputs = [] - - # Iterate over all training graph outputs that are gradient outputs - for idx, param in enumerate(trainable_parameters): - - param_name = param.name - grad_name = f"{param_name}{gradient_output_suffix}" - first_order_moment_name = f"{param_name}.{first_order_moment_suffix}" - second_order_moment_name = f"{param_name}.{second_order_moment_fuffix}" - # prepare node (and graph) inputs and outputs - node_input_names = [ - learning_rate_name, # 
learning rate - step_name, # training step (used for beta correction) - param_name, # param to be updated - grad_name, # gradient of the param to be used for update - first_order_moment_name, # first order moment for this param - second_order_moment_name, # second order moment for this param + # create the graph inputs for the lr, step, params, grads, moments + onnx_model.graph.input.extend( + [ + onnx.helper.make_tensor_value_info(learning_rate_name, onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info(step_name, onnx.TensorProto.INT64, [1]), ] - - param_tensor_value_info = onnx.helper.make_tensor_value_info( - param_name, param.data_type, param.dims - ) - grad_tensor_value_info = onnx.helper.make_tensor_value_info( - grad_name, param.data_type, param.dims - ) - first_order_moment_tensor_value_info = onnx.helper.make_tensor_value_info( - first_order_moment_name, param.data_type, param.dims - ) - second_order_moment_tensor_value_info = onnx.helper.make_tensor_value_info( - second_order_moment_name, param.data_type, param.dims - ) - node_inputs = [ - param_tensor_value_info, - grad_tensor_value_info, - first_order_moment_tensor_value_info, - second_order_moment_tensor_value_info, - ] - graph_inputs.extend(node_inputs) - - step_output_name = f"{param_name}.{step_name}.{output_name_suffix}" - param_output_name = f"{param_name}.{output_name_suffix}" - first_order_moment_output_name = ( - f"{first_order_moment_name}.{output_name_suffix}" - ) - second_order_moment_output_name = ( - f"{second_order_moment_name}.{output_name_suffix}" - ) - - param_output_tensor_value_info = onnx.helper.make_tensor_value_info( - param_output_name, param.data_type, param.dims - ) - first_order_moment_output_tensor_value_info = ( - onnx.helper.make_tensor_value_info( - first_order_moment_output_name, param.data_type, param.dims - ) - ) - second_order_moment_output_tensor_value_info = ( - onnx.helper.make_tensor_value_info( - second_order_moment_output_name, param.data_type, 
param.dims - ) - ) - - node_output_names = [ - step_output_name, # step out - first_order_moment_output_name, # first order moment output - second_order_moment_output_name, # second order moment output - param_output_name, # updated weights - ] - - node_outputs = [ - onnx.helper.make_tensor_value_info( - step_output_name, onnx.TensorProto.INT64, [1] - ), - first_order_moment_output_tensor_value_info, - second_order_moment_output_tensor_value_info, - param_output_tensor_value_info, - ] - graph_outputs.extend(node_outputs) - - # AdamOptimizer node attributes - node_attributes = { - "alpha": self._betas[0], # beta1 - "beta": self._betas[1], # beta2 - "lambda": self._weight_decay, # weight decay - "epsilon": self._eps, # epsilon - "do_bias_correction": 1 - if self._bias_correction - else 0, # bias_correction - "weight_decay_mode": 1, # weight decay mode - "max_norm_clip": self._max_norm_clip, # used for gradient scaling - } - - # make the node - optimizer_node = onnx.helper.make_node( - "AdamOptimizer", - node_input_names, - node_output_names, - name=f"AdamOptimizer{idx}", - domain="com.microsoft", - **node_attributes, - ) - - graph_nodes.append(optimizer_node) - - # make the graph and the model - graph = onnx.helper.make_graph( - graph_nodes, "Optimizer Graph", graph_inputs, graph_outputs - ) - model = onnx.helper.make_model( - graph, - producer_name=onnxblock._producer_name, - opset_imports=[onnxblock._opset_import], ) - accessor.global_accessor.model = model + # Prepare the tensor sequence inputs for params and moments + for input_name in [params_name, first_order_moments_name, second_order_moments_name]: + onnx_model.graph.input.append( + onnx.helper.make_tensor_sequence_value_info(input_name, trainable_parameters[0].data_type, None) + ) - return [output.name for output in graph_outputs] + # TODO: Make the grads as a tensor sequence input after implementing clip grad + # normalization implementation which takes in a tensor sequence. 
+ grad_names = [] + for param in trainable_parameters: + grad_names.append(f"{param.name}{gradient_suffix}") + onnx_model.graph.input.append( + onnx.helper.make_tensor_value_info(grad_names[-1], param.data_type, param.dims) + ) + + # Clip the gradients if needed + if self._clip_grad is not None: + grad_names = self._clip_grad(*grad_names) + + # Run multi tensor AdamWOptimizer + updated_flag_name = self._adamw( + learning_rate_name, + step_name, + params_name, + self._sc(*grad_names), + first_order_moments_name, + second_order_moments_name, + ) + + # Create the graph outputs + onnx_model.graph.output.append( + onnx.helper.make_tensor_value_info(updated_flag_name, onnx.TensorProto.INT64, [1]) + ) + + return updated_flag_name diff --git a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py index 609be3a777..4a8d6981f1 100644 --- a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py @@ -421,8 +421,12 @@ def test_bcewithlogits_loss_training_graph_execution(): assert np.allclose(ort_grad, _to_numpy(pt_param.grad)) -@pytest.mark.parametrize("graph", [SimpleTrainingModelWithMSELoss, SimpleTrainingModelWithCrossEntropyLoss]) -def test_adamw_optimizer_composition(graph): +@pytest.mark.parametrize( + "graph", + [SimpleTrainingModelWithMSELoss, SimpleTrainingModelWithCrossEntropyLoss, SimpleTrainingModelWithBCEWithLogitsLoss], +) +@pytest.mark.parametrize("grad_clipping", [None, onnxblock.optim.ClipGradNorm(2.5)]) +def test_adamw_optimizer_composition(graph, grad_clipping): # Given device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 @@ -433,13 +437,15 @@ def test_adamw_optimizer_composition(graph): with onnxblock.onnx_model(onnx_model): _ = simple_model(onnx_model.graph.output[0].name) - optimizer = onnxblock.optim.AdamW() + optimizer = onnxblock.optim.AdamW(clip_grad=grad_clipping) with onnxblock.onnx_model() 
as accessor: _ = optimizer(simple_model.parameters()) optimizer_model = accessor.model assert optimizer_model +# TODO: Add a test for correctness when creation of ortvalues of +# tensor seq is possible on cuda def test_adamw_optimizer_execution(): # Given device = "cuda" @@ -455,14 +461,12 @@ def test_adamw_optimizer_execution(): optimizer = onnxblock.optim.AdamW() with onnxblock.onnx_model() as accessor: - _ = optimizer(simple_model.parameters()) + output_name = optimizer(simple_model.parameters()) optimizer_model = accessor.model learning_rate = 0.001 step = 1 - ort_output_names = [] - for name, _ in pt_model.named_parameters(): - ort_output_names.append(f"{name}.out") + ort_output_names = [output_name] def mse_loss(prediction, target): loss = torch.nn.MSELoss() @@ -478,12 +482,15 @@ def test_adamw_optimizer_execution(): ort_inputs = { "learning_rate": np.full(1, learning_rate, dtype=np.float32), "step": np.full(1, step, dtype=np.int64), + "params": [], + "first_order_moments": [], + "second_order_moments": [], } for name, param in pt_model.named_parameters(): - ort_inputs[name] = _to_numpy(copy.deepcopy(param)) - ort_inputs[f"{name}_grad.accumulation.out"] = _to_numpy(copy.deepcopy(param.grad)) - ort_inputs[f"{name}.exp_avg"] = _to_numpy(torch.zeros_like(param)) - ort_inputs[f"{name}.exp_avg_sq"] = _to_numpy(torch.zeros_like(param)) + ort_inputs["params"].append(_to_numpy(copy.deepcopy(param))) + ort_inputs[f"{name}_grad"] = _to_numpy(copy.deepcopy(param.grad)) + ort_inputs["first_order_moments"].append(_to_numpy(torch.zeros_like(param))) + ort_inputs["second_order_moments"].append(_to_numpy(torch.zeros_like(param))) # Then no error occurs when executing the model ort_session = onnxruntime.InferenceSession(onnx_fo.name, providers=C.get_available_providers()) @@ -642,3 +649,64 @@ def test_weighted_average_model_composition(model_type): weighted_model = WeightedAvg(random.random(), random.random()) with onnxblock.onnx_model(onnx_model): _ = 
weighted_model(onnx_model.graph.output[0].name, onnx_model.graph.output[1].name) + + +def test_grad_clipping_execution(): + # Given + device = "cuda" + N, D_in, H, D_out = 64, 784, 500, 10 + pt_model, _ = _get_models(device, N, D_in, H, D_out) + x = torch.randn(N, D_in, device=device) + target = torch.randn(N, D_out, device=device) + + # Prepare the onnx model with only grad clipping + onnx_model = onnx.ModelProto() + onnx_model.graph.name = "AdamW Optimizer Model" + onnx_model.producer_name = "grad clipping test" + onnx_model.opset_import.extend(onnxblock.optim.optim._OPSET_IMPORTS) + onnx_model.ir_version = onnx.IR_VERSION + + class GradClippingModel(onnxblock.Model): + def __init__(self, max_norm): + self._grad_clip = onnxblock.optim.ClipGradNorm(max_norm) + + def build(self, *grad_names): + return self._grad_clip(*grad_names) + + grad_names = [] + for name, param in pt_model.named_parameters(): + grad_names.append(f"{name}_grad") + + onnx_model.graph.input.append( + onnx.helper.make_tensor_value_info(grad_names[-1], onnx.TensorProto.FLOAT, param.shape) + ) + + grad_clip = GradClippingModel(2.5) + + with onnxblock.onnx_model(onnx_model): + ort_output_names = grad_clip(*grad_names) + + def mse_loss(prediction, target): + loss = torch.nn.MSELoss() + return loss(prediction, target) + + # When + with tempfile.NamedTemporaryFile(suffix=".onnx") as onnx_fo: + onnx.save(onnx_model, onnx_fo.name) + + loss = mse_loss(pt_model(x), target) + loss.backward() + + ort_inputs = {} + for name, param in pt_model.named_parameters(): + ort_inputs[f"{name}_grad"] = _to_numpy(copy.deepcopy(param.grad)) + + torch.nn.utils.clip_grad_norm_(pt_model.parameters(), 2.5) + + # Then no error occurs when executing the model + ort_session = onnxruntime.InferenceSession(onnx_fo.name, providers=C.get_available_providers()) + ort_outs = ort_session.run(ort_output_names, ort_inputs) + + # assert all the gradients are close + for ort_grad, pt_param in zip(ort_outs, pt_model.parameters()): + assert 
np.allclose(ort_grad, _to_numpy(pt_param.grad)) diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc index 9fc5edc57f..df38f21817 100644 --- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc @@ -19,7 +19,11 @@ #include "core/platform/path_lib.h" #include "orttraining/core/framework/checkpoint_common.h" -#include "orttraining/training_api/include/interfaces.h" +#include "orttraining/training_api/include/module.h" +#include "orttraining/training_api/include/optimizer.h" +#include "orttraining/training_api/include/checkpoint_property.h" +#include "orttraining/training_api/include/checkpoint.h" +#include "orttraining/training_api/include/lr_scheduler.h" #include "test/test_environment.h" #include "test/util/include/asserts.h" @@ -27,6 +31,7 @@ #include "test/util/include/test/test_environment.h" #include "orttraining/test/training_api/common/synthetic_data_loader.h" #include "orttraining/test/training_api/core/data_utils.h" +#include "default_providers.h" using onnxruntime::test::TemporaryDirectory; using namespace onnxruntime::training::api; @@ -159,7 +164,9 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) { * Save Optimizer states into ORT checkpoint files, * Then load it into ORT, compare with the initial optimizer states values. */ -TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CPU) { +#if defined(USE_CUDA) || defined(USE_ROCM) + +TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) { /// Phase 1 - Test Preparison /// Prepare the data and dest folder for saving checkpoint. /// Also cooked the data for test result comparison. 
@@ -201,15 +208,17 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CPU) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ORT_THROW_IF_ERROR(Environment::Create(nullptr, env)); + std::vector> cuda_provider{onnxruntime::test::DefaultCudaExecutionProvider()}; auto model = std::make_unique(model_uri, named_parameters, session_option, - *env); - auto optimizer = Optimizer(optim_uri, model->NamedParameters(), session_option, *env); + *env, cuda_provider); + auto optimizer = std::make_unique(optim_uri, model->NamedParameters(), session_option, + *env, cuda_provider); /// Phase 2 - Run Optimizer.GetStateDict and call save checkpoint APIs. /// And check the result checkpoint files. CheckpointState checkpoint_state; - ORT_ENFORCE(optimizer.GetStateDict(checkpoint_state.optimizer_checkpoint_state).IsOK()); + ORT_ENFORCE(optimizer->GetStateDict(checkpoint_state.optimizer_checkpoint_state).IsOK()); // Remove the tempoprary directory if it already exists. auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir"); @@ -286,6 +295,8 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CPU) { } } +#endif + /** * Create PropertyBag with sets of properties, * Save properties into ORT checkpoint files, diff --git a/orttraining/orttraining/test/training_api/core/data_utils.h b/orttraining/orttraining/test/training_api/core/data_utils.h index 5090a5116d..f114720e71 100644 --- a/orttraining/orttraining/test/training_api/core/data_utils.h +++ b/orttraining/orttraining/test/training_api/core/data_utils.h @@ -20,6 +20,25 @@ void OrtValueToVec(const OrtValue& val, std::vector& output) { output.assign(val_ptr, val_ptr + num_elem); } +template +void CudaOrtValueToCpuVec(const OrtValue& val, std::vector& output, + std::shared_ptr cuda_provider, + std::shared_ptr cpu_provider) { + const Tensor& src_tensor = val.Get(); + + auto allocator = cpu_provider->GetAllocator(0, OrtMemTypeDefault); + ORT_ENFORCE(allocator, "Cpu allocator is a 
nullptr."); + auto dst_tensor = std::make_unique(src_tensor.DataType(), src_tensor.Shape(), allocator); + + auto data_transfer = cuda_provider->GetDataTransfer(); + ORT_ENFORCE(data_transfer, "Cuda data transfer is a nullptr."); + + ORT_THROW_IF_ERROR(data_transfer->CopyTensor(src_tensor, *dst_tensor)); + + const T* val_ptr = dst_tensor->template Data(); + output.assign(val_ptr, val_ptr + src_tensor.Shape().Size()); +} + } // namespace test } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index e8507289b6..3621fcc0a6 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -11,8 +11,13 @@ #include "core/common/path_utils.h" #include "core/framework/tensorprotoutils.h" #include "orttraining/training_api/include/utils.h" -#include "orttraining/training_api/include/interfaces.h" +#include "orttraining/training_api/include/module.h" +#include "orttraining/training_api/include/optimizer.h" +#include "orttraining/training_api/include/checkpoint_property.h" +#include "orttraining/training_api/include/checkpoint.h" +#include "orttraining/training_api/include/lr_scheduler.h" #include "orttraining/test/training_api/core/data_utils.h" +#include "default_providers.h" using json = nlohmann::json; using namespace onnxruntime::training; @@ -53,7 +58,7 @@ TEST(TrainingApiTest, ModuleTrainStep) { std::unique_ptr env; ORT_THROW_IF_ERROR(Environment::Create(nullptr, env)); auto model = std::make_unique(model_uri, state.module_checkpoint_state.named_parameters, session_option, - *env); + *env, std::vector>()); OrtValue input, target; GenerateRandomInput(std::array{2, 784}, input); @@ -95,6 +100,8 @@ TEST(TrainingApiTest, ModuleTrainStep) { } } +#if defined(USE_CUDA) || defined(USE_ROCM) + TEST(TrainingApiTest, OptimStep) { auto 
model_uri = MODEL_FOLDER "gradient_graph.onnx"; auto optim_uri = MODEL_FOLDER "adamw.onnx"; @@ -105,10 +112,14 @@ TEST(TrainingApiTest, OptimStep) { onnxruntime::SessionOptions session_option; std::unique_ptr env; + std::vector> providers{onnxruntime::test::DefaultCudaExecutionProvider()}; + std::shared_ptr cuda_provider = providers.front(); + std::shared_ptr cpu_provider = onnxruntime::test::DefaultCpuExecutionProvider(); ORT_THROW_IF_ERROR(Environment::Create(nullptr, env)); auto model = std::make_unique(model_uri, state.module_checkpoint_state.named_parameters, session_option, - *env); - auto optim = std::make_unique(optim_uri, model->NamedParameters(), session_option, *env); + *env, providers); + auto optim = std::make_unique(optim_uri, model->NamedParameters(), session_option, + *env, providers); OrtValue input, target; GenerateRandomInput(std::array{2, 784}, input); @@ -125,8 +136,11 @@ TEST(TrainingApiTest, OptimStep) { optimizer_states.group_named_optimizer_states["group0"]->param_named_optimizer_states.at(param_name); OrtValue& moment_1 = param_state.momentum_named_states.at("momentum0"); + std::vector param_vec_before_optimizer_step; + CudaOrtValueToCpuVec(model->NamedParameters().at(param_name)->Data(), param_vec_before_optimizer_step, + cuda_provider, cpu_provider); std::vector moment_1_vec; - OrtValueToVec(moment_1, moment_1_vec); + CudaOrtValueToCpuVec(moment_1, moment_1_vec, cuda_provider, cpu_provider); for (size_t i = 0; i < moment_1_vec.size(); i++) { ORT_ENFORCE(moment_1_vec[i] == 0.0f); } @@ -136,14 +150,25 @@ TEST(TrainingApiTest, OptimStep) { std::vector& inputs = *it; std::vector fetches; ORT_ENFORCE(model->TrainStep(inputs, fetches).IsOK()); + std::vector grads; + CudaOrtValueToCpuVec(model->NamedParameters().at(param_name)->Gradient(), grads, + cuda_provider, cpu_provider); ORT_ENFORCE(optim->Step().IsOK()); // get optim state and check if it is updated - OrtValueToVec(moment_1, moment_1_vec); + CudaOrtValueToCpuVec(moment_1, moment_1_vec, 
cuda_provider, cpu_provider); for (size_t i = 0; i < moment_1_vec.size(); i++) { - if (moment_1_vec[i] != 0.0f) { - // moment was updated - break; + if (grads[i] != 0.0f) { + ORT_ENFORCE(moment_1_vec[i] != 0.0f); + } + } + + std::vector param_vec_after_optimizer_step; + CudaOrtValueToCpuVec(model->NamedParameters().at(param_name)->Data(), param_vec_after_optimizer_step, + cuda_provider, cpu_provider); + for (size_t i = 0; i < param_vec_after_optimizer_step.size(); ++i) { + if (grads[i] != 0.0f && moment_1_vec[i] != 0.0f) { + ORT_ENFORCE(param_vec_after_optimizer_step[i] != param_vec_before_optimizer_step[i]); } } } @@ -167,9 +192,11 @@ void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t onnxruntime::SessionOptions session_option; std::unique_ptr env; ORT_THROW_IF_ERROR(Environment::Create(nullptr, env)); + const std::vector> providers{onnxruntime::test::DefaultCudaExecutionProvider()}; auto model = std::make_unique(model_uri, state.module_checkpoint_state.named_parameters, session_option, - *env); - auto optim = std::make_shared(optim_uri, model->NamedParameters(), session_option, *env); + *env, providers); + auto optim = std::make_shared(optim_uri, model->NamedParameters(), session_option, + *env, providers); OrtValue input, target; GenerateRandomInput(std::array{2, 784}, input); @@ -259,6 +286,8 @@ TEST(TrainingApiTest, LinearLRScheduler_WarmUp200Step_ResumeFromCheckpoint_Test) TestLRSchduler("warmup_linear_scheduler_warmupstep-200_restored.json", initial_lr, total_step_count, 200); } +#endif + } // namespace } // namespace test } // namespace training diff --git a/orttraining/orttraining/training_api/include/interfaces.h b/orttraining/orttraining/training_api/include/interfaces.h deleted file mode 100644 index 1418d96b0d..0000000000 --- a/orttraining/orttraining/training_api/include/interfaces.h +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -/** - * @brief Temporary classes bridging between public C interfaces and internal ones. - * Currently implementation is not covering all targeted public training APIs, more changes are expected to add. - * - * TODO: For longer-term, we should re-arange following interfaces following the ways of exiting APIs' definitions and exposures - * in include/onnxruntime/core/session/onnxruntime_c_api.h and include/onnxruntime/core/session/onnxruntime_cxx_api.h. - * - * This is an intermediate header file for our training api internal development usage. - * - */ - -#pragma once - -#include - -#include "core/session/inference_session.h" -#include "core/session/environment.h" - -#include "orttraining/training_api/include/module.h" -#include "orttraining/training_api/include/optimizer.h" -#include "orttraining/training_api/include/lr_scheduler.h" -#include "orttraining/training_api/include/checkpoint_property.h" -#include "orttraining/training_api/include/checkpoint.h" - -using onnxruntime::training::api::LinearLRScheduler; -using onnxruntime::training::api::Module; -using onnxruntime::training::api::ModuleCheckpointState; -using onnxruntime::training::api::Optimizer; -using onnxruntime::training::api::OptimizerCheckpointState; -using onnxruntime::training::api::Parameter; - -namespace Ort { - -namespace { - -void ToOrtValue(const std::vector& ort_value_list, std::vector& ortvalue_list) { - size_t input_len = ort_value_list.size(); - ortvalue_list.clear(); - ortvalue_list.reserve(input_len); - const Ort::Value* ort_value_inputs_ptr = ort_value_list.data(); - auto ortvalue_inputs_ptr = reinterpret_cast(const_cast(ort_value_inputs_ptr)); - for (size_t i = 0; i < input_len; ++i) { - auto& ort_value = *reinterpret_cast(ortvalue_inputs_ptr[i]); - ortvalue_list.push_back(ort_value); - } -} - -void FromOrtValue(std::vector& ortvalue_list, std::vector& ort_value_list) { - // Clean the output. 
- ort_value_list.clear(); - size_t output_names_len = ortvalue_list.size(); - for (size_t i = 0; i < output_names_len; ++i) - ort_value_list.emplace_back(nullptr); - - Ort::Value* ort_value_outputs_ptr = ort_value_list.data(); - auto ortvalue_outputs_ptr = reinterpret_cast(ort_value_outputs_ptr); - for (size_t i = 0; i != output_names_len; ++i) { - OrtValue& value = ortvalue_list[i]; - // Ort::Value will release the pointer once it goes out of scope. - ortvalue_outputs_ptr[i] = new OrtValue(value); - } -} - -} // namespace - -struct OrtModule { - public: - OrtModule( - OrtEnv* env, OrtSessionOptions* session_options, - const std::string& train_model_path_or_bytes, - std::unordered_map>& parameters, - const std::optional& eval_model_path_or_bytes = std::nullopt) { - module_ = std::make_unique( - train_model_path_or_bytes, - parameters, - *reinterpret_cast<::onnxruntime::SessionOptions*>(session_options), - *reinterpret_cast<::onnxruntime::Environment*>(env), - eval_model_path_or_bytes); - } - - std::unordered_map> NamedParameters() const { - return module_->NamedParameters(); - } - - bool ResetGrad() { - return module_->ResetGrad().IsOK(); - } - - bool TrainStep(const std::vector& inputs, std::vector& outputs) { - std::vector feeds; - ToOrtValue(inputs, feeds); - - std::vector fetches; - if (!module_->TrainStep(feeds, fetches).IsOK()) { - return false; - } - - // Clean the output. - outputs.clear(); - FromOrtValue(fetches, outputs); - - return true; - } - - bool EvalStep(const std::vector& inputs, std::vector& outputs) { - std::vector feeds; - ToOrtValue(inputs, feeds); - - std::vector fetches; - if (!module_->EvalStep(feeds, fetches).IsOK()) { - return false; - } - - // Clean the output. 
- outputs.clear(); - FromOrtValue(fetches, outputs); - return true; - } - - bool GetStateDict(onnxruntime::training::api::ModuleCheckpointState& module_checkpoint_states) { - return module_->GetStateDict(module_checkpoint_states).IsOK(); - } - - private: - std::unique_ptr module_; -}; - -struct OrtOptimizer { - friend struct OrtLinearLRScheduler; - - OrtOptimizer( - OrtEnv* env, OrtSessionOptions* session_options, - const std::string& optim_path_or_bytes, - const std::unordered_map>& parameters) { - optimizer_ = std::make_shared( - optim_path_or_bytes, - parameters, - *reinterpret_cast<::onnxruntime::SessionOptions*>(session_options), - *reinterpret_cast<::onnxruntime::Environment*>(env)); - } - - bool Step() { - return optimizer_->Step().IsOK(); - } - - bool GetStateDict(onnxruntime::training::api::OptimizerCheckpointState& optimizer_checkpoint_states) { - return optimizer_->GetStateDict(optimizer_checkpoint_states).IsOK(); - } - - private: - std::shared_ptr optimizer_; -}; - -struct OrtLinearLRScheduler { - OrtLinearLRScheduler(OrtOptimizer& optimizer, int64_t warmup_step_count, int64_t total_step_count) - : optim_(optimizer.optimizer_) { - linear_lr_scheduler_ = std::make_unique(optim_, warmup_step_count, total_step_count); - } - - bool Step() { - return linear_lr_scheduler_->Step().IsOK(); - } - - private: - std::unique_ptr linear_lr_scheduler_; - std::shared_ptr optim_; -}; - -} // namespace Ort diff --git a/orttraining/orttraining/training_api/include/module.h b/orttraining/orttraining/training_api/include/module.h index b6be7714c6..dabc372034 100644 --- a/orttraining/orttraining/training_api/include/module.h +++ b/orttraining/orttraining/training_api/include/module.h @@ -61,6 +61,7 @@ struct Module { const std::unordered_map>& named_parameters, const onnxruntime::SessionOptions& session_options, const Environment& env, + const std::vector>& providers, const std::optional& eval_model_path_or_bytes = std::nullopt); // Return the trainable/nontrainable 
parameters diff --git a/orttraining/orttraining/training_api/include/optimizer.h b/orttraining/orttraining/training_api/include/optimizer.h index 56ae0a2ccc..cb8a20f602 100644 --- a/orttraining/orttraining/training_api/include/optimizer.h +++ b/orttraining/orttraining/training_api/include/optimizer.h @@ -59,7 +59,8 @@ struct Optimizer { Optimizer(const std::string& optim_path_or_bytes, const std::unordered_map>& named_parameters, const onnxruntime::SessionOptions& session_options, - const Environment& env); + const Environment& env, + const std::vector>& providers); Status Step(); diff --git a/orttraining/orttraining/training_api/include/training_session.h b/orttraining/orttraining/training_api/include/training_session.h index 36a66baa4b..263355dbe6 100644 --- a/orttraining/orttraining/training_api/include/training_session.h +++ b/orttraining/orttraining/training_api/include/training_session.h @@ -12,16 +12,23 @@ namespace training { namespace api { using namespace common; +struct ModelIdentifiers { + const std::string train_model; + const std::optional eval_model, optim_model; + ModelIdentifiers(const std::string& train_model_uri, + const std::optional& eval_model_uri, + const std::optional& optim_model_uri) + : train_model(train_model_uri), eval_model(eval_model_uri), optim_model(optim_model_uri) {} +}; + // Wrapper on top of module and optimizer classes and is the only class exposed via capis class TrainingSession { public: TrainingSession(const Environment& session_env, const SessionOptions& session_options, - const std::unordered_map>& parameters); - - Status Initialize(const std::string& train_model_uri, - const std::optional& eval_model_uri, - const std::optional& optim_model_uri); + const std::vector>& providers, + const std::unordered_map>& parameters, + const ModelIdentifiers& model_identifiers); size_t GetTrainModeOutputCount() const noexcept; @@ -44,8 +51,6 @@ class TrainingSession { private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TrainingSession); - 
const Environment& environment_; - SessionOptions session_options_; const std::unordered_map> named_parameters_; std::unique_ptr module_; std::unique_ptr optimizer_; diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 8084c71a41..e362e1118c 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -55,10 +55,14 @@ Module::Module(const std::string& train_model_path_or_bytes, const std::unordered_map>& named_parameters, const onnxruntime::SessionOptions& session_options, const Environment& env, - const std::optional& eval_model_path_or_bytes) : named_parameters_{named_parameters} { - // Create session for training model + const std::vector>& providers, + const std::optional& eval_model_path_or_bytes) + : named_parameters_{named_parameters} { train_sess_ = std::make_unique(session_options, env); ORT_THROW_IF_ERROR(train_sess_->Load(train_model_path_or_bytes)); + for (const auto& provider : providers) { + ORT_THROW_IF_ERROR(train_sess_->RegisterExecutionProvider(provider)); + } ORT_THROW_IF_ERROR(train_sess_->Initialize()); // Extract model input and output names @@ -123,7 +127,6 @@ Module::Module(const std::string& train_model_path_or_bytes, auto target_tensor = std::make_unique(param_data_tensor.DataType(), param_data_tensor.Shape(), target_allocator); ORT_THROW_IF_ERROR(train_sess_state.GetDataTransferMgr().CopyTensor(param_data_tensor, *target_tensor.get())); auto ml_tensor_type = DataTypeImpl::GetType(); - // TODO test the original buffer is released. 
param_data.Init(target_tensor.release(), ml_tensor_type, ml_tensor_type->GetDeleteFunc()); } diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 155b9aa27f..730ad208c4 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -9,6 +9,20 @@ #include "orttraining/training_api/include/training_session.h" #include "core/session/abi_session_options_impl.h" +namespace { + +std::vector> CreateProviders( + const std::vector>& provider_factories) { + std::vector> execution_providers; + for (const auto& factory : provider_factories) { + execution_providers.emplace_back(std::move(factory->CreateProvider())); + } + + return execution_providers; +} + +} // namespace + ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, @@ -20,16 +34,18 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ *out = nullptr; ORT_TRY { + using ProvidersType = std::vector>; train_sess = std::make_unique( env->GetEnvironment(), options == nullptr ? onnxruntime::SessionOptions() : options->value, - chkpt_state->module_checkpoint_state.named_parameters); - - ORT_API_RETURN_IF_STATUS_NOT_OK(train_sess->Initialize(train_model_path, - eval_model_path ? std::optional{eval_model_path} - : std::nullopt, - optimizer_model_path ? std::optional{optimizer_model_path} - : std::nullopt)); + options == nullptr ? ProvidersType() : CreateProviders(options->provider_factories), + chkpt_state->module_checkpoint_state.named_parameters, + onnxruntime::training::api::ModelIdentifiers{ + train_model_path, + eval_model_path ? 
std::optional{eval_model_path} + : std::nullopt, + optimizer_model_path ? std::optional{optimizer_model_path} + : std::nullopt}); *out = reinterpret_cast(train_sess.release()); } diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index e2e988b5c4..39f1c78617 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -20,9 +20,14 @@ const std::string GROUP_ZERO_NAME = "group0"; // TODO: don't hard code the state names, should get the state names according to the optimizer types. // TODO: Conolidate with frontend tooling -const std::vector MOMENT_SUFFIXES{".exp_avg", ".exp_avg_sq"}; const std::vector MOMENT_STATE_NAMES{"momentum0", "momentum1"}; +constexpr char LearningRateName[] = "learning_rate"; +constexpr char StepName[] = "step"; +constexpr char ParamsName[] = "params"; +constexpr char FirstOrderMomentsName[] = "first_order_moments"; +constexpr char SecondOrderMomentsName[] = "second_order_moments"; + } // namespace Status Optimizer::GenerateMomentumNamedStates() { @@ -49,35 +54,70 @@ Status Optimizer::ConstructInputs() { if (optimizer_type_ == OptimizerType::AdamW) { auto& param_named_optimizer_states = optimizer_state_.param_named_optimizer_states; - std::string param_name; - std::vector param_names, grad_names, moment1_names, moment2_names, user_inputs; - for (size_t i = 2; i < input_names_.size(); i++) { - std::string& name = input_names_[i]; - auto it = named_parameters_.find(name); - if (it != named_parameters_.end()) { // is param - param_names.push_back(name); - inputs_.push_back(it->second->Data()); - } else if (utils::GetParamNameFromGradient(name, param_name)) { - grad_names.emplace_back(name); - // assert param_name is valid. 
- auto it = named_parameters_.find(param_name); - ORT_ENFORCE(it != named_parameters_.end(), "Unknown param: ", param_name, " for field: ", name); - inputs_.push_back(it->second->Gradient()); - } else if (utils::GetParamNameFromSuffix(name, MOMENT_SUFFIXES[0], param_name)) { - moment1_names.push_back(name); - auto it = named_parameters_.find(param_name); - ORT_ENFORCE(it != named_parameters_.end(), "Unknown param: ", param_name, " for field: ", name); - inputs_.push_back(param_named_optimizer_states.at(param_name).momentum_named_states.at(MOMENT_STATE_NAMES[0])); - } else if (utils::GetParamNameFromSuffix(name, MOMENT_SUFFIXES[1], param_name)) { - moment2_names.push_back(name); - auto it = named_parameters_.find(param_name); - ORT_ENFORCE(it != named_parameters_.end(), "Unknown param: ", param_name, " for field: ", name); - inputs_.push_back(param_named_optimizer_states.at(param_name).momentum_named_states.at(MOMENT_STATE_NAMES[1])); + std::vector params, first_order_moments, second_order_moments; + // TODO: Change to tensor seq implementation once clip grad norm op + // that accepts tensor seq as input for gradients is complete. + std::vector grads; + + // Input names 0-4 are reserved for lr, step, params, first order moments, second order moments + // input names 5 onwards are all the gradient names. + // Collect all the inputs based on the gradient names order. 
+ for (size_t i = 5; i < input_names_.size(); i++) { + std::string param_name; + if (utils::GetParamNameFromGradient(input_names_[i], param_name)) { + const auto named_parameter_it = named_parameters_.find(param_name); + ORT_ENFORCE(named_parameter_it != named_parameters_.end(), + "Unknown param: ", param_name, " for field: ", input_names_[i]); + + // Collect the gradients as ortvalues + grads.push_back(named_parameter_it->second->Gradient()); + + // Collect parameters and prepare for tensorseq creation + auto* param_tensor = named_parameter_it->second->Data().GetMutable(); + params.emplace_back( + Tensor(param_tensor->DataType(), param_tensor->Shape(), + param_tensor->MutableDataRaw(), param_tensor->Location())); + + // Collect first order moments and prepare for tensorseq creation + auto* first_order_moment_tensor = param_named_optimizer_states.at(param_name) + .momentum_named_states.at(MOMENT_STATE_NAMES[0]) + .GetMutable(); + first_order_moments.emplace_back( + Tensor(first_order_moment_tensor->DataType(), first_order_moment_tensor->Shape(), + first_order_moment_tensor->MutableDataRaw(), first_order_moment_tensor->Location())); + + // Collect second order moments and prepare for tensorseq creation + auto* second_order_moment_tensor = param_named_optimizer_states.at(param_name) + .momentum_named_states.at(MOMENT_STATE_NAMES[1]) + .GetMutable(); + second_order_moments.emplace_back( + Tensor(second_order_moment_tensor->DataType(), second_order_moment_tensor->Shape(), + second_order_moment_tensor->MutableDataRaw(), second_order_moment_tensor->Location())); } else { - ORT_ENFORCE("This is an invalid graph. Optimizer graph contains unknown user input:", name); + ORT_ENFORCE( + false, "This is an invalid graph. 
Optimizer graph contains unknown user input:", input_names_[i]); } - ORT_ENFORCE(inputs_.back().IsAllocated() && inputs_.back().IsTensor(), "Uninitialized tensor data for ", name); } + + const auto tensorseq_inserter = [](auto& tensors, auto* inputs) { + ORT_ENFORCE(!tensors.empty(), "Tensors cannot be empty while building a tensor sequence."); + + auto tensor_seq = std::make_unique(tensors.front().DataType()); + tensor_seq->SetElements(std::move(tensors)); + inputs->emplace_back( + OrtValue(tensor_seq.release(), DataTypeImpl::GetType(), + DataTypeImpl::GetType()->GetDeleteFunc())); + }; + + // Add the params and moments as tensorseq ortvalues to inputs + tensorseq_inserter(params, &inputs_); + tensorseq_inserter(first_order_moments, &inputs_); + tensorseq_inserter(second_order_moments, &inputs_); + + // Add the gradients as ortvalues to inputs + inputs_.insert(inputs_.end(), + std::make_move_iterator(grads.begin()), + std::make_move_iterator(grads.end())); } // Add other optimizer reordering logic here return Status::OK(); @@ -86,17 +126,26 @@ Status Optimizer::ConstructInputs() { Optimizer::Optimizer(const std::string& optim_path_or_bytes, const std::unordered_map>& named_parameters, const onnxruntime::SessionOptions& session_options, - const Environment& env) : named_parameters_(named_parameters) { + const Environment& env, + const std::vector>& providers) + : named_parameters_(named_parameters) { optim_sess_ = std::move(std::make_unique(session_options, env)); + for (const auto& execution_provider : providers) { + ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider)); + } ORT_THROW_IF_ERROR(optim_sess_->Load(optim_path_or_bytes)); ORT_THROW_IF_ERROR(optim_sess_->Initialize()); utils::GetGraphInputOutputNames(optim_sess_, input_names_, output_names_); - ORT_ENFORCE(input_names_[0] == "learning_rate"); // TODO: make this better - ORT_ENFORCE(input_names_[1] == "step"); // TODO: make this better + ORT_ENFORCE(input_names_[0] == 
LearningRateName); // TODO: make this better + ORT_ENFORCE(input_names_[1] == StepName); // TODO: make this better + ORT_ENFORCE(input_names_[2] == ParamsName); // TODO: make this better if (optimizer_type_ == OptimizerType::AdamW) { + ORT_ENFORCE(input_names_[3] == FirstOrderMomentsName); // TODO: make this better + ORT_ENFORCE(input_names_[4] == SecondOrderMomentsName); // TODO: make this better + ORT_THROW_IF_ERROR(GenerateMomentumNamedStates()); } else { ORT_THROW("Unsupported optimizer type"); @@ -107,7 +156,10 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes, Status Optimizer::Step() { OrtValue learning_rate_input, step_input; utils::WrapInOrtValue(optimizer_state_.learning_rate, &learning_rate_input); - utils::WrapInOrtValue(optimizer_state_.step, &step_input); + // Use step count + 1 before running optimizer step. + // This is necessary since bias correction uses the step + // as a power. Using power of 0 is wrong. + utils::WrapInOrtValue(optimizer_state_.step + 1, &step_input); std::vector feeds({learning_rate_input, step_input}); feeds.insert(feeds.end(), inputs_.begin(), inputs_.end()); @@ -116,8 +168,9 @@ Status Optimizer::Step() { ORT_THROW_IF_ERROR(status); // extract step output and update - // TODO: need to remove hardcoding - optimizer_state_.step = utils::GetValue(outputs[0]); + if (utils::GetValue(outputs[0]) == 1LL) { + optimizer_state_.step++; + } return Status::OK(); } diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index 020a708c29..ad565c21d0 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -9,23 +9,17 @@ namespace api { TrainingSession::TrainingSession(const Environment& session_env, const SessionOptions& session_options, - const std::unordered_map>& parameters) - : environment_(session_env), - session_options_{session_options}, - named_parameters_{parameters} 
{} - -Status TrainingSession::Initialize(const std::string& train_model_uri, const std::optional& eval_model_uri, - const std::optional& optim_model_uri) { - module_ = std::move(std::make_unique(train_model_uri, named_parameters_, session_options_, - environment_, eval_model_uri)); - - if (optim_model_uri.has_value()) { - optimizer_ = std::move(std::make_unique(optim_model_uri.value(), named_parameters_, - session_options_, environment_)); - } - - return Status::OK(); -} + const std::vector>& providers, + const std::unordered_map>& parameters, + const ModelIdentifiers& model_identifiers) + : named_parameters_{parameters}, + module_{std::make_unique(model_identifiers.train_model, named_parameters_, + session_options, session_env, providers, model_identifiers.eval_model)}, + optimizer_{model_identifiers.optim_model.has_value() + ? std::make_unique( + model_identifiers.optim_model.value(), named_parameters_, + session_options, session_env, providers) + : std::unique_ptr()} {} size_t TrainingSession::GetTrainModeOutputCount() const noexcept { return module_->GetTrainModeOutputCount(); diff --git a/orttraining/orttraining/training_api/utils.cc b/orttraining/orttraining/training_api/utils.cc index d7c71826b4..1169f15fe6 100644 --- a/orttraining/orttraining/training_api/utils.cc +++ b/orttraining/orttraining/training_api/utils.cc @@ -63,12 +63,22 @@ Status OrtValueLike(const SessionState& sess_state, const OrtValue& input_val, O auto element_type = param_tensor.DataType(); auto p_tensor = std::make_unique(element_type, shape, allocator); - // TODO: handle CUDA memset if (tensor_location.device.Type() == OrtDevice::CPU || tensor_location.mem_type == OrtMemTypeCPUInput || tensor_location.mem_type == OrtMemTypeCPUOutput) { memset(p_tensor->MutableDataRaw(), 0, p_tensor->SizeInBytes()); + } else if (tensor_location.device.Type() == OrtDevice::GPU) { + // Use a tensor on cpu and copy it over to the desired device using + // the data transfer manager. 
+ AllocatorPtr cpu_allocator = sess_state.GetAllocator(OrtDevice()); + auto p_cpu_tensor = std::make_unique(element_type, shape, cpu_allocator); + memset(p_cpu_tensor->MutableDataRaw(), 0, p_cpu_tensor->SizeInBytes()); + // No need to free the cpu buffer, tensor destructor takes care of it using the cpu_allocator + ORT_THROW_IF_ERROR(sess_state.GetDataTransferMgr().CopyTensor(*p_cpu_tensor, *p_tensor)); + } else { + ORT_THROW("Cannot create tensor on device ", tensor_location.device.Type()); } + output_val.Init(p_tensor.release(), DataTypeImpl::GetType(), DataTypeImpl::GetType()->GetDeleteFunc()); diff --git a/pyproject.toml b/pyproject.toml index 91e3c13440..b79a02d5ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ extend_skip_glob = [ convention = "google" [tool.pylint.'MESSAGES CONTROL'] -disable = ["format", "line-too-long", "import-error", "no-name-in-module"] +disable = ["format", "line-too-long", "import-error", "no-name-in-module", "fixme", "too-few-public-methods"] [tool.pyright] exclude = ["onnxruntime/core/flatbuffers/*"] diff --git a/setup.py b/setup.py index 49aadbb5ca..86ecd25938 100644 --- a/setup.py +++ b/setup.py @@ -374,7 +374,7 @@ requirements_file = "requirements.txt" local_version = None enable_training = parse_arg_remove_boolean(sys.argv, "--enable_training") -enable_training_on_device = parse_arg_remove_boolean(sys.argv, '--enable_training_on_device') +enable_training_on_device = parse_arg_remove_boolean(sys.argv, "--enable_training_on_device") disable_auditwheel_repair = parse_arg_remove_boolean(sys.argv, "--disable_auditwheel_repair") default_training_package_device = parse_arg_remove_boolean(sys.argv, "--default_training_package_device") @@ -421,7 +421,9 @@ if enable_training: ] ) if enable_training_on_device: - packages.append('onnxruntime.training.onnxblock') + packages.append("onnxruntime.training.onnxblock") + packages.append("onnxruntime.training.onnxblock.loss") + 
packages.append("onnxruntime.training.onnxblock.optim") package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"] package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"] package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"] diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 90fc942ed5..e7ece8dc6d 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -179,8 +179,7 @@ def parse_arguments(): parser.add_argument( "--enable_training_torch_interop", action="store_true", help="Enable training kernels interop with torch." ) - parser.add_argument( - "--enable_training_on_device", action='store_true', help="Enable on device training in ORT.") + parser.add_argument("--enable_training_on_device", action="store_true", help="Enable on device training in ORT.") parser.add_argument("--disable_nccl", action="store_true", help="Disable Nccl.") parser.add_argument("--mpi_home", help="Path to MPI installation dir") parser.add_argument("--nccl_home", help="Path to NCCL installation dir")