mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Add support for gradient clipping, AdamWOptimizer and tensorseq as inputs (#11697)
This commit is contained in:
parent
f14f0e19ec
commit
fac8dae9df
28 changed files with 749 additions and 681 deletions
|
|
@ -226,7 +226,7 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
|
|||
}
|
||||
++source_iter;
|
||||
++target_iter;
|
||||
} //while
|
||||
} // while
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported OrtValue type to copy between device.");
|
||||
}
|
||||
|
|
@ -322,9 +322,9 @@ static common::Status CalculateStaticCopyInfoForFetches(const SessionState& sess
|
|||
// If for some reason using just the device from the allocation plan isn't enough, the following
|
||||
// would use the NodeInfo from the node producing the output
|
||||
//
|
||||
//std::vector<SessionState::NodeInfo> node_info_vec;
|
||||
//auto status = session_state.GetOutputNodeInfo(output_name, node_info_vec);
|
||||
//if (status.IsOK()) {
|
||||
// std::vector<SessionState::NodeInfo> node_info_vec;
|
||||
// auto status = session_state.GetOutputNodeInfo(output_name, node_info_vec);
|
||||
// if (status.IsOK()) {
|
||||
// const auto& node_info = node_info_vec.front(); // only one entry as only one node can produce a given output
|
||||
// copy_info[idx].source_device = *node_info.device;
|
||||
//} else {
|
||||
|
|
@ -428,6 +428,11 @@ static void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager
|
|||
const auto& feed = feeds[i];
|
||||
if (feed.IsTensor()) {
|
||||
feed_locations[i] = feed.Get<Tensor>().Location().device;
|
||||
} else if (feed.IsTensorSequence()) {
|
||||
const auto& tensor_seq = feed.Get<TensorSeq>();
|
||||
if (tensor_seq.Size() != std::size_t{0}) {
|
||||
feed_locations[i] = tensor_seq.Get(0).Location().device;
|
||||
}
|
||||
} else if (feed.IsSparseTensor()) {
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
feed_locations[i] = feed.Get<SparseTensor>().Location().device;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/training_api/adamw.onnx
vendored
BIN
onnxruntime/test/testdata/training_api/adamw.onnx
vendored
Binary file not shown.
|
|
@ -2,13 +2,8 @@
|
|||
# Licensed under the MIT License.
|
||||
# __init__.py
|
||||
|
||||
from .building_blocks import Block
|
||||
from .model import Model, TrainingModel
|
||||
from .checkpoint_utils import save_checkpoint
|
||||
from . import loss, optim
|
||||
from .building_blocks import Block
|
||||
from .checkpoint_utils import save_checkpoint
|
||||
from .model import Model, TrainingModel
|
||||
from .model_accessor import onnx_model
|
||||
|
||||
import onnx
|
||||
|
||||
_producer_name = "onnxblock offline tooling"
|
||||
_opset_import = onnx.helper.make_opsetid("com.microsoft", 1)
|
||||
|
|
|
|||
|
|
@ -1,146 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# _building_blocks.py
|
||||
|
||||
import copy
|
||||
import onnx
|
||||
|
||||
import onnxruntime.training.onnxblock as onnxblock
|
||||
import onnxruntime.training.onnxblock.model_accessor as accessor
|
||||
import onnxruntime.training.onnxblock._graph_utils as graph_utils
|
||||
|
||||
|
||||
class Sub(onnxblock.Model):
    """Block that appends an elementwise Sub node to the global onnx model."""

    def __init__(self):
        super().__init__()

    def build(self, sub_input_name1, sub_input_name2):
        """Append a Sub node computing ``sub_input_name1 - sub_input_name2``.

        All existing graph outputs are dropped and replaced by a single
        output: a deep copy of whatever
        ``graph_utils.get_output_from_output_name`` returns for
        ``sub_input_name1``, renamed to the new node's output.

        Args:
            sub_input_name1: graph name of the first operand.
            sub_input_name2: graph name of the second operand.

        Returns:
            The generated name of the Sub node's output.
        """
        model = accessor.global_accessor.model

        # Emit the node. The output name is generated before the node name.
        result_name = graph_utils.generate_random_graph_name("sub_output")
        model.graph.node.append(
            onnx.helper.make_node(
                "Sub",
                [sub_input_name1, sub_input_name2],
                [result_name],
                name=graph_utils.generate_random_graph_name("Sub"),
            )
        )

        # Describe the result by cloning the first operand's output entry.
        result_info = copy.deepcopy(graph_utils.get_output_from_output_name(model, sub_input_name1))
        result_info.name = result_name

        # The Sub result becomes the only graph output.
        del model.graph.output[:]
        model.graph.output.append(result_info)

        return result_name
|
||||
|
||||
|
||||
class Pow(onnxblock.Model):
    """Block that appends a Pow node with a fixed exponent to the onnx model."""

    def __init__(self, exponent):
        """Remember the exponent that ``build`` will bake into the graph."""
        super().__init__()

        self._exponent = exponent

    def build(self, pow_input_name):
        """Append a Pow node raising ``pow_input_name`` to the stored exponent.

        The exponent is stored as a one-element FLOAT initializer. All existing
        graph outputs are dropped and replaced by a deep copy of whatever
        ``graph_utils.get_output_from_output_name`` returns for
        ``pow_input_name``, renamed to the new node's output.

        Args:
            pow_input_name: graph name of the base operand.

        Returns:
            The generated name of the Pow node's output.
        """
        model = accessor.global_accessor.model

        # Exponent initializer (shape [1], FLOAT).
        exponent_name = graph_utils.generate_random_graph_name("pow_exponent")
        exponent_tensor = onnx.helper.make_tensor(exponent_name, onnx.TensorProto.FLOAT, [1], [self._exponent])
        model.graph.initializer.append(exponent_tensor)

        # Pow node. The output name is generated before the node name.
        result_name = graph_utils.generate_random_graph_name("pow_output")
        node_name = graph_utils.generate_random_graph_name("Pow")
        model.graph.node.append(
            onnx.helper.make_node("Pow", [pow_input_name, exponent_name], [result_name], name=node_name)
        )

        # Describe the result by cloning the base operand's output entry.
        result_info = copy.deepcopy(graph_utils.get_output_from_output_name(model, pow_input_name))
        result_info.name = result_name

        # The Pow result becomes the only graph output.
        del model.graph.output[:]
        model.graph.output.append(result_info)

        return result_name
|
||||
|
||||
|
||||
class _Reduce(onnxblock.Model):
    """Shared implementation for the reduce blocks."""

    def __init__(self):
        super().__init__()

    def _reduce(self, reduce_input_name, reduction_op):
        """Append a ``reduction_op`` node over ``reduce_input_name``.

        All existing graph outputs are dropped and replaced by one FLOAT
        output whose shape is ``[1]`` repeated once per dimension of the entry
        that ``graph_utils.get_output_from_output_name`` returns for
        ``reduce_input_name``.

        Args:
            reduce_input_name: graph name of the value to reduce.
            reduction_op: ONNX op type to emit; also used as the node-name token.

        Returns:
            The generated name of the reduce node's output.
        """
        model = accessor.global_accessor.model

        # Reduce node. The output name is generated before the node name.
        result_name = graph_utils.generate_random_graph_name("reduce_output")
        node_name = graph_utils.generate_random_graph_name(reduction_op)
        model.graph.node.append(
            onnx.helper.make_node(reduction_op, [reduce_input_name], [result_name], name=node_name)
        )

        # The output keeps the input's rank, with every dimension set to 1.
        source_info = copy.deepcopy(graph_utils.get_output_from_output_name(model, reduce_input_name))
        rank = len(source_info.type.tensor_type.shape.dim)
        result_info = onnx.helper.make_tensor_value_info(result_name, onnx.TensorProto.FLOAT, [1] * rank)

        # The reduce result becomes the only graph output.
        del model.graph.output[:]
        model.graph.output.extend([result_info])

        return result_name
|
||||
|
||||
|
||||
class ReduceMean(_Reduce):
    """Block that appends a ReduceMean node to the onnx model."""

    def __init__(self):
        super().__init__()

    def build(self, reduce_input_name):
        """Reduce ``reduce_input_name`` with ReduceMean and return the output name."""
        op_type = "ReduceMean"
        return super()._reduce(reduce_input_name, op_type)
|
||||
|
||||
|
||||
class ReduceSum(_Reduce):
    """Block that appends a ReduceSum node to the onnx model."""

    def __init__(self):
        super().__init__()

    def build(self, reduce_input_name):
        """Reduce ``reduce_input_name`` with ReduceSum and return the output name."""
        op_type = "ReduceSum"
        return super()._reduce(reduce_input_name, op_type)
|
||||
|
|
@ -3,9 +3,10 @@
|
|||
# _graph_utils.py
|
||||
|
||||
import copy
|
||||
import onnx
|
||||
import random
|
||||
|
||||
import onnx
|
||||
|
||||
from onnxruntime.capi._pybind_state import GradientGraphBuilder
|
||||
|
||||
|
||||
|
|
@ -21,44 +22,52 @@ def get_output_from_output_name(onnx_model, output_name):
|
|||
|
||||
|
||||
def get_random_number():
|
||||
"""Return a random number in the range 0, 1000."""
|
||||
"""Return a random number in the range 0, 100000."""
|
||||
|
||||
return random.randint(0, 1000)
|
||||
return random.randint(0, 100000)
|
||||
|
||||
|
||||
def generate_random_graph_name(token):
|
||||
"""Return a string that can be used in the graph as a graph attribute name."""
|
||||
|
||||
return f"{get_random_number()}.{token}"
|
||||
return f"onnx::{token}::{get_random_number()}"
|
||||
|
||||
|
||||
def build_gradient_graph(
|
||||
accessor, user_args_requiring_grad, user_args_not_requiring_grad, output_names
|
||||
):
|
||||
def _reorder_outputs(model, user_output_names, args_requiring_gradient_names):
    """Reorder ``model.graph.output`` in place: user outputs first, then gradients.

    The gradient output for an argument ``arg`` is looked up under the name
    ``f"{arg}_grad"``. Graph outputs that are neither requested by the user
    nor a gradient of a listed argument are dropped.

    Names that occur more than once (for example an argument listed twice)
    are emitted only once, at the position of their first occurrence, so the
    resulting graph never carries the same output twice.

    Args:
        model: model whose ``graph.output`` sequence is rewritten.
        user_output_names: iterable of output names, in the desired order.
        args_requiring_gradient_names: iterable of argument names whose
            ``<name>_grad`` outputs follow the user outputs, in this order.

    Raises:
        KeyError: if a requested name is not currently a graph output. The
            graph is left untouched in that case.
    """
    available = {output.name: output for output in model.graph.output}

    # dict.fromkeys de-duplicates while keeping first-seen order.
    wanted = dict.fromkeys(
        [*user_output_names, *(f"{arg}_grad" for arg in args_requiring_gradient_names)]
    )

    # Validate before mutating so a failure cannot leave the graph half-rewritten.
    missing = [name for name in wanted if name not in available]
    if missing:
        raise KeyError(f"The following names are not graph outputs of the model: {missing}")

    ordered_graph_outputs = [available[name] for name in wanted]

    del model.graph.output[:]
    model.graph.output.extend(ordered_graph_outputs)
|
||||
|
||||
|
||||
def build_gradient_graph(accessor, user_args_requiring_grad, user_args_not_requiring_grad, output_names):
|
||||
"""Builds the gradient graph on top of the given input forward only graph."""
|
||||
|
||||
model = accessor.model
|
||||
|
||||
# Collect names of parameters that need gradients computed
|
||||
all_args_requiring_gradient = set()
|
||||
all_args_requiring_gradient = []
|
||||
# Move all trainable and non trainable initializers to graph inputs.
|
||||
# This allows training to pass in the parameters from outside the graph
|
||||
# so as to share the parameters across multiple sessions.
|
||||
graph_inputs = model.graph.input
|
||||
initializers = []
|
||||
for initializer in model.graph.initializer:
|
||||
if not initializer.name[0].isdigit():
|
||||
if not initializer.name.startswith("onnx::"):
|
||||
# Move only those initializers as inputs that are not local
|
||||
# to the onnx model. i.e. initializers that are model parameters.
|
||||
# These are typically those initializers without any number prefixed
|
||||
# These are typically those initializers without any onnx:: prefixed
|
||||
# to their names.
|
||||
graph_inputs.append(
|
||||
onnx.helper.make_tensor_value_info(
|
||||
initializer.name, initializer.data_type, initializer.dims
|
||||
)
|
||||
onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)
|
||||
)
|
||||
if initializer.name not in user_args_not_requiring_grad:
|
||||
all_args_requiring_gradient.add(initializer.name)
|
||||
all_args_requiring_gradient.append(initializer.name)
|
||||
else:
|
||||
# All other initializers stay where they were.
|
||||
initializers.append(initializer)
|
||||
|
|
@ -74,7 +83,7 @@ def build_gradient_graph(
|
|||
# args_requiring_grad. So, add these arguments to set of arguments
|
||||
# whose gradient should be built.
|
||||
for argument_name in user_args_requiring_grad:
|
||||
all_args_requiring_gradient.add(argument_name)
|
||||
all_args_requiring_gradient.append(argument_name)
|
||||
|
||||
# Assumption is that the first graph output is the loss output
|
||||
if isinstance(output_names, str):
|
||||
|
|
@ -82,13 +91,16 @@ def build_gradient_graph(
|
|||
builder = GradientGraphBuilder(
|
||||
model.SerializeToString(),
|
||||
set(output_names),
|
||||
all_args_requiring_gradient,
|
||||
set(all_args_requiring_gradient),
|
||||
output_names[0],
|
||||
)
|
||||
builder.build()
|
||||
gradient_model = onnx.load_from_string(builder.get_model())
|
||||
|
||||
# copy the gradient graph into the user's graph
|
||||
# Reorder gradient outputs for the gradient model based on the all_args_requiring_gradient order
|
||||
_reorder_outputs(gradient_model, output_names, all_args_requiring_gradient)
|
||||
|
||||
# copy the gradient model into the user's model
|
||||
model.CopyFrom(gradient_model)
|
||||
|
||||
return all_args_requiring_gradient
|
||||
|
|
@ -117,9 +129,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
|
|||
"""
|
||||
|
||||
# TODO: Avoid hard coded input/output strings
|
||||
gradient_output_names = {
|
||||
f"{arg_name}_grad" for arg_name in all_args_requiring_gradient_names
|
||||
}
|
||||
gradient_output_names = {f"{arg_name}_grad" for arg_name in all_args_requiring_gradient_names}
|
||||
|
||||
graph_inputs = grad_model.graph.input
|
||||
graph_nodes = grad_model.graph.node
|
||||
|
|
@ -139,12 +149,8 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
|
|||
|
||||
# gradient accumulation node inputs and output names
|
||||
grad_name = graph_output.name
|
||||
grad_accumulation_buffer_name = (
|
||||
f"{grad_name}.{gradient_accumulation_name}.{gradient_buffer_name_suffix}"
|
||||
)
|
||||
grad_accumulation_output_name = (
|
||||
f"{grad_name}.{gradient_accumulation_name}.{gradient_output_name_suffix}"
|
||||
)
|
||||
grad_accumulation_buffer_name = f"{grad_name}.{gradient_accumulation_name}.{gradient_buffer_name_suffix}"
|
||||
grad_accumulation_output_name = f"{grad_name}.{gradient_accumulation_name}.{gradient_output_name_suffix}"
|
||||
|
||||
# Gradient accumulation node
|
||||
acc_node = onnx.helper.make_node(
|
||||
|
|
@ -167,9 +173,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
|
|||
grad_accumulation_output.name = grad_accumulation_output_name
|
||||
graph_outputs.append(grad_accumulation_output)
|
||||
|
||||
lazy_reset_grad_input = onnx.helper.make_tensor_value_info(
|
||||
lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1]
|
||||
)
|
||||
lazy_reset_grad_input = onnx.helper.make_tensor_value_info(lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1])
|
||||
graph_inputs.append(lazy_reset_grad_input)
|
||||
|
||||
del grad_model.graph.output[:]
|
||||
|
|
@ -183,18 +187,18 @@ def get_model_parameters(model, args_not_requiring_gradient):
|
|||
non_trainable_params = []
|
||||
for initializer in model.graph.initializer:
|
||||
# All model parameters should have their names not begin with
|
||||
# a digit. So, check to see if the initializer's first char
|
||||
# is a digit. If not, it is either a trainable, or a non
|
||||
# `onnx::`. So, check to see if the initializer begins with
|
||||
# `onnx::`. If not, it is either a trainable, or a non
|
||||
# trainable parameter.
|
||||
# Note that this assumption can be made because the export logic
|
||||
# does not change the names of the original model parameters.
|
||||
# and the original model parameters don't have their names begin
|
||||
# with a digit.
|
||||
# `onnx::`.
|
||||
# On the other hand, const initializers are generated by export
|
||||
# logic and have a digit prefix.
|
||||
# logic and have a `onnx::` prefix.
|
||||
# TODO: validate this assumption. If assumption is not valid,
|
||||
# the alternative is to enforce the user to provide the parameter names.
|
||||
if not initializer.name[0].isdigit():
|
||||
if not initializer.name.startswith("onnx::"):
|
||||
if initializer.name in args_not_requiring_gradient:
|
||||
non_trainable_params.append(initializer)
|
||||
else:
|
||||
|
|
@ -215,9 +219,7 @@ def build_graph_outputs(model, output_names):
|
|||
if isinstance(output_names, str):
|
||||
output_names = [output_names]
|
||||
|
||||
name_value_info_mapping = {
|
||||
value_info.name: value_info for value_info in model.graph.value_info
|
||||
}
|
||||
name_value_info_mapping = {value_info.name: value_info for value_info in model.graph.value_info}
|
||||
name_graph_output_mapping = {output.name: output for output in model.graph.output}
|
||||
|
||||
# collect all new graph outputs (i.e. graph outputs that are not
|
||||
|
|
@ -229,9 +231,7 @@ def build_graph_outputs(model, output_names):
|
|||
elif output_name in name_value_info_mapping:
|
||||
graph_outputs.append(name_value_info_mapping[output_name])
|
||||
else:
|
||||
raise LookupError(
|
||||
f"The provided name {output_name} is not a graph value info or a graph output."
|
||||
)
|
||||
raise LookupError(f"The provided name {output_name} is not a graph value info or a graph output.")
|
||||
|
||||
# Clear all existing graph outputs
|
||||
del model.graph.output[:]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# _building_blocks.py
|
||||
# building_blocks.py
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ class Block(ABC):
|
|||
|
||||
class _BinaryOp(Block):
|
||||
def __init__(self, op_name):
|
||||
super(_BinaryOp, self).__init__()
|
||||
super().__init__()
|
||||
self._op_name = op_name
|
||||
|
||||
def build(self, input_name1, input_name2):
|
||||
|
|
@ -68,28 +68,35 @@ class Add(_BinaryOp):
|
|||
"""Adds Add node to an onnx model."""
|
||||
|
||||
def __init__(self):
|
||||
super(Add, self).__init__("Add")
|
||||
super().__init__("Add")
|
||||
|
||||
|
||||
class Sub(_BinaryOp):
|
||||
"""Adds Sub node to an onnx model."""
|
||||
|
||||
def __init__(self):
|
||||
super(Sub, self).__init__("Sub")
|
||||
super().__init__("Sub")
|
||||
|
||||
|
||||
class Mul(_BinaryOp):
|
||||
"""Adds Mul node to an onnx model."""
|
||||
|
||||
def __init__(self):
|
||||
super(Mul, self).__init__("Mul")
|
||||
super().__init__("Mul")
|
||||
|
||||
|
||||
class Div(_BinaryOp):
    """Block that adds a Div node to an onnx model."""

    def __init__(self):
        # The binary-op base class is told which ONNX op type to emit.
        op_name = "Div"
        super().__init__(op_name)
|
||||
|
||||
|
||||
class Pow(Block):
|
||||
"""Adds Pow node to the onnx model."""
|
||||
|
||||
def __init__(self, exponent):
|
||||
super(Pow, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self._exponent = exponent
|
||||
|
||||
|
|
@ -119,8 +126,10 @@ class Pow(Block):
|
|||
|
||||
|
||||
class _UnaryOp(Block):
|
||||
"""Base class for all nodes that take in a single argument."""
|
||||
|
||||
def __init__(self, op_name):
|
||||
super(_UnaryOp, self).__init__()
|
||||
super().__init__()
|
||||
self._op_name = op_name
|
||||
|
||||
def build(self, input_name):
|
||||
|
|
@ -150,29 +159,35 @@ class ReduceMean(_UnaryOp):
|
|||
"""Adds ReduceMean node to the onnx model."""
|
||||
|
||||
def __init__(self):
|
||||
super(ReduceMean, self).__init__("ReduceMean")
|
||||
super().__init__("ReduceMean")
|
||||
|
||||
|
||||
class ReduceSum(_UnaryOp):
|
||||
"""Adds ReduceSum node to the onnx model."""
|
||||
|
||||
def __init__(self):
|
||||
super(ReduceSum, self).__init__("ReduceSum")
|
||||
super().__init__("ReduceSum")
|
||||
|
||||
|
||||
class Sigmoid(_UnaryOp):
|
||||
"""Adds Sigmoid node to the onnx model."""
|
||||
|
||||
def __init__(self):
|
||||
super(Sigmoid, self).__init__("Sigmoid")
|
||||
super().__init__("Sigmoid")
|
||||
|
||||
|
||||
class Log(_UnaryOp):
|
||||
"""Adds Log node to the onnx model."""
|
||||
|
||||
def __init__(self):
|
||||
super(Log, self).__init__("Log")
|
||||
super().__init__("Log")
|
||||
|
||||
|
||||
class Neg(_UnaryOp):
|
||||
"""Adds Neg node to the onnx model."""
|
||||
|
||||
def __init__(self):
|
||||
super(Neg, self).__init__("Neg")
|
||||
super().__init__("Neg")
|
||||
|
||||
|
||||
class Constant(Block):
|
||||
|
|
@ -192,3 +207,109 @@ class Constant(Block):
|
|||
onnx.helper.make_tensor(initializer_name, onnx.TensorProto.FLOAT, [1], [self._value])
|
||||
)
|
||||
return initializer_name
|
||||
|
||||
|
||||
class SequenceConstruct(Block):
    """Adds SequenceConstruct node to the onnx model."""

    def __init__(self):
        super().__init__()

    def build(self, *sequence_input_names):
        """Append a SequenceConstruct node over the given inputs.

        Args:
            *sequence_input_names: graph names of the values to pack into the
                sequence, in order. At least one name is required.

        Returns:
            The generated name of the SequenceConstruct node's output.

        Raises:
            RuntimeError: if no input names are provided.
        """
        # The ONNX SequenceConstruct op takes one or more inputs; reject an
        # empty call here instead of silently emitting an invalid node.
        if not sequence_input_names:
            raise RuntimeError("SequenceConstruct requires at least one input name.")

        # get the model to manipulate
        onnx_model = accessor.global_accessor.model

        # create the graph node for this sequence construct node
        sc_node_input_names = list(sequence_input_names)
        sc_node_output_name = graph_utils.generate_random_graph_name("sequenceconstruct_output")
        sc_node_output_names = [sc_node_output_name]
        sc_node = onnx.helper.make_node(
            "SequenceConstruct",
            sc_node_input_names,
            sc_node_output_names,
            name=graph_utils.generate_random_graph_name("SequenceConstruct"),
        )
        onnx_model.graph.node.append(sc_node)

        return sc_node_output_name
|
||||
|
||||
|
||||
class ReduceAllL2(Block):
    """Adds ReduceAllL2 node to the onnx model.

    ReduceAllL2 is a part of the com.microsoft domain and might not be accessible outside this domain.
    """

    def __init__(self):
        super().__init__()

    def build(self, *reduce_node_input_names):
        """Append a ReduceAllL2 node over the given inputs.

        Because the node lives in the ``com.microsoft`` domain, the output's
        value info (FLOAT, shape ``[1]``) is recorded on the graph by hand.

        Args:
            *reduce_node_input_names: graph names of the node's inputs, in order.

        Returns:
            The generated name of the ReduceAllL2 node's output.
        """
        model = accessor.global_accessor.model

        # Node output name first, then the node name.
        output_name = graph_utils.generate_random_graph_name("reducealll2_output")
        node = onnx.helper.make_node(
            "ReduceAllL2",
            list(reduce_node_input_names),
            [output_name],
            graph_utils.generate_random_graph_name("ReduceAllL2"),
            domain="com.microsoft",
        )
        model.graph.node.append(node)

        # TODO: register shape inference with onnx
        output_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, [1])
        model.graph.value_info.append(output_info)

        return output_name
|
||||
|
||||
|
||||
class Clip(Block):
    """Adds Clip node to the onnx model.

    Either bound may be left as ``None``; an omitted bound is wired into the
    node as an empty input name and no initializer is created for it.
    """

    def __init__(self, clip_min=None, clip_max=None):
        super().__init__()

        self._min = clip_min
        self._max = clip_max

    def build(self, clip_input_name):
        """Append a Clip node over ``clip_input_name``.

        Args:
            clip_input_name: graph name of the value to clip.

        Returns:
            The generated name of the Clip node's output.
        """
        model = accessor.global_accessor.model

        def bound_input(token, value):
            # Name to wire into the node for one bound: "" when the bound is
            # absent, otherwise a new one-element FLOAT initializer.
            if value is None:
                return ""
            name = graph_utils.generate_random_graph_name(token)
            model.graph.initializer.append(onnx.helper.make_tensor(name, onnx.TensorProto.FLOAT, [1], [value]))
            return name

        # Min is handled before max, so generated names are drawn in that order.
        min_name = bound_input("clip_min", self._min)
        max_name = bound_input("clip_max", self._max)

        # Node output name first, then the node name.
        output_name = graph_utils.generate_random_graph_name("clip_output")
        node = onnx.helper.make_node(
            "Clip",
            [clip_input_name, min_name, max_name],
            [output_name],
            graph_utils.generate_random_graph_name("Clip"),
        )
        model.graph.node.append(node)

        return output_name
|
||||
|
|
|
|||
|
|
@ -16,6 +16,4 @@ def save_checkpoint(parameters, path_to_checkpoint):
|
|||
trainable_params, non_trainable_params = parameters
|
||||
trainable_params = [param.SerializeToString() for param in trainable_params]
|
||||
non_trainable_params = [param.SerializeToString() for param in non_trainable_params]
|
||||
_internal_save_checkpoint(
|
||||
trainable_params, non_trainable_params, path_to_checkpoint
|
||||
)
|
||||
_internal_save_checkpoint(trainable_params, non_trainable_params, path_to_checkpoint)
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@
|
|||
# loss.py
|
||||
|
||||
import copy
|
||||
|
||||
import onnx
|
||||
|
||||
import onnxruntime.training.onnxblock.model_accessor as accessor
|
||||
import onnxruntime.training.onnxblock.building_blocks as building_blocks
|
||||
import onnxruntime.training.onnxblock._graph_utils as graph_utils
|
||||
import onnxruntime.training.onnxblock.building_blocks as building_blocks
|
||||
import onnxruntime.training.onnxblock.model_accessor as accessor
|
||||
|
||||
|
||||
class MSELoss(building_blocks.Block):
|
||||
|
|
@ -25,11 +26,7 @@ class MSELoss(building_blocks.Block):
|
|||
if reduction != "mean" and reduction != "sum":
|
||||
raise RuntimeError(f"Reduction {reduction} not supported.")
|
||||
|
||||
self._reduce = (
|
||||
building_blocks.ReduceMean()
|
||||
if reduction == "mean"
|
||||
else building_blocks.ReduceSum()
|
||||
)
|
||||
self._reduce = building_blocks.ReduceMean() if reduction == "mean" else building_blocks.ReduceSum()
|
||||
self._sub = building_blocks.Sub()
|
||||
self._square = building_blocks.Pow(2.0)
|
||||
|
||||
|
|
@ -53,9 +50,7 @@ class MSELoss(building_blocks.Block):
|
|||
# create a new graph input. this is the target input needed to compare
|
||||
# the graph output against to calculate loss.
|
||||
# TODO: Move input creation outside of the blocks.
|
||||
target_input = copy.deepcopy(
|
||||
graph_utils.get_output_from_output_name(onnx_model, loss_input_name)
|
||||
)
|
||||
target_input = copy.deepcopy(graph_utils.get_output_from_output_name(onnx_model, loss_input_name))
|
||||
target_input.name = target_name
|
||||
onnx_model.graph.input.append(target_input)
|
||||
|
||||
|
|
@ -106,15 +101,11 @@ class CrossEntropyLoss(building_blocks.Block):
|
|||
|
||||
weight_name = graph_utils.generate_random_graph_name("celoss.weight")
|
||||
if self._weight is not None:
|
||||
onnx_model.graph.initializer.append(
|
||||
onnx.numpy_helper.from_array(self._weight, weight_name)
|
||||
)
|
||||
onnx_model.graph.initializer.append(onnx.numpy_helper.from_array(self._weight, weight_name))
|
||||
|
||||
# create a new graph input. this is the labels input needed to compare
|
||||
# the graph output against to calculate loss.
|
||||
labels_input = copy.deepcopy(
|
||||
graph_utils.get_output_from_output_name(onnx_model, scores_input_name)
|
||||
)
|
||||
labels_input = copy.deepcopy(graph_utils.get_output_from_output_name(onnx_model, scores_input_name))
|
||||
labels_input.name = labels_name
|
||||
labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT32
|
||||
# if the predictions are (num_examples x num_classes)
|
||||
|
|
@ -163,11 +154,7 @@ class BCEWithLogitsLoss(building_blocks.Block):
|
|||
raise RuntimeError(f"Reduction {reduction} not supported.")
|
||||
|
||||
self._weight = weight
|
||||
self._reduce = (
|
||||
building_blocks.ReduceMean()
|
||||
if reduction == "mean"
|
||||
else building_blocks.ReduceSum()
|
||||
)
|
||||
self._reduce = building_blocks.ReduceMean() if reduction == "mean" else building_blocks.ReduceSum()
|
||||
self._pos_weight = pos_weight
|
||||
|
||||
self._sigmoid = building_blocks.Sigmoid()
|
||||
|
|
@ -198,38 +185,24 @@ class BCEWithLogitsLoss(building_blocks.Block):
|
|||
# create the graph initializers for pos_weight, weight, and the sub operands ([1])
|
||||
pos_weight_name = graph_utils.generate_random_graph_name("bceloss.pos_weight")
|
||||
if self._pos_weight is not None:
|
||||
onnx_model.graph.initializer.append(
|
||||
onnx.numpy_helper.from_array(self._pos_weight, pos_weight_name)
|
||||
)
|
||||
onnx_model.graph.initializer.append(onnx.numpy_helper.from_array(self._pos_weight, pos_weight_name))
|
||||
|
||||
weight_name = graph_utils.generate_random_graph_name("bceloss.weight")
|
||||
if self._weight is not None:
|
||||
onnx_model.graph.initializer.append(
|
||||
onnx.numpy_helper.from_array(self._weight, weight_name)
|
||||
)
|
||||
onnx_model.graph.initializer.append(onnx.numpy_helper.from_array(self._weight, weight_name))
|
||||
|
||||
sub_ones_operand_name1 = graph_utils.generate_random_graph_name(
|
||||
"bceloss.sub_ones"
|
||||
)
|
||||
sub_ones_operand_name1 = graph_utils.generate_random_graph_name("bceloss.sub_ones")
|
||||
onnx_model.graph.initializer.append(
|
||||
onnx.helper.make_tensor(
|
||||
sub_ones_operand_name1, onnx.TensorProto.FLOAT, [1], [1.0]
|
||||
)
|
||||
)
|
||||
sub_ones_operand_name2 = graph_utils.generate_random_graph_name(
|
||||
"bceloss.sub_ones"
|
||||
onnx.helper.make_tensor(sub_ones_operand_name1, onnx.TensorProto.FLOAT, [1], [1.0])
|
||||
)
|
||||
sub_ones_operand_name2 = graph_utils.generate_random_graph_name("bceloss.sub_ones")
|
||||
onnx_model.graph.initializer.append(
|
||||
onnx.helper.make_tensor(
|
||||
sub_ones_operand_name2, onnx.TensorProto.FLOAT, [1], [1.0]
|
||||
)
|
||||
onnx.helper.make_tensor(sub_ones_operand_name2, onnx.TensorProto.FLOAT, [1], [1.0])
|
||||
)
|
||||
|
||||
# create a new graph input. this is the target input needed to compare
|
||||
# the graph output against to calculate loss.
|
||||
target_input = copy.deepcopy(
|
||||
graph_utils.get_output_from_output_name(onnx_model, loss_input_name)
|
||||
)
|
||||
target_input = copy.deepcopy(graph_utils.get_output_from_output_name(onnx_model, loss_input_name))
|
||||
target_input.name = target_name
|
||||
onnx_model.graph.input.append(target_input)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@
|
|||
# model.py
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import onnx
|
||||
import onnxruntime.training.onnxblock.model_accessor as accessor
|
||||
import onnxruntime.training.onnxblock.building_blocks as building_blocks
|
||||
|
||||
import onnxruntime.training.onnxblock._graph_utils as graph_utils
|
||||
import onnxruntime.training.onnxblock.building_blocks as building_blocks
|
||||
import onnxruntime.training.onnxblock.model_accessor as accessor
|
||||
|
||||
|
||||
class Model(building_blocks.Block):
|
||||
|
|
@ -32,9 +34,7 @@ class Model(building_blocks.Block):
|
|||
output = self.build(*args, **kwargs)
|
||||
|
||||
# Perform shape inference
|
||||
model_with_shapes = onnx.shape_inference.infer_shapes(
|
||||
accessor.global_accessor.model
|
||||
)
|
||||
model_with_shapes = onnx.shape_inference.infer_shapes(accessor.global_accessor.model)
|
||||
accessor.global_accessor.model.CopyFrom(model_with_shapes)
|
||||
|
||||
# Build the graph outputs
|
||||
|
|
@ -85,10 +85,7 @@ class TrainingModel(building_blocks.Block):
|
|||
is built, an exception will be raised.
|
||||
"""
|
||||
if self._parameters is None:
|
||||
raise RuntimeError(
|
||||
"Please build the training model first before trying to "
|
||||
"retrieve the parameters."
|
||||
)
|
||||
raise RuntimeError("Please build the training model first before trying to " "retrieve the parameters.")
|
||||
|
||||
return self._parameters
|
||||
|
||||
|
|
@ -108,9 +105,7 @@ class TrainingModel(building_blocks.Block):
|
|||
output = self.build(*args, **kwargs)
|
||||
|
||||
# Perform shape inference
|
||||
model_with_shapes = onnx.shape_inference.infer_shapes(
|
||||
accessor.global_accessor.model
|
||||
)
|
||||
model_with_shapes = onnx.shape_inference.infer_shapes(accessor.global_accessor.model)
|
||||
accessor.global_accessor.model.CopyFrom(model_with_shapes)
|
||||
|
||||
# Build the graph outputs
|
||||
|
|
@ -122,6 +117,7 @@ class TrainingModel(building_blocks.Block):
|
|||
accessor.global_accessor.model, self._arg_not_requiring_grad
|
||||
)
|
||||
|
||||
# build the gradient graph
|
||||
all_args_requiring_gradient_names = graph_utils.build_gradient_graph(
|
||||
accessor.global_accessor,
|
||||
self._arg_requiring_grad,
|
||||
|
|
@ -130,9 +126,7 @@ class TrainingModel(building_blocks.Block):
|
|||
)
|
||||
|
||||
# add gradient accumulation nodes
|
||||
graph_utils.build_gradient_accumulation_graph(
|
||||
accessor.global_accessor.model, all_args_requiring_gradient_names
|
||||
)
|
||||
graph_utils.build_gradient_accumulation_graph(accessor.global_accessor.model, all_args_requiring_gradient_names)
|
||||
|
||||
# validate and check the model
|
||||
onnx.checker.check_model(accessor.global_accessor.model, True)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,41 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# global_model.py
|
||||
# model_accessor.py
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
import onnx
|
||||
|
||||
|
||||
class ModelAccessor:
|
||||
"""This class stores the onnx model that is manipulated by the onnx blocks."""
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.eval_model = None
|
||||
self._model = model
|
||||
self._eval_model = None
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""ModelAccessor property that gets the modified model."""
|
||||
|
||||
if self._model is None:
|
||||
raise RuntimeError(
|
||||
"The onnx model was not set. Please use the context manager onnxblock.onnx_model to create the model."
|
||||
)
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def eval_model(self):
|
||||
"""ModelAccessor property that gets the eval model."""
|
||||
|
||||
if self._eval_model is None:
|
||||
raise RuntimeError("The eval onnx model was not set.")
|
||||
return self._eval_model
|
||||
|
||||
@eval_model.setter
|
||||
def eval_model(self, value):
|
||||
"""ModelAccessor property that sets the eval model."""
|
||||
self._eval_model = value
|
||||
|
||||
|
||||
# This variable resides in the global namespace.
|
||||
|
|
@ -26,6 +51,14 @@ def onnx_model(model=None):
|
|||
Manages the construction and destruction of the global model.
|
||||
"""
|
||||
global global_accessor
|
||||
if global_accessor is not None:
|
||||
raise RuntimeError("Base onnx model already exists. Cannot create multiple ModelAccessors.")
|
||||
|
||||
# If the user did not provide a model, then assume that they want to build from scratch.
|
||||
# It is the duty of the caller to fill the model however they deem fit.
|
||||
if model is None:
|
||||
model = onnx.ModelProto()
|
||||
|
||||
global_accessor = ModelAccessor(model)
|
||||
try:
|
||||
yield global_accessor
|
||||
|
|
|
|||
|
|
@ -2,4 +2,4 @@
|
|||
# Licensed under the MIT License.
|
||||
# __init__.py
|
||||
|
||||
from .optim import AdamW
|
||||
from .optim import AdamW, ClipGradNorm
|
||||
|
|
|
|||
|
|
@ -2,160 +2,216 @@
|
|||
# Licensed under the MIT License.
|
||||
# optim.py
|
||||
|
||||
import copy
|
||||
import onnx
|
||||
|
||||
import onnxruntime.training.onnxblock as onnxblock
|
||||
import onnxruntime.training.onnxblock._graph_utils as graph_utils
|
||||
import onnxruntime.training.onnxblock.building_blocks as building_blocks
|
||||
import onnxruntime.training.onnxblock.model as model
|
||||
import onnxruntime.training.onnxblock.model_accessor as accessor
|
||||
|
||||
# TODO: Find a better place for these constants
|
||||
_PRODUCER_NAME = "onnxblock offline tooling"
|
||||
_OPSET_IMPORTS = [onnx.helper.make_opsetid("com.microsoft", 1), onnx.helper.make_opsetid("", 14)]
|
||||
|
||||
class AdamW(onnxblock.Model):
|
||||
"""Builds AdamW optimizer onnxblock for the given training model."""
|
||||
|
||||
def __init__(
|
||||
self, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0
|
||||
):
|
||||
super(AdamW, self).__init__()
|
||||
class AdamWOptimizer(building_blocks.Block):
|
||||
"""Adds an AdamWOptimizer node to the onnx model."""
|
||||
|
||||
def __init__(self, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0):
|
||||
super().__init__()
|
||||
|
||||
self._bias_correction = bias_correction
|
||||
self._betas = betas
|
||||
self._eps = eps
|
||||
self._weight_decay = weight_decay
|
||||
self._max_norm_clip = 1.0
|
||||
|
||||
def build( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
learning_rate_name,
|
||||
step_name,
|
||||
parameter_sequence_name,
|
||||
gradient_sequence_name,
|
||||
first_order_moment_sequence_name,
|
||||
second_order_moment_sequence_name,
|
||||
):
|
||||
"""Adds the AdamWOptimizer node to the model."""
|
||||
|
||||
# get the model to manipulate
|
||||
onnx_model = accessor.global_accessor.model
|
||||
|
||||
# define the node attributes
|
||||
node_attributes = {
|
||||
"alpha": self._betas[0], # beta1
|
||||
"beta": self._betas[1], # beta2
|
||||
"epsilon": self._eps, # epsilon
|
||||
"weight_decay": self._weight_decay, # weight decay
|
||||
"correct_bias": 1 if self._bias_correction else 0, # bias_correction
|
||||
"adam_mode": 1, # adam mode (1 for hf/transformers/AdamW)
|
||||
}
|
||||
|
||||
# add the adamw node to the onnx model
|
||||
adamw_input_names = [
|
||||
learning_rate_name, # learning rate
|
||||
step_name, # training step
|
||||
parameter_sequence_name, # param to be updated
|
||||
gradient_sequence_name, # gradient of the param to be used for update
|
||||
first_order_moment_sequence_name, # first order moment for this param
|
||||
second_order_moment_sequence_name, # second order moment for this param
|
||||
]
|
||||
adamw_output_name = graph_utils.generate_random_graph_name("adamw.updated_flag")
|
||||
adamw_output_names = [adamw_output_name]
|
||||
adamw_node = onnx.helper.make_node(
|
||||
"AdamWOptimizer",
|
||||
adamw_input_names,
|
||||
adamw_output_names,
|
||||
name=graph_utils.generate_random_graph_name("AdamWOptimizer"),
|
||||
domain="com.microsoft",
|
||||
**node_attributes,
|
||||
)
|
||||
onnx_model.graph.node.append(adamw_node)
|
||||
|
||||
return adamw_output_name
|
||||
|
||||
|
||||
class ClipGradNorm(building_blocks.Block):
    """Adds a clip-gradients-by-norm sub graph to the onnx model.

    The gradients are fed together through a ReduceAllL2 block, 1e-6 is added
    to that result, max_norm is divided by the sum, the quotient goes through
    a Clip block created with clip_max=1.0, and every gradient is multiplied
    by the resulting coefficient.

    Args:
        max_norm: float indicating the max norm of the gradients.

    Returns:
        A list with one output name per input gradient name, in the same
        order as the gradient names that were passed in.
    """

    def __init__(self, max_norm):
        super().__init__()

        self._max_norm = max_norm

        # Blocks composed by build(); one instance of each is enough.
        self._reduce = building_blocks.ReduceAllL2()
        self._add = building_blocks.Add()
        self._div = building_blocks.Div()
        self._mul = building_blocks.Mul()
        self._clip = building_blocks.Clip(clip_max=1.0)

    def build(self, *gradient_names):
        """Adds a clip grad norm sub graph to the onnx model."""

        # The graph of the model currently held by the global accessor.
        graph = accessor.global_accessor.model.graph

        # Initializer holding the small constant added to the total norm.
        eps_name = graph_utils.generate_random_graph_name("add_eps")
        eps_tensor = onnx.helper.make_tensor(eps_name, onnx.TensorProto.FLOAT, [1], [1e-6])
        graph.initializer.append(eps_tensor)

        # Initializer holding the user supplied max norm.
        bound_name = graph_utils.generate_random_graph_name("max_norm")
        bound_tensor = onnx.helper.make_tensor(bound_name, onnx.TensorProto.FLOAT, [1], [self._max_norm])
        graph.initializer.append(bound_tensor)

        # Chain the blocks: reduce -> add eps -> divide -> clip.
        norm_name = self._reduce(*gradient_names)
        padded_norm_name = self._add(norm_name, eps_name)
        ratio_name = self._div(bound_name, padded_norm_name)
        coefficient_name = self._clip(ratio_name)

        # Scale every gradient by the shared coefficient.
        clipped_names = []
        for gradient_name in gradient_names:
            clipped_names.append(self._mul(gradient_name, coefficient_name))
        return clipped_names
|
||||
|
||||
|
||||
class AdamW(model.Model):
|
||||
"""Builds AdamW optimizer onnxblock for the given training parameters.
|
||||
|
||||
Creates a block that updates the model parameters based on the calculated
|
||||
gradient following the AdamW algorithm.
|
||||
|
||||
Args:
|
||||
bias_correction: bool indicating whether to perform bias correction.
|
||||
betas: AdamW decay rate hyperparameters.
|
||||
eps: term added to the denominator for computing the moments.
|
||||
weight_decay: AdamW weight decay
|
||||
clip_grad (optional): an instance of the ClipGradNorm. If not provided,
|
||||
gradient clipping will not be done.
|
||||
|
||||
Returns:
|
||||
Returns a string of the output names from this optimizer node.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, clip_grad=None
|
||||
): # pylint: disable=too-many-arguments
|
||||
super().__init__()
|
||||
|
||||
self._adamw = AdamWOptimizer(
|
||||
bias_correction=bias_correction,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
self._clip_grad = clip_grad
|
||||
self._sc = building_blocks.SequenceConstruct()
|
||||
|
||||
def build(self, parameters):
|
||||
"""Returns an AdamW optimizer model based on the input training model."""
|
||||
"""Returns an AdamW optimizer model based on the input parameters."""
|
||||
|
||||
# get the model to manipulate and update its namespace
|
||||
onnx_model = accessor.global_accessor.model
|
||||
onnx_model.graph.name = "AdamW Optimizer Model"
|
||||
onnx_model.producer_name = _PRODUCER_NAME
|
||||
onnx_model.opset_import.extend(_OPSET_IMPORTS)
|
||||
onnx_model.ir_version = onnx.IR_VERSION
|
||||
|
||||
# TODO: Avoid hard coded input/output strings
|
||||
learning_rate_name = "learning_rate"
|
||||
step_name = "step"
|
||||
gradient_output_suffix = "_grad.accumulation.out"
|
||||
first_order_moment_suffix = "exp_avg"
|
||||
second_order_moment_fuffix = "exp_avg_sq"
|
||||
output_name_suffix = "out"
|
||||
params_name = "params"
|
||||
first_order_moments_name = "first_order_moments"
|
||||
second_order_moments_name = "second_order_moments"
|
||||
gradient_suffix = "_grad"
|
||||
|
||||
trainable_parameters, _ = parameters
|
||||
|
||||
graph_nodes = []
|
||||
graph_inputs = [
|
||||
onnx.helper.make_tensor_value_info(
|
||||
learning_rate_name, onnx.TensorProto.FLOAT, [1]
|
||||
),
|
||||
onnx.helper.make_tensor_value_info(step_name, onnx.TensorProto.INT64, [1]),
|
||||
]
|
||||
graph_outputs = []
|
||||
|
||||
# Iterate over all training graph outputs that are gradient outputs
|
||||
for idx, param in enumerate(trainable_parameters):
|
||||
|
||||
param_name = param.name
|
||||
grad_name = f"{param_name}{gradient_output_suffix}"
|
||||
first_order_moment_name = f"{param_name}.{first_order_moment_suffix}"
|
||||
second_order_moment_name = f"{param_name}.{second_order_moment_fuffix}"
|
||||
# prepare node (and graph) inputs and outputs
|
||||
node_input_names = [
|
||||
learning_rate_name, # learning rate
|
||||
step_name, # training step (used for beta correction)
|
||||
param_name, # param to be updated
|
||||
grad_name, # gradient of the param to be used for update
|
||||
first_order_moment_name, # first order moment for this param
|
||||
second_order_moment_name, # second order moment for this param
|
||||
# create the graph inputs for the lr, step, params, grads, moments
|
||||
onnx_model.graph.input.extend(
|
||||
[
|
||||
onnx.helper.make_tensor_value_info(learning_rate_name, onnx.TensorProto.FLOAT, [1]),
|
||||
onnx.helper.make_tensor_value_info(step_name, onnx.TensorProto.INT64, [1]),
|
||||
]
|
||||
|
||||
param_tensor_value_info = onnx.helper.make_tensor_value_info(
|
||||
param_name, param.data_type, param.dims
|
||||
)
|
||||
grad_tensor_value_info = onnx.helper.make_tensor_value_info(
|
||||
grad_name, param.data_type, param.dims
|
||||
)
|
||||
first_order_moment_tensor_value_info = onnx.helper.make_tensor_value_info(
|
||||
first_order_moment_name, param.data_type, param.dims
|
||||
)
|
||||
second_order_moment_tensor_value_info = onnx.helper.make_tensor_value_info(
|
||||
second_order_moment_name, param.data_type, param.dims
|
||||
)
|
||||
node_inputs = [
|
||||
param_tensor_value_info,
|
||||
grad_tensor_value_info,
|
||||
first_order_moment_tensor_value_info,
|
||||
second_order_moment_tensor_value_info,
|
||||
]
|
||||
graph_inputs.extend(node_inputs)
|
||||
|
||||
step_output_name = f"{param_name}.{step_name}.{output_name_suffix}"
|
||||
param_output_name = f"{param_name}.{output_name_suffix}"
|
||||
first_order_moment_output_name = (
|
||||
f"{first_order_moment_name}.{output_name_suffix}"
|
||||
)
|
||||
second_order_moment_output_name = (
|
||||
f"{second_order_moment_name}.{output_name_suffix}"
|
||||
)
|
||||
|
||||
param_output_tensor_value_info = onnx.helper.make_tensor_value_info(
|
||||
param_output_name, param.data_type, param.dims
|
||||
)
|
||||
first_order_moment_output_tensor_value_info = (
|
||||
onnx.helper.make_tensor_value_info(
|
||||
first_order_moment_output_name, param.data_type, param.dims
|
||||
)
|
||||
)
|
||||
second_order_moment_output_tensor_value_info = (
|
||||
onnx.helper.make_tensor_value_info(
|
||||
second_order_moment_output_name, param.data_type, param.dims
|
||||
)
|
||||
)
|
||||
|
||||
node_output_names = [
|
||||
step_output_name, # step out
|
||||
first_order_moment_output_name, # first order moment output
|
||||
second_order_moment_output_name, # second order moment output
|
||||
param_output_name, # updated weights
|
||||
]
|
||||
|
||||
node_outputs = [
|
||||
onnx.helper.make_tensor_value_info(
|
||||
step_output_name, onnx.TensorProto.INT64, [1]
|
||||
),
|
||||
first_order_moment_output_tensor_value_info,
|
||||
second_order_moment_output_tensor_value_info,
|
||||
param_output_tensor_value_info,
|
||||
]
|
||||
graph_outputs.extend(node_outputs)
|
||||
|
||||
# AdamOptimizer node attributes
|
||||
node_attributes = {
|
||||
"alpha": self._betas[0], # beta1
|
||||
"beta": self._betas[1], # beta2
|
||||
"lambda": self._weight_decay, # weight decay
|
||||
"epsilon": self._eps, # epsilon
|
||||
"do_bias_correction": 1
|
||||
if self._bias_correction
|
||||
else 0, # bias_correction
|
||||
"weight_decay_mode": 1, # weight decay mode
|
||||
"max_norm_clip": self._max_norm_clip, # used for gradient scaling
|
||||
}
|
||||
|
||||
# make the node
|
||||
optimizer_node = onnx.helper.make_node(
|
||||
"AdamOptimizer",
|
||||
node_input_names,
|
||||
node_output_names,
|
||||
name=f"AdamOptimizer{idx}",
|
||||
domain="com.microsoft",
|
||||
**node_attributes,
|
||||
)
|
||||
|
||||
graph_nodes.append(optimizer_node)
|
||||
|
||||
# make the graph and the model
|
||||
graph = onnx.helper.make_graph(
|
||||
graph_nodes, "Optimizer Graph", graph_inputs, graph_outputs
|
||||
)
|
||||
model = onnx.helper.make_model(
|
||||
graph,
|
||||
producer_name=onnxblock._producer_name,
|
||||
opset_imports=[onnxblock._opset_import],
|
||||
)
|
||||
|
||||
accessor.global_accessor.model = model
|
||||
# Prepare the tensor sequence inputs for params and moments
|
||||
for input_name in [params_name, first_order_moments_name, second_order_moments_name]:
|
||||
onnx_model.graph.input.append(
|
||||
onnx.helper.make_tensor_sequence_value_info(input_name, trainable_parameters[0].data_type, None)
|
||||
)
|
||||
|
||||
return [output.name for output in graph_outputs]
|
||||
# TODO: Make the grads as a tensor sequence input after implementing clip grad
|
||||
# normalization implementation which takes in a tensor sequence.
|
||||
grad_names = []
|
||||
for param in trainable_parameters:
|
||||
grad_names.append(f"{param.name}{gradient_suffix}")
|
||||
onnx_model.graph.input.append(
|
||||
onnx.helper.make_tensor_value_info(grad_names[-1], param.data_type, param.dims)
|
||||
)
|
||||
|
||||
# Clip the gradients if needed
|
||||
if self._clip_grad is not None:
|
||||
grad_names = self._clip_grad(*grad_names)
|
||||
|
||||
# Run multi tensor AdamWOptimizer
|
||||
updated_flag_name = self._adamw(
|
||||
learning_rate_name,
|
||||
step_name,
|
||||
params_name,
|
||||
self._sc(*grad_names),
|
||||
first_order_moments_name,
|
||||
second_order_moments_name,
|
||||
)
|
||||
|
||||
# Create the graph outputs
|
||||
onnx_model.graph.output.append(
|
||||
onnx.helper.make_tensor_value_info(updated_flag_name, onnx.TensorProto.INT64, [1])
|
||||
)
|
||||
|
||||
return updated_flag_name
|
||||
|
|
|
|||
|
|
@ -421,8 +421,12 @@ def test_bcewithlogits_loss_training_graph_execution():
|
|||
assert np.allclose(ort_grad, _to_numpy(pt_param.grad))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("graph", [SimpleTrainingModelWithMSELoss, SimpleTrainingModelWithCrossEntropyLoss])
|
||||
def test_adamw_optimizer_composition(graph):
|
||||
@pytest.mark.parametrize(
|
||||
"graph",
|
||||
[SimpleTrainingModelWithMSELoss, SimpleTrainingModelWithCrossEntropyLoss, SimpleTrainingModelWithBCEWithLogitsLoss],
|
||||
)
|
||||
@pytest.mark.parametrize("grad_clipping", [None, onnxblock.optim.ClipGradNorm(2.5)])
|
||||
def test_adamw_optimizer_composition(graph, grad_clipping):
|
||||
# Given
|
||||
device = "cuda"
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
|
|
@ -433,13 +437,15 @@ def test_adamw_optimizer_composition(graph):
|
|||
with onnxblock.onnx_model(onnx_model):
|
||||
_ = simple_model(onnx_model.graph.output[0].name)
|
||||
|
||||
optimizer = onnxblock.optim.AdamW()
|
||||
optimizer = onnxblock.optim.AdamW(clip_grad=grad_clipping)
|
||||
with onnxblock.onnx_model() as accessor:
|
||||
_ = optimizer(simple_model.parameters())
|
||||
optimizer_model = accessor.model
|
||||
assert optimizer_model
|
||||
|
||||
|
||||
# TODO: Add a test for correctness when creation of ortvalues of
|
||||
# tensor seq is possible on cuda
|
||||
def test_adamw_optimizer_execution():
|
||||
# Given
|
||||
device = "cuda"
|
||||
|
|
@ -455,14 +461,12 @@ def test_adamw_optimizer_execution():
|
|||
|
||||
optimizer = onnxblock.optim.AdamW()
|
||||
with onnxblock.onnx_model() as accessor:
|
||||
_ = optimizer(simple_model.parameters())
|
||||
output_name = optimizer(simple_model.parameters())
|
||||
optimizer_model = accessor.model
|
||||
|
||||
learning_rate = 0.001
|
||||
step = 1
|
||||
ort_output_names = []
|
||||
for name, _ in pt_model.named_parameters():
|
||||
ort_output_names.append(f"{name}.out")
|
||||
ort_output_names = [output_name]
|
||||
|
||||
def mse_loss(prediction, target):
|
||||
loss = torch.nn.MSELoss()
|
||||
|
|
@ -478,12 +482,15 @@ def test_adamw_optimizer_execution():
|
|||
ort_inputs = {
|
||||
"learning_rate": np.full(1, learning_rate, dtype=np.float32),
|
||||
"step": np.full(1, step, dtype=np.int64),
|
||||
"params": [],
|
||||
"first_order_moments": [],
|
||||
"second_order_moments": [],
|
||||
}
|
||||
for name, param in pt_model.named_parameters():
|
||||
ort_inputs[name] = _to_numpy(copy.deepcopy(param))
|
||||
ort_inputs[f"{name}_grad.accumulation.out"] = _to_numpy(copy.deepcopy(param.grad))
|
||||
ort_inputs[f"{name}.exp_avg"] = _to_numpy(torch.zeros_like(param))
|
||||
ort_inputs[f"{name}.exp_avg_sq"] = _to_numpy(torch.zeros_like(param))
|
||||
ort_inputs["params"].append(_to_numpy(copy.deepcopy(param)))
|
||||
ort_inputs[f"{name}_grad"] = _to_numpy(copy.deepcopy(param.grad))
|
||||
ort_inputs["first_order_moments"].append(_to_numpy(torch.zeros_like(param)))
|
||||
ort_inputs["second_order_moments"].append(_to_numpy(torch.zeros_like(param)))
|
||||
|
||||
# Then no error occurs when executing the model
|
||||
ort_session = onnxruntime.InferenceSession(onnx_fo.name, providers=C.get_available_providers())
|
||||
|
|
@ -642,3 +649,64 @@ def test_weighted_average_model_composition(model_type):
|
|||
weighted_model = WeightedAvg(random.random(), random.random())
|
||||
with onnxblock.onnx_model(onnx_model):
|
||||
_ = weighted_model(onnx_model.graph.output[0].name, onnx_model.graph.output[1].name)
|
||||
|
||||
|
||||
def test_grad_clipping_execution():
|
||||
# Given
|
||||
device = "cuda"
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model, _ = _get_models(device, N, D_in, H, D_out)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
target = torch.randn(N, D_out, device=device)
|
||||
|
||||
# Prepare the onnx model with only grad clipping
|
||||
onnx_model = onnx.ModelProto()
|
||||
onnx_model.graph.name = "AdamW Optimizer Model"
|
||||
onnx_model.producer_name = "grad clipping test"
|
||||
onnx_model.opset_import.extend(onnxblock.optim.optim._OPSET_IMPORTS)
|
||||
onnx_model.ir_version = onnx.IR_VERSION
|
||||
|
||||
class GradClippingModel(onnxblock.Model):
|
||||
def __init__(self, max_norm):
|
||||
self._grad_clip = onnxblock.optim.ClipGradNorm(max_norm)
|
||||
|
||||
def build(self, *grad_names):
|
||||
return self._grad_clip(*grad_names)
|
||||
|
||||
grad_names = []
|
||||
for name, param in pt_model.named_parameters():
|
||||
grad_names.append(f"{name}_grad")
|
||||
|
||||
onnx_model.graph.input.append(
|
||||
onnx.helper.make_tensor_value_info(grad_names[-1], onnx.TensorProto.FLOAT, param.shape)
|
||||
)
|
||||
|
||||
grad_clip = GradClippingModel(2.5)
|
||||
|
||||
with onnxblock.onnx_model(onnx_model):
|
||||
ort_output_names = grad_clip(*grad_names)
|
||||
|
||||
def mse_loss(prediction, target):
|
||||
loss = torch.nn.MSELoss()
|
||||
return loss(prediction, target)
|
||||
|
||||
# When
|
||||
with tempfile.NamedTemporaryFile(suffix=".onnx") as onnx_fo:
|
||||
onnx.save(onnx_model, onnx_fo.name)
|
||||
|
||||
loss = mse_loss(pt_model(x), target)
|
||||
loss.backward()
|
||||
|
||||
ort_inputs = {}
|
||||
for name, param in pt_model.named_parameters():
|
||||
ort_inputs[f"{name}_grad"] = _to_numpy(copy.deepcopy(param.grad))
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(pt_model.parameters(), 2.5)
|
||||
|
||||
# Then no error occurs when executing the model
|
||||
ort_session = onnxruntime.InferenceSession(onnx_fo.name, providers=C.get_available_providers())
|
||||
ort_outs = ort_session.run(ort_output_names, ort_inputs)
|
||||
|
||||
# assert all the gradients are close
|
||||
for ort_grad, pt_param in zip(ort_outs, pt_model.parameters()):
|
||||
assert np.allclose(ort_grad, _to_numpy(pt_param.grad))
|
||||
|
|
|
|||
|
|
@ -19,7 +19,11 @@
|
|||
#include "core/platform/path_lib.h"
|
||||
|
||||
#include "orttraining/core/framework/checkpoint_common.h"
|
||||
#include "orttraining/training_api/include/interfaces.h"
|
||||
#include "orttraining/training_api/include/module.h"
|
||||
#include "orttraining/training_api/include/optimizer.h"
|
||||
#include "orttraining/training_api/include/checkpoint_property.h"
|
||||
#include "orttraining/training_api/include/checkpoint.h"
|
||||
#include "orttraining/training_api/include/lr_scheduler.h"
|
||||
|
||||
#include "test/test_environment.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
|
@ -27,6 +31,7 @@
|
|||
#include "test/util/include/test/test_environment.h"
|
||||
#include "orttraining/test/training_api/common/synthetic_data_loader.h"
|
||||
#include "orttraining/test/training_api/core/data_utils.h"
|
||||
#include "default_providers.h"
|
||||
|
||||
using onnxruntime::test::TemporaryDirectory;
|
||||
using namespace onnxruntime::training::api;
|
||||
|
|
@ -159,7 +164,9 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) {
|
|||
* Save Optimizer states into ORT checkpoint files,
|
||||
* Then load it into ORT, compare with the initial optimizer states values.
|
||||
*/
|
||||
TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CPU) {
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) {
|
||||
/// Phase 1 - Test Preparation
|
||||
/// Prepare the data and dest folder for saving checkpoint.
|
||||
/// Also cooked the data for test result comparison.
|
||||
|
|
@ -201,15 +208,17 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CPU) {
|
|||
onnxruntime::SessionOptions session_option;
|
||||
std::unique_ptr<Environment> env;
|
||||
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
|
||||
std::vector<std::shared_ptr<IExecutionProvider>> cuda_provider{onnxruntime::test::DefaultCudaExecutionProvider()};
|
||||
auto model = std::make_unique<Module>(model_uri, named_parameters, session_option,
|
||||
*env);
|
||||
auto optimizer = Optimizer(optim_uri, model->NamedParameters(), session_option, *env);
|
||||
*env, cuda_provider);
|
||||
auto optimizer = std::make_unique<Optimizer>(optim_uri, model->NamedParameters(), session_option,
|
||||
*env, cuda_provider);
|
||||
|
||||
/// Phase 2 - Run Optimizer.GetStateDict and call save checkpoint APIs.
|
||||
/// And check the result checkpoint files.
|
||||
|
||||
CheckpointState checkpoint_state;
|
||||
ORT_ENFORCE(optimizer.GetStateDict(checkpoint_state.optimizer_checkpoint_state).IsOK());
|
||||
ORT_ENFORCE(optimizer->GetStateDict(checkpoint_state.optimizer_checkpoint_state).IsOK());
|
||||
|
||||
// Remove the temporary directory if it already exists.
|
||||
auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir");
|
||||
|
|
@ -286,6 +295,8 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CPU) {
|
|||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Create PropertyBag with sets of properties,
|
||||
* Save properties into ORT checkpoint files,
|
||||
|
|
|
|||
|
|
@ -20,6 +20,25 @@ void OrtValueToVec(const OrtValue& val, std::vector<T>& output) {
|
|||
output.assign(val_ptr, val_ptr + num_elem);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CudaOrtValueToCpuVec(const OrtValue& val, std::vector<T>& output,
|
||||
std::shared_ptr<IExecutionProvider> cuda_provider,
|
||||
std::shared_ptr<IExecutionProvider> cpu_provider) {
|
||||
const Tensor& src_tensor = val.Get<Tensor>();
|
||||
|
||||
auto allocator = cpu_provider->GetAllocator(0, OrtMemTypeDefault);
|
||||
ORT_ENFORCE(allocator, "Cpu allocator is a nullptr.");
|
||||
auto dst_tensor = std::make_unique<Tensor>(src_tensor.DataType(), src_tensor.Shape(), allocator);
|
||||
|
||||
auto data_transfer = cuda_provider->GetDataTransfer();
|
||||
ORT_ENFORCE(data_transfer, "Cuda data transfer is a nullptr.");
|
||||
|
||||
ORT_THROW_IF_ERROR(data_transfer->CopyTensor(src_tensor, *dst_tensor));
|
||||
|
||||
const T* val_ptr = dst_tensor->template Data<T>();
|
||||
output.assign(val_ptr, val_ptr + src_tensor.Shape().Size());
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -11,8 +11,13 @@
|
|||
#include "core/common/path_utils.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "orttraining/training_api/include/utils.h"
|
||||
#include "orttraining/training_api/include/interfaces.h"
|
||||
#include "orttraining/training_api/include/module.h"
|
||||
#include "orttraining/training_api/include/optimizer.h"
|
||||
#include "orttraining/training_api/include/checkpoint_property.h"
|
||||
#include "orttraining/training_api/include/checkpoint.h"
|
||||
#include "orttraining/training_api/include/lr_scheduler.h"
|
||||
#include "orttraining/test/training_api/core/data_utils.h"
|
||||
#include "default_providers.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
using namespace onnxruntime::training;
|
||||
|
|
@ -53,7 +58,7 @@ TEST(TrainingApiTest, ModuleTrainStep) {
|
|||
std::unique_ptr<Environment> env;
|
||||
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
|
||||
auto model = std::make_unique<Module>(model_uri, state.module_checkpoint_state.named_parameters, session_option,
|
||||
*env);
|
||||
*env, std::vector<std::shared_ptr<IExecutionProvider>>());
|
||||
|
||||
OrtValue input, target;
|
||||
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
|
||||
|
|
@ -95,6 +100,8 @@ TEST(TrainingApiTest, ModuleTrainStep) {
|
|||
}
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
TEST(TrainingApiTest, OptimStep) {
|
||||
auto model_uri = MODEL_FOLDER "gradient_graph.onnx";
|
||||
auto optim_uri = MODEL_FOLDER "adamw.onnx";
|
||||
|
|
@ -105,10 +112,14 @@ TEST(TrainingApiTest, OptimStep) {
|
|||
|
||||
onnxruntime::SessionOptions session_option;
|
||||
std::unique_ptr<Environment> env;
|
||||
std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCudaExecutionProvider()};
|
||||
std::shared_ptr<IExecutionProvider> cuda_provider = providers.front();
|
||||
std::shared_ptr<IExecutionProvider> cpu_provider = onnxruntime::test::DefaultCpuExecutionProvider();
|
||||
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
|
||||
auto model = std::make_unique<Module>(model_uri, state.module_checkpoint_state.named_parameters, session_option,
|
||||
*env);
|
||||
auto optim = std::make_unique<Optimizer>(optim_uri, model->NamedParameters(), session_option, *env);
|
||||
*env, providers);
|
||||
auto optim = std::make_unique<Optimizer>(optim_uri, model->NamedParameters(), session_option,
|
||||
*env, providers);
|
||||
|
||||
OrtValue input, target;
|
||||
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
|
||||
|
|
@ -125,8 +136,11 @@ TEST(TrainingApiTest, OptimStep) {
|
|||
optimizer_states.group_named_optimizer_states["group0"]->param_named_optimizer_states.at(param_name);
|
||||
OrtValue& moment_1 = param_state.momentum_named_states.at("momentum0");
|
||||
|
||||
std::vector<float> param_vec_before_optimizer_step;
|
||||
CudaOrtValueToCpuVec(model->NamedParameters().at(param_name)->Data(), param_vec_before_optimizer_step,
|
||||
cuda_provider, cpu_provider);
|
||||
std::vector<float> moment_1_vec;
|
||||
OrtValueToVec(moment_1, moment_1_vec);
|
||||
CudaOrtValueToCpuVec(moment_1, moment_1_vec, cuda_provider, cpu_provider);
|
||||
for (size_t i = 0; i < moment_1_vec.size(); i++) {
|
||||
ORT_ENFORCE(moment_1_vec[i] == 0.0f);
|
||||
}
|
||||
|
|
@ -136,14 +150,25 @@ TEST(TrainingApiTest, OptimStep) {
|
|||
std::vector<OrtValue>& inputs = *it;
|
||||
std::vector<OrtValue> fetches;
|
||||
ORT_ENFORCE(model->TrainStep(inputs, fetches).IsOK());
|
||||
std::vector<float> grads;
|
||||
CudaOrtValueToCpuVec(model->NamedParameters().at(param_name)->Gradient(), grads,
|
||||
cuda_provider, cpu_provider);
|
||||
ORT_ENFORCE(optim->Step().IsOK());
|
||||
|
||||
// get optim state and check if it is updated
|
||||
OrtValueToVec(moment_1, moment_1_vec);
|
||||
CudaOrtValueToCpuVec(moment_1, moment_1_vec, cuda_provider, cpu_provider);
|
||||
for (size_t i = 0; i < moment_1_vec.size(); i++) {
|
||||
if (moment_1_vec[i] != 0.0f) {
|
||||
// moment was updated
|
||||
break;
|
||||
if (grads[i] != 0.0f) {
|
||||
ORT_ENFORCE(moment_1_vec[i] != 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> param_vec_after_optimizer_step;
|
||||
CudaOrtValueToCpuVec(model->NamedParameters().at(param_name)->Data(), param_vec_after_optimizer_step,
|
||||
cuda_provider, cpu_provider);
|
||||
for (size_t i = 0; i < param_vec_after_optimizer_step.size(); ++i) {
|
||||
if (grads[i] != 0.0f && moment_1_vec[i] != 0.0f) {
|
||||
ORT_ENFORCE(param_vec_after_optimizer_step[i] != param_vec_before_optimizer_step[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -167,9 +192,11 @@ void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t
|
|||
onnxruntime::SessionOptions session_option;
|
||||
std::unique_ptr<Environment> env;
|
||||
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCudaExecutionProvider()};
|
||||
auto model = std::make_unique<Module>(model_uri, state.module_checkpoint_state.named_parameters, session_option,
|
||||
*env);
|
||||
auto optim = std::make_shared<Optimizer>(optim_uri, model->NamedParameters(), session_option, *env);
|
||||
*env, providers);
|
||||
auto optim = std::make_shared<Optimizer>(optim_uri, model->NamedParameters(), session_option,
|
||||
*env, providers);
|
||||
|
||||
OrtValue input, target;
|
||||
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
|
||||
|
|
@ -259,6 +286,8 @@ TEST(TrainingApiTest, LinearLRScheduler_WarmUp200Step_ResumeFromCheckpoint_Test)
|
|||
TestLRSchduler("warmup_linear_scheduler_warmupstep-200_restored.json", initial_lr, total_step_count, 200);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -1,172 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
/**
|
||||
* @brief Temporary classes bridging between public C interfaces and internal ones.
|
||||
* Currently the implementation does not cover all targeted public training APIs; more changes are expected.
|
||||
*
|
||||
* TODO: For the longer term, we should re-arrange the following interfaces to follow the way existing APIs are defined and exposed
|
||||
* in include/onnxruntime/core/session/onnxruntime_c_api.h and include/onnxruntime/core/session/onnxruntime_cxx_api.h.
|
||||
*
|
||||
* This is an intermediate header file for our training api internal development usage.
|
||||
*
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/session/environment.h"
|
||||
|
||||
#include "orttraining/training_api/include/module.h"
|
||||
#include "orttraining/training_api/include/optimizer.h"
|
||||
#include "orttraining/training_api/include/lr_scheduler.h"
|
||||
#include "orttraining/training_api/include/checkpoint_property.h"
|
||||
#include "orttraining/training_api/include/checkpoint.h"
|
||||
|
||||
using onnxruntime::training::api::LinearLRScheduler;
|
||||
using onnxruntime::training::api::Module;
|
||||
using onnxruntime::training::api::ModuleCheckpointState;
|
||||
using onnxruntime::training::api::Optimizer;
|
||||
using onnxruntime::training::api::OptimizerCheckpointState;
|
||||
using onnxruntime::training::api::Parameter;
|
||||
|
||||
namespace Ort {
|
||||
|
||||
namespace {
|
||||
|
||||
void ToOrtValue(const std::vector<Ort::Value>& ort_value_list, std::vector<OrtValue>& ortvalue_list) {
|
||||
size_t input_len = ort_value_list.size();
|
||||
ortvalue_list.clear();
|
||||
ortvalue_list.reserve(input_len);
|
||||
const Ort::Value* ort_value_inputs_ptr = ort_value_list.data();
|
||||
auto ortvalue_inputs_ptr = reinterpret_cast<const OrtValue**>(const_cast<Ort::Value*>(ort_value_inputs_ptr));
|
||||
for (size_t i = 0; i < input_len; ++i) {
|
||||
auto& ort_value = *reinterpret_cast<const ::OrtValue*>(ortvalue_inputs_ptr[i]);
|
||||
ortvalue_list.push_back(ort_value);
|
||||
}
|
||||
}
|
||||
|
||||
void FromOrtValue(std::vector<OrtValue>& ortvalue_list, std::vector<Ort::Value>& ort_value_list) {
|
||||
// Clean the output.
|
||||
ort_value_list.clear();
|
||||
size_t output_names_len = ortvalue_list.size();
|
||||
for (size_t i = 0; i < output_names_len; ++i)
|
||||
ort_value_list.emplace_back(nullptr);
|
||||
|
||||
Ort::Value* ort_value_outputs_ptr = ort_value_list.data();
|
||||
auto ortvalue_outputs_ptr = reinterpret_cast<OrtValue**>(ort_value_outputs_ptr);
|
||||
for (size_t i = 0; i != output_names_len; ++i) {
|
||||
OrtValue& value = ortvalue_list[i];
|
||||
// Ort::Value will release the pointer once it goes out of scope.
|
||||
ortvalue_outputs_ptr[i] = new OrtValue(value);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
struct OrtModule {
|
||||
public:
|
||||
OrtModule(
|
||||
OrtEnv* env, OrtSessionOptions* session_options,
|
||||
const std::string& train_model_path_or_bytes,
|
||||
std::unordered_map<std::string, std::shared_ptr<onnxruntime::training::api::Parameter>>& parameters,
|
||||
const std::optional<std::string>& eval_model_path_or_bytes = std::nullopt) {
|
||||
module_ = std::make_unique<onnxruntime::training::api::Module>(
|
||||
train_model_path_or_bytes,
|
||||
parameters,
|
||||
*reinterpret_cast<::onnxruntime::SessionOptions*>(session_options),
|
||||
*reinterpret_cast<::onnxruntime::Environment*>(env),
|
||||
eval_model_path_or_bytes);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::shared_ptr<onnxruntime::training::api::Parameter>> NamedParameters() const {
|
||||
return module_->NamedParameters();
|
||||
}
|
||||
|
||||
bool ResetGrad() {
|
||||
return module_->ResetGrad().IsOK();
|
||||
}
|
||||
|
||||
bool TrainStep(const std::vector<Ort::Value>& inputs, std::vector<Ort::Value>& outputs) {
|
||||
std::vector<OrtValue> feeds;
|
||||
ToOrtValue(inputs, feeds);
|
||||
|
||||
std::vector<OrtValue> fetches;
|
||||
if (!module_->TrainStep(feeds, fetches).IsOK()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Clean the output.
|
||||
outputs.clear();
|
||||
FromOrtValue(fetches, outputs);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EvalStep(const std::vector<Ort::Value>& inputs, std::vector<Ort::Value>& outputs) {
|
||||
std::vector<OrtValue> feeds;
|
||||
ToOrtValue(inputs, feeds);
|
||||
|
||||
std::vector<OrtValue> fetches;
|
||||
if (!module_->EvalStep(feeds, fetches).IsOK()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Clean the output.
|
||||
outputs.clear();
|
||||
FromOrtValue(fetches, outputs);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GetStateDict(onnxruntime::training::api::ModuleCheckpointState& module_checkpoint_states) {
|
||||
return module_->GetStateDict(module_checkpoint_states).IsOK();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<onnxruntime::training::api::Module> module_;
|
||||
};
|
||||
|
||||
struct OrtOptimizer {
|
||||
friend struct OrtLinearLRScheduler;
|
||||
|
||||
OrtOptimizer(
|
||||
OrtEnv* env, OrtSessionOptions* session_options,
|
||||
const std::string& optim_path_or_bytes,
|
||||
const std::unordered_map<std::string, std::shared_ptr<onnxruntime::training::api::Parameter>>& parameters) {
|
||||
optimizer_ = std::make_shared<onnxruntime::training::api::Optimizer>(
|
||||
optim_path_or_bytes,
|
||||
parameters,
|
||||
*reinterpret_cast<::onnxruntime::SessionOptions*>(session_options),
|
||||
*reinterpret_cast<::onnxruntime::Environment*>(env));
|
||||
}
|
||||
|
||||
bool Step() {
|
||||
return optimizer_->Step().IsOK();
|
||||
}
|
||||
|
||||
bool GetStateDict(onnxruntime::training::api::OptimizerCheckpointState& optimizer_checkpoint_states) {
|
||||
return optimizer_->GetStateDict(optimizer_checkpoint_states).IsOK();
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<onnxruntime::training::api::Optimizer> optimizer_;
|
||||
};
|
||||
|
||||
struct OrtLinearLRScheduler {
|
||||
OrtLinearLRScheduler(OrtOptimizer& optimizer, int64_t warmup_step_count, int64_t total_step_count)
|
||||
: optim_(optimizer.optimizer_) {
|
||||
linear_lr_scheduler_ = std::make_unique<LinearLRScheduler>(optim_, warmup_step_count, total_step_count);
|
||||
}
|
||||
|
||||
bool Step() {
|
||||
return linear_lr_scheduler_->Step().IsOK();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<onnxruntime::training::api::LinearLRScheduler> linear_lr_scheduler_;
|
||||
std::shared_ptr<onnxruntime::training::api::Optimizer> optim_;
|
||||
};
|
||||
|
||||
} // namespace Ort
|
||||
|
|
@ -61,6 +61,7 @@ struct Module {
|
|||
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& named_parameters,
|
||||
const onnxruntime::SessionOptions& session_options,
|
||||
const Environment& env,
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
|
||||
const std::optional<std::string>& eval_model_path_or_bytes = std::nullopt);
|
||||
|
||||
// Return the trainable/nontrainable parameters
|
||||
|
|
|
|||
|
|
@ -59,7 +59,8 @@ struct Optimizer {
|
|||
Optimizer(const std::string& optim_path_or_bytes,
|
||||
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& named_parameters,
|
||||
const onnxruntime::SessionOptions& session_options,
|
||||
const Environment& env);
|
||||
const Environment& env,
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers);
|
||||
|
||||
Status Step();
|
||||
|
||||
|
|
|
|||
|
|
@ -12,16 +12,23 @@ namespace training {
|
|||
namespace api {
|
||||
using namespace common;
|
||||
|
||||
struct ModelIdentifiers {
|
||||
const std::string train_model;
|
||||
const std::optional<std::string> eval_model, optim_model;
|
||||
ModelIdentifiers(const std::string& train_model_uri,
|
||||
const std::optional<std::string>& eval_model_uri,
|
||||
const std::optional<std::string>& optim_model_uri)
|
||||
: train_model(train_model_uri), eval_model(eval_model_uri), optim_model(optim_model_uri) {}
|
||||
};
|
||||
|
||||
// Wrapper on top of the module and optimizer classes; it is the only class exposed via the C APIs
|
||||
class TrainingSession {
|
||||
public:
|
||||
TrainingSession(const Environment& session_env,
|
||||
const SessionOptions& session_options,
|
||||
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& parameters);
|
||||
|
||||
Status Initialize(const std::string& train_model_uri,
|
||||
const std::optional<std::string>& eval_model_uri,
|
||||
const std::optional<std::string>& optim_model_uri);
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
|
||||
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& parameters,
|
||||
const ModelIdentifiers& model_identifiers);
|
||||
|
||||
size_t GetTrainModeOutputCount() const noexcept;
|
||||
|
||||
|
|
@ -44,8 +51,6 @@ class TrainingSession {
|
|||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TrainingSession);
|
||||
|
||||
const Environment& environment_;
|
||||
SessionOptions session_options_;
|
||||
const std::unordered_map<std::string, std::shared_ptr<Parameter>> named_parameters_;
|
||||
std::unique_ptr<Module> module_;
|
||||
std::unique_ptr<Optimizer> optimizer_;
|
||||
|
|
|
|||
|
|
@ -55,10 +55,14 @@ Module::Module(const std::string& train_model_path_or_bytes,
|
|||
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& named_parameters,
|
||||
const onnxruntime::SessionOptions& session_options,
|
||||
const Environment& env,
|
||||
const std::optional<std::string>& eval_model_path_or_bytes) : named_parameters_{named_parameters} {
|
||||
// Create session for training model
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
|
||||
const std::optional<std::string>& eval_model_path_or_bytes)
|
||||
: named_parameters_{named_parameters} {
|
||||
train_sess_ = std::make_unique<onnxruntime::InferenceSession>(session_options, env);
|
||||
ORT_THROW_IF_ERROR(train_sess_->Load(train_model_path_or_bytes));
|
||||
for (const auto& provider : providers) {
|
||||
ORT_THROW_IF_ERROR(train_sess_->RegisterExecutionProvider(provider));
|
||||
}
|
||||
ORT_THROW_IF_ERROR(train_sess_->Initialize());
|
||||
|
||||
// Extract model input and output names
|
||||
|
|
@ -123,7 +127,6 @@ Module::Module(const std::string& train_model_path_or_bytes,
|
|||
auto target_tensor = std::make_unique<Tensor>(param_data_tensor.DataType(), param_data_tensor.Shape(), target_allocator);
|
||||
ORT_THROW_IF_ERROR(train_sess_state.GetDataTransferMgr().CopyTensor(param_data_tensor, *target_tensor.get()));
|
||||
auto ml_tensor_type = DataTypeImpl::GetType<Tensor>();
|
||||
// TODO test the original buffer is released.
|
||||
param_data.Init(target_tensor.release(), ml_tensor_type, ml_tensor_type->GetDeleteFunc());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,20 @@
|
|||
#include "orttraining/training_api/include/training_session.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<std::shared_ptr<onnxruntime::IExecutionProvider>> CreateProviders(
|
||||
const std::vector<std::shared_ptr<onnxruntime::IExecutionProviderFactory>>& provider_factories) {
|
||||
std::vector<std::shared_ptr<onnxruntime::IExecutionProvider>> execution_providers;
|
||||
for (const auto& factory : provider_factories) {
|
||||
execution_providers.emplace_back(std::move(factory->CreateProvider()));
|
||||
}
|
||||
|
||||
return execution_providers;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
|
||||
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
|
||||
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
|
||||
|
|
@ -20,16 +34,18 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_
|
|||
*out = nullptr;
|
||||
|
||||
ORT_TRY {
|
||||
using ProvidersType = std::vector<std::shared_ptr<onnxruntime::IExecutionProvider>>;
|
||||
train_sess = std::make_unique<onnxruntime::training::api::TrainingSession>(
|
||||
env->GetEnvironment(),
|
||||
options == nullptr ? onnxruntime::SessionOptions() : options->value,
|
||||
chkpt_state->module_checkpoint_state.named_parameters);
|
||||
|
||||
ORT_API_RETURN_IF_STATUS_NOT_OK(train_sess->Initialize(train_model_path,
|
||||
eval_model_path ? std::optional<std::string>{eval_model_path}
|
||||
: std::nullopt,
|
||||
optimizer_model_path ? std::optional<std::string>{optimizer_model_path}
|
||||
: std::nullopt));
|
||||
options == nullptr ? ProvidersType() : CreateProviders(options->provider_factories),
|
||||
chkpt_state->module_checkpoint_state.named_parameters,
|
||||
onnxruntime::training::api::ModelIdentifiers{
|
||||
train_model_path,
|
||||
eval_model_path ? std::optional<std::string>{eval_model_path}
|
||||
: std::nullopt,
|
||||
optimizer_model_path ? std::optional<std::string>{optimizer_model_path}
|
||||
: std::nullopt});
|
||||
|
||||
*out = reinterpret_cast<OrtTrainingSession*>(train_sess.release());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,9 +20,14 @@ const std::string GROUP_ZERO_NAME = "group0";
|
|||
|
||||
// TODO: don't hard code the state names, should get the state names according to the optimizer types.
|
||||
// TODO: Consolidate with frontend tooling
|
||||
const std::vector<std::string> MOMENT_SUFFIXES{".exp_avg", ".exp_avg_sq"};
|
||||
const std::vector<std::string> MOMENT_STATE_NAMES{"momentum0", "momentum1"};
|
||||
|
||||
constexpr char LearningRateName[] = "learning_rate";
|
||||
constexpr char StepName[] = "step";
|
||||
constexpr char ParamsName[] = "params";
|
||||
constexpr char FirstOrderMomentsName[] = "first_order_moments";
|
||||
constexpr char SecondOrderMomentsName[] = "second_order_moments";
|
||||
|
||||
} // namespace
|
||||
|
||||
Status Optimizer::GenerateMomentumNamedStates() {
|
||||
|
|
@ -49,35 +54,70 @@ Status Optimizer::ConstructInputs() {
|
|||
if (optimizer_type_ == OptimizerType::AdamW) {
|
||||
auto& param_named_optimizer_states = optimizer_state_.param_named_optimizer_states;
|
||||
|
||||
std::string param_name;
|
||||
std::vector<std::string> param_names, grad_names, moment1_names, moment2_names, user_inputs;
|
||||
for (size_t i = 2; i < input_names_.size(); i++) {
|
||||
std::string& name = input_names_[i];
|
||||
auto it = named_parameters_.find(name);
|
||||
if (it != named_parameters_.end()) { // is param
|
||||
param_names.push_back(name);
|
||||
inputs_.push_back(it->second->Data());
|
||||
} else if (utils::GetParamNameFromGradient(name, param_name)) {
|
||||
grad_names.emplace_back(name);
|
||||
// assert param_name is valid.
|
||||
auto it = named_parameters_.find(param_name);
|
||||
ORT_ENFORCE(it != named_parameters_.end(), "Unknown param: ", param_name, " for field: ", name);
|
||||
inputs_.push_back(it->second->Gradient());
|
||||
} else if (utils::GetParamNameFromSuffix(name, MOMENT_SUFFIXES[0], param_name)) {
|
||||
moment1_names.push_back(name);
|
||||
auto it = named_parameters_.find(param_name);
|
||||
ORT_ENFORCE(it != named_parameters_.end(), "Unknown param: ", param_name, " for field: ", name);
|
||||
inputs_.push_back(param_named_optimizer_states.at(param_name).momentum_named_states.at(MOMENT_STATE_NAMES[0]));
|
||||
} else if (utils::GetParamNameFromSuffix(name, MOMENT_SUFFIXES[1], param_name)) {
|
||||
moment2_names.push_back(name);
|
||||
auto it = named_parameters_.find(param_name);
|
||||
ORT_ENFORCE(it != named_parameters_.end(), "Unknown param: ", param_name, " for field: ", name);
|
||||
inputs_.push_back(param_named_optimizer_states.at(param_name).momentum_named_states.at(MOMENT_STATE_NAMES[1]));
|
||||
std::vector<Tensor> params, first_order_moments, second_order_moments;
|
||||
// TODO: Change to tensor seq implementation once clip grad norm op
|
||||
// that accepts tensor seq as input for gradients is complete.
|
||||
std::vector<OrtValue> grads;
|
||||
|
||||
// Input names 0-4 are reserved for lr, step, params, first order moments, second order moments
|
||||
// input names 5 onwards are all the gradient names.
|
||||
// Collect all the inputs based on the gradient names order.
|
||||
for (size_t i = 5; i < input_names_.size(); i++) {
|
||||
std::string param_name;
|
||||
if (utils::GetParamNameFromGradient(input_names_[i], param_name)) {
|
||||
const auto named_parameter_it = named_parameters_.find(param_name);
|
||||
ORT_ENFORCE(named_parameter_it != named_parameters_.end(),
|
||||
"Unknown param: ", param_name, " for field: ", input_names_[i]);
|
||||
|
||||
// Collect the gradients as ortvalues
|
||||
grads.push_back(named_parameter_it->second->Gradient());
|
||||
|
||||
// Collect parameters and prepare for tensorseq creation
|
||||
auto* param_tensor = named_parameter_it->second->Data().GetMutable<Tensor>();
|
||||
params.emplace_back(
|
||||
Tensor(param_tensor->DataType(), param_tensor->Shape(),
|
||||
param_tensor->MutableDataRaw(), param_tensor->Location()));
|
||||
|
||||
// Collect first order moments and prepare for tensorseq creation
|
||||
auto* first_order_moment_tensor = param_named_optimizer_states.at(param_name)
|
||||
.momentum_named_states.at(MOMENT_STATE_NAMES[0])
|
||||
.GetMutable<Tensor>();
|
||||
first_order_moments.emplace_back(
|
||||
Tensor(first_order_moment_tensor->DataType(), first_order_moment_tensor->Shape(),
|
||||
first_order_moment_tensor->MutableDataRaw(), first_order_moment_tensor->Location()));
|
||||
|
||||
// Collect second order moments and prepare for tensorseq creation
|
||||
auto* second_order_moment_tensor = param_named_optimizer_states.at(param_name)
|
||||
.momentum_named_states.at(MOMENT_STATE_NAMES[1])
|
||||
.GetMutable<Tensor>();
|
||||
second_order_moments.emplace_back(
|
||||
Tensor(second_order_moment_tensor->DataType(), second_order_moment_tensor->Shape(),
|
||||
second_order_moment_tensor->MutableDataRaw(), second_order_moment_tensor->Location()));
|
||||
} else {
|
||||
ORT_ENFORCE("This is an invalid graph. Optimizer graph contains unknown user input:", name);
|
||||
ORT_ENFORCE(
|
||||
false, "This is an invalid graph. Optimizer graph contains unknown user input:", input_names_[i]);
|
||||
}
|
||||
ORT_ENFORCE(inputs_.back().IsAllocated() && inputs_.back().IsTensor(), "Uninitialized tensor data for ", name);
|
||||
}
|
||||
|
||||
const auto tensorseq_inserter = [](auto& tensors, auto* inputs) {
|
||||
ORT_ENFORCE(!tensors.empty(), "Tensors cannot be empty while building a tensor sequence.");
|
||||
|
||||
auto tensor_seq = std::make_unique<TensorSeq>(tensors.front().DataType());
|
||||
tensor_seq->SetElements(std::move(tensors));
|
||||
inputs->emplace_back(
|
||||
OrtValue(tensor_seq.release(), DataTypeImpl::GetType<TensorSeq>(),
|
||||
DataTypeImpl::GetType<TensorSeq>()->GetDeleteFunc()));
|
||||
};
|
||||
|
||||
// Add the params and moments as tensorseq ortvalues to inputs
|
||||
tensorseq_inserter(params, &inputs_);
|
||||
tensorseq_inserter(first_order_moments, &inputs_);
|
||||
tensorseq_inserter(second_order_moments, &inputs_);
|
||||
|
||||
// Add the gradients as ortvalues to inputs
|
||||
inputs_.insert(inputs_.end(),
|
||||
std::make_move_iterator(grads.begin()),
|
||||
std::make_move_iterator(grads.end()));
|
||||
}
|
||||
// Add other optimizer reordering logic here
|
||||
return Status::OK();
|
||||
|
|
@ -86,17 +126,26 @@ Status Optimizer::ConstructInputs() {
|
|||
Optimizer::Optimizer(const std::string& optim_path_or_bytes,
|
||||
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& named_parameters,
|
||||
const onnxruntime::SessionOptions& session_options,
|
||||
const Environment& env) : named_parameters_(named_parameters) {
|
||||
const Environment& env,
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers)
|
||||
: named_parameters_(named_parameters) {
|
||||
optim_sess_ = std::move(std::make_unique<InferenceSession>(session_options, env));
|
||||
for (const auto& execution_provider : providers) {
|
||||
ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider));
|
||||
}
|
||||
|
||||
ORT_THROW_IF_ERROR(optim_sess_->Load(optim_path_or_bytes));
|
||||
ORT_THROW_IF_ERROR(optim_sess_->Initialize());
|
||||
|
||||
utils::GetGraphInputOutputNames(optim_sess_, input_names_, output_names_);
|
||||
ORT_ENFORCE(input_names_[0] == "learning_rate"); // TODO: make this better
|
||||
ORT_ENFORCE(input_names_[1] == "step"); // TODO: make this better
|
||||
ORT_ENFORCE(input_names_[0] == LearningRateName); // TODO: make this better
|
||||
ORT_ENFORCE(input_names_[1] == StepName); // TODO: make this better
|
||||
ORT_ENFORCE(input_names_[2] == ParamsName); // TODO: make this better
|
||||
|
||||
if (optimizer_type_ == OptimizerType::AdamW) {
|
||||
ORT_ENFORCE(input_names_[3] == FirstOrderMomentsName); // TODO: make this better
|
||||
ORT_ENFORCE(input_names_[4] == SecondOrderMomentsName); // TODO: make this better
|
||||
|
||||
ORT_THROW_IF_ERROR(GenerateMomentumNamedStates());
|
||||
} else {
|
||||
ORT_THROW("Unsupported optimizer type");
|
||||
|
|
@ -107,7 +156,10 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes,
|
|||
Status Optimizer::Step() {
|
||||
OrtValue learning_rate_input, step_input;
|
||||
utils::WrapInOrtValue<float>(optimizer_state_.learning_rate, &learning_rate_input);
|
||||
utils::WrapInOrtValue<int64_t>(optimizer_state_.step, &step_input);
|
||||
// Use step count + 1 before running optimizer step.
|
||||
// This is necessary since bias correction uses the step
|
||||
// as a power. Using power of 0 is wrong.
|
||||
utils::WrapInOrtValue<int64_t>(optimizer_state_.step + 1, &step_input);
|
||||
std::vector<OrtValue> feeds({learning_rate_input, step_input});
|
||||
feeds.insert(feeds.end(), inputs_.begin(), inputs_.end());
|
||||
|
||||
|
|
@ -116,8 +168,9 @@ Status Optimizer::Step() {
|
|||
ORT_THROW_IF_ERROR(status);
|
||||
|
||||
// extract step output and update
|
||||
// TODO: need to remove hardcoding
|
||||
optimizer_state_.step = utils::GetValue<int64_t>(outputs[0]);
|
||||
if (utils::GetValue<int64_t>(outputs[0]) == 1LL) {
|
||||
optimizer_state_.step++;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,23 +9,17 @@ namespace api {
|
|||
|
||||
TrainingSession::TrainingSession(const Environment& session_env,
|
||||
const SessionOptions& session_options,
|
||||
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& parameters)
|
||||
: environment_(session_env),
|
||||
session_options_{session_options},
|
||||
named_parameters_{parameters} {}
|
||||
|
||||
Status TrainingSession::Initialize(const std::string& train_model_uri, const std::optional<std::string>& eval_model_uri,
|
||||
const std::optional<std::string>& optim_model_uri) {
|
||||
module_ = std::move(std::make_unique<Module>(train_model_uri, named_parameters_, session_options_,
|
||||
environment_, eval_model_uri));
|
||||
|
||||
if (optim_model_uri.has_value()) {
|
||||
optimizer_ = std::move(std::make_unique<Optimizer>(optim_model_uri.value(), named_parameters_,
|
||||
session_options_, environment_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
|
||||
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& parameters,
|
||||
const ModelIdentifiers& model_identifiers)
|
||||
: named_parameters_{parameters},
|
||||
module_{std::make_unique<Module>(model_identifiers.train_model, named_parameters_,
|
||||
session_options, session_env, providers, model_identifiers.eval_model)},
|
||||
optimizer_{model_identifiers.optim_model.has_value()
|
||||
? std::make_unique<Optimizer>(
|
||||
model_identifiers.optim_model.value(), named_parameters_,
|
||||
session_options, session_env, providers)
|
||||
: std::unique_ptr<Optimizer>()} {}
|
||||
|
||||
size_t TrainingSession::GetTrainModeOutputCount() const noexcept {
|
||||
return module_->GetTrainModeOutputCount();
|
||||
|
|
|
|||
|
|
@ -63,12 +63,22 @@ Status OrtValueLike(const SessionState& sess_state, const OrtValue& input_val, O
|
|||
auto element_type = param_tensor.DataType();
|
||||
auto p_tensor = std::make_unique<Tensor>(element_type, shape, allocator);
|
||||
|
||||
// TODO: handle CUDA memset
|
||||
if (tensor_location.device.Type() == OrtDevice::CPU ||
|
||||
tensor_location.mem_type == OrtMemTypeCPUInput ||
|
||||
tensor_location.mem_type == OrtMemTypeCPUOutput) {
|
||||
memset(p_tensor->MutableDataRaw(), 0, p_tensor->SizeInBytes());
|
||||
} else if (tensor_location.device.Type() == OrtDevice::GPU) {
|
||||
// Use a tensor on cpu and copy it over to the desired device using
|
||||
// the data transfer manager.
|
||||
AllocatorPtr cpu_allocator = sess_state.GetAllocator(OrtDevice());
|
||||
auto p_cpu_tensor = std::make_unique<Tensor>(element_type, shape, cpu_allocator);
|
||||
memset(p_cpu_tensor->MutableDataRaw(), 0, p_cpu_tensor->SizeInBytes());
|
||||
// No need to free the cpu buffer, tensor destructor takes care of it using the cpu_allocator
|
||||
ORT_THROW_IF_ERROR(sess_state.GetDataTransferMgr().CopyTensor(*p_cpu_tensor, *p_tensor));
|
||||
} else {
|
||||
ORT_THROW("Cannot create tensor on device ", tensor_location.device.Type());
|
||||
}
|
||||
|
||||
output_val.Init(p_tensor.release(),
|
||||
DataTypeImpl::GetType<Tensor>(),
|
||||
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ extend_skip_glob = [
|
|||
convention = "google"
|
||||
|
||||
[tool.pylint.'MESSAGES CONTROL']
|
||||
disable = ["format", "line-too-long", "import-error", "no-name-in-module"]
|
||||
disable = ["format", "line-too-long", "import-error", "no-name-in-module", "fixme", "too-few-public-methods"]
|
||||
|
||||
[tool.pyright]
|
||||
exclude = ["onnxruntime/core/flatbuffers/*"]
|
||||
|
|
|
|||
6
setup.py
6
setup.py
|
|
@ -374,7 +374,7 @@ requirements_file = "requirements.txt"
|
|||
|
||||
local_version = None
|
||||
enable_training = parse_arg_remove_boolean(sys.argv, "--enable_training")
|
||||
enable_training_on_device = parse_arg_remove_boolean(sys.argv, '--enable_training_on_device')
|
||||
enable_training_on_device = parse_arg_remove_boolean(sys.argv, "--enable_training_on_device")
|
||||
disable_auditwheel_repair = parse_arg_remove_boolean(sys.argv, "--disable_auditwheel_repair")
|
||||
default_training_package_device = parse_arg_remove_boolean(sys.argv, "--default_training_package_device")
|
||||
|
||||
|
|
@ -421,7 +421,9 @@ if enable_training:
|
|||
]
|
||||
)
|
||||
if enable_training_on_device:
|
||||
packages.append('onnxruntime.training.onnxblock')
|
||||
packages.append("onnxruntime.training.onnxblock")
|
||||
packages.append("onnxruntime.training.onnxblock.loss")
|
||||
packages.append("onnxruntime.training.onnxblock.optim")
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"]
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]
|
||||
|
|
|
|||
|
|
@ -179,8 +179,7 @@ def parse_arguments():
|
|||
parser.add_argument(
|
||||
"--enable_training_torch_interop", action="store_true", help="Enable training kernels interop with torch."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_training_on_device", action='store_true', help="Enable on device training in ORT.")
|
||||
parser.add_argument("--enable_training_on_device", action="store_true", help="Enable on device training in ORT.")
|
||||
parser.add_argument("--disable_nccl", action="store_true", help="Disable Nccl.")
|
||||
parser.add_argument("--mpi_home", help="Path to MPI installation dir")
|
||||
parser.add_argument("--nccl_home", help="Path to NCCL installation dir")
|
||||
|
|
|
|||
Loading…
Reference in a new issue