Relu grad kernel (#4864)

* create branch for debug

* move unit test

* more changes

* move relu to activations_grad*

* Fix ReluGrad Domain and opset version

* added unit test, CudaKernelTest.Relu_basic doesn't work yet

* remove CudaKernelTest.Relu_basic

* PR comment

* add unit test ReluGradTest_Basic

Co-authored-by: Jingyan Wang <jingywa@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
jingyanwangms 2020-08-22 01:03:44 -07:00 committed by GitHub
parent dce2ce7a4f
commit fa68bbc82e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 119 additions and 62 deletions

View file

@ -7,63 +7,62 @@
namespace onnxruntime {
namespace test {
TEST_F(ActivationOpTest, Sigmoid) {
TestActivationOp("Sigmoid",
input_values,
[](float x) {
auto y = 1.f / (1.f + std::exp(-std::abs(x))); // safe sigmoid
y = x > 0 ? y : 1 - y;
return y;
});
input_values,
[](float x) {
auto y = 1.f / (1.f + std::exp(-std::abs(x))); // safe sigmoid
y = x > 0 ? y : 1 - y;
return y;
});
}
TEST_F(ActivationOpTest, HardSigmoid) {
float alpha = 0.2f;
float beta = 0.5f;
TestActivationOp("HardSigmoid",
input_values,
[alpha, beta](float x) {
return std::max(std::min((alpha * x + beta), 1.0f), 0.0f);
},
{{"alpha", alpha}, {"beta", beta}});
input_values,
[alpha, beta](float x) {
return std::max(std::min((alpha * x + beta), 1.0f), 0.0f);
},
{{"alpha", alpha}, {"beta", beta}});
}
TEST_F(ActivationOpTest, Tanh) {
TestActivationOp("Tanh",
input_values,
[](float x) { return std::tanh(x); });
input_values,
[](float x) { return std::tanh(x); });
}
TEST_F(ActivationOpTest, Relu) {
TestActivationOp("Relu",
input_values,
[](float x) { return std::max(x, 0.0f); });
input_values,
[](float x) { return std::max(x, 0.0f); });
}
TEST_F(ActivationOpTest, Elu) {
float alpha = 0.1f;
TestActivationOp("Elu",
input_values,
[alpha](float x) { return (x >= 0) ? x : alpha * (exp(x) - 1); },
{{"alpha", alpha}});
input_values,
[alpha](float x) { return (x >= 0) ? x : alpha * (exp(x) - 1); },
{{"alpha", alpha}});
}
TEST_F(ActivationOpTest, LeakyRelu) {
float alpha = 0.1f;
TestActivationOp("LeakyRelu",
input_values,
[alpha](float x) { return (x >= 0) ? x : alpha * x; },
{{"alpha", alpha}});
input_values,
[alpha](float x) { return (x >= 0) ? x : alpha * x; },
{{"alpha", alpha}});
}
TEST_F(ActivationOpTest, ThresholdedRelu) {
float alpha = 0.1f;
TestActivationOp(
"ThresholdedRelu",
input_values,
[alpha](float x) { return (x >= alpha) ? x : 0; },
{{"alpha", alpha}}, true, 10);
"ThresholdedRelu",
input_values,
[alpha](float x) { return (x >= alpha) ? x : 0; },
{{"alpha", alpha}}, true, 10);
}
TEST_F(ActivationOpTest, Selu) {
@ -71,9 +70,9 @@ TEST_F(ActivationOpTest, Selu) {
static constexpr float gamma = 1.0507f;
TestActivationOp("Selu",
input_values,
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
{{"alpha", alpha}, {"gamma", gamma}});
input_values,
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
{{"alpha", alpha}, {"gamma", gamma}});
}
TEST_F(ActivationOpTest, Selu_Attributes) {
@ -81,9 +80,9 @@ TEST_F(ActivationOpTest, Selu_Attributes) {
static constexpr float gamma = 0.5f;
TestActivationOp("Selu",
input_values,
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
{{"alpha", alpha}, {"gamma", gamma}});
input_values,
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
{{"alpha", alpha}, {"gamma", gamma}});
}
TEST_F(ActivationOpTest, PRelu) {
@ -146,20 +145,20 @@ TEST_F(ActivationOpTest, PRelu_MultiChannel) {
TEST_F(ActivationOpTest, Softplus) {
TestActivationOp("Softplus",
input_values,
[](float x) {
if (x > 0)
return x + logf(expf(-x) + 1);
else
return logf(expf(x) + 1);
});
input_values,
[](float x) {
if (x > 0)
return x + logf(expf(-x) + 1);
else
return logf(expf(x) + 1);
});
}
TEST_F(ActivationOpNoInfTest, Softsign) {
TestActivationOp(
"Softsign",
input_values,
[](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches
"Softsign",
input_values,
[](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches
}
} // namespace test

View file

@ -781,7 +781,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherElementsGradient) {
IMPLEMENT_GRADIENT_BUILDER(GetReluGradient) {
return std::vector<NodeDef>{
NodeDef("ReluGrad",
NodeDef(OpDef{"ReluGrad", kMSDomain, 1},
{GO(0), O(0)},
{GI(0)})};
}

View file

@ -322,7 +322,8 @@ OpSchema& RegisterLambOpSchema(OpSchema&& op_schema) {
void RegisterTrainingOpSchemas() {
ONNX_CONTRIB_OPERATOR_SCHEMA(ReluGrad)
.SinceVersion(9)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Input(0, "dY", "Gradient of output Y", "T")
.Input(1, "X", "Input tensor", "T")
.Output(0, "dX", "Gradient of input X", "T")

View file

@ -523,6 +523,10 @@ TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
RunReductionTests(op_def);
}
TEST(GradientCheckerTest, ReluGrad) {
UnaryOpGradientTest("Relu");
}
#ifndef USE_CUDA
TEST(GradientCheckerTest, CastGrad) {
// A dummy test that cast float to float
@ -540,10 +544,6 @@ TEST(GradientCheckerTest, CastGrad) {
}
}
TEST(GradientCheckerTest, ReluGrad) {
UnaryOpGradientTest("Relu");
}
TEST(GradientCheckerTest, SplitGrad) {
TensorShape shape({9, 5});
float max_error;

View file

@ -71,6 +71,10 @@ float GeluApproximationGrad(float dy, float x) {
float result = dy * 0.5f * (tanh_value + (sech_sqr_value * (kAlpha * x + kBeta * x_cube)) + 1.0f);
return result;
}
float ReluGrad(float dy, float x) {
return x > 0 ? dy : 0;
}
} // namespace
TEST(GeluGradTest, Basic) {
@ -139,6 +143,22 @@ TEST(BiasFastGeluGradDxTest, Basic) {
{}, 1, kMSDomain);
}
TEST(ReluGradTest, Basic) {
const std::vector<float> x_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
const std::vector<float> dY(7, 1.0f);
TestElementwiseGradientOp(
"ReluGrad",
{{"dY", dY}, {"X", x_vals}},
[](const std::vector<float>& params) {
ORT_ENFORCE(params.size() == 2);
const auto dy = params[0], x = params[1];
return ReluGrad(dy, x);
},
{}, 1, kMSDomain);
}
namespace {
template <typename TComputeGeluGradScalarFn>
void TestBiasGeluGradBroadcastBias(const std::string& op, int opset_version, const std::string& domain,

View file

@ -63,6 +63,13 @@ TEST(CudaKernelTest, FastGeluGrad_basic) {
}
}
TEST(CudaKernelTest, ReluGrad_basic) {
std::vector<std::vector<int64_t>> test_dims{{4}, {16, 2}, {8, 2, 128, 128}};
for (const auto& test_dim : test_dims) {
TestActivations(test_dim, "ReluGrad", true /* grad_op */);
}
}
static void TestActivationsWithBroadcastBias(
const std::vector<int64_t>& tensor_dim,
const std::string& operator_name,

View file

@ -34,7 +34,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConvGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ReluGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSoftmaxGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, AveragePoolGrad);
@ -126,7 +126,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConvGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSoftmaxGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, AveragePoolGrad)>,

View file

@ -29,9 +29,11 @@ Status SinGrad<T>::Compute(OpKernelContext* context) const {
return Status::OK();
}
ONNX_CPU_OPERATOR_KERNEL(
ONNX_OPERATOR_KERNEL_EX(
ReluGrad,
9,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ReluGrad<float>);
@ -101,12 +103,12 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
}
ONNX_OPERATOR_KERNEL_EX(
LogSoftmaxGrad,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
LogSoftmaxGrad<float>);
LogSoftmaxGrad,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
LogSoftmaxGrad<float>);
template <typename T>
Status LogSoftmaxGrad<T>::Compute(OpKernelContext* context) const {
@ -133,14 +135,14 @@ Status LogSoftmaxGrad<T>::Compute(OpKernelContext* context) const {
std::vector<float> eY(nd);
float* eYdata = eY.data();
// dX_ai = d(log Y_ai) - [sum_j d(log Y_aj)] exp(log Y_ai)
gsl::copy(gsl::make_span(dYdata, nd), gsl::make_span(dXdata, nd));
math::Exp<float, CPUMathUtil>(nd, Ydata, eYdata, nullptr);
for (size_t i = 0; i < N; ++i) {
float sdY;
math::Sum<float, CPUMathUtil>(d, dYdata + i*d, &sdY, nullptr, nullptr);
math::Axpy<float, CPUMathUtil>(d, -sdY, eYdata + i*d, dXdata + i*d, nullptr);
math::Sum<float, CPUMathUtil>(d, dYdata + i * d, &sdY, nullptr, nullptr);
math::Axpy<float, CPUMathUtil>(d, -sdY, eYdata + i * d, dXdata + i * d, nullptr);
}
return Status::OK();

View file

@ -45,6 +45,7 @@ namespace cuda {
ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
} //namespace cuda
} // namespace onnxruntime

View file

@ -33,5 +33,16 @@ class FastGeluGrad final : public BinaryElementwise<ShouldNotBroadcast> {
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class ReluGrad final : public BinaryElementwise<ShouldNotBroadcast> {
public:
ReluGrad(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_NULL()
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -24,6 +24,13 @@ struct OP_FastGeluGrad : public CtxGeluGrad {
}
};
template <typename T>
struct OP_ReluGrad : public CtxReluGrad {
__device__ __inline__ T operator()(const T& dy, const T& x) const {
return x > T {0} ? dy : T {0};
}
};
#define BINARY_ELEMENTWISE_IMPL(name) \
BINARY_ELEMENTWISE_IMPL_DECLARATION(name) { \
BinaryElementWiseNoBroadcastImpl(lhs_data, rhs_data, \

View file

@ -9,10 +9,12 @@ namespace cuda {
typedef onnxruntime::cuda::CtxNull CtxGeluGrad;
typedef onnxruntime::cuda::CtxNull CtxFastGeluGrad;
typedef onnxruntime::cuda::CtxNull CtxReluGrad;
#define ACTIVATION_GRAD_OPS() \
ACTIVATION_GRAD_OP_NAME(GeluGrad) \
ACTIVATION_GRAD_OP_NAME(FastGeluGrad)
ACTIVATION_GRAD_OP_NAME(FastGeluGrad) \
ACTIVATION_GRAD_OP_NAME(ReluGrad)
#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \

View file

@ -85,6 +85,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGeluGrad_dX);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasFastGeluGrad_dX);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ReluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ReluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ReluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, IsFinite);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, IsFinite);
@ -205,6 +209,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FastGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGeluGrad_dX)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasFastGeluGrad_dX)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ReluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, IsFinite)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, IsFinite)>,