diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 685ae59fc6..1005423035 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -61,7 +61,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float, LayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_float, LayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_double, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); @@ -125,7 +125,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/layer_norm.cc b/onnxruntime/contrib_ops/cuda/layer_norm.cc index bd6d14eef0..de89bd424d 100644 --- a/onnxruntime/contrib_ops/cuda/layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/layer_norm.cc @@ -24,7 +24,7 @@ namespace cuda { LayerNorm); REGISTER_KERNEL_TYPED(float, float) -REGISTER_KERNEL_TYPED(double, float) +REGISTER_KERNEL_TYPED(double, double) REGISTER_KERNEL_TYPED(MLFloat16, float) template diff --git a/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu index 747a2ff70e..4251c21033 100644 --- a/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu @@ -376,7 +376,7 @@ void HostApplyLayerNorm( LAYERNORM_LINEAR_IMPL(float, float) LAYERNORM_LINEAR_IMPL(half, float) -LAYERNORM_LINEAR_IMPL(double, float) +LAYERNORM_LINEAR_IMPL(double, double) //LAYERNORM_LINEAR_IMPL(half, half) } // namespace cuda diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc index 0b32ec0857..fe44ae78b6 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc @@ -6,6 +6,7 @@ #include "core/graph/schema_registry.h" #include "orttraining/core/framework/gradient_graph_builder.h" #include "orttraining/core/graph/gradient_builder_registry.h" +#include "orttraining/core/graph/gradient_config.h" #include "orttraining/core/optimizer/insert_output_rewriter.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/rule_based_graph_transformer.h" @@ -22,9 +23,11 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph, const unordered_set& y_node_arg_names, const unordered_set& x_node_arg_names, string loss_node_arg_name, + const GradientGraphConfiguration& gradient_graph_config, const bool set_gradient_as_graph_output) : graph_(graph), loss_node_arg_name_(loss_node_arg_name), + gradient_graph_config_(gradient_graph_config), set_gradient_as_graph_output_(set_gradient_as_graph_output) { auto rule_based_graph_transformer = onnxruntime::make_unique("pre_training_rule_based_graph_transformer"); @@ -187,7 +190,7 @@ Status GradientGraphBuilder::Build() { } } - GradientDef node_defs = GetGradientForOp(node, output_args_need_grad, input_args_need_grad); + GradientDef node_defs = GetGradientForOp(gradient_graph_config_, node, output_args_need_grad, input_args_need_grad); // updates arg name if gradient accumulation is needed for (auto& op_def : node_defs) { diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 2c40ab86e8..403d543613 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -58,6 +58,7 @@ class GradientGraphBuilder { const std::unordered_set& y_node_arg_names, const std::unordered_set& x_node_arg_names, std::string loss_node_arg_name, + const GradientGraphConfiguration& gradient_graph_config, const bool set_gradient_as_graph_output = false); Status Build(); @@ -73,6 +74,8 @@ class GradientGraphBuilder { std::string loss_node_arg_name_; + const GradientGraphConfiguration& gradient_graph_config_; + onnxruntime::GraphTransformerManager graph_transformation_mgr_{5}; // key: ArgDef for the gradient after accumulation diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 9c6cbf72fc..a823d46ba9 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -956,11 +956,19 @@ IMPLEMENT_GRADIENT_BUILDER(GetFastGeluGradient) { } IMPLEMENT_GRADIENT_BUILDER(GetLayerNormalizationGradient) { - return std::vector{ - NodeDef(OpDef{"LayerNormalizationGrad", kMSDomain, 1}, - {GO(0), I(0), I(1), O(1), O(2)}, - {GI(0), GI(1), GI(2)}, - {SrcNodeAttributes()})}; + if (GetGradientGraphConfiguration().use_invertible_layernorm_grad) { + return std::vector{ + NodeDef(OpDef{"InvertibleLayerNormalizationGrad", kMSDomain, 1}, + {GO(0), O(0), I(1), I(2), O(2)}, + {GI(0), GI(1), GI(2)}, + {SrcNodeAttributes()})}; + } else { + return std::vector{ + NodeDef(OpDef{"LayerNormalizationGrad", kMSDomain, 1}, + {GO(0), I(0), I(1), O(1), O(2)}, + {GI(0), GI(1), GI(2)}, + {SrcNodeAttributes()})}; + } } IMPLEMENT_GRADIENT_BUILDER(GetBatchNormalizationGradient) { diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index 0f5c79eb8e..90f5ed5124 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -7,6 +7,7 @@ #include #include "core/graph/graph.h" #include "orttraining/core/graph/graph_augmenter.h" +#include "orttraining/core/graph/gradient_config.h" #include "onnx/defs/attr_proto_util.h" namespace onnxruntime { @@ -27,10 +28,11 @@ typedef std::vector GradientDef; class GradientBuilderBase { public: GradientBuilderBase( + const GradientGraphConfiguration& gradient_graph_config, const Node* node, const std::unordered_set& gradient_inputs, const std::unordered_set& gradient_outputs) - : node_(node), gradient_inputs_(gradient_inputs), gradient_outputs_(gradient_outputs) { + : gradient_graph_config_(gradient_graph_config), node_(node), gradient_inputs_(gradient_inputs), gradient_outputs_(gradient_outputs) { unique_node_prefix_ = CreateUniqueNodePrefix(); } @@ -54,6 +56,10 @@ class GradientBuilderBase { protected: virtual GradientDef GetGradientDefsImpl() const = 0; + const GradientGraphConfiguration& GetGradientGraphConfiguration() const { + return gradient_graph_config_; + } + // i-th input of forward op ArgDef I(const size_t i) const { ORT_ENFORCE(i < node_->InputDefs().size()); @@ -185,6 +191,7 @@ class GradientBuilderBase { return unique_prefix.str(); } + const GradientGraphConfiguration& gradient_graph_config_; const Node* node_; std::string unique_node_prefix_; diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 54149f7bf3..94b4e1e096 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -3,11 +3,13 @@ #include "orttraining/core/graph/gradient_builder_registry.h" #include "orttraining/core/graph/gradient_builder.h" +#include "orttraining/core/graph/gradient_config.h" namespace onnxruntime { namespace training { -GradientDef GetGradientForOp(const Node* node, +GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config, + const Node* node, const std::unordered_set& output_args_need_grad, const std::unordered_set& input_args_need_grad) { @@ -16,6 +18,7 @@ GradientDef GetGradientForOp(const Node* node, // less than 9 is not supported and for Slice we have Slice-1, Slice-10 and Slice-11. auto gradient_builder = GradientBuilderRegistry::GetInstance().MakeUnique(node->OpType(), + gradient_graph_config, node, output_args_need_grad, input_args_need_grad); diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.h b/orttraining/orttraining/core/graph/gradient_builder_registry.h index 9568aff6bd..7acec7e7f4 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.h +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.h @@ -12,6 +12,7 @@ namespace onnxruntime { namespace training { typedef GenericRegistry&, // gradient_inputs const std::unordered_set&> // gradient_outputs @@ -31,7 +32,8 @@ class GradientBuilderRegistry : public GradientRegistryType { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GradientBuilderRegistry); }; -GradientDef GetGradientForOp(const Node* node, +GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config, + const Node* node, const std::unordered_set& output_args_need_grad, const std::unordered_set& input_args_need_grad); diff --git a/orttraining/orttraining/core/graph/gradient_config.h b/orttraining/orttraining/core/graph/gradient_config.h new file mode 100644 index 0000000000..bee896965e --- /dev/null +++ b/orttraining/orttraining/core/graph/gradient_config.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace training { + +struct GradientGraphConfiguration { + // Layernorm gradient can be computed based on either input or output of layernorm. + // That is to say, either input or output needs to be stashed for layernorm gradient. + // To save memory, ideally, only one(input vs output) should be stashed rather than both. + // By default, the input based algorithm is used. This flag is to enable the output based algorithm. + bool use_invertible_layernorm_grad{false}; +}; + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 17c816268c..23d99cf071 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -1432,6 +1432,32 @@ Example 4: {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors."); + ONNX_CONTRIB_OPERATOR_SCHEMA(InvertibleLayerNormalizationGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc("LayerNormalizationGrad") + .Attr("axis", + "The first normalization dimension: normalization will be performed along dimensions axis : rank(inputs).", + AttributeProto::INT, static_cast(-1)) + .AllowUncheckedAttributes() + .Input(0, "Y_grad", "The gradient tensor from output.", "T") + .Input(1, "Y", "Output data tensor from the forward path", "T") + .Input(2, "scale", "Scale tensor.", "T") + .Input(3, "bias", "Bias tensor.", "T") + .Input(4, "inv_std_var", "inverse std variance of X.", "U") + .Output(0, "X_grad", "Gradient of the input.", "T") + .Output(1, "scale_grad", "Gradient of the scale.", "T") + .Output(2, "bias_grad", "Gradient of the bias.", "T") + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types (except mean and inv_std_var) to float tensors.") + .TypeConstraint( + "U", + {"tensor(float)"}, + "Constrain mean and inv_std_var to float tensors."); + ONNX_CONTRIB_OPERATOR_SCHEMA(BatchNormalizationGrad) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 6b650eed47..1898639f12 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -219,7 +219,7 @@ Status TrainingSession::ConfigureForTraining( } ORT_RETURN_IF_ERROR(BuildGradientGraph( - weight_names_to_train, loss_name, config.set_gradients_as_graph_outputs)); + weight_names_to_train, loss_name, config.gradient_graph_config, config.set_gradients_as_graph_outputs)); // transform for mixed precision std::unordered_map fp32_weight_name_to_fp16_node_arg{}; @@ -425,12 +425,14 @@ static Status ConfigureLossFunctionInternal( static Status BuildGradientGraphInternal(Graph& graph, const std::string& loss_function_output_name, const std::unordered_set& node_arg_names_to_train, + const GradientGraphConfiguration& gradient_graph_config, const bool set_gradient_as_graph_output = false) { // Compute the gradient graph def. GradientGraphBuilder grad_graph_builder(&graph, {loss_function_output_name}, node_arg_names_to_train, loss_function_output_name, + gradient_graph_config, set_gradient_as_graph_output); return grad_graph_builder.Build(); } @@ -638,13 +640,16 @@ Status TrainingSession::EnableMixedPrecision(const std::unordered_set& weights_to_train, const std::string& loss_function_output_name, + const GradientGraphConfiguration& gradient_graph_config, const bool set_gradient_as_graph_output) { // Fill weights_to_train_ according to weights_to_train weights_to_train_ = weights_to_train; + gradient_graph_config_ = gradient_graph_config; ORT_RETURN_IF_ERROR(BuildGradientGraphInternal(model_->MainGraph(), loss_function_output_name, weights_to_train_, + gradient_graph_config_, set_gradient_as_graph_output)); return DoPostLoadProcessing(*model_); @@ -762,6 +767,7 @@ Status TrainingSession::Save(const PathString& model_uri, TrainingSession::SaveO ORT_RETURN_IF_ERROR(BuildGradientGraphInternal(new_model->MainGraph(), actual_loss_name, weights_to_train_, + gradient_graph_config_, false)); OptimizerOutputKeyMap opt_graph_outputs; diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 9ede468f50..5453b54273 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -10,6 +10,7 @@ #include "orttraining/core/graph/loss_function_registry.h" #include "orttraining/core/graph/optimizer_graph_output_key.h" #include "orttraining/core/graph/optimizer_config.h" +#include "orttraining/core/graph/gradient_config.h" namespace onnxruntime { namespace training { @@ -42,6 +43,9 @@ class TrainingSession : public InferenceSession { // The immutable weights specification. ImmutableWeights immutable_weights; + // Gradient graph configuration + GradientGraphConfiguration gradient_graph_config{}; + // Whether to set the gradients as graph outputs. bool set_gradients_as_graph_outputs{false}; @@ -409,6 +413,7 @@ class TrainingSession : public InferenceSession { */ common::Status BuildGradientGraph(const std::unordered_set& weights_to_train, const std::string& loss_function_output_name, + const GradientGraphConfiguration& gradient_graph_config, const bool set_gradient_as_graph_output = false); common::Status BuildAccumulationNode(const std::unordered_set& weights_to_train); @@ -469,6 +474,8 @@ class TrainingSession : public InferenceSession { std::unordered_set dropout_eval_feeds_; OptimizerGraphConfig opt_graph_config_; std::unordered_map opt_configs_; + + GradientGraphConfiguration gradient_graph_config_; }; } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index effa43b7ab..89067c9269 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -165,7 +165,9 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet ("enable_grad_norm_clip", "Specify whether to enable gradient clipping for optimizers.", cxxopts::value()->default_value("true")) ("enable_gelu_approximation", "Specify whether to enable GELU approximation.", - cxxopts::value()->default_value("true")); + cxxopts::value()->default_value("true")) + ("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)", + cxxopts::value()->default_value("false")); options .add_options("ORT configuration") ("ort_log_severity", "ORT minimum logging severity (see onnxruntime::logging::Severity values)", @@ -458,12 +460,15 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet "Log severity must be in the range [", static_cast(logging::Severity::kVERBOSE), ", ", static_cast(logging::Severity::kFATAL), "]."); ort_params.vlog_level = flags["ort_vlog_level"].as(); + + params.use_invertible_layernorm_grad = flags["use_invertible_layernorm_grad"].as(); } catch (const exception& e) { const std::string msg = "Failed to parse the command line arguments"; cerr << msg << ": " << e.what() << "\n" << options.help() << "\n"; return Status(ONNXRUNTIME, INVALID_ARGUMENT, msg); } + return Status::OK(); } diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 226dbe2d27..e94dc86977 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -91,6 +91,7 @@ Status TrainingRunner::Initialize() { config.weight_names_to_not_train = params_.weights_not_to_train; config.immutable_weights = params_.immutable_weights; + config.gradient_graph_config.use_invertible_layernorm_grad = params_.use_invertible_layernorm_grad; config.set_gradients_as_graph_outputs = false; config.gradient_accumulation_steps = params_.gradient_accumulation_steps; diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index 015d21b470..4b42a56168 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -169,6 +169,9 @@ class TrainingRunner { // Enable GELU approximation bool enable_gelu_approximation = false; + + // Use invertible layernorm grad + bool use_invertible_layernorm_grad = false; }; TrainingRunner(Parameters params, const Environment& env); diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index c58580902b..33e00252e9 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -376,7 +376,8 @@ def create_ort_training_session_with_optimizer(model, device, training_optimizer deepspeed_zero_stage=0, enable_grad_norm_clip=True, frozen_weights=[], opset_version=DEFAULT_OPSET_VERSION, - use_deterministic_compute=False): + use_deterministic_compute=False, + use_invertible_layernorm_grad=False): output_name = model.graph.output[0].name ort_parameters = ort.TrainingParameters() ort_parameters.loss_output_name = output_name @@ -384,11 +385,11 @@ def create_ort_training_session_with_optimizer(model, device, training_optimizer ort_parameters.world_rank = world_rank ort_parameters.world_size = world_size ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps - ort_parameters.use_mixed_precision = use_mixed_precision ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip ort_parameters.set_gradients_as_graph_outputs = False + ort_parameters.use_invertible_layernorm_grad = use_invertible_layernorm_grad output_types = {} for output in model.graph.output: @@ -530,7 +531,8 @@ class ORTTrainer(): world_rank=0, world_size=1, use_mixed_precision=False, allreduce_post_accumulation=False, global_step=0, get_lr_this_step=None, loss_scaler=None, deepspeed_zero_stage=0, enable_grad_norm_clip=True, frozen_weights=[], _opset_version=DEFAULT_OPSET_VERSION, - _enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False): + _enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False, + use_invertible_layernorm_grad=False): super(ORTTrainer, self).__init__() """ Initialize ORTTrainer. @@ -599,6 +601,8 @@ class ORTTrainer(): Defaults to True _extra_postprocess: a callable to postprocess the ONNX model that is converted from PyTorch. Defaults to None + use_invertible_layernorm_grad: use invertible layernorm grad + Defaults to False """ warnings.warn('DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it') self.is_train = True @@ -651,6 +655,7 @@ class ORTTrainer(): self.state_dict_ = None self._enable_internal_postprocess = _enable_internal_postprocess self._use_deterministic_compute = _use_deterministic_compute + self.use_invertible_layernorm_grad = use_invertible_layernorm_grad # use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs. # see prepare_input_and_fetches for more details. @@ -679,7 +684,8 @@ class ORTTrainer(): deepspeed_zero_stage=self.deepspeed_zero_stage_, enable_grad_norm_clip=self.enable_grad_norm_clip_, frozen_weights=self.frozen_weights_, opset_version=self.opset_version_, - use_deterministic_compute=self._use_deterministic_compute) + use_deterministic_compute=self._use_deterministic_compute, + use_invertible_layernorm_grad=self.use_invertible_layernorm_grad) self.loss_scale_input_name = self.session.loss_scale_input_name diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index a9d52c7a80..415e1c2c0e 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -47,6 +47,7 @@ struct TrainingParameters { int deepspeed_zero_stage = 0; bool enable_grad_norm_clip = true; bool set_gradients_as_graph_outputs = false; + bool use_invertible_layernorm_grad = false; }; struct TrainingConfigurationResult { @@ -144,6 +145,8 @@ TrainingConfigurationResult ConfigureSessionForTraining( config.optimizer_config = opt; } + config.gradient_graph_config.use_invertible_layernorm_grad = parameters.use_invertible_layernorm_grad; + training::TrainingSession::TrainingConfigurationResult config_result{}; OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result)); @@ -177,7 +180,8 @@ void addObjectMethodsForTraining(py::module& m) { .def_readwrite("gradient_accumulation_steps", &TrainingParameters::gradient_accumulation_steps) .def_readwrite("deepspeed_zero_stage", &TrainingParameters::deepspeed_zero_stage) .def_readwrite("enable_grad_norm_clip", &TrainingParameters::enable_grad_norm_clip) - .def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs); + .def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs) + .def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad); py::class_ config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc"); config_result.def(py::init()) diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index e923b81926..92892bff21 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -18,6 +18,7 @@ limitations under the License. #include "gradient_checker.h" #include "gradient_op_test_utils.h" #include "orttraining/core/framework/gradient_graph_builder.h" +#include "orttraining/core/graph/gradient_config.h" #include "test/util/include/test_random_seed.h" #include @@ -305,10 +306,12 @@ inline Status GradientChecker::InitOpTesterWithGradGraph( } } + training::GradientGraphConfiguration gradient_graph_config; training::GradientGraphBuilder grad_graph_builder(&graph, dy_values, weights_to_train, "", + gradient_graph_config, true); Status status = grad_graph_builder.Build(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index c8d546dba3..048efc6078 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -5,6 +5,7 @@ #include "core/session/inference_session.h" #include "orttraining/core/session/training_session.h" #include "orttraining/core/framework/gradient_graph_builder.h" +#include "orttraining/core/graph/gradient_config.h" #include "default_providers.h" namespace onnxruntime { @@ -69,10 +70,12 @@ void GradientOpTester::Run( } } + training::GradientGraphConfiguration gradient_graph_config; training::GradientGraphBuilder grad_graph_builder(&graph, dy_values, weights_to_train, "", + gradient_graph_config, true); status = grad_graph_builder.Build(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); diff --git a/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc b/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc index 8fb8f543e7..cbd23c43f6 100644 --- a/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc @@ -97,5 +97,143 @@ TEST(CudaKernelTest, LayerNormGrad_LargeSizeTensor) { TestLayerNormGrad(X_dims, -1, 5e-3); } +static void TestInvertibleLayerNormGrad( + const std::vector& x_dims, + int64_t axis = -1, + double error_tolerance = 1e-4, + bool test_fp16=false) { + const std::vector& n_x_m_dims = x_dims; + std::vector n_dims, m_dims; + ASSERT_TRUE(SplitDims(n_x_m_dims, axis, n_dims, m_dims).IsOK()); + + const auto N = std::accumulate(n_dims.begin(), n_dims.end(), static_cast(1), std::multiplies<>{}); + const auto M = std::accumulate(m_dims.begin(), m_dims.end(), static_cast(1), std::multiplies<>{}); + + CompareOpTester test{"InvertibleLayerNormalizationGrad", 1, kMSDomain}; + + test.AddAttribute("axis", axis); + + RandomValueGenerator random{}; + const auto Y_grad_data = random.Uniform(n_x_m_dims, k_random_data_min, k_random_data_max); + const auto X_data = random.Uniform(n_x_m_dims, k_random_data_min, k_random_data_max); + const auto scale_data = random.Uniform(m_dims, k_random_data_min, k_random_data_max); + const auto bias_data = random.Uniform(m_dims, k_random_data_min, k_random_data_max); + + // these inputs are dependent on X_data + std::vector mean_data(N); // mean(X) + std::vector inv_std_var_data(N); // 1 / sqrt(mean(X^2) - mean(X)^2 + epsilon) + std::vector Y_data(N*M); + { + using ConstEigenArrayMap = Eigen::Map>; + using EigenArrayMap = Eigen::Map>; + + ConstEigenArrayMap X{X_data.data(), M, N}; + + for (int i = 0; i < N; ++i) { + mean_data[i] = X.col(i).mean(); + inv_std_var_data[i] = X.col(i).square().mean() - mean_data[i] * mean_data[i]; + } + + // Compute Y = ((x - mean) * (inv_var) * scale + bias + EigenArrayMap Y(Y_data.data(), M, N); + + using EigenVectorArrayMap = Eigen::Map>; + using ConstEigenVectorArrayMap = Eigen::Map>; + ConstEigenVectorArrayMap mean(mean_data.data(), N); + EigenVectorArrayMap inv_std_var(inv_std_var_data.data(), N); + inv_std_var = (inv_std_var + k_epsilon_default).sqrt().inverse(); + + Y = (X.rowwise() - mean.transpose()).rowwise() * inv_std_var.transpose(); + + ConstEigenVectorArrayMap scale(scale_data.data(), M); + ConstEigenVectorArrayMap bias(bias_data.data(), M); + Y = (Y.colwise() * scale).colwise() + bias; + } + + if (test_fp16) { + std::vector Y_grad_data_half(Y_grad_data.size()); + std::vector Y_data_half(Y_data.size()); + std::vector scale_data_half(scale_data.size()); + std::vector bias_data_half(bias_data.size()); + ConvertFloatToMLFloat16(Y_grad_data.data(),Y_grad_data_half.data(), int(Y_grad_data.size())); + ConvertFloatToMLFloat16(Y_data.data(),Y_data_half.data(), int(Y_data.size())); + ConvertFloatToMLFloat16(scale_data.data(), scale_data_half.data(), int(scale_data.size())); + ConvertFloatToMLFloat16(bias_data.data(), bias_data_half.data(), int(bias_data.size())); + + test.AddInput("Y_grad", n_x_m_dims, Y_grad_data_half); + test.AddInput("Y", n_x_m_dims, Y_data_half); + test.AddInput("scale", m_dims, scale_data_half, true); + test.AddInput("bias", m_dims, bias_data_half); + + const auto X_grad_data = FillZeros(n_x_m_dims); + const auto scale_grad_data = FillZeros(m_dims); + const auto bias_grad_data = FillZeros(m_dims); + test.AddOutput("X_grad", n_x_m_dims, X_grad_data); + test.AddOutput("scale_grad_data", m_dims, scale_grad_data); + test.AddOutput("bias_grad_data", m_dims, bias_grad_data); + } else { + test.AddInput("Y_grad", n_x_m_dims, Y_grad_data); + test.AddInput("Y", n_x_m_dims, Y_data); + test.AddInput("scale", m_dims, scale_data, true); + test.AddInput("bias", m_dims, bias_data); + + const auto X_grad_data = FillZeros(n_x_m_dims); + const auto scale_grad_data = FillZeros(m_dims); + const auto bias_grad_data = FillZeros(m_dims); + test.AddOutput("X_grad", n_x_m_dims, X_grad_data); + test.AddOutput("scale_grad_data", m_dims, scale_grad_data); + test.AddOutput("bias_grad_data", m_dims, bias_grad_data); + } + test.AddInput("inv_std_var", n_dims, inv_std_var_data); + + if (test_fp16) { + test.CompareWithCPU(kCudaExecutionProvider, error_tolerance, error_tolerance); + } else { + test.CompareWithCPU(kCudaExecutionProvider, error_tolerance); + } +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor) { + const std::vector X_dims{4, 20, 128}; + TestInvertibleLayerNormGrad(X_dims); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_IntermediateAxis) { + const std::vector X_dims{4, 20, 16, 8}; + const int64_t axis = -2; + TestInvertibleLayerNormGrad(X_dims, axis); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_MidSizeTensor) { + const std::vector X_dims{8, 80, 768}; + TestInvertibleLayerNormGrad(X_dims); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_LargeSizeTensor) { + const std::vector X_dims{16, 512, 1024}; + TestInvertibleLayerNormGrad(X_dims, -1, 5e-3); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_FP16) { + const std::vector X_dims{4, 20, 128}; + TestInvertibleLayerNormGrad(X_dims, -1, 2e-3, true); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_IntermediateAxis_FP16) { + const std::vector X_dims{4, 20, 16, 8}; + const int64_t axis = -2; + TestInvertibleLayerNormGrad(X_dims, axis, 2e-3, true); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_MidSizeTensor_FP16) { + const std::vector X_dims{8, 80, 768}; + TestInvertibleLayerNormGrad(X_dims, -1, 2e-3, true); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_LargeSizeTensor_FP16) { + const std::vector X_dims{16, 512, 1024}; + TestInvertibleLayerNormGrad(X_dims, -1, 2e-3, true); +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index 1ae886545c..a302ad0f52 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -77,6 +77,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GistB class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GistBinarizeDecoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LayerNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, InvertibleLayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, InvertibleLayerNormalizationGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SliceGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGeluGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasGeluGrad_dX); @@ -155,6 +157,8 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cpu/nn/layer_norm.cc b/orttraining/orttraining/training_ops/cpu/nn/layer_norm.cc index fc2128b05b..e55fe5e411 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/layer_norm.cc +++ b/orttraining/orttraining/training_ops/cpu/nn/layer_norm.cc @@ -20,7 +20,16 @@ namespace contrib { kCpuExecutionProvider, \ KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - LayerNormGrad); + LayerNormGrad); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + InvertibleLayerNormalizationGrad, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + InvertibleLayerNormGrad); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) @@ -99,5 +108,76 @@ Status LayerNormGrad::Compute(OpKernelContext* op_kernel_context) const { return Status::OK(); } +template +InvertibleLayerNormGrad::InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info) + : OpKernel{op_kernel_info} { + ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK()); +} + +template +Status InvertibleLayerNormGrad::Compute(OpKernelContext* op_kernel_context) const { + const Tensor* Y_grad = op_kernel_context->Input(0); + const Tensor* Y = op_kernel_context->Input(1); + const Tensor* scale = op_kernel_context->Input(2); + const Tensor* bias = op_kernel_context->Input(3); + const Tensor* inv_std_var = op_kernel_context->Input(4); + + const auto& Y_shape = Y_grad->Shape(); + const auto& X_shape = Y_shape; + const auto axis = HandleNegativeAxis(axis_, X_shape.NumDimensions()); + const auto N = X_shape.SizeToDimension(axis); + const auto M = X_shape.SizeFromDimension(axis); + ORT_ENFORCE(M != 1); + const auto& scale_shape = scale->Shape(); + + Tensor* X_grad = op_kernel_context->Output(0, X_shape); + Tensor* scale_grad = op_kernel_context->Output(1, scale_shape); + Tensor* bias_grad = op_kernel_context->Output(2, scale_shape); + + // Note: Eigen has column-major storage order by default + ConstEigenArrayMap Y_grad_arr{Y_grad->Data(), M, N}; + ConstEigenArrayMap Y_arr{Y->Data(), M, N}; + ConstEigenVectorArrayMap scale_vec{scale->Data(), M}; + ConstEigenVectorArrayMap bias_vec{bias->Data(), M}; + ConstEigenVectorArrayMap inv_std_var_vec{inv_std_var->Data(), N}; + + EigenArrayMap X_grad_arr{X_grad->MutableData(), M, N}; + EigenVectorArrayMap scale_grad_vec{scale_grad->MutableData(), M}; + EigenVectorArrayMap bias_grad_vec{bias_grad->MutableData(), M}; + + using Array = Eigen::ArrayXX; + using RowVector = Eigen::Array; + + // A, B, C are calculated as below: + // A = Y_grad * (X - mean(X)) * inv_std_var + // B = Y_grad * scale * inv_std_var + // C = Y_grad * scale * inv_std_var * (X - mean(X)) * inv_std_var + + // A, B, and C are M x N + Array X_mean_difference_over_std_var = (Y_arr.colwise() - bias_vec).colwise() / scale_vec; + Array A = Y_grad_arr * X_mean_difference_over_std_var; + Array B = (Y_grad_arr.colwise() * scale_vec).rowwise() * inv_std_var_vec.cast().transpose(); + Array C = B * X_mean_difference_over_std_var; + + // mean_B = mean(Y_grad * scale * inv_std_var) + RowVector mean_B = B.colwise().mean(); // 1 x N + + // mean_C = mean(Y_grad * scale * inv_std_var * (X - mean(X)) * inv_std_var) + RowVector mean_C = C.colwise().mean(); // 1 x N + + // X_grad = Y_grad * scale * inv_std_var - mean_B - (X - mean(X)) * inv_std_var * mean_C + // = B - mean_B - (X - mean(X)) * inv_std_var * mean_c + X_grad_arr = B.rowwise() - mean_B - X_mean_difference_over_std_var.rowwise() * mean_C; + + // bias_grad = sum(Y_grad) + bias_grad_vec = Y_grad_arr.rowwise().sum(); + + // scale_grad = sum(Y_grad * (X - mean(X)) * inv_std_var) + // = sum(A) + scale_grad_vec = A.rowwise().sum(); + + return Status::OK(); +} + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/nn/layer_norm.h b/orttraining/orttraining/training_ops/cpu/nn/layer_norm.h index 14024c6068..4cd6207c26 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/layer_norm.h +++ b/orttraining/orttraining/training_ops/cpu/nn/layer_norm.h @@ -18,5 +18,15 @@ class LayerNormGrad final : public OpKernel { int64_t axis_; }; +template +class InvertibleLayerNormGrad final : public OpKernel { + public: + InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info); + Status Compute(OpKernelContext* op_kernel_context) const override; + + private: + int64_t axis_; +}; + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index e648e6fb14..e3a3e87662 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -89,8 +89,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, LayerNormalizationGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, LayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, LayerNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, LayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, InvertibleLayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, InvertibleLayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, InvertibleLayerNormalizationGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad); @@ -198,8 +201,11 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc index f2e3b0e026..d7f8acfc8c 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc @@ -19,9 +19,19 @@ namespace cuda { KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("U", DataTypeImpl::GetTensorType()), \ - LayerNormGrad); + LayerNormGrad); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + InvertibleLayerNormalizationGrad, \ + kMSDomain, \ + 1, \ + T##_##U, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", DataTypeImpl::GetTensorType()), \ + InvertibleLayerNormGrad); REGISTER_GRADIENT_KERNEL_TYPED(float, float) -REGISTER_GRADIENT_KERNEL_TYPED(double, float) +REGISTER_GRADIENT_KERNEL_TYPED(double, double) REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16, float) template @@ -65,7 +75,58 @@ Status LayerNormGrad::ComputeInternal(OpKernelContext* p_op_kernel_context auto part_grad_gamma = GetScratchBuffer(part_size * n2); auto part_grad_beta = GetScratchBuffer(part_size * n2); - HostLayerNormGradient(GetDeviceProp(), Y_grad_data, mean_data, inv_std_var_data, X_data, n1, n2, scale_data, X_grad_data, scale_grad_data, bias_grad_data, + HostLayerNormGradient(GetDeviceProp(), Y_grad_data, X_data, reinterpret_cast(NULL), + scale_data, reinterpret_cast(NULL), mean_data, inv_std_var_data, n1, n2, + X_grad_data, scale_grad_data, bias_grad_data, + part_grad_gamma.get(), part_grad_beta.get(), part_size); + return Status::OK(); +} + +template +InvertibleLayerNormGrad::InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK()); +} + +template +Status InvertibleLayerNormGrad::ComputeInternal(OpKernelContext* p_op_kernel_context) const { + typedef typename ToCudaType::MappedType CudaT; + typedef typename ToCudaType::MappedType CudaU; + // Inputs + const Tensor* Y_grad = p_op_kernel_context->Input(0); + const Tensor* Y = p_op_kernel_context->Input(1); + const Tensor* scale = p_op_kernel_context->Input(2); + const Tensor* bias = p_op_kernel_context->Input(3); + const Tensor* inv_std_var = p_op_kernel_context->Input(4); + + auto Y_grad_data = reinterpret_cast(Y_grad->template Data()); + auto Y_data = reinterpret_cast(Y->template Data()); + auto scale_data = reinterpret_cast(scale->template Data()); + auto bias_data = reinterpret_cast(bias->template Data()); + auto inv_std_var_data = reinterpret_cast(inv_std_var->template Data()); + + const TensorShape& y_shape = Y->Shape(); + const TensorShape& x_shape = y_shape; + const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions()); + auto n1 = x_shape.SizeToDimension(axis); + auto n2 = x_shape.SizeFromDimension(axis); + ORT_ENFORCE(n2 != 1, "n2 should not be 1"); + + // Outputs + Tensor* X_grad = p_op_kernel_context->Output(0, x_shape); + auto X_grad_data = reinterpret_cast(X_grad->template MutableData()); + + Tensor* scale_grad = p_op_kernel_context->Output(1, scale->Shape()); + Tensor* bias_grad = p_op_kernel_context->Output(2, scale->Shape()); + auto scale_grad_data = reinterpret_cast(scale_grad->template MutableData()); + auto bias_grad_data = reinterpret_cast(bias_grad->template MutableData()); + + const int part_size = 16; + auto part_grad_gamma = GetScratchBuffer(part_size * n2); + auto part_grad_beta = GetScratchBuffer(part_size * n2); + + HostLayerNormGradient(GetDeviceProp(), Y_grad_data, reinterpret_cast(NULL), Y_data, + scale_data, bias_data, reinterpret_cast(NULL), inv_std_var_data, n1, n2, + X_grad_data, scale_grad_data, bias_grad_data, part_grad_gamma.get(), part_grad_beta.get(), part_size); return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h index fd21b09ba4..ab092ed12a 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h @@ -25,5 +25,15 @@ class LayerNormGrad final : public CudaKernel { int64_t axis_; }; +template +class InvertibleLayerNormGrad final : public CudaKernel { + public: + InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + int64_t axis_; +}; + } // namespace cuda } // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu index 7f85c5676f..11753e080b 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu @@ -372,10 +372,10 @@ void HostApplyLayerNorm( LAYERNORM_LINEAR_IMPL(float, float) LAYERNORM_LINEAR_IMPL(half, float) -LAYERNORM_LINEAR_IMPL(double, float) +LAYERNORM_LINEAR_IMPL(double, double) //LAYERNORM_LINEAR_IMPL(half, half) -template +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -385,24 +385,34 @@ __device__ void cuLoadWriteStridedInputs( U* warp_buf1, U* warp_buf2, const T* input, + const T* output, const T* dout, const int i1_end, const int n2, + const T* __restrict__ gamma, + const T* __restrict__ beta, const U* __restrict__ mean, const U* __restrict__ invvar) { int i1 = i1_block + thr_load_row_off; if (i1 < i1_end) { - U curr_mean = mean[i1]; - U curr_invvar = invvar[i1]; + U curr_mean = use_mean ? mean[i1] : U(0); + U curr_invvar = use_mean ? invvar[i1] : U(0); for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1 * n2 + i2; int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); U curr_dout = static_cast(dout[load_idx]); warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + if (use_mean) { + U curr_input = static_cast(input[load_idx]); + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + U curr_gamma = static_cast(gamma[i2]); + U curr_beta = static_cast(beta[i2]); + U curr_output = static_cast(output[load_idx]); + warp_buf2[write_idx] = curr_dout * (curr_output - curr_beta) / curr_gamma; + } } else { warp_buf1[write_idx] = U(0); warp_buf2[write_idx] = U(0); @@ -417,7 +427,7 @@ __device__ void cuLoadWriteStridedInputs( } } -template +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -427,37 +437,50 @@ __device__ void cuLoadAddStridedInputs( U* warp_buf1, U* warp_buf2, const T* input, + const T* output, const T* dout, const int i1_end, const int n2, + const T* __restrict__ gamma, + const T* __restrict__ beta, const U* __restrict__ mean, const U* __restrict__ invvar) { int i1 = i1_block + thr_load_row_off; if (i1 < i1_end) { - U curr_mean = mean[i1]; - U curr_invvar = invvar[i1]; + U curr_mean = use_mean ? mean[i1] : U(0); + U curr_invvar = use_mean ? invvar[i1] : U(0); for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1 * n2 + i2; int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); U curr_dout = static_cast(dout[load_idx]); warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + if (use_mean) { + U curr_input = static_cast(input[load_idx]); + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + U curr_gamma = static_cast(gamma[i2]); + U curr_beta = static_cast(beta[i2]); + U curr_output = static_cast(output[load_idx]); + warp_buf2[write_idx] += curr_dout * (curr_output - curr_beta) / curr_gamma; + } } } } } -template +template __global__ void cuComputePartGradGammaBeta( const T* __restrict__ dout, const T* __restrict__ input, - const int n1, - const int n2, + const T* __restrict__ output, + const T* __restrict__ gamma, + const T* __restrict__ beta, const U* __restrict__ mean, const U* __restrict__ invvar, + const int n1, + const int n2, U* part_grad_gamma, U* part_grad_beta) { const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); @@ -475,9 +498,9 @@ __global__ void cuComputePartGradGammaBeta( U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar); + cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, output, dout, i1_end, n2, gamma, beta, mean, invvar); for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) { - cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar); + cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, output, dout, i1_end, n2, gamma, beta, mean, invvar); } __syncthreads(); // inter-warp reductions @@ -566,22 +589,25 @@ __global__ void cuComputeGradGammaBeta( } } -template +template __global__ void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, - const int n1, - const int n2, + const T* __restrict__ output, + const T* gamma, + const T* beta, const U* __restrict__ mean, const U* __restrict__ invvar, - const T* gamma, + const int n1, + const int n2, T* grad_input) { for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { U sum_loss1 = U(0); U sum_loss2 = U(0); - const U c_mean = mean[i1]; + const U c_mean = use_mean ? mean[i1] : U(0); const U c_invvar = invvar[i1]; - const T* k_input = input + i1 * n2; + const T* k_input = use_mean ? input + i1 * n2 : nullptr; + const T* k_output = use_mean ? nullptr: output + i1 * n2; const T* k_dout = dout + i1 * n2; const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; @@ -589,33 +615,53 @@ __global__ void cuComputeGradInput( int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); const U c_loss = static_cast(k_dout[l + k]); sum_loss1 += c_loss * U(gamma[l + k]); - sum_loss2 += c_loss * U(gamma[l + k]) * (c_h - c_mean) * c_invvar; + if (use_mean) { + const U c_h = static_cast(k_input[l + k]); + sum_loss2 += c_loss * U(gamma[l + k]) * (c_h - c_mean) * c_invvar; + } else { + const U c_output = static_cast(k_output[l + k]); + sum_loss2 += c_loss * (c_output - U(beta[l + k])); + } } } for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); sum_loss1 += c_loss * U(gamma[l]); - sum_loss2 += c_loss * U(gamma[l]) * (c_h - c_mean) * c_invvar; + if (use_mean) { + const U c_h = static_cast(k_input[l]); + sum_loss2 += c_loss * U(gamma[l]) * (c_h - c_mean) * c_invvar; + } else { + const U c_output = static_cast(k_output[l]); + sum_loss2 += c_loss * (c_output - U(beta[l])); + } } } else { int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); const U c_loss = static_cast(k_dout[l + k]); sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + if (use_mean) { + const U c_h = static_cast(k_input[l + k]); + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + const U c_output = static_cast(k_output[l + k]); + sum_loss2 += c_loss * c_output; + } } } for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + if (use_mean) { + const U c_h = static_cast(k_input[l]); + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + const U c_output = static_cast(k_output[l]); + sum_loss2 += c_loss * c_output; + } } } // intra-warp reductions @@ -659,21 +705,31 @@ __global__ void cuComputeGradInput( T* k_grad_input = grad_input + i1 * n2; if (gamma != NULL) { for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); U f_grad_input = fH * c_loss * U(gamma[l]); f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + if (use_mean) { + const U c_h = static_cast(k_input[l]); + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + const U c_output = static_cast(k_output[l]); + f_grad_input -= (c_output - U(beta[l])) / U(gamma[l]) * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } } else { for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); U f_grad_input = fH * c_loss; f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + if (use_mean) { + const U c_h = static_cast(k_input[l]); + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + const U c_output = static_cast(k_output[l]); + f_grad_input -= c_output * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -683,20 +739,22 @@ __global__ void cuComputeGradInput( template void HostLayerNormGradient( - const cudaDeviceProp& prop, - const T* dout, - const U* mean, - const U* invvar, - const T* input, - int64_t n1, - int64_t n2, - const T* gamma, - T* grad_input, - T* grad_gamma, - T* grad_beta, - U* part_grad_gamma, - U* part_grad_beta, - const int part_size) { + const cudaDeviceProp& prop, + const T* dout, + const T* input, + const T* output, + const T* gamma, + const T* beta, + const U* mean, + const U* invvar, + int64_t n1, + int64_t n2, + T* grad_input, + T* grad_gamma, + T* grad_beta, + U* part_grad_gamma, + U* part_grad_beta, + const int part_size) { const int warp_size = prop.warpSize; ORT_ENFORCE(warp_size == GPU_WARP_SIZE); @@ -706,14 +764,31 @@ void HostLayerNormGradient( const int nshared2_b = threads2.x * threads2.y * sizeof(U); const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - cuComputePartGradGammaBeta<<>>( + if (mean == nullptr) { + cuComputePartGradGammaBeta<<>>( dout, input, - n1, n2, + output, + gamma, + beta, mean, invvar, + n1, n2, part_grad_gamma, part_grad_beta); + } else { + cuComputePartGradGammaBeta<<>>( + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + part_grad_gamma, + part_grad_beta); + } const dim3 threads3(warp_size, 8, 1); const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); @@ -732,22 +807,38 @@ void HostLayerNormGradient( const dim3 threads1(warp_size, 4, 1); int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; - cuComputeGradInput<<>>( + if (mean == nullptr) { + cuComputeGradInput<<>>( dout, input, - n1, n2, + output, + gamma, + beta, mean, invvar, - gamma, + n1, n2, grad_input); + } else { + cuComputeGradInput<<>>( + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); + } } -#define LAYERNORMGRAD_IMPL(T, U) \ - template void HostLayerNormGradient(const cudaDeviceProp& prop, const T* dout, const U* mean, const U* invvar, const T* input, int64_t n1, int64_t n2, const T* gamma, \ +#define LAYERNORMGRAD_IMPL(T, U) \ + template void HostLayerNormGradient(const cudaDeviceProp& prop, const T* dout, const T* input, const T* output, \ + const T* gamma, const T* beta, const U* mean, const U* invvar, int64_t n1, int64_t n2, \ T* grad_input, T* grad_gamma, T* grad_beta, U* part_grad_gamma, U* part_grad_beta, const int part_size); LAYERNORMGRAD_IMPL(float, float) -LAYERNORMGRAD_IMPL(double, float) +LAYERNORMGRAD_IMPL(double, double) LAYERNORMGRAD_IMPL(half, float) } // namespace cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h index 285d36e94e..ea61ff71ab 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h @@ -45,12 +45,14 @@ template void HostLayerNormGradient( const cudaDeviceProp& prop, const T* dout, + const T* input, + const T* output, + const T* gamma, + const T* beta, const U* mean, const U* invvar, - const T* input, int64_t n1, int64_t n2, - const T* gamma, T* grad_input, T* grad_gamma, T* grad_beta,