LeakyRelu Gradient (#17039)

This commit is contained in:
Baiju Meswani 2023-08-10 20:45:34 -07:00 committed by GitHub
parent 0180c0429f
commit 3e7f70bf88
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 183 additions and 2 deletions

View file

@ -64,6 +64,10 @@ constexpr float SigmoidGrad(float dy, float y) {
constexpr float TanhGrad(float dy, float y) {
return dy * (1 - y * y);
}
constexpr float LeakyReluGrad(float dy, float y, float alpha) {
return dy * (y > 0.0f ? 1.0f : alpha);
}
} // namespace
#endif
@ -669,6 +673,23 @@ TEST(TanhGradInferenceTest, Basic) {
},
{}, 1, kMSDomain);
}
TEST(LeakyReluGradInferenceTest, Basic) {
const std::vector<float> y_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
const std::vector<float> dY(7, 1.0f);
float alpha = 0.5f;
TestElementwiseGradientOp(
"LeakyReluGrad",
{{"dY", dY}, {"Y", y_vals}},
[alpha](const std::vector<float>& params) {
ORT_ENFORCE(params.size() == 2);
const auto dy = params[0], y = params[1];
return LeakyReluGrad(dy, y, alpha);
},
{{"alpha", alpha}}, 1, kMSDomain);
}
#endif
} // namespace test

View file

@ -2065,5 +2065,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetReciprocalGradient) {
NodeDef("Mul", {GO(0), IA("Neg_Square_O0")}, {GI(0)})};
}
IMPLEMENT_GRADIENT_BUILDER(GetLeakyReluGradient) {
return {NodeDef(OpDef{"LeakyReluGrad", kMSDomain, 1},
{GO(0), O(0)}, {GI(0)}, SrcNodeAttributes())};
}
} // namespace training
} // namespace onnxruntime

View file

@ -87,6 +87,7 @@ DECLARE_GRADIENT_BUILDER(GetFakeQuantGradient)
DECLARE_GRADIENT_BUILDER(GetLSTMGradient)
DECLARE_GRADIENT_BUILDER(GetGRUGradient)
DECLARE_GRADIENT_BUILDER(GetReciprocalGradient)
DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient)
DECLARE_GRADIENT_BUILDER(GetExternalGradient)

View file

@ -119,6 +119,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("LSTMTraining", GetLSTMGradient);
REGISTER_GRADIENT_BUILDER("GRUTraining", GetGRUGradient);
REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient);
REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient);
REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
};

View file

@ -2900,6 +2900,19 @@ Example 4:
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {});
});
ONNX_CONTRIB_OPERATOR_SCHEMA(LeakyReluGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc("Gradient operator for LeakyRelu.")
.Attr("alpha", "Alpha (negative slope) value.", AttributeProto::FLOAT, 0.01f)
.AllowUncheckedAttributes()
.Input(0, "dY", "The gradient tensor from output.", "T")
.Input(1, "Y", "The output tensor. ", "T")
.Output(0, "dX", "Gradient of the input.", "T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
ONNX_CONTRIB_OPERATOR_SCHEMA(LayerNormalizationGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)

View file

@ -3033,6 +3033,12 @@ TEST(GradientCheckerTest, ReciprocalGrad) {
UnaryOpGradientTest("Reciprocal", kOnnxDomain, 12, nullptr, &transformer);
}
TEST(GradientCheckerTest, LeakyReluGrad) {
// Gradient is non continuous at 0, so we need to avoid it.
std::function<float(float)> transformer = [](float x) { return x > 0 ? x + 0.2f : x - 0.2f; };
UnaryOpGradientTest("LeakyRelu", kOnnxDomain, 16, nullptr, &transformer);
}
} // namespace test
} // namespace onnxruntime

View file

