From bd11ab68161bcb1cd3af7881fc1875eb3f9929c8 Mon Sep 17 00:00:00 2001 From: Weixing Zhang Date: Thu, 2 Jul 2020 22:09:30 -0700 Subject: [PATCH] Optimize LayernormGrad (#4156) * Draft for LayerNorm Optimization * Modify LayernormGrad kernel based on new backward graph. * keep two LayernormGrad implementations. One is implemented based on input X, mean. The other is based on output Y, scale, bias. The first one is enabled by default. The second one can be enabled by --use_invertible_layernorm_grad * expose use_invertible_layernorm_grad to frontend. * add fp16 tests. Co-authored-by: Sherlock Huang Co-authored-by: Weixing Zhang --- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 +- onnxruntime/contrib_ops/cuda/layer_norm.cc | 2 +- .../contrib_ops/cuda/layer_norm_impl.cu | 2 +- .../core/framework/gradient_graph_builder.cc | 5 +- .../core/framework/gradient_graph_builder.h | 3 + .../core/graph/gradient_builder.cc | 18 +- .../core/graph/gradient_builder_base.h | 9 +- .../core/graph/gradient_builder_registry.cc | 5 +- .../core/graph/gradient_builder_registry.h | 4 +- .../orttraining/core/graph/gradient_config.h | 18 ++ .../core/graph/training_op_defs.cc | 26 +++ .../core/session/training_session.cc | 8 +- .../core/session/training_session.h | 7 + orttraining/orttraining/models/bert/main.cc | 7 +- .../models/runner/training_runner.cc | 1 + .../models/runner/training_runner.h | 3 + orttraining/orttraining/python/ort_trainer.py | 14 +- .../python/orttraining_pybind_state.cc | 6 +- .../test/gradient/gradient_checker.cc | 3 + .../test/gradient/gradient_op_test_utils.cc | 3 + .../test/training_ops/cuda/layer_norm_test.cc | 138 ++++++++++++ .../training_ops/cpu/cpu_training_kernels.cc | 4 + .../training_ops/cpu/nn/layer_norm.cc | 82 ++++++- .../training_ops/cpu/nn/layer_norm.h | 10 + .../cuda/cuda_training_kernels.cc | 10 +- .../training_ops/cuda/nn/layer_norm.cc | 67 +++++- .../training_ops/cuda/nn/layer_norm.h | 10 + .../training_ops/cuda/nn/layer_norm_impl.cu | 203 +++++++++++++----- .../training_ops/cuda/nn/layer_norm_impl.h | 6 +- 29 files changed, 594 insertions(+), 84 deletions(-) create mode 100644 orttraining/orttraining/core/graph/gradient_config.h diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 685ae59fc6..1005423035 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -61,7 +61,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float, LayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_float, LayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_double, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); @@ -125,7 +125,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/layer_norm.cc b/onnxruntime/contrib_ops/cuda/layer_norm.cc index bd6d14eef0..de89bd424d 100644 --- a/onnxruntime/contrib_ops/cuda/layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/layer_norm.cc @@ -24,7 +24,7 @@ namespace cuda { LayerNorm); REGISTER_KERNEL_TYPED(float, float) -REGISTER_KERNEL_TYPED(double, float) +REGISTER_KERNEL_TYPED(double, double) REGISTER_KERNEL_TYPED(MLFloat16, float) template diff --git a/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu index 747a2ff70e..4251c21033 100644 --- a/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu @@ -376,7 +376,7 @@ void HostApplyLayerNorm( LAYERNORM_LINEAR_IMPL(float, float) LAYERNORM_LINEAR_IMPL(half, float) -LAYERNORM_LINEAR_IMPL(double, float) +LAYERNORM_LINEAR_IMPL(double, double) //LAYERNORM_LINEAR_IMPL(half, half) } // namespace cuda diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc index 0b32ec0857..fe44ae78b6 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc @@ -6,6 +6,7 @@ #include "core/graph/schema_registry.h" #include "orttraining/core/framework/gradient_graph_builder.h" #include "orttraining/core/graph/gradient_builder_registry.h" +#include "orttraining/core/graph/gradient_config.h" #include "orttraining/core/optimizer/insert_output_rewriter.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/rule_based_graph_transformer.h" @@ -22,9 +23,11 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph, const unordered_set& y_node_arg_names, const unordered_set& x_node_arg_names, string loss_node_arg_name, + const GradientGraphConfiguration& gradient_graph_config, const bool set_gradient_as_graph_output) : graph_(graph), loss_node_arg_name_(loss_node_arg_name), + gradient_graph_config_(gradient_graph_config), set_gradient_as_graph_output_(set_gradient_as_graph_output) { auto rule_based_graph_transformer = onnxruntime::make_unique("pre_training_rule_based_graph_transformer"); @@ -187,7 +190,7 @@ Status GradientGraphBuilder::Build() { } } - GradientDef node_defs = GetGradientForOp(node, output_args_need_grad, input_args_need_grad); + GradientDef node_defs = GetGradientForOp(gradient_graph_config_, node, output_args_need_grad, input_args_need_grad); // updates arg name if gradient accumulation is needed for (auto& op_def : node_defs) { diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 2c40ab86e8..403d543613 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -58,6 +58,7 @@ class GradientGraphBuilder { const std::unordered_set& y_node_arg_names, const std::unordered_set& x_node_arg_names, std::string loss_node_arg_name, + const GradientGraphConfiguration& gradient_graph_config, const bool set_gradient_as_graph_output = false); Status Build(); @@ -73,6 +74,8 @@ class GradientGraphBuilder { std::string loss_node_arg_name_; + const GradientGraphConfiguration& gradient_graph_config_; + onnxruntime::GraphTransformerManager graph_transformation_mgr_{5}; // key: ArgDef for the gradient after accumulation diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 9c6cbf72fc..a823d46ba9 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -956,11 +956,19 @@ IMPLEMENT_GRADIENT_BUILDER(GetFastGeluGradient) { } IMPLEMENT_GRADIENT_BUILDER(GetLayerNormalizationGradient) { - return std::vector{ - NodeDef(OpDef{"LayerNormalizationGrad", kMSDomain, 1}, - {GO(0), I(0), I(1), O(1), O(2)}, - {GI(0), GI(1), GI(2)}, - {SrcNodeAttributes()})}; + if (GetGradientGraphConfiguration().use_invertible_layernorm_grad) { + return std::vector{ + NodeDef(OpDef{"InvertibleLayerNormalizationGrad", kMSDomain, 1}, + {GO(0), O(0), I(1), I(2), O(2)}, + {GI(0), GI(1), GI(2)}, + {SrcNodeAttributes()})}; + } else { + return std::vector{ + NodeDef(OpDef{"LayerNormalizationGrad", kMSDomain, 1}, + {GO(0), I(0), I(1), O(1), O(2)}, + {GI(0), GI(1), GI(2)}, + {SrcNodeAttributes()})}; + } } IMPLEMENT_GRADIENT_BUILDER(GetBatchNormalizationGradient) { diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index 0f5c79eb8e..90f5ed5124 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -7,6 +7,7 @@ #include #include "core/graph/graph.h" #include "orttraining/core/graph/graph_augmenter.h" +#include "orttraining/core/graph/gradient_config.h" #include "onnx/defs/attr_proto_util.h" namespace onnxruntime { @@ -27,10 +28,11 @@ typedef std::vector GradientDef; class GradientBuilderBase { public: GradientBuilderBase( + const GradientGraphConfiguration& gradient_graph_config, const Node* node, const std::unordered_set& gradient_inputs, const std::unordered_set& gradient_outputs) - : node_(node), gradient_inputs_(gradient_inputs), gradient_outputs_(gradient_outputs) { + : gradient_graph_config_(gradient_graph_config), node_(node), gradient_inputs_(gradient_inputs), gradient_outputs_(gradient_outputs) { unique_node_prefix_ = CreateUniqueNodePrefix(); } @@ -54,6 +56,10 @@ class GradientBuilderBase { protected: virtual GradientDef GetGradientDefsImpl() const = 0; + const GradientGraphConfiguration& GetGradientGraphConfiguration() const { + return gradient_graph_config_; + } + // i-th input of forward op ArgDef I(const size_t i) const { ORT_ENFORCE(i < node_->InputDefs().size()); @@ -185,6 +191,7 @@ class GradientBuilderBase { return unique_prefix.str(); } + const GradientGraphConfiguration& gradient_graph_config_; const Node* node_; std::string unique_node_prefix_; diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 54149f7bf3..94b4e1e096 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -3,11 +3,13 @@ #include "orttraining/core/graph/gradient_builder_registry.h" #include "orttraining/core/graph/gradient_builder.h" +#include "orttraining/core/graph/gradient_config.h" namespace onnxruntime { namespace training { -GradientDef GetGradientForOp(const Node* node, +GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config, + const Node* node, const std::unordered_set& output_args_need_grad, const std::unordered_set& input_args_need_grad) { @@ -16,6 +18,7 @@ GradientDef GetGradientForOp(const Node* node, // less than 9 is not supported and for Slice we have Slice-1, Slice-10 and Slice-11. auto gradient_builder = GradientBuilderRegistry::GetInstance().MakeUnique(node->OpType(), + gradient_graph_config, node, output_args_need_grad, input_args_need_grad); diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.h b/orttraining/orttraining/core/graph/gradient_builder_registry.h index 9568aff6bd..7acec7e7f4 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.h +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.h @@ -12,6 +12,7 @@ namespace onnxruntime { namespace training { typedef GenericRegistry&, // gradient_inputs const std::unordered_set&> // gradient_outputs @@ -31,7 +32,8 @@ class GradientBuilderRegistry : public GradientRegistryType { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GradientBuilderRegistry); }; -GradientDef GetGradientForOp(const Node* node, +GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config, + const Node* node, const std::unordered_set& output_args_need_grad, const std::unordered_set& input_args_need_grad); diff --git a/orttraining/orttraining/core/graph/gradient_config.h b/orttraining/orttraining/core/graph/gradient_config.h new file mode 100644 index 0000000000..bee896965e --- /dev/null +++ b/orttraining/orttraining/core/graph/gradient_config.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace training { + +struct GradientGraphConfiguration { + // Layernorm gradient can be computed based on either input or output of layernorm. + // That is to say, either input or output needs to be stashed for layernorm gradient. + // To save memory, ideally, only one(input vs output) should be stashed rather than both. + // By default, the input based algorithm is used. This flag is to enable the output based algorithm. + bool use_invertible_layernorm_grad{false}; +}; + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 17c816268c..23d99cf071 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -1432,6 +1432,32 @@ Example 4: {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors."); + ONNX_CONTRIB_OPERATOR_SCHEMA(InvertibleLayerNormalizationGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc("LayerNormalizationGrad") + .Attr("axis", + "The first normalization dimension: normalization will be performed along dimensions axis : rank(inputs).", + AttributeProto::INT, static_cast(-1)) + .AllowUncheckedAttributes() + .Input(0, "Y_grad", "The gradient tensor from output.", "T") + .Input(1, "Y", "Output data tensor from the forward path", "T") + .Input(2, "scale", "Scale tensor.", "T") + .Input(3, "bias", "Bias tensor.", "T") + .Input(4, "inv_std_var", "inverse std variance of X.", "U") + .Output(0, "X_grad", "Gradient of the input.", "T") + .Output(1, "scale_grad", "Gradient of the scale.", "T") + .Output(2, "bias_grad", "Gradient of the bias.", "T") + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types (except mean and inv_std_var) to float tensors.") + .TypeConstraint( + "U", + {"tensor(float)"}, + "Constrain mean and inv_std_var to float tensors."); + ONNX_CONTRIB_OPERATOR_SCHEMA(BatchNormalizationGrad) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 6b650eed47..1898639f12 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -219,7 +219,7 @@ Status TrainingSession::ConfigureForTraining( } ORT_RETURN_IF_ERROR(BuildGradientGraph( - weight_names_to_train, loss_name, config.set_gradients_as_graph_outputs)); + weight_names_to_train, loss_name, config.gradient_graph_config, config.set_gradients_as_graph_outputs)); // transform for mixed precision std::unordered_map fp32_weight_name_to_fp16_node_arg{}; @@ -425,12 +425,14 @@ static Status ConfigureLossFunctionInternal( static Status BuildGradientGraphInternal(Graph& graph, const std::string& loss_function_output_name, const std::unordered_set& node_arg_names_to_train, + const GradientGraphConfiguration& gradient_graph_config, const bool set_gradient_as_graph_output = false) { // Compute the gradient graph def. GradientGraphBuilder grad_graph_builder(&graph, {loss_function_output_name}, node_arg_names_to_train, loss_function_output_name, + gradient_graph_config, set_gradient_as_graph_output); return grad_graph_builder.Build(); } @@ -638,13 +640,16 @@ Status TrainingSession::EnableMixedPrecision(const std::unordered_set& weights_to_train, const std::string& loss_function_output_name, + const GradientGraphConfiguration& gradient_graph_config, const bool set_gradient_as_graph_output) { // Fill weights_to_train_ according to weights_to_train weights_to_train_ = weights_to_train; + gradient_graph_config_ = gradient_graph_config; ORT_RETURN_IF_ERROR(BuildGradientGraphInternal(model_->MainGraph(), loss_function_output_name, weights_to_train_, + gradient_graph_config_, set_gradient_as_graph_output)); return DoPostLoadProcessing(*model_); @@ -762,6 +767,7 @@ Status TrainingSession::Save(const PathString& model_uri, TrainingSession::SaveO ORT_RETURN_IF_ERROR(BuildGradientGraphInternal(new_model->MainGraph(), actual_loss_name, weights_to_train_, + gradient_graph_config_, false)); OptimizerOutputKeyMap opt_graph_outputs; diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 9ede468f50..5453b54273 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -10,6 +10,7 @@ #include "orttraining/core/graph/loss_function_registry.h" #include "orttraining/core/graph/optimizer_graph_output_key.h" #include "orttraining/core/graph/optimizer_config.h" +#include "orttraining/core/graph/gradient_config.h" namespace onnxruntime { namespace training { @@ -42,6 +43,9 @@ class TrainingSession : public InferenceSession { // The immutable weights specification. ImmutableWeights immutable_weights; + // Gradient graph configuration + GradientGraphConfiguration gradient_graph_config{}; + // Whether to set the gradients as graph outputs. bool set_gradients_as_graph_outputs{false}; @@ -409,6 +413,7 @@ class TrainingSession : public InferenceSession { */ common::Status BuildGradientGraph(const std::unordered_set& weights_to_train, const std::string& loss_function_output_name, + const GradientGraphConfiguration& gradient_graph_config, const bool set_gradient_as_graph_output = false); common::Status BuildAccumulationNode(const std::unordered_set& weights_to_train); @@ -469,6 +474,8 @@ class TrainingSession : public InferenceSession { std::unordered_set dropout_eval_feeds_; OptimizerGraphConfig opt_graph_config_; std::unordered_map opt_configs_; + + GradientGraphConfiguration gradient_graph_config_; }; } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index effa43b7ab..89067c9269 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -165,7 +165,9 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet ("enable_grad_norm_clip", "Specify whether to enable gradient clipping for optimizers.", cxxopts::value()->default_value("true")) ("enable_gelu_approximation", "Specify whether to enable GELU approximation.", - cxxopts::value()->default_value("true")); + cxxopts::value()->default_value("true")) + ("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)", + cxxopts::value()->default_value("false")); options .add_options("ORT configuration") ("ort_log_severity", "ORT minimum logging severity (see onnxruntime::logging::Severity values)", @@ -458,12 +460,15 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet "Log severity must be in the range [", static_cast(logging::Severity::kVERBOSE), ", ", static_cast(logging::Severity::kFATAL), "]."); ort_params.vlog_level = flags["ort_vlog_level"].as(); + + params.use_invertible_layernorm_grad = flags["use_invertible_layernorm_grad"].as(); } catch (const exception& e) { const std::string msg = "Failed to parse the command line arguments"; cerr << msg << ": " << e.what() << "\n" << options.help() << "\n"; return Status(ONNXRUNTIME, INVALID_ARGUMENT, msg); } + return Status::OK(); } diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 226dbe2d27..e94dc86977 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -91,6 +91,7 @@ Status TrainingRunner::Initialize() { config.weight_names_to_not_train = params_.weights_not_to_train; config.immutable_weights = params_.immutable_weights; + config.gradient_graph_config.use_invertible_layernorm_grad = params_.use_invertible_layernorm_grad; config.set_gradients_as_graph_outputs = false; config.gradient_accumulation_steps = params_.gradient_accumulation_steps; diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index 015d21b470..4b42a56168 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -169,6 +169,9 @@ class TrainingRunner { // Enable GELU approximation bool enable_gelu_approximation = false; + + // Use invertible layernorm grad + bool use_invertible_layernorm_grad = false; }; TrainingRunner(Parameters params, const Environment& env); diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index c58580902b..33e00252e9 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -376,7 +376,8 @@ def create_ort_training_session_with_optimizer(model, device, training_optimizer deepspeed_zero_stage=0, enable_grad_norm_clip=True, frozen_weights=[], opset_version=DEFAULT_OPSET_VERSION, - use_deterministic_compute=False): + use_deterministic_compute=False, + use_invertible_layernorm_grad=False): output_name = model.graph.output[0].name ort_parameters = ort.TrainingParameters() ort_parameters.loss_output_name = output_name @@ -384,11 +385,11 @@ def create_ort_training_session_with_optimizer(model, device, training_optimizer ort_parameters.world_rank = world_rank ort_parameters.world_size = world_size ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps - ort_parameters.use_mixed_precision = use_mixed_precision ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip ort_parameters.set_gradients_as_graph_outputs = False + ort_parameters.use_invertible_layernorm_grad = use_invertible_layernorm_grad output_types = {} for output in model.graph.output: @@ -530,7 +531,8 @@ class ORTTrainer(): world_rank=0, world_size=1, use_mixed_precision=False, allreduce_post_accumulation=False, global_step=0, get_lr_this_step=None, loss_scaler=None, deepspeed_zero_stage=0, enable_grad_norm_clip=True, frozen_weights=[], _opset_version=DEFAULT_OPSET_VERSION, - _enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False): + _enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False, + use_invertible_layernorm_grad=False): super(ORTTrainer, self).__init__() """ Initialize ORTTrainer. @@ -599,6 +601,8 @@ class ORTTrainer(): Defaults to True _extra_postprocess: a callable to postprocess the ONNX model that is converted from PyTorch. Defaults to None + use_invertible_layernorm_grad: use invertible layernorm grad + Defaults to False """ warnings.warn('DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it') self.is_train = True @@ -651,6 +655,7 @@ class ORTTrainer(): self.state_dict_ = None self._enable_internal_postprocess = _enable_internal_postprocess self._use_deterministic_compute = _use_deterministic_compute + self.use_invertible_layernorm_grad = use_invertible_layernorm_grad # use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs. # see prepare_input_and_fetches for more details. @@ -679,7 +684,8 @@ class ORTTrainer(): deepspeed_zero_stage=self.deepspeed_zero_stage_, enable_grad_norm_clip=self.enable_grad_norm_clip_, frozen_weights=self.frozen_weights_, opset_version=self.opset_version_, - use_deterministic_compute=self._use_deterministic_compute) + use_deterministic_compute=self._use_deterministic_compute, + use_invertible_layernorm_grad=self.use_invertible_layernorm_grad) self.loss_scale_input_name = self.session.loss_scale_input_name diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index a9d52c7a80..415e1c2c0e 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -47,6 +47,7 @@ struct TrainingParameters { int deepspeed_zero_stage = 0; bool enable_grad_norm_clip = true; bool set_gradients_as_graph_outputs = false; + bool use_invertible_layernorm_grad = false; }; struct TrainingConfigurationResult { @@ -144,6 +145,8 @@ TrainingConfigurationResult ConfigureSessionForTraining( config.optimizer_config = opt; } + config.gradient_graph_config.use_invertible_layernorm_grad = parameters.use_invertible_layernorm_grad; + training::TrainingSession::TrainingConfigurationResult config_result{}; OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result)); @@ -177,7 +180,8 @@ void addObjectMethodsForTraining(py::module& m) { .def_readwrite("gradient_accumulation_steps", &TrainingParameters::gradient_accumulation_steps) .def_readwrite("deepspeed_zero_stage", &TrainingParameters::deepspeed_zero_stage) .def_readwrite("enable_grad_norm_clip", &TrainingParameters::enable_grad_norm_clip) - .def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs); + .def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs) + .def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad); py::class_ config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc"); config_result.def(py::init()) diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index e923b81926..92892bff21 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -18,6 +18,7 @@ limitations under the License. #include "gradient_checker.h" #include "gradient_op_test_utils.h" #include "orttraining/core/framework/gradient_graph_builder.h" +#include "orttraining/core/graph/gradient_config.h" #include "test/util/include/test_random_seed.h" #include @@ -305,10 +306,12 @@ inline Status GradientChecker::InitOpTesterWithGradGraph( } } + training::GradientGraphConfiguration gradient_graph_config; training::GradientGraphBuilder grad_graph_builder(&graph, dy_values, weights_to_train, "", + gradient_graph_config, true); Status status = grad_graph_builder.Build(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index c8d546dba3..048efc6078 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -5,6 +5,7 @@ #include "core/session/inference_session.h" #include "orttraining/core/session/training_session.h" #include "orttraining/core/framework/gradient_graph_builder.h" +#include "orttraining/core/graph/gradient_config.h" #include "default_providers.h" namespace onnxruntime { @@ -69,10 +70,12 @@ void GradientOpTester::Run( } } + training::GradientGraphConfiguration gradient_graph_config; training::GradientGraphBuilder grad_graph_builder(&graph, dy_values, weights_to_train, "", + gradient_graph_config, true); status = grad_graph_builder.Build(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); diff --git a/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc b/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc index 8fb8f543e7..cbd23c43f6 100644 --- a/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc @@ -97,5 +97,143 @@ TEST(CudaKernelTest, LayerNormGrad_LargeSizeTensor) { TestLayerNormGrad(X_dims, -1, 5e-3); } +static void TestInvertibleLayerNormGrad( + const std::vector& x_dims, + int64_t axis = -1, + double error_tolerance = 1e-4, + bool test_fp16=false) { + const std::vector& n_x_m_dims = x_dims; + std::vector n_dims, m_dims; + ASSERT_TRUE(SplitDims(n_x_m_dims, axis, n_dims, m_dims).IsOK()); + + const auto N = std::accumulate(n_dims.begin(), n_dims.end(), static_cast(1), std::multiplies<>{}); + const auto M = std::accumulate(m_dims.begin(), m_dims.end(), static_cast(1), std::multiplies<>{}); + + CompareOpTester test{"InvertibleLayerNormalizationGrad", 1, kMSDomain}; + + test.AddAttribute("axis", axis); + + RandomValueGenerator random{}; + const auto Y_grad_data = random.Uniform(n_x_m_dims, k_random_data_min, k_random_data_max); + const auto X_data = random.Uniform(n_x_m_dims, k_random_data_min, k_random_data_max); + const auto scale_data = random.Uniform(m_dims, k_random_data_min, k_random_data_max); + const auto bias_data = random.Uniform(m_dims, k_random_data_min, k_random_data_max); + + // these inputs are dependent on X_data + std::vector mean_data(N); // mean(X) + std::vector inv_std_var_data(N); // 1 / sqrt(mean(X^2) - mean(X)^2 + epsilon) + std::vector Y_data(N*M); + { + using ConstEigenArrayMap = Eigen::Map>; + using EigenArrayMap = Eigen::Map>; + + ConstEigenArrayMap X{X_data.data(), M, N}; + + for (int i = 0; i < N; ++i) { + mean_data[i] = X.col(i).mean(); + inv_std_var_data[i] = X.col(i).square().mean() - mean_data[i] * mean_data[i]; + } + + // Compute Y = ((x - mean) * (inv_var) * scale + bias + EigenArrayMap Y(Y_data.data(), M, N); + + using EigenVectorArrayMap = Eigen::Map>; + using ConstEigenVectorArrayMap = Eigen::Map>; + ConstEigenVectorArrayMap mean(mean_data.data(), N); + EigenVectorArrayMap inv_std_var(inv_std_var_data.data(), N); + inv_std_var = (inv_std_var + k_epsilon_default).sqrt().inverse(); + + Y = (X.rowwise() - mean.transpose()).rowwise() * inv_std_var.transpose(); + + ConstEigenVectorArrayMap scale(scale_data.data(), M); + ConstEigenVectorArrayMap bias(bias_data.data(), M); + Y = (Y.colwise() * scale).colwise() + bias; + } + + if (test_fp16) { + std::vector Y_grad_data_half(Y_grad_data.size()); + std::vector Y_data_half(Y_data.size()); + std::vector scale_data_half(scale_data.size()); + std::vector bias_data_half(bias_data.size()); + ConvertFloatToMLFloat16(Y_grad_data.data(),Y_grad_data_half.data(), int(Y_grad_data.size())); + ConvertFloatToMLFloat16(Y_data.data(),Y_data_half.data(), int(Y_data.size())); + ConvertFloatToMLFloat16(scale_data.data(), scale_data_half.data(), int(scale_data.size())); + ConvertFloatToMLFloat16(bias_data.data(), bias_data_half.data(), int(bias_data.size())); + + test.AddInput("Y_grad", n_x_m_dims, Y_grad_data_half); + test.AddInput("Y", n_x_m_dims, Y_data_half); + test.AddInput("scale", m_dims, scale_data_half, true); + test.AddInput("bias", m_dims, bias_data_half); + + const auto X_grad_data = FillZeros(n_x_m_dims); + const auto scale_grad_data = FillZeros(m_dims); + const auto bias_grad_data = FillZeros(m_dims); + test.AddOutput("X_grad", n_x_m_dims, X_grad_data); + test.AddOutput("scale_grad_data", m_dims, scale_grad_data); + test.AddOutput("bias_grad_data", m_dims, bias_grad_data); + } else { + test.AddInput("Y_grad", n_x_m_dims, Y_grad_data); + test.AddInput("Y", n_x_m_dims, Y_data); + test.AddInput("scale", m_dims, scale_data, true); + test.AddInput("bias", m_dims, bias_data); + + const auto X_grad_data = FillZeros(n_x_m_dims); + const auto scale_grad_data = FillZeros(m_dims); + const auto bias_grad_data = FillZeros(m_dims); + test.AddOutput("X_grad", n_x_m_dims, X_grad_data); + test.AddOutput("scale_grad_data", m_dims, scale_grad_data); + test.AddOutput("bias_grad_data", m_dims, bias_grad_data); + } + test.AddInput("inv_std_var", n_dims, inv_std_var_data); + + if (test_fp16) { + test.CompareWithCPU(kCudaExecutionProvider, error_tolerance, error_tolerance); + } else { + test.CompareWithCPU(kCudaExecutionProvider, error_tolerance); + } +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor) { + const std::vector X_dims{4, 20, 128}; + TestInvertibleLayerNormGrad(X_dims); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_IntermediateAxis) { + const std::vector X_dims{4, 20, 16, 8}; + const int64_t axis = -2; + TestInvertibleLayerNormGrad(X_dims, axis); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_MidSizeTensor) { + const std::vector X_dims{8, 80, 768}; + TestInvertibleLayerNormGrad(X_dims); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_LargeSizeTensor) { + const std::vector X_dims{16, 512, 1024}; + TestInvertibleLayerNormGrad(X_dims, -1, 5e-3); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_FP16) { + const std::vector X_dims{4, 20, 128}; + TestInvertibleLayerNormGrad(X_dims, -1, 2e-3, true); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_IntermediateAxis_FP16) { + const std::vector X_dims{4, 20, 16, 8}; + const int64_t axis = -2; + TestInvertibleLayerNormGrad(X_dims, axis, 2e-3, true); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_MidSizeTensor_FP16) { + const std::vector X_dims{8, 80, 768}; + TestInvertibleLayerNormGrad(X_dims, -1, 2e-3, true); +} + +TEST(CudaKernelTest, InvertibleLayerNormGrad_LargeSizeTensor_FP16) { + const std::vector X_dims{16, 512, 1024}; + TestInvertibleLayerNormGrad(X_dims, -1, 2e-3, true); +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index 1ae886545c..a302ad0f52 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -77,6 +77,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GistB class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GistBinarizeDecoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LayerNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, InvertibleLayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, InvertibleLayerNormalizationGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SliceGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGeluGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasGeluGrad_dX); @@ -155,6 +157,8 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cpu/nn/layer_norm.cc b/orttraining/orttraining/training_ops/cpu/nn/layer_norm.cc index fc2128b05b..e55fe5e411 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/layer_norm.cc +++ b/orttraining/orttraining/training_ops/cpu/nn/layer_norm.cc @@ -20,7 +20,16 @@ namespace contrib { kCpuExecutionProvider, \ KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - LayerNormGrad); + LayerNormGrad); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + InvertibleLayerNormalizationGrad, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + InvertibleLayerNormGrad); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) @@ -99,5 +108,76 @@ Status LayerNormGrad::Compute(OpKernelContext* op_kernel_context) const { return Status::OK(); } +template +InvertibleLayerNormGrad::InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info) + : OpKernel{op_kernel_info} { + ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK()); +} + +template +Status InvertibleLayerNormGrad::Compute(OpKernelContext* op_kernel_context) const { + const Tensor* Y_grad = op_kernel_context->Input(0); + const Tensor* Y = op_kernel_context->Input(1); + const Tensor* scale = op_kernel_context->Input(2); + const Tensor* bias = op_kernel_context->Input(3); + const Tensor* inv_std_var = op_kernel_context->Input(4); + + const auto& Y_shape = Y_grad->Shape(); + const auto& X_shape = Y_shape; + const auto axis = HandleNegativeAxis(axis_, X_shape.NumDimensions()); + const auto N = X_shape.SizeToDimension(axis); + const auto M = X_shape.SizeFromDimension(axis); + ORT_ENFORCE(M != 1); + const auto& scale_shape = scale->Shape(); + + Tensor* X_grad = op_kernel_context->Output(0, X_shape); + Tensor* scale_grad = op_kernel_context->Output(1, scale_shape); + Tensor* bias_grad = op_kernel_context->Output(2, scale_shape); + + // Note: Eigen has column-major storage order by default + ConstEigenArrayMap Y_grad_arr{Y_grad->Data(), M, N}; + ConstEigenArrayMap Y_arr{Y->Data(), M, N}; + ConstEigenVectorArrayMap scale_vec{scale->Data(), M}; + ConstEigenVectorArrayMap bias_vec{bias->Data(), M}; + ConstEigenVectorArrayMap inv_std_var_vec{inv_std_var->Data(), N}; + + EigenArrayMap X_grad_arr{X_grad->MutableData(), M, N}; + EigenVectorArrayMap scale_grad_vec{scale_grad->MutableData(), M}; + EigenVectorArrayMap bias_grad_vec{bias_grad->MutableData(), M}; + + using Array = Eigen::ArrayXX; + using RowVector = Eigen::Array; + + // A, B, C are calculated as below: + // A = Y_grad * (X - mean(X)) * inv_std_var + // B = Y_grad * scale * inv_std_var + // C = Y_grad * scale * inv_std_var * (X - mean(X)) * inv_std_var + + // A, B, and C are M x N + Array X_mean_difference_over_std_var = (Y_arr.colwise() - bias_vec).colwise() / scale_vec; + Array A = Y_grad_arr * X_mean_difference_over_std_var; + Array B = (Y_grad_arr.colwise() * scale_vec).rowwise() * inv_std_var_vec.cast().transpose(); + Array C = B * X_mean_difference_over_std_var; + + // mean_B = mean(Y_grad * scale * inv_std_var) + RowVector mean_B = B.colwise().mean(); // 1 x N + + // mean_C = mean(Y_grad * scale * inv_std_var * (X - mean(X)) * inv_std_var) + RowVector mean_C = C.colwise().mean(); // 1 x N + + // X_grad = Y_grad * scale * inv_std_var - mean_B - (X - mean(X)) * inv_std_var * mean_C + // = B - mean_B - (X - mean(X)) * inv_std_var * mean_c + X_grad_arr = B.rowwise() - mean_B - X_mean_difference_over_std_var.rowwise() * mean_C; + + // bias_grad = sum(Y_grad) + bias_grad_vec = Y_grad_arr.rowwise().sum(); + + // scale_grad = sum(Y_grad * (X - mean(X)) * inv_std_var) + // = sum(A) + scale_grad_vec = A.rowwise().sum(); + + return Status::OK(); +} + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/nn/layer_norm.h b/orttraining/orttraining/training_ops/cpu/nn/layer_norm.h index 14024c6068..4cd6207c26 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/layer_norm.h +++ b/orttraining/orttraining/training_ops/cpu/nn/layer_norm.h @@ -18,5 +18,15 @@ class LayerNormGrad final : public OpKernel { int64_t axis_; }; +template +class InvertibleLayerNormGrad final : public OpKernel { + public: + InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info); + Status Compute(OpKernelContext* op_kernel_context) const override; + + private: + int64_t axis_; +}; + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index e648e6fb14..e3a3e87662 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -89,8 +89,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, LayerNormalizationGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, LayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, LayerNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, LayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, InvertibleLayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, InvertibleLayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, InvertibleLayerNormalizationGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad); @@ -198,8 +201,11 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc index f2e3b0e026..d7f8acfc8c 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc @@ -19,9 +19,19 @@ namespace cuda { KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("U", DataTypeImpl::GetTensorType()), \ - LayerNormGrad); + LayerNormGrad); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + InvertibleLayerNormalizationGrad, \ + kMSDomain, \ + 1, \ + T##_##U, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", DataTypeImpl::GetTensorType()), \ + InvertibleLayerNormGrad); REGISTER_GRADIENT_KERNEL_TYPED(float, float) -REGISTER_GRADIENT_KERNEL_TYPED(double, float) +REGISTER_GRADIENT_KERNEL_TYPED(double, double) REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16, float) template @@ -65,7 +75,58 @@ Status LayerNormGrad::ComputeInternal(OpKernelContext* p_op_kernel_context auto part_grad_gamma = GetScratchBuffer(part_size * n2); auto part_grad_beta = GetScratchBuffer(part_size * n2); - HostLayerNormGradient(GetDeviceProp(), Y_grad_data, mean_data, inv_std_var_data, X_data, n1, n2, scale_data, X_grad_data, scale_grad_data, bias_grad_data, + HostLayerNormGradient(GetDeviceProp(), Y_grad_data, X_data, reinterpret_cast(NULL), + scale_data, reinterpret_cast(NULL), mean_data, inv_std_var_data, n1, n2, + X_grad_data, scale_grad_data, bias_grad_data, + part_grad_gamma.get(), part_grad_beta.get(), part_size); + return Status::OK(); +} + +template +InvertibleLayerNormGrad::InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK()); +} + +template +Status InvertibleLayerNormGrad::ComputeInternal(OpKernelContext* p_op_kernel_context) const { + typedef typename ToCudaType::MappedType CudaT; + typedef typename ToCudaType::MappedType CudaU; + // Inputs + const Tensor* Y_grad = p_op_kernel_context->Input(0); + const Tensor* Y = p_op_kernel_context->Input(1); + const Tensor* scale = p_op_kernel_context->Input(2); + const Tensor* bias = p_op_kernel_context->Input(3); + const Tensor* inv_std_var = p_op_kernel_context->Input(4); + + auto Y_grad_data = reinterpret_cast(Y_grad->template Data()); + auto Y_data = reinterpret_cast(Y->template Data()); + auto scale_data = reinterpret_cast(scale->template Data()); + auto bias_data = reinterpret_cast(bias->template Data()); + auto inv_std_var_data = reinterpret_cast(inv_std_var->template Data()); + + const TensorShape& y_shape = Y->Shape(); + const TensorShape& x_shape = y_shape; + const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions()); + auto n1 = x_shape.SizeToDimension(axis); + auto n2 = x_shape.SizeFromDimension(axis); + ORT_ENFORCE(n2 != 1, "n2 should not be 1"); + + // Outputs + Tensor* X_grad = p_op_kernel_context->Output(0, x_shape); + auto X_grad_data = reinterpret_cast(X_grad->template MutableData()); + + Tensor* scale_grad = p_op_kernel_context->Output(1, scale->Shape()); + Tensor* bias_grad = p_op_kernel_context->Output(2, scale->Shape()); + auto scale_grad_data = reinterpret_cast(scale_grad->template MutableData()); + auto bias_grad_data = reinterpret_cast(bias_grad->template MutableData()); + + const int part_size = 16; + auto part_grad_gamma = GetScratchBuffer(part_size * n2); + auto part_grad_beta = GetScratchBuffer(part_size * n2); + + HostLayerNormGradient(GetDeviceProp(), Y_grad_data, reinterpret_cast(NULL), Y_data, + scale_data, bias_data, reinterpret_cast(NULL), inv_std_var_data, n1, n2, + X_grad_data, scale_grad_data, bias_grad_data, part_grad_gamma.get(), part_grad_beta.get(), part_size); return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h index fd21b09ba4..ab092ed12a 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h @@ -25,5 +25,15 @@ class LayerNormGrad final : public CudaKernel { int64_t axis_; }; +template +class InvertibleLayerNormGrad final : public CudaKernel { + public: + InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + int64_t axis_; +}; + } // namespace cuda } // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu index 7f85c5676f..11753e080b 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu @@ -372,10 +372,10 @@ void HostApplyLayerNorm( LAYERNORM_LINEAR_IMPL(float, float) LAYERNORM_LINEAR_IMPL(half, float) -LAYERNORM_LINEAR_IMPL(double, float) +LAYERNORM_LINEAR_IMPL(double, double) //LAYERNORM_LINEAR_IMPL(half, half) -template +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -385,24 +385,34 @@ __device__ void cuLoadWriteStridedInputs( U* warp_buf1, U* warp_buf2, const T* input, + const T* output, const T* dout, const int i1_end, const int n2, + const T* __restrict__ gamma, + const T* __restrict__ beta, const U* __restrict__ mean, const U* __restrict__ invvar) { int i1 = i1_block + thr_load_row_off; if (i1 < i1_end) { - U curr_mean = mean[i1]; - U curr_invvar = invvar[i1]; + U curr_mean = use_mean ? mean[i1] : U(0); + U curr_invvar = use_mean ? invvar[i1] : U(0); for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1 * n2 + i2; int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); U curr_dout = static_cast(dout[load_idx]); warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + if (use_mean) { + U curr_input = static_cast(input[load_idx]); + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + U curr_gamma = static_cast(gamma[i2]); + U curr_beta = static_cast(beta[i2]); + U curr_output = static_cast(output[load_idx]); + warp_buf2[write_idx] = curr_dout * (curr_output - curr_beta) / curr_gamma; + } } else { warp_buf1[write_idx] = U(0); warp_buf2[write_idx] = U(0); @@ -417,7 +427,7 @@ __device__ void cuLoadWriteStridedInputs( } } -template +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -427,37 +437,50 @@ __device__ void cuLoadAddStridedInputs( U* warp_buf1, U* warp_buf2, const T* input, + const T* output, const T* dout, const int i1_end, const int n2, + const T* __restrict__ gamma, + const T* __restrict__ beta, const U* __restrict__ mean, const U* __restrict__ invvar) { int i1 = i1_block + thr_load_row_off; if (i1 < i1_end) { - U curr_mean = mean[i1]; - U curr_invvar = invvar[i1]; + U curr_mean = use_mean ? mean[i1] : U(0); + U curr_invvar = use_mean ? invvar[i1] : U(0); for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1 * n2 + i2; int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); U curr_dout = static_cast(dout[load_idx]); warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + if (use_mean) { + U curr_input = static_cast(input[load_idx]); + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + U curr_gamma = static_cast(gamma[i2]); + U curr_beta = static_cast(beta[i2]); + U curr_output = static_cast(output[load_idx]); + warp_buf2[write_idx] += curr_dout * (curr_output - curr_beta) / curr_gamma; + } } } } } -template +template __global__ void cuComputePartGradGammaBeta( const T* __restrict__ dout, const T* __restrict__ input, - const int n1, - const int n2, + const T* __restrict__ output, + const T* __restrict__ gamma, + const T* __restrict__ beta, const U* __restrict__ mean, const U* __restrict__ invvar, + const int n1, + const int n2, U* part_grad_gamma, U* part_grad_beta) { const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); @@ -475,9 +498,9 @@ __global__ void cuComputePartGradGammaBeta( U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar); + cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, output, dout, i1_end, n2, gamma, beta, mean, invvar); for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) { - cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar); + cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, output, dout, i1_end, n2, gamma, beta, mean, invvar); } __syncthreads(); // inter-warp reductions @@ -566,22 +589,25 @@ __global__ void cuComputeGradGammaBeta( } } -template +template __global__ void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, - const int n1, - const int n2, + const T* __restrict__ output, + const T* gamma, + const T* beta, const U* __restrict__ mean, const U* __restrict__ invvar, - const T* gamma, + const int n1, + const int n2, T* grad_input) { for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { U sum_loss1 = U(0); U sum_loss2 = U(0); - const U c_mean = mean[i1]; + const U c_mean = use_mean ? mean[i1] : U(0); const U c_invvar = invvar[i1]; - const T* k_input = input + i1 * n2; + const T* k_input = use_mean ? input + i1 * n2 : nullptr; + const T* k_output = use_mean ? nullptr: output + i1 * n2; const T* k_dout = dout + i1 * n2; const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; @@ -589,33 +615,53 @@ __global__ void cuComputeGradInput( int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); const U c_loss = static_cast(k_dout[l + k]); sum_loss1 += c_loss * U(gamma[l + k]); - sum_loss2 += c_loss * U(gamma[l + k]) * (c_h - c_mean) * c_invvar; + if (use_mean) { + const U c_h = static_cast(k_input[l + k]); + sum_loss2 += c_loss * U(gamma[l + k]) * (c_h - c_mean) * c_invvar; + } else { + const U c_output = static_cast(k_output[l + k]); + sum_loss2 += c_loss * (c_output - U(beta[l + k])); + } } } for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); sum_loss1 += c_loss * U(gamma[l]); - sum_loss2 += c_loss * U(gamma[l]) * (c_h - c_mean) * c_invvar; + if (use_mean) { + const U c_h = static_cast(k_input[l]); + sum_loss2 += c_loss * U(gamma[l]) * (c_h - c_mean) * c_invvar; + } else { + const U c_output = static_cast(k_output[l]); + sum_loss2 += c_loss * (c_output - U(beta[l])); + } } } else { int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); const U c_loss = static_cast(k_dout[l + k]); sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + if (use_mean) { + const U c_h = static_cast(k_input[l + k]); + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + const U c_output = static_cast(k_output[l + k]); + sum_loss2 += c_loss * c_output; + } } } for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + if (use_mean) { + const U c_h = static_cast(k_input[l]); + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + const U c_output = static_cast(k_output[l]); + sum_loss2 += c_loss * c_output; + } } } // intra-warp reductions @@ -659,21 +705,31 @@ __global__ void cuComputeGradInput( T* k_grad_input = grad_input + i1 * n2; if (gamma != NULL) { for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); U f_grad_input = fH * c_loss * U(gamma[l]); f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + if (use_mean) { + const U c_h = static_cast(k_input[l]); + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + const U c_output = static_cast(k_output[l]); + f_grad_input -= (c_output - U(beta[l])) / U(gamma[l]) * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } } else { for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); U f_grad_input = fH * c_loss; f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + if (use_mean) { + const U c_h = static_cast(k_input[l]); + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + const U c_output = static_cast(k_output[l]); + f_grad_input -= c_output * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -683,20 +739,22 @@ __global__ void cuComputeGradInput( template void HostLayerNormGradient( - const cudaDeviceProp& prop, - const T* dout, - const U* mean, - const U* invvar, - const T* input, - int64_t n1, - int64_t n2, - const T* gamma, - T* grad_input, - T* grad_gamma, - T* grad_beta, - U* part_grad_gamma, - U* part_grad_beta, - const int part_size) { + const cudaDeviceProp& prop, + const T* dout, + const T* input, + const T* output, + const T* gamma, + const T* beta, + const U* mean, + const U* invvar, + int64_t n1, + int64_t n2, + T* grad_input, + T* grad_gamma, + T* grad_beta, + U* part_grad_gamma, + U* part_grad_beta, + const int part_size) { const int warp_size = prop.warpSize; ORT_ENFORCE(warp_size == GPU_WARP_SIZE); @@ -706,14 +764,31 @@ void HostLayerNormGradient( const int nshared2_b = threads2.x * threads2.y * sizeof(U); const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - cuComputePartGradGammaBeta<<>>( + if (mean == nullptr) { + cuComputePartGradGammaBeta<<>>( dout, input, - n1, n2, + output, + gamma, + beta, mean, invvar, + n1, n2, part_grad_gamma, part_grad_beta); + } else { + cuComputePartGradGammaBeta<<>>( + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + part_grad_gamma, + part_grad_beta); + } const dim3 threads3(warp_size, 8, 1); const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); @@ -732,22 +807,38 @@ void HostLayerNormGradient( const dim3 threads1(warp_size, 4, 1); int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; - cuComputeGradInput<<>>( + if (mean == nullptr) { + cuComputeGradInput<<>>( dout, input, - n1, n2, + output, + gamma, + beta, mean, invvar, - gamma, + n1, n2, grad_input); + } else { + cuComputeGradInput<<>>( + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); + } } -#define LAYERNORMGRAD_IMPL(T, U) \ - template void HostLayerNormGradient(const cudaDeviceProp& prop, const T* dout, const U* mean, const U* invvar, const T* input, int64_t n1, int64_t n2, const T* gamma, \ +#define LAYERNORMGRAD_IMPL(T, U) \ + template void HostLayerNormGradient(const cudaDeviceProp& prop, const T* dout, const T* input, const T* output, \ + const T* gamma, const T* beta, const U* mean, const U* invvar, int64_t n1, int64_t n2, \ T* grad_input, T* grad_gamma, T* grad_beta, U* part_grad_gamma, U* part_grad_beta, const int part_size); LAYERNORMGRAD_IMPL(float, float) -LAYERNORMGRAD_IMPL(double, float) +LAYERNORMGRAD_IMPL(double, double) LAYERNORMGRAD_IMPL(half, float) } // namespace cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h index 285d36e94e..ea61ff71ab 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.h @@ -45,12 +45,14 @@ template void HostLayerNormGradient( const cudaDeviceProp& prop, const T* dout, + const T* input, + const T* output, + const T* gamma, + const T* beta, const U* mean, const U* invvar, - const T* input, int64_t n1, int64_t n2, - const T* gamma, T* grad_input, T* grad_gamma, T* grad_beta,