mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
Optimize LayernormGrad (#4156)
* Draft for LayerNorm Optimization * Modify LayernormGrad kernel based on new backward graph. * keep two LayernormGrad implementations. One is implemented based on input X, mean. The other is based on output Y, scale, bias. The first one is enabled by default. The second one can be enabled by --use_invertible_layernorm_grad * expose use_invertible_layernorm_grad to frontend. * add fp16 tests. Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net> Co-authored-by: Weixing Zhang <wezhan@microsoft.com>
This commit is contained in:
parent
33e06be4ac
commit
bd11ab6816
29 changed files with 594 additions and 84 deletions
|
|
@ -61,7 +61,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_float, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_double, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear);
|
||||
|
|
@ -125,7 +125,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_double, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse)>,
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ namespace cuda {
|
|||
LayerNorm<T, U>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float, float)
|
||||
REGISTER_KERNEL_TYPED(double, float)
|
||||
REGISTER_KERNEL_TYPED(double, double)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16, float)
|
||||
|
||||
template <typename T, typename U>
|
||||
|
|
|
|||
|
|
@ -376,7 +376,7 @@ void HostApplyLayerNorm(
|
|||
|
||||
LAYERNORM_LINEAR_IMPL(float, float)
|
||||
LAYERNORM_LINEAR_IMPL(half, float)
|
||||
LAYERNORM_LINEAR_IMPL(double, float)
|
||||
LAYERNORM_LINEAR_IMPL(double, double)
|
||||
//LAYERNORM_LINEAR_IMPL(half, half)
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/graph/schema_registry.h"
|
||||
#include "orttraining/core/framework/gradient_graph_builder.h"
|
||||
#include "orttraining/core/graph/gradient_builder_registry.h"
|
||||
#include "orttraining/core/graph/gradient_config.h"
|
||||
#include "orttraining/core/optimizer/insert_output_rewriter.h"
|
||||
#include "core/optimizer/gelu_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
|
|
@ -22,9 +23,11 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
|
|||
const unordered_set<string>& y_node_arg_names,
|
||||
const unordered_set<string>& x_node_arg_names,
|
||||
string loss_node_arg_name,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const bool set_gradient_as_graph_output)
|
||||
: graph_(graph),
|
||||
loss_node_arg_name_(loss_node_arg_name),
|
||||
gradient_graph_config_(gradient_graph_config),
|
||||
set_gradient_as_graph_output_(set_gradient_as_graph_output) {
|
||||
auto rule_based_graph_transformer =
|
||||
onnxruntime::make_unique<RuleBasedGraphTransformer>("pre_training_rule_based_graph_transformer");
|
||||
|
|
@ -187,7 +190,7 @@ Status GradientGraphBuilder::Build() {
|
|||
}
|
||||
}
|
||||
|
||||
GradientDef node_defs = GetGradientForOp(node, output_args_need_grad, input_args_need_grad);
|
||||
GradientDef node_defs = GetGradientForOp(gradient_graph_config_, node, output_args_need_grad, input_args_need_grad);
|
||||
|
||||
// updates arg name if gradient accumulation is needed
|
||||
for (auto& op_def : node_defs) {
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class GradientGraphBuilder {
|
|||
const std::unordered_set<std::string>& y_node_arg_names,
|
||||
const std::unordered_set<std::string>& x_node_arg_names,
|
||||
std::string loss_node_arg_name,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const bool set_gradient_as_graph_output = false);
|
||||
|
||||
Status Build();
|
||||
|
|
@ -73,6 +74,8 @@ class GradientGraphBuilder {
|
|||
|
||||
std::string loss_node_arg_name_;
|
||||
|
||||
const GradientGraphConfiguration& gradient_graph_config_;
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr_{5};
|
||||
|
||||
// key: ArgDef for the gradient after accumulation
|
||||
|
|
|
|||
|
|
@ -956,11 +956,19 @@ IMPLEMENT_GRADIENT_BUILDER(GetFastGeluGradient) {
|
|||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetLayerNormalizationGradient) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef(OpDef{"LayerNormalizationGrad", kMSDomain, 1},
|
||||
{GO(0), I(0), I(1), O(1), O(2)},
|
||||
{GI(0), GI(1), GI(2)},
|
||||
{SrcNodeAttributes()})};
|
||||
if (GetGradientGraphConfiguration().use_invertible_layernorm_grad) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef(OpDef{"InvertibleLayerNormalizationGrad", kMSDomain, 1},
|
||||
{GO(0), O(0), I(1), I(2), O(2)},
|
||||
{GI(0), GI(1), GI(2)},
|
||||
{SrcNodeAttributes()})};
|
||||
} else {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef(OpDef{"LayerNormalizationGrad", kMSDomain, 1},
|
||||
{GO(0), I(0), I(1), O(1), O(2)},
|
||||
{GI(0), GI(1), GI(2)},
|
||||
{SrcNodeAttributes()})};
|
||||
}
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetBatchNormalizationGradient) {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include <string>
|
||||
#include "core/graph/graph.h"
|
||||
#include "orttraining/core/graph/graph_augmenter.h"
|
||||
#include "orttraining/core/graph/gradient_config.h"
|
||||
#include "onnx/defs/attr_proto_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -27,10 +28,11 @@ typedef std::vector<NodeDef> GradientDef;
|
|||
class GradientBuilderBase {
|
||||
public:
|
||||
GradientBuilderBase(
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const Node* node,
|
||||
const std::unordered_set<std::string>& gradient_inputs,
|
||||
const std::unordered_set<std::string>& gradient_outputs)
|
||||
: node_(node), gradient_inputs_(gradient_inputs), gradient_outputs_(gradient_outputs) {
|
||||
: gradient_graph_config_(gradient_graph_config), node_(node), gradient_inputs_(gradient_inputs), gradient_outputs_(gradient_outputs) {
|
||||
unique_node_prefix_ = CreateUniqueNodePrefix();
|
||||
}
|
||||
|
||||
|
|
@ -54,6 +56,10 @@ class GradientBuilderBase {
|
|||
protected:
|
||||
virtual GradientDef GetGradientDefsImpl() const = 0;
|
||||
|
||||
const GradientGraphConfiguration& GetGradientGraphConfiguration() const {
|
||||
return gradient_graph_config_;
|
||||
}
|
||||
|
||||
// i-th input of forward op
|
||||
ArgDef I(const size_t i) const {
|
||||
ORT_ENFORCE(i < node_->InputDefs().size());
|
||||
|
|
@ -185,6 +191,7 @@ class GradientBuilderBase {
|
|||
return unique_prefix.str();
|
||||
}
|
||||
|
||||
const GradientGraphConfiguration& gradient_graph_config_;
|
||||
const Node* node_;
|
||||
std::string unique_node_prefix_;
|
||||
|
||||
|
|
|
|||
|
|
@ -3,11 +3,13 @@
|
|||
|
||||
#include "orttraining/core/graph/gradient_builder_registry.h"
|
||||
#include "orttraining/core/graph/gradient_builder.h"
|
||||
#include "orttraining/core/graph/gradient_config.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
GradientDef GetGradientForOp(const Node* node,
|
||||
GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config,
|
||||
const Node* node,
|
||||
const std::unordered_set<std::string>& output_args_need_grad,
|
||||
const std::unordered_set<std::string>& input_args_need_grad) {
|
||||
|
||||
|
|
@ -16,6 +18,7 @@ GradientDef GetGradientForOp(const Node* node,
|
|||
// less than 9 is not supported and for Slice we have Slice-1, Slice-10 and Slice-11.
|
||||
|
||||
auto gradient_builder = GradientBuilderRegistry::GetInstance().MakeUnique(node->OpType(),
|
||||
gradient_graph_config,
|
||||
node,
|
||||
output_args_need_grad,
|
||||
input_args_need_grad);
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ namespace onnxruntime {
|
|||
namespace training {
|
||||
|
||||
typedef GenericRegistry<GradientBuilderBase,
|
||||
const GradientGraphConfiguration&,
|
||||
const Node*&, //node
|
||||
const std::unordered_set<std::string>&, // gradient_inputs
|
||||
const std::unordered_set<std::string>&> // gradient_outputs
|
||||
|
|
@ -31,7 +32,8 @@ class GradientBuilderRegistry : public GradientRegistryType {
|
|||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GradientBuilderRegistry);
|
||||
};
|
||||
|
||||
GradientDef GetGradientForOp(const Node* node,
|
||||
GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config,
|
||||
const Node* node,
|
||||
const std::unordered_set<std::string>& output_args_need_grad,
|
||||
const std::unordered_set<std::string>& input_args_need_grad);
|
||||
|
||||
|
|
|
|||
18
orttraining/orttraining/core/graph/gradient_config.h
Normal file
18
orttraining/orttraining/core/graph/gradient_config.h
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
struct GradientGraphConfiguration {
|
||||
// Layernorm gradient can be computed based on either input or output of layernorm.
|
||||
// That is to say, either input or output needs to be stashed for layernorm gradient.
|
||||
// To save memory, ideally, only one(input vs output) should be stashed rather than both.
|
||||
// By default, the input based algorithm is used. This flag is to enable the output based algorithm.
|
||||
bool use_invertible_layernorm_grad{false};
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1432,6 +1432,32 @@ Example 4:
|
|||
{"tensor(float)"},
|
||||
"Constrain mean and inv_std_var to float tensors.");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(InvertibleLayerNormalizationGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
|
||||
.SetDoc("LayerNormalizationGrad")
|
||||
.Attr("axis",
|
||||
"The first normalization dimension: normalization will be performed along dimensions axis : rank(inputs).",
|
||||
AttributeProto::INT, static_cast<int64_t>(-1))
|
||||
.AllowUncheckedAttributes()
|
||||
.Input(0, "Y_grad", "The gradient tensor from output.", "T")
|
||||
.Input(1, "Y", "Output data tensor from the forward path", "T")
|
||||
.Input(2, "scale", "Scale tensor.", "T")
|
||||
.Input(3, "bias", "Bias tensor.", "T")
|
||||
.Input(4, "inv_std_var", "inverse std variance of X.", "U")
|
||||
.Output(0, "X_grad", "Gradient of the input.", "T")
|
||||
.Output(1, "scale_grad", "Gradient of the scale.", "T")
|
||||
.Output(2, "bias_grad", "Gradient of the bias.", "T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)"},
|
||||
"Constrain input and output types (except mean and inv_std_var) to float tensors.")
|
||||
.TypeConstraint(
|
||||
"U",
|
||||
{"tensor(float)"},
|
||||
"Constrain mean and inv_std_var to float tensors.");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(BatchNormalizationGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
|
|
|
|||
|
|
@ -219,7 +219,7 @@ Status TrainingSession::ConfigureForTraining(
|
|||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(BuildGradientGraph(
|
||||
weight_names_to_train, loss_name, config.set_gradients_as_graph_outputs));
|
||||
weight_names_to_train, loss_name, config.gradient_graph_config, config.set_gradients_as_graph_outputs));
|
||||
|
||||
// transform for mixed precision
|
||||
std::unordered_map<std::string, NodeArg*> fp32_weight_name_to_fp16_node_arg{};
|
||||
|
|
@ -425,12 +425,14 @@ static Status ConfigureLossFunctionInternal(
|
|||
static Status BuildGradientGraphInternal(Graph& graph,
|
||||
const std::string& loss_function_output_name,
|
||||
const std::unordered_set<std::string>& node_arg_names_to_train,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const bool set_gradient_as_graph_output = false) {
|
||||
// Compute the gradient graph def.
|
||||
GradientGraphBuilder grad_graph_builder(&graph,
|
||||
{loss_function_output_name},
|
||||
node_arg_names_to_train,
|
||||
loss_function_output_name,
|
||||
gradient_graph_config,
|
||||
set_gradient_as_graph_output);
|
||||
return grad_graph_builder.Build();
|
||||
}
|
||||
|
|
@ -638,13 +640,16 @@ Status TrainingSession::EnableMixedPrecision(const std::unordered_set<std::strin
|
|||
|
||||
Status TrainingSession::BuildGradientGraph(const std::unordered_set<std::string>& weights_to_train,
|
||||
const std::string& loss_function_output_name,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const bool set_gradient_as_graph_output) {
|
||||
// Fill weights_to_train_ according to weights_to_train
|
||||
weights_to_train_ = weights_to_train;
|
||||
gradient_graph_config_ = gradient_graph_config;
|
||||
|
||||
ORT_RETURN_IF_ERROR(BuildGradientGraphInternal(model_->MainGraph(),
|
||||
loss_function_output_name,
|
||||
weights_to_train_,
|
||||
gradient_graph_config_,
|
||||
set_gradient_as_graph_output));
|
||||
|
||||
return DoPostLoadProcessing(*model_);
|
||||
|
|
@ -762,6 +767,7 @@ Status TrainingSession::Save(const PathString& model_uri, TrainingSession::SaveO
|
|||
ORT_RETURN_IF_ERROR(BuildGradientGraphInternal(new_model->MainGraph(),
|
||||
actual_loss_name,
|
||||
weights_to_train_,
|
||||
gradient_graph_config_,
|
||||
false));
|
||||
|
||||
OptimizerOutputKeyMap<std::string> opt_graph_outputs;
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include "orttraining/core/graph/loss_function_registry.h"
|
||||
#include "orttraining/core/graph/optimizer_graph_output_key.h"
|
||||
#include "orttraining/core/graph/optimizer_config.h"
|
||||
#include "orttraining/core/graph/gradient_config.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
|
@ -42,6 +43,9 @@ class TrainingSession : public InferenceSession {
|
|||
// The immutable weights specification.
|
||||
ImmutableWeights immutable_weights;
|
||||
|
||||
// Gradient graph configuration
|
||||
GradientGraphConfiguration gradient_graph_config{};
|
||||
|
||||
// Whether to set the gradients as graph outputs.
|
||||
bool set_gradients_as_graph_outputs{false};
|
||||
|
||||
|
|
@ -409,6 +413,7 @@ class TrainingSession : public InferenceSession {
|
|||
*/
|
||||
common::Status BuildGradientGraph(const std::unordered_set<std::string>& weights_to_train,
|
||||
const std::string& loss_function_output_name,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const bool set_gradient_as_graph_output = false);
|
||||
|
||||
common::Status BuildAccumulationNode(const std::unordered_set<std::string>& weights_to_train);
|
||||
|
|
@ -469,6 +474,8 @@ class TrainingSession : public InferenceSession {
|
|||
std::unordered_set<std::string> dropout_eval_feeds_;
|
||||
OptimizerGraphConfig opt_graph_config_;
|
||||
std::unordered_map<std::string, OptimizerNodeConfig> opt_configs_;
|
||||
|
||||
GradientGraphConfiguration gradient_graph_config_;
|
||||
};
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -165,7 +165,9 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
|
|||
("enable_grad_norm_clip", "Specify whether to enable gradient clipping for optimizers.",
|
||||
cxxopts::value<bool>()->default_value("true"))
|
||||
("enable_gelu_approximation", "Specify whether to enable GELU approximation.",
|
||||
cxxopts::value<bool>()->default_value("true"));
|
||||
cxxopts::value<bool>()->default_value("true"))
|
||||
("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
options
|
||||
.add_options("ORT configuration")
|
||||
("ort_log_severity", "ORT minimum logging severity (see onnxruntime::logging::Severity values)",
|
||||
|
|
@ -458,12 +460,15 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
|
|||
"Log severity must be in the range [", static_cast<int>(logging::Severity::kVERBOSE),
|
||||
", ", static_cast<int>(logging::Severity::kFATAL), "].");
|
||||
ort_params.vlog_level = flags["ort_vlog_level"].as<int>();
|
||||
|
||||
params.use_invertible_layernorm_grad = flags["use_invertible_layernorm_grad"].as<bool>();
|
||||
} catch (const exception& e) {
|
||||
const std::string msg = "Failed to parse the command line arguments";
|
||||
cerr << msg << ": " << e.what() << "\n"
|
||||
<< options.help() << "\n";
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -91,6 +91,7 @@ Status TrainingRunner::Initialize() {
|
|||
config.weight_names_to_not_train = params_.weights_not_to_train;
|
||||
config.immutable_weights = params_.immutable_weights;
|
||||
|
||||
config.gradient_graph_config.use_invertible_layernorm_grad = params_.use_invertible_layernorm_grad;
|
||||
config.set_gradients_as_graph_outputs = false;
|
||||
|
||||
config.gradient_accumulation_steps = params_.gradient_accumulation_steps;
|
||||
|
|
|
|||
|
|
@ -169,6 +169,9 @@ class TrainingRunner {
|
|||
|
||||
// Enable GELU approximation
|
||||
bool enable_gelu_approximation = false;
|
||||
|
||||
// Use invertible layernorm grad
|
||||
bool use_invertible_layernorm_grad = false;
|
||||
};
|
||||
|
||||
TrainingRunner(Parameters params, const Environment& env);
|
||||
|
|
|
|||
|
|
@ -376,7 +376,8 @@ def create_ort_training_session_with_optimizer(model, device, training_optimizer
|
|||
deepspeed_zero_stage=0,
|
||||
enable_grad_norm_clip=True,
|
||||
frozen_weights=[], opset_version=DEFAULT_OPSET_VERSION,
|
||||
use_deterministic_compute=False):
|
||||
use_deterministic_compute=False,
|
||||
use_invertible_layernorm_grad=False):
|
||||
output_name = model.graph.output[0].name
|
||||
ort_parameters = ort.TrainingParameters()
|
||||
ort_parameters.loss_output_name = output_name
|
||||
|
|
@ -384,11 +385,11 @@ def create_ort_training_session_with_optimizer(model, device, training_optimizer
|
|||
ort_parameters.world_rank = world_rank
|
||||
ort_parameters.world_size = world_size
|
||||
ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
ort_parameters.use_mixed_precision = use_mixed_precision
|
||||
ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation
|
||||
ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage
|
||||
ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip
|
||||
ort_parameters.set_gradients_as_graph_outputs = False
|
||||
ort_parameters.use_invertible_layernorm_grad = use_invertible_layernorm_grad
|
||||
|
||||
output_types = {}
|
||||
for output in model.graph.output:
|
||||
|
|
@ -530,7 +531,8 @@ class ORTTrainer():
|
|||
world_rank=0, world_size=1, use_mixed_precision=False, allreduce_post_accumulation=False,
|
||||
global_step=0, get_lr_this_step=None, loss_scaler=None, deepspeed_zero_stage=0,
|
||||
enable_grad_norm_clip=True, frozen_weights=[], _opset_version=DEFAULT_OPSET_VERSION,
|
||||
_enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False):
|
||||
_enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False,
|
||||
use_invertible_layernorm_grad=False):
|
||||
super(ORTTrainer, self).__init__()
|
||||
"""
|
||||
Initialize ORTTrainer.
|
||||
|
|
@ -599,6 +601,8 @@ class ORTTrainer():
|
|||
Defaults to True
|
||||
_extra_postprocess: a callable to postprocess the ONNX model that is converted from PyTorch.
|
||||
Defaults to None
|
||||
use_invertible_layernorm_grad: use invertible layernorm grad
|
||||
Defaults to False
|
||||
"""
|
||||
warnings.warn('DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it')
|
||||
self.is_train = True
|
||||
|
|
@ -651,6 +655,7 @@ class ORTTrainer():
|
|||
self.state_dict_ = None
|
||||
self._enable_internal_postprocess = _enable_internal_postprocess
|
||||
self._use_deterministic_compute = _use_deterministic_compute
|
||||
self.use_invertible_layernorm_grad = use_invertible_layernorm_grad
|
||||
|
||||
# use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs.
|
||||
# see prepare_input_and_fetches for more details.
|
||||
|
|
@ -679,7 +684,8 @@ class ORTTrainer():
|
|||
deepspeed_zero_stage=self.deepspeed_zero_stage_,
|
||||
enable_grad_norm_clip=self.enable_grad_norm_clip_,
|
||||
frozen_weights=self.frozen_weights_, opset_version=self.opset_version_,
|
||||
use_deterministic_compute=self._use_deterministic_compute)
|
||||
use_deterministic_compute=self._use_deterministic_compute,
|
||||
use_invertible_layernorm_grad=self.use_invertible_layernorm_grad)
|
||||
|
||||
self.loss_scale_input_name = self.session.loss_scale_input_name
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ struct TrainingParameters {
|
|||
int deepspeed_zero_stage = 0;
|
||||
bool enable_grad_norm_clip = true;
|
||||
bool set_gradients_as_graph_outputs = false;
|
||||
bool use_invertible_layernorm_grad = false;
|
||||
};
|
||||
|
||||
struct TrainingConfigurationResult {
|
||||
|
|
@ -144,6 +145,8 @@ TrainingConfigurationResult ConfigureSessionForTraining(
|
|||
config.optimizer_config = opt;
|
||||
}
|
||||
|
||||
config.gradient_graph_config.use_invertible_layernorm_grad = parameters.use_invertible_layernorm_grad;
|
||||
|
||||
training::TrainingSession::TrainingConfigurationResult config_result{};
|
||||
|
||||
OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result));
|
||||
|
|
@ -177,7 +180,8 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
.def_readwrite("gradient_accumulation_steps", &TrainingParameters::gradient_accumulation_steps)
|
||||
.def_readwrite("deepspeed_zero_stage", &TrainingParameters::deepspeed_zero_stage)
|
||||
.def_readwrite("enable_grad_norm_clip", &TrainingParameters::enable_grad_norm_clip)
|
||||
.def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs);
|
||||
.def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs)
|
||||
.def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad);
|
||||
|
||||
py::class_<TrainingConfigurationResult> config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc");
|
||||
config_result.def(py::init())
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
#include "gradient_checker.h"
|
||||
#include "gradient_op_test_utils.h"
|
||||
#include "orttraining/core/framework/gradient_graph_builder.h"
|
||||
#include "orttraining/core/graph/gradient_config.h"
|
||||
#include "test/util/include/test_random_seed.h"
|
||||
#include <random>
|
||||
|
||||
|
|
@ -305,10 +306,12 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGradGraph(
|
|||
}
|
||||
}
|
||||
|
||||
training::GradientGraphConfiguration gradient_graph_config;
|
||||
training::GradientGraphBuilder grad_graph_builder(&graph,
|
||||
dy_values,
|
||||
weights_to_train,
|
||||
"",
|
||||
gradient_graph_config,
|
||||
true);
|
||||
Status status = grad_graph_builder.Build();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include "core/session/inference_session.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/framework/gradient_graph_builder.h"
|
||||
#include "orttraining/core/graph/gradient_config.h"
|
||||
#include "default_providers.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -69,10 +70,12 @@ void GradientOpTester::Run(
|
|||
}
|
||||
}
|
||||
|
||||
training::GradientGraphConfiguration gradient_graph_config;
|
||||
training::GradientGraphBuilder grad_graph_builder(&graph,
|
||||
dy_values,
|
||||
weights_to_train,
|
||||
"",
|
||||
gradient_graph_config,
|
||||
true);
|
||||
status = grad_graph_builder.Build();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
|
|
|||
|
|
@ -97,5 +97,143 @@ TEST(CudaKernelTest, LayerNormGrad_LargeSizeTensor) {
|
|||
TestLayerNormGrad(X_dims, -1, 5e-3);
|
||||
}
|
||||
|
||||
static void TestInvertibleLayerNormGrad(
|
||||
const std::vector<int64_t>& x_dims,
|
||||
int64_t axis = -1,
|
||||
double error_tolerance = 1e-4,
|
||||
bool test_fp16=false) {
|
||||
const std::vector<int64_t>& n_x_m_dims = x_dims;
|
||||
std::vector<int64_t> n_dims, m_dims;
|
||||
ASSERT_TRUE(SplitDims(n_x_m_dims, axis, n_dims, m_dims).IsOK());
|
||||
|
||||
const auto N = std::accumulate(n_dims.begin(), n_dims.end(), static_cast<int64_t>(1), std::multiplies<>{});
|
||||
const auto M = std::accumulate(m_dims.begin(), m_dims.end(), static_cast<int64_t>(1), std::multiplies<>{});
|
||||
|
||||
CompareOpTester test{"InvertibleLayerNormalizationGrad", 1, kMSDomain};
|
||||
|
||||
test.AddAttribute("axis", axis);
|
||||
|
||||
RandomValueGenerator random{};
|
||||
const auto Y_grad_data = random.Uniform<float>(n_x_m_dims, k_random_data_min, k_random_data_max);
|
||||
const auto X_data = random.Uniform<float>(n_x_m_dims, k_random_data_min, k_random_data_max);
|
||||
const auto scale_data = random.Uniform<float>(m_dims, k_random_data_min, k_random_data_max);
|
||||
const auto bias_data = random.Uniform<float>(m_dims, k_random_data_min, k_random_data_max);
|
||||
|
||||
// these inputs are dependent on X_data
|
||||
std::vector<float> mean_data(N); // mean(X)
|
||||
std::vector<float> inv_std_var_data(N); // 1 / sqrt(mean(X^2) - mean(X)^2 + epsilon)
|
||||
std::vector<float> Y_data(N*M);
|
||||
{
|
||||
using ConstEigenArrayMap = Eigen::Map<const Eigen::ArrayXX<float>>;
|
||||
using EigenArrayMap = Eigen::Map<Eigen::ArrayXX<float>>;
|
||||
|
||||
ConstEigenArrayMap X{X_data.data(), M, N};
|
||||
|
||||
for (int i = 0; i < N; ++i) {
|
||||
mean_data[i] = X.col(i).mean();
|
||||
inv_std_var_data[i] = X.col(i).square().mean() - mean_data[i] * mean_data[i];
|
||||
}
|
||||
|
||||
// Compute Y = ((x - mean) * (inv_var) * scale + bias
|
||||
EigenArrayMap Y(Y_data.data(), M, N);
|
||||
|
||||
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
|
||||
using ConstEigenVectorArrayMap = Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
|
||||
ConstEigenVectorArrayMap mean(mean_data.data(), N);
|
||||
EigenVectorArrayMap inv_std_var(inv_std_var_data.data(), N);
|
||||
inv_std_var = (inv_std_var + k_epsilon_default).sqrt().inverse();
|
||||
|
||||
Y = (X.rowwise() - mean.transpose()).rowwise() * inv_std_var.transpose();
|
||||
|
||||
ConstEigenVectorArrayMap scale(scale_data.data(), M);
|
||||
ConstEigenVectorArrayMap bias(bias_data.data(), M);
|
||||
Y = (Y.colwise() * scale).colwise() + bias;
|
||||
}
|
||||
|
||||
if (test_fp16) {
|
||||
std::vector<MLFloat16> Y_grad_data_half(Y_grad_data.size());
|
||||
std::vector<MLFloat16> Y_data_half(Y_data.size());
|
||||
std::vector<MLFloat16> scale_data_half(scale_data.size());
|
||||
std::vector<MLFloat16> bias_data_half(bias_data.size());
|
||||
ConvertFloatToMLFloat16(Y_grad_data.data(),Y_grad_data_half.data(), int(Y_grad_data.size()));
|
||||
ConvertFloatToMLFloat16(Y_data.data(),Y_data_half.data(), int(Y_data.size()));
|
||||
ConvertFloatToMLFloat16(scale_data.data(), scale_data_half.data(), int(scale_data.size()));
|
||||
ConvertFloatToMLFloat16(bias_data.data(), bias_data_half.data(), int(bias_data.size()));
|
||||
|
||||
test.AddInput<MLFloat16>("Y_grad", n_x_m_dims, Y_grad_data_half);
|
||||
test.AddInput<MLFloat16>("Y", n_x_m_dims, Y_data_half);
|
||||
test.AddInput<MLFloat16>("scale", m_dims, scale_data_half, true);
|
||||
test.AddInput<MLFloat16>("bias", m_dims, bias_data_half);
|
||||
|
||||
const auto X_grad_data = FillZeros<MLFloat16>(n_x_m_dims);
|
||||
const auto scale_grad_data = FillZeros<MLFloat16>(m_dims);
|
||||
const auto bias_grad_data = FillZeros<MLFloat16>(m_dims);
|
||||
test.AddOutput("X_grad", n_x_m_dims, X_grad_data);
|
||||
test.AddOutput("scale_grad_data", m_dims, scale_grad_data);
|
||||
test.AddOutput("bias_grad_data", m_dims, bias_grad_data);
|
||||
} else {
|
||||
test.AddInput("Y_grad", n_x_m_dims, Y_grad_data);
|
||||
test.AddInput("Y", n_x_m_dims, Y_data);
|
||||
test.AddInput("scale", m_dims, scale_data, true);
|
||||
test.AddInput("bias", m_dims, bias_data);
|
||||
|
||||
const auto X_grad_data = FillZeros<float>(n_x_m_dims);
|
||||
const auto scale_grad_data = FillZeros<float>(m_dims);
|
||||
const auto bias_grad_data = FillZeros<float>(m_dims);
|
||||
test.AddOutput("X_grad", n_x_m_dims, X_grad_data);
|
||||
test.AddOutput("scale_grad_data", m_dims, scale_grad_data);
|
||||
test.AddOutput("bias_grad_data", m_dims, bias_grad_data);
|
||||
}
|
||||
test.AddInput<float>("inv_std_var", n_dims, inv_std_var_data);
|
||||
|
||||
if (test_fp16) {
|
||||
test.CompareWithCPU(kCudaExecutionProvider, error_tolerance, error_tolerance);
|
||||
} else {
|
||||
test.CompareWithCPU(kCudaExecutionProvider, error_tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor) {
|
||||
const std::vector<int64_t> X_dims{4, 20, 128};
|
||||
TestInvertibleLayerNormGrad(X_dims);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_IntermediateAxis) {
|
||||
const std::vector<int64_t> X_dims{4, 20, 16, 8};
|
||||
const int64_t axis = -2;
|
||||
TestInvertibleLayerNormGrad(X_dims, axis);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, InvertibleLayerNormGrad_MidSizeTensor) {
|
||||
const std::vector<int64_t> X_dims{8, 80, 768};
|
||||
TestInvertibleLayerNormGrad(X_dims);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, InvertibleLayerNormGrad_LargeSizeTensor) {
|
||||
const std::vector<int64_t> X_dims{16, 512, 1024};
|
||||
TestInvertibleLayerNormGrad(X_dims, -1, 5e-3);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_FP16) {
|
||||
const std::vector<int64_t> X_dims{4, 20, 128};
|
||||
TestInvertibleLayerNormGrad(X_dims, -1, 2e-3, true);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_IntermediateAxis_FP16) {
|
||||
const std::vector<int64_t> X_dims{4, 20, 16, 8};
|
||||
const int64_t axis = -2;
|
||||
TestInvertibleLayerNormGrad(X_dims, axis, 2e-3, true);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, InvertibleLayerNormGrad_MidSizeTensor_FP16) {
|
||||
const std::vector<int64_t> X_dims{8, 80, 768};
|
||||
TestInvertibleLayerNormGrad(X_dims, -1, 2e-3, true);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, InvertibleLayerNormGrad_LargeSizeTensor_FP16) {
|
||||
const std::vector<int64_t> X_dims{16, 512, 1024};
|
||||
TestInvertibleLayerNormGrad(X_dims, -1, 2e-3, true);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -77,6 +77,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GistB
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GistBinarizeDecoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, InvertibleLayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, InvertibleLayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SliceGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGeluGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasGeluGrad_dX);
|
||||
|
|
@ -155,6 +157,8 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SummaryText)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, InvertibleLayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, InvertibleLayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GistBinarizeEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GistBinarizeDecoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SliceGrad)>,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,16 @@ namespace contrib {
|
|||
kCpuExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
LayerNormGrad<T>);
|
||||
LayerNormGrad<T>); \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
InvertibleLayerNormalizationGrad, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCpuExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
InvertibleLayerNormGrad<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(double)
|
||||
|
|
@ -99,5 +108,76 @@ Status LayerNormGrad<T>::Compute(OpKernelContext* op_kernel_context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
InvertibleLayerNormGrad<T>::InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info)
|
||||
: OpKernel{op_kernel_info} {
|
||||
ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status InvertibleLayerNormGrad<T>::Compute(OpKernelContext* op_kernel_context) const {
|
||||
const Tensor* Y_grad = op_kernel_context->Input<Tensor>(0);
|
||||
const Tensor* Y = op_kernel_context->Input<Tensor>(1);
|
||||
const Tensor* scale = op_kernel_context->Input<Tensor>(2);
|
||||
const Tensor* bias = op_kernel_context->Input<Tensor>(3);
|
||||
const Tensor* inv_std_var = op_kernel_context->Input<Tensor>(4);
|
||||
|
||||
const auto& Y_shape = Y_grad->Shape();
|
||||
const auto& X_shape = Y_shape;
|
||||
const auto axis = HandleNegativeAxis(axis_, X_shape.NumDimensions());
|
||||
const auto N = X_shape.SizeToDimension(axis);
|
||||
const auto M = X_shape.SizeFromDimension(axis);
|
||||
ORT_ENFORCE(M != 1);
|
||||
const auto& scale_shape = scale->Shape();
|
||||
|
||||
Tensor* X_grad = op_kernel_context->Output(0, X_shape);
|
||||
Tensor* scale_grad = op_kernel_context->Output(1, scale_shape);
|
||||
Tensor* bias_grad = op_kernel_context->Output(2, scale_shape);
|
||||
|
||||
// Note: Eigen has column-major storage order by default
|
||||
ConstEigenArrayMap<T> Y_grad_arr{Y_grad->Data<T>(), M, N};
|
||||
ConstEigenArrayMap<T> Y_arr{Y->Data<T>(), M, N};
|
||||
ConstEigenVectorArrayMap<T> scale_vec{scale->Data<T>(), M};
|
||||
ConstEigenVectorArrayMap<T> bias_vec{bias->Data<T>(), M};
|
||||
ConstEigenVectorArrayMap<float> inv_std_var_vec{inv_std_var->Data<float>(), N};
|
||||
|
||||
EigenArrayMap<T> X_grad_arr{X_grad->MutableData<T>(), M, N};
|
||||
EigenVectorArrayMap<T> scale_grad_vec{scale_grad->MutableData<T>(), M};
|
||||
EigenVectorArrayMap<T> bias_grad_vec{bias_grad->MutableData<T>(), M};
|
||||
|
||||
using Array = Eigen::ArrayXX<T>;
|
||||
using RowVector = Eigen::Array<T, 1, Eigen::Dynamic>;
|
||||
|
||||
// A, B, C are calculated as below:
|
||||
// A = Y_grad * (X - mean(X)) * inv_std_var
|
||||
// B = Y_grad * scale * inv_std_var
|
||||
// C = Y_grad * scale * inv_std_var * (X - mean(X)) * inv_std_var
|
||||
|
||||
// A, B, and C are M x N
|
||||
Array X_mean_difference_over_std_var = (Y_arr.colwise() - bias_vec).colwise() / scale_vec;
|
||||
Array A = Y_grad_arr * X_mean_difference_over_std_var;
|
||||
Array B = (Y_grad_arr.colwise() * scale_vec).rowwise() * inv_std_var_vec.cast<T>().transpose();
|
||||
Array C = B * X_mean_difference_over_std_var;
|
||||
|
||||
// mean_B = mean(Y_grad * scale * inv_std_var)
|
||||
RowVector mean_B = B.colwise().mean(); // 1 x N
|
||||
|
||||
// mean_C = mean(Y_grad * scale * inv_std_var * (X - mean(X)) * inv_std_var)
|
||||
RowVector mean_C = C.colwise().mean(); // 1 x N
|
||||
|
||||
// X_grad = Y_grad * scale * inv_std_var - mean_B - (X - mean(X)) * inv_std_var * mean_C
|
||||
// = B - mean_B - (X - mean(X)) * inv_std_var * mean_c
|
||||
X_grad_arr = B.rowwise() - mean_B - X_mean_difference_over_std_var.rowwise() * mean_C;
|
||||
|
||||
// bias_grad = sum(Y_grad)
|
||||
bias_grad_vec = Y_grad_arr.rowwise().sum();
|
||||
|
||||
// scale_grad = sum(Y_grad * (X - mean(X)) * inv_std_var)
|
||||
// = sum(A)
|
||||
scale_grad_vec = A.rowwise().sum();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -18,5 +18,15 @@ class LayerNormGrad final : public OpKernel {
|
|||
int64_t axis_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class InvertibleLayerNormGrad final : public OpKernel {
|
||||
public:
|
||||
InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info);
|
||||
Status Compute(OpKernelContext* op_kernel_context) const override;
|
||||
|
||||
private:
|
||||
int64_t axis_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -89,8 +89,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, ReduceAllL2);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, ReduceAllL2);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, LayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, LayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, LayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, LayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, InvertibleLayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, InvertibleLayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, InvertibleLayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad);
|
||||
|
||||
|
|
@ -198,8 +201,11 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, ReduceAllL2)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, ReduceAllL2)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, LayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, LayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, LayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, LayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, InvertibleLayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, InvertibleLayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, InvertibleLayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad)>,
|
||||
|
||||
|
|
|
|||
|
|
@ -19,9 +19,19 @@ namespace cuda {
|
|||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("U", DataTypeImpl::GetTensorType<U>()), \
|
||||
LayerNormGrad<T, U>);
|
||||
LayerNormGrad<T, U>); \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
InvertibleLayerNormalizationGrad, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T##_##U, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("U", DataTypeImpl::GetTensorType<U>()), \
|
||||
InvertibleLayerNormGrad<T, U>);
|
||||
REGISTER_GRADIENT_KERNEL_TYPED(float, float)
|
||||
REGISTER_GRADIENT_KERNEL_TYPED(double, float)
|
||||
REGISTER_GRADIENT_KERNEL_TYPED(double, double)
|
||||
REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16, float)
|
||||
|
||||
template <typename T, typename U>
|
||||
|
|
@ -65,7 +75,58 @@ Status LayerNormGrad<T, U>::ComputeInternal(OpKernelContext* p_op_kernel_context
|
|||
auto part_grad_gamma = GetScratchBuffer<CudaU>(part_size * n2);
|
||||
auto part_grad_beta = GetScratchBuffer<CudaU>(part_size * n2);
|
||||
|
||||
HostLayerNormGradient(GetDeviceProp(), Y_grad_data, mean_data, inv_std_var_data, X_data, n1, n2, scale_data, X_grad_data, scale_grad_data, bias_grad_data,
|
||||
HostLayerNormGradient(GetDeviceProp(), Y_grad_data, X_data, reinterpret_cast<const CudaT*>(NULL),
|
||||
scale_data, reinterpret_cast<const CudaT*>(NULL), mean_data, inv_std_var_data, n1, n2,
|
||||
X_grad_data, scale_grad_data, bias_grad_data,
|
||||
part_grad_gamma.get(), part_grad_beta.get(), part_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
InvertibleLayerNormGrad<T, U>::InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
|
||||
ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
Status InvertibleLayerNormGrad<T, U>::ComputeInternal(OpKernelContext* p_op_kernel_context) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
typedef typename ToCudaType<U>::MappedType CudaU;
|
||||
// Inputs
|
||||
const Tensor* Y_grad = p_op_kernel_context->Input<Tensor>(0);
|
||||
const Tensor* Y = p_op_kernel_context->Input<Tensor>(1);
|
||||
const Tensor* scale = p_op_kernel_context->Input<Tensor>(2);
|
||||
const Tensor* bias = p_op_kernel_context->Input<Tensor>(3);
|
||||
const Tensor* inv_std_var = p_op_kernel_context->Input<Tensor>(4);
|
||||
|
||||
auto Y_grad_data = reinterpret_cast<const CudaT*>(Y_grad->template Data<T>());
|
||||
auto Y_data = reinterpret_cast<const CudaT*>(Y->template Data<T>());
|
||||
auto scale_data = reinterpret_cast<const CudaT*>(scale->template Data<T>());
|
||||
auto bias_data = reinterpret_cast<const CudaT*>(bias->template Data<T>());
|
||||
auto inv_std_var_data = reinterpret_cast<const CudaU*>(inv_std_var->template Data<U>());
|
||||
|
||||
const TensorShape& y_shape = Y->Shape();
|
||||
const TensorShape& x_shape = y_shape;
|
||||
const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());
|
||||
auto n1 = x_shape.SizeToDimension(axis);
|
||||
auto n2 = x_shape.SizeFromDimension(axis);
|
||||
ORT_ENFORCE(n2 != 1, "n2 should not be 1");
|
||||
|
||||
// Outputs
|
||||
Tensor* X_grad = p_op_kernel_context->Output(0, x_shape);
|
||||
auto X_grad_data = reinterpret_cast<CudaT*>(X_grad->template MutableData<T>());
|
||||
|
||||
Tensor* scale_grad = p_op_kernel_context->Output(1, scale->Shape());
|
||||
Tensor* bias_grad = p_op_kernel_context->Output(2, scale->Shape());
|
||||
auto scale_grad_data = reinterpret_cast<CudaT*>(scale_grad->template MutableData<T>());
|
||||
auto bias_grad_data = reinterpret_cast<CudaT*>(bias_grad->template MutableData<T>());
|
||||
|
||||
const int part_size = 16;
|
||||
auto part_grad_gamma = GetScratchBuffer<CudaU>(part_size * n2);
|
||||
auto part_grad_beta = GetScratchBuffer<CudaU>(part_size * n2);
|
||||
|
||||
HostLayerNormGradient(GetDeviceProp(), Y_grad_data, reinterpret_cast<const CudaT*>(NULL), Y_data,
|
||||
scale_data, bias_data, reinterpret_cast<const CudaU*>(NULL), inv_std_var_data, n1, n2,
|
||||
X_grad_data, scale_grad_data, bias_grad_data,
|
||||
part_grad_gamma.get(), part_grad_beta.get(), part_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,5 +25,15 @@ class LayerNormGrad final : public CudaKernel {
|
|||
int64_t axis_;
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
class InvertibleLayerNormGrad final : public CudaKernel {
|
||||
public:
|
||||
InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info);
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
int64_t axis_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -372,10 +372,10 @@ void HostApplyLayerNorm(
|
|||
|
||||
LAYERNORM_LINEAR_IMPL(float, float)
|
||||
LAYERNORM_LINEAR_IMPL(half, float)
|
||||
LAYERNORM_LINEAR_IMPL(double, float)
|
||||
LAYERNORM_LINEAR_IMPL(double, double)
|
||||
//LAYERNORM_LINEAR_IMPL(half, half)
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, bool use_mean>
|
||||
__device__ void cuLoadWriteStridedInputs(
|
||||
const int i1_block,
|
||||
const int thr_load_row_off,
|
||||
|
|
@ -385,24 +385,34 @@ __device__ void cuLoadWriteStridedInputs(
|
|||
U* warp_buf1,
|
||||
U* warp_buf2,
|
||||
const T* input,
|
||||
const T* output,
|
||||
const T* dout,
|
||||
const int i1_end,
|
||||
const int n2,
|
||||
const T* __restrict__ gamma,
|
||||
const T* __restrict__ beta,
|
||||
const U* __restrict__ mean,
|
||||
const U* __restrict__ invvar) {
|
||||
int i1 = i1_block + thr_load_row_off;
|
||||
if (i1 < i1_end) {
|
||||
U curr_mean = mean[i1];
|
||||
U curr_invvar = invvar[i1];
|
||||
U curr_mean = use_mean ? mean[i1] : U(0);
|
||||
U curr_invvar = use_mean ? invvar[i1] : U(0);
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int i2 = i2_off + k;
|
||||
int load_idx = i1 * n2 + i2;
|
||||
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
|
||||
if (i2 < n2) {
|
||||
U curr_input = static_cast<U>(input[load_idx]);
|
||||
U curr_dout = static_cast<U>(dout[load_idx]);
|
||||
warp_buf1[write_idx] = curr_dout;
|
||||
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
|
||||
if (use_mean) {
|
||||
U curr_input = static_cast<U>(input[load_idx]);
|
||||
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
|
||||
} else {
|
||||
U curr_gamma = static_cast<U>(gamma[i2]);
|
||||
U curr_beta = static_cast<U>(beta[i2]);
|
||||
U curr_output = static_cast<U>(output[load_idx]);
|
||||
warp_buf2[write_idx] = curr_dout * (curr_output - curr_beta) / curr_gamma;
|
||||
}
|
||||
} else {
|
||||
warp_buf1[write_idx] = U(0);
|
||||
warp_buf2[write_idx] = U(0);
|
||||
|
|
@ -417,7 +427,7 @@ __device__ void cuLoadWriteStridedInputs(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, bool use_mean>
|
||||
__device__ void cuLoadAddStridedInputs(
|
||||
const int i1_block,
|
||||
const int thr_load_row_off,
|
||||
|
|
@ -427,37 +437,50 @@ __device__ void cuLoadAddStridedInputs(
|
|||
U* warp_buf1,
|
||||
U* warp_buf2,
|
||||
const T* input,
|
||||
const T* output,
|
||||
const T* dout,
|
||||
const int i1_end,
|
||||
const int n2,
|
||||
const T* __restrict__ gamma,
|
||||
const T* __restrict__ beta,
|
||||
const U* __restrict__ mean,
|
||||
const U* __restrict__ invvar) {
|
||||
int i1 = i1_block + thr_load_row_off;
|
||||
if (i1 < i1_end) {
|
||||
U curr_mean = mean[i1];
|
||||
U curr_invvar = invvar[i1];
|
||||
U curr_mean = use_mean ? mean[i1] : U(0);
|
||||
U curr_invvar = use_mean ? invvar[i1] : U(0);
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int i2 = i2_off + k;
|
||||
int load_idx = i1 * n2 + i2;
|
||||
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
|
||||
if (i2 < n2) {
|
||||
U curr_input = static_cast<U>(input[load_idx]);
|
||||
U curr_dout = static_cast<U>(dout[load_idx]);
|
||||
warp_buf1[write_idx] += curr_dout;
|
||||
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
|
||||
if (use_mean) {
|
||||
U curr_input = static_cast<U>(input[load_idx]);
|
||||
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
|
||||
} else {
|
||||
U curr_gamma = static_cast<U>(gamma[i2]);
|
||||
U curr_beta = static_cast<U>(beta[i2]);
|
||||
U curr_output = static_cast<U>(output[load_idx]);
|
||||
warp_buf2[write_idx] += curr_dout * (curr_output - curr_beta) / curr_gamma;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, bool use_mean>
|
||||
__global__ void cuComputePartGradGammaBeta(
|
||||
const T* __restrict__ dout,
|
||||
const T* __restrict__ input,
|
||||
const int n1,
|
||||
const int n2,
|
||||
const T* __restrict__ output,
|
||||
const T* __restrict__ gamma,
|
||||
const T* __restrict__ beta,
|
||||
const U* __restrict__ mean,
|
||||
const U* __restrict__ invvar,
|
||||
const int n1,
|
||||
const int n2,
|
||||
U* part_grad_gamma,
|
||||
U* part_grad_beta) {
|
||||
const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
|
||||
|
|
@ -475,9 +498,9 @@ __global__ void cuComputePartGradGammaBeta(
|
|||
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
|
||||
// compute partial sums from strided inputs
|
||||
// do this to increase number of loads in flight
|
||||
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
|
||||
cuLoadWriteStridedInputs<T, U, use_mean>(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, output, dout, i1_end, n2, gamma, beta, mean, invvar);
|
||||
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) {
|
||||
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
|
||||
cuLoadAddStridedInputs<T, U, use_mean>(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, output, dout, i1_end, n2, gamma, beta, mean, invvar);
|
||||
}
|
||||
__syncthreads();
|
||||
// inter-warp reductions
|
||||
|
|
@ -566,22 +589,25 @@ __global__ void cuComputeGradGammaBeta(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, bool use_mean>
|
||||
__global__ void cuComputeGradInput(
|
||||
const T* __restrict__ dout,
|
||||
const T* __restrict__ input,
|
||||
const int n1,
|
||||
const int n2,
|
||||
const T* __restrict__ output,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
const U* __restrict__ mean,
|
||||
const U* __restrict__ invvar,
|
||||
const T* gamma,
|
||||
const int n1,
|
||||
const int n2,
|
||||
T* grad_input) {
|
||||
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
U sum_loss1 = U(0);
|
||||
U sum_loss2 = U(0);
|
||||
const U c_mean = mean[i1];
|
||||
const U c_mean = use_mean ? mean[i1] : U(0);
|
||||
const U c_invvar = invvar[i1];
|
||||
const T* k_input = input + i1 * n2;
|
||||
const T* k_input = use_mean ? input + i1 * n2 : nullptr;
|
||||
const T* k_output = use_mean ? nullptr: output + i1 * n2;
|
||||
const T* k_dout = dout + i1 * n2;
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
|
|
@ -589,33 +615,53 @@ __global__ void cuComputeGradInput(
|
|||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const U c_h = static_cast<U>(k_input[l + k]);
|
||||
const U c_loss = static_cast<U>(k_dout[l + k]);
|
||||
sum_loss1 += c_loss * U(gamma[l + k]);
|
||||
sum_loss2 += c_loss * U(gamma[l + k]) * (c_h - c_mean) * c_invvar;
|
||||
if (use_mean) {
|
||||
const U c_h = static_cast<U>(k_input[l + k]);
|
||||
sum_loss2 += c_loss * U(gamma[l + k]) * (c_h - c_mean) * c_invvar;
|
||||
} else {
|
||||
const U c_output = static_cast<U>(k_output[l + k]);
|
||||
sum_loss2 += c_loss * (c_output - U(beta[l + k]));
|
||||
}
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
sum_loss1 += c_loss * U(gamma[l]);
|
||||
sum_loss2 += c_loss * U(gamma[l]) * (c_h - c_mean) * c_invvar;
|
||||
if (use_mean) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
sum_loss2 += c_loss * U(gamma[l]) * (c_h - c_mean) * c_invvar;
|
||||
} else {
|
||||
const U c_output = static_cast<U>(k_output[l]);
|
||||
sum_loss2 += c_loss * (c_output - U(beta[l]));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const U c_h = static_cast<U>(k_input[l + k]);
|
||||
const U c_loss = static_cast<U>(k_dout[l + k]);
|
||||
sum_loss1 += c_loss;
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
if (use_mean) {
|
||||
const U c_h = static_cast<U>(k_input[l + k]);
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
} else {
|
||||
const U c_output = static_cast<U>(k_output[l + k]);
|
||||
sum_loss2 += c_loss * c_output;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
sum_loss1 += c_loss;
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
if (use_mean) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
} else {
|
||||
const U c_output = static_cast<U>(k_output[l]);
|
||||
sum_loss2 += c_loss * c_output;
|
||||
}
|
||||
}
|
||||
}
|
||||
// intra-warp reductions
|
||||
|
|
@ -659,21 +705,31 @@ __global__ void cuComputeGradInput(
|
|||
T* k_grad_input = grad_input + i1 * n2;
|
||||
if (gamma != NULL) {
|
||||
for (int l = thrx; l < n2; l += numx) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss * U(gamma[l]);
|
||||
f_grad_input -= sum_loss1;
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
if (use_mean) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
} else {
|
||||
const U c_output = static_cast<U>(k_output[l]);
|
||||
f_grad_input -= (c_output - U(beta[l])) / U(gamma[l]) * sum_loss2;
|
||||
}
|
||||
f_grad_input *= term1;
|
||||
k_grad_input[l] = static_cast<T>(f_grad_input);
|
||||
}
|
||||
} else {
|
||||
for (int l = thrx; l < n2; l += numx) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss;
|
||||
f_grad_input -= sum_loss1;
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
if (use_mean) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
} else {
|
||||
const U c_output = static_cast<U>(k_output[l]);
|
||||
f_grad_input -= c_output * sum_loss2;
|
||||
}
|
||||
f_grad_input *= term1;
|
||||
k_grad_input[l] = static_cast<T>(f_grad_input);
|
||||
}
|
||||
|
|
@ -683,20 +739,22 @@ __global__ void cuComputeGradInput(
|
|||
|
||||
template <typename T, typename U>
|
||||
void HostLayerNormGradient(
|
||||
const cudaDeviceProp& prop,
|
||||
const T* dout,
|
||||
const U* mean,
|
||||
const U* invvar,
|
||||
const T* input,
|
||||
int64_t n1,
|
||||
int64_t n2,
|
||||
const T* gamma,
|
||||
T* grad_input,
|
||||
T* grad_gamma,
|
||||
T* grad_beta,
|
||||
U* part_grad_gamma,
|
||||
U* part_grad_beta,
|
||||
const int part_size) {
|
||||
const cudaDeviceProp& prop,
|
||||
const T* dout,
|
||||
const T* input,
|
||||
const T* output,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
const U* mean,
|
||||
const U* invvar,
|
||||
int64_t n1,
|
||||
int64_t n2,
|
||||
T* grad_input,
|
||||
T* grad_gamma,
|
||||
T* grad_beta,
|
||||
U* part_grad_gamma,
|
||||
U* part_grad_beta,
|
||||
const int part_size) {
|
||||
const int warp_size = prop.warpSize;
|
||||
ORT_ENFORCE(warp_size == GPU_WARP_SIZE);
|
||||
|
||||
|
|
@ -706,14 +764,31 @@ void HostLayerNormGradient(
|
|||
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
|
||||
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
|
||||
|
||||
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, 0>>>(
|
||||
if (mean == nullptr) {
|
||||
cuComputePartGradGammaBeta<T, U, false><<<blocks2, threads2, nshared2, 0>>>(
|
||||
dout,
|
||||
input,
|
||||
n1, n2,
|
||||
output,
|
||||
gamma,
|
||||
beta,
|
||||
mean,
|
||||
invvar,
|
||||
n1, n2,
|
||||
part_grad_gamma,
|
||||
part_grad_beta);
|
||||
} else {
|
||||
cuComputePartGradGammaBeta<T, U, true><<<blocks2, threads2, nshared2, 0>>>(
|
||||
dout,
|
||||
input,
|
||||
output,
|
||||
gamma,
|
||||
beta,
|
||||
mean,
|
||||
invvar,
|
||||
n1, n2,
|
||||
part_grad_gamma,
|
||||
part_grad_beta);
|
||||
}
|
||||
|
||||
const dim3 threads3(warp_size, 8, 1);
|
||||
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
|
||||
|
|
@ -732,22 +807,38 @@ void HostLayerNormGradient(
|
|||
const dim3 threads1(warp_size, 4, 1);
|
||||
int nshared =
|
||||
threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
|
||||
cuComputeGradInput<<<blocks1, threads1, nshared, 0>>>(
|
||||
if (mean == nullptr) {
|
||||
cuComputeGradInput<T, U, false><<<blocks1, threads1, nshared, 0>>>(
|
||||
dout,
|
||||
input,
|
||||
n1, n2,
|
||||
output,
|
||||
gamma,
|
||||
beta,
|
||||
mean,
|
||||
invvar,
|
||||
gamma,
|
||||
n1, n2,
|
||||
grad_input);
|
||||
} else {
|
||||
cuComputeGradInput<T, U, true><<<blocks1, threads1, nshared, 0>>>(
|
||||
dout,
|
||||
input,
|
||||
output,
|
||||
gamma,
|
||||
beta,
|
||||
mean,
|
||||
invvar,
|
||||
n1, n2,
|
||||
grad_input);
|
||||
}
|
||||
}
|
||||
|
||||
#define LAYERNORMGRAD_IMPL(T, U) \
|
||||
template void HostLayerNormGradient(const cudaDeviceProp& prop, const T* dout, const U* mean, const U* invvar, const T* input, int64_t n1, int64_t n2, const T* gamma, \
|
||||
#define LAYERNORMGRAD_IMPL(T, U) \
|
||||
template void HostLayerNormGradient(const cudaDeviceProp& prop, const T* dout, const T* input, const T* output, \
|
||||
const T* gamma, const T* beta, const U* mean, const U* invvar, int64_t n1, int64_t n2, \
|
||||
T* grad_input, T* grad_gamma, T* grad_beta, U* part_grad_gamma, U* part_grad_beta, const int part_size);
|
||||
|
||||
LAYERNORMGRAD_IMPL(float, float)
|
||||
LAYERNORMGRAD_IMPL(double, float)
|
||||
LAYERNORMGRAD_IMPL(double, double)
|
||||
LAYERNORMGRAD_IMPL(half, float)
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -45,12 +45,14 @@ template <typename T, typename U>
|
|||
void HostLayerNormGradient(
|
||||
const cudaDeviceProp& prop,
|
||||
const T* dout,
|
||||
const T* input,
|
||||
const T* output,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
const U* mean,
|
||||
const U* invvar,
|
||||
const T* input,
|
||||
int64_t n1,
|
||||
int64_t n2,
|
||||
const T* gamma,
|
||||
T* grad_input,
|
||||
T* grad_gamma,
|
||||
T* grad_beta,
|
||||
|
|
|
|||
Loading…
Reference in a new issue