@ -6162,7 +6162,39 @@ def test_reciprocal_gradient():
pt_model = ReciprocalModel().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
pt_x = torch.zeros(3, 224, 224, requires_grad=True, device=device)
pt_x = torch.randn(3, 224, 224, requires_grad=True, device=device)
with torch.no_grad():
pt_x[pt_x <= 0] -= 0.2
pt_x[pt_x > 0] += 0.2
ort_x = copy.deepcopy(pt_x)
pt_prediction, pt_loss = run_step(pt_model, pt_x)
ort_prediction, ort_loss = run_step(ort_model, ort_x)
_test_helpers.assert_values_are_close(pt_prediction, ort_prediction)
_test_helpers.assert_values_are_close(pt_loss, ort_loss)
_test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad)
def test_leakyrelu_gradient():
class LeakyReluModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.leakyrelu = nn.LeakyReLU(0.5)
def forward(self, x):
return self.leakyrelu(x)
def run_step(model, x):
prediction = model(x)
loss = prediction.sum()
loss.backward()
return prediction, loss
device = "cuda"
pt_model = LeakyReluModel().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
pt_x = torch.randn(3, 224, 224, requires_grad=True, device=device)
with torch.no_grad():
pt_x[pt_x <= 0] -= 0.2
pt_x[pt_x > 0] += 0.2

View file

@ -89,6 +89,10 @@ float QuickGeluGrad(float dy, float x, float alpha) {
float sigmoid = v >= 0 ? 1.f / (1.f + std::exp(-v)) : 1.f - 1.f / (1 + std::exp(v));
return dy * sigmoid * (1 + v * (1 - sigmoid));
}
constexpr float LeakyReluGrad(float dy, float y, float alpha) {
return dy * (y > 0.0f ? 1.0f : alpha);
}
} // namespace
TEST(GeluGradTest, Basic) {
@ -263,6 +267,23 @@ TEST(QuickGeluGradTest, Basic) {
}
}
TEST(LeakyReluGradTest, Basic) {
const std::vector<float> y_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
const std::vector<float> dY(7, 1.0f);
float alpha = 0.5f;
TestElementwiseGradientOp(
"LeakyReluGrad",
{{"dY", dY}, {"Y", y_vals}},
[alpha](const std::vector<float>& params) {
ORT_ENFORCE(params.size() == 2);
const auto dy = params[0], y = params[1];
return LeakyReluGrad(dy, y, alpha);
},
{{"alpha", alpha}}, 1, kMSDomain);
}
namespace {
template <typename TComputeGeluGradScalarFn>
void TestBiasGeluGradBroadcastBias(const std::string& op, int opset_version, const std::string& domain,

View file

@ -90,6 +90,13 @@ TEST(CudaKernelTest, TanhGrad_basic) {
}
}
TEST(CudaKernelTest, LeakyReluGrad_basic) {
std::vector<std::vector<int64_t>> test_dims{{4}, {16, 2}, {8, 2, 128, 128}};
for (const auto& test_dim : test_dims) {
TestActivations(test_dim, "LeakyReluGrad", true /* grad_op */);
}
}
static void TestActivationsWithBroadcastBias(
const std::vector<int64_t>& tensor_dim,
const std::string& operator_name,

View file

@ -53,6 +53,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluG
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGeluGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LeakyReluGrad);
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
// However these types work on GPU implementation.
@ -182,6 +183,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LeakyReluGrad)>,
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
// However these types work on GPU implementation.
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, DropoutGrad)>,

View file

@ -260,5 +260,24 @@ Status QuickGeluGrad<T>::Compute(OpKernelContext* context) const {
return Status::OK();
}
ONNX_OPERATOR_KERNEL_EX(LeakyReluGrad, kMSDomain, 1, kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
LeakyReluGrad<float>);
template <typename T>
Status LeakyReluGrad<T>::Compute(OpKernelContext* context) const {
auto& dY = *context->Input<Tensor>(0);
auto& Y = *context->Input<Tensor>(1);
auto& dX = *context->Output(0, dY.Shape());
EigenVectorArrayMap<float> dx = EigenVectorArrayMap<float>(dX.template MutableData<T>(),
narrow<Eigen::Index>(dX.Shape().Size()));
ConstEigenVectorArrayMap<float> y = ConstEigenVectorArrayMap<float>(Y.template Data<T>(),
narrow<Eigen::Index>(Y.Shape().Size()));
ConstEigenVectorArrayMap<float> dy = ConstEigenVectorArrayMap<float>(dY.template Data<T>(),
narrow<Eigen::Index>(dY.Shape().Size()));
dx = (y > 0.0f).select(dy, alpha_ * dy);
return Status::OK();
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -78,5 +78,19 @@ class SoftmaxGrad final : public OpKernel {
bool is_logsoftmaxgrad_;
};
template <typename T>
class LeakyReluGrad final : public OpKernel {
public:
explicit LeakyReluGrad(const OpKernelInfo& info) : OpKernel(info) {
alpha_ = info.GetAttrOrDefault<float>("alpha", 0.01f);
}
Status Compute(OpKernelContext* context) const override;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LeakyReluGrad);
float alpha_;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -49,6 +49,7 @@ ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(LeakyReluGrad, 1, kMSDomain);
} // namespace cuda
} // namespace onnxruntime

View file

@ -79,5 +79,20 @@ class TanhGrad final : public BinaryElementwise<ShouldNotBroadcast> {
private:
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class LeakyReluGrad final : public BinaryElementwise<ShouldNotBroadcast> {
public:
LeakyReluGrad(const OpKernelInfo& info) : BinaryElementwise(info) {
alpha_ = info.GetAttrOrDefault<float>("alpha", 0.01f);
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA()
float alpha_;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -64,6 +64,13 @@ struct OP_TanhGrad : public CtxTanhGrad {
}
};
template <typename T>
struct OP_LeakyReluGrad : public CtxLeakyReluGrad {
__device__ __inline__ T operator()(const T& dy, const T& y) const {
return dy * (y > T{0} ? T{1} : static_cast<T>(alpha));
}
};
#define BINARY_ELEMENTWISE_IMPL(name) \
BINARY_ELEMENTWISE_IMPL_DECLARATION(name) { \
BinaryElementWiseNoBroadcastImpl(stream, \

View file

@ -13,6 +13,7 @@ typedef onnxruntime::cuda::CtxNull CtxReluGrad;
typedef onnxruntime::cuda::CtxNull CtxSigmoidGrad;
typedef onnxruntime::cuda::CtxAlpha CtxQuickGeluGrad;
typedef onnxruntime::cuda::CtxNull CtxTanhGrad;
typedef onnxruntime::cuda::CtxAlpha CtxLeakyReluGrad;
#define ACTIVATION_GRAD_OPS() \
ACTIVATION_GRAD_OP_NAME(GeluGrad) \
@ -20,7 +21,8 @@ typedef onnxruntime::cuda::CtxNull CtxTanhGrad;
ACTIVATION_GRAD_OP_NAME(ReluGrad) \
ACTIVATION_GRAD_OP_NAME(SigmoidGrad) \
ACTIVATION_GRAD_OP_NAME(QuickGeluGrad) \
ACTIVATION_GRAD_OP_NAME(TanhGrad)
ACTIVATION_GRAD_OP_NAME(TanhGrad) \
ACTIVATION_GRAD_OP_NAME(LeakyReluGrad)
#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \

View file

@ -123,6 +123,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, LeakyReluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, LeakyReluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, IsFinite);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, IsFinite);
@ -366,6 +370,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, LeakyReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, LeakyReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, IsFinite)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, IsFinite)>,

View file

@ -119,6 +119,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LeakyReluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, LeakyReluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, IsFinite);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, IsFinite);
@ -316,6 +320,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LeakyReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, LeakyReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, IsFinite)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, IsFinite)>,