From fa68bbc82e030c8f4cacba8a222bbedbc2107042 Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Sat, 22 Aug 2020 01:03:44 -0700 Subject: [PATCH] Relu grad kernel (#4864) * create branch for debug * move unit test * more changes * move relu to activations_grad* * Fix ReluGrad Domain and opset version * added unit test, CudaKernelTest.Relu_basic doesn't work yet * remove CudaKernelTest.Relu_basic * PR comment * add unit test ReluGradTest_Basic Co-authored-by: Jingyan Wang Co-authored-by: Sherlock Huang --- .../cpu/activation/activation_op_test.cc | 83 +++++++++---------- .../core/graph/gradient_builder.cc | 2 +- .../core/graph/training_op_defs.cc | 3 +- .../test/gradient/gradient_ops_test.cc | 8 +- .../cpu/activation/activation_op_test.cc | 20 +++++ .../training_ops/cuda/activations_test.cc | 7 ++ .../training_ops/cpu/cpu_training_kernels.cc | 4 +- .../training_ops/cpu/op_gradients.cc | 24 +++--- .../cuda/activation/activations_grad.cc | 1 + .../cuda/activation/activations_grad.h | 11 +++ .../cuda/activation/activations_grad_impl.cu | 7 ++ .../cuda/activation/activations_grad_impl.h | 4 +- .../cuda/cuda_training_kernels.cc | 7 ++ 13 files changed, 119 insertions(+), 62 deletions(-) diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index b58779df7f..1921e4335c 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -7,63 +7,62 @@ namespace onnxruntime { namespace test { - TEST_F(ActivationOpTest, Sigmoid) { TestActivationOp("Sigmoid", - input_values, - [](float x) { - auto y = 1.f / (1.f + std::exp(-std::abs(x))); // safe sigmoid - y = x > 0 ? y : 1 - y; - return y; - }); + input_values, + [](float x) { + auto y = 1.f / (1.f + std::exp(-std::abs(x))); // safe sigmoid + y = x > 0 ? y : 1 - y; + return y; + }); } TEST_F(ActivationOpTest, HardSigmoid) { float alpha = 0.2f; float beta = 0.5f; TestActivationOp("HardSigmoid", - input_values, - [alpha, beta](float x) { - return std::max(std::min((alpha * x + beta), 1.0f), 0.0f); - }, - {{"alpha", alpha}, {"beta", beta}}); + input_values, + [alpha, beta](float x) { + return std::max(std::min((alpha * x + beta), 1.0f), 0.0f); + }, + {{"alpha", alpha}, {"beta", beta}}); } TEST_F(ActivationOpTest, Tanh) { TestActivationOp("Tanh", - input_values, - [](float x) { return std::tanh(x); }); + input_values, + [](float x) { return std::tanh(x); }); } TEST_F(ActivationOpTest, Relu) { TestActivationOp("Relu", - input_values, - [](float x) { return std::max(x, 0.0f); }); + input_values, + [](float x) { return std::max(x, 0.0f); }); } TEST_F(ActivationOpTest, Elu) { float alpha = 0.1f; TestActivationOp("Elu", - input_values, - [alpha](float x) { return (x >= 0) ? x : alpha * (exp(x) - 1); }, - {{"alpha", alpha}}); + input_values, + [alpha](float x) { return (x >= 0) ? x : alpha * (exp(x) - 1); }, + {{"alpha", alpha}}); } TEST_F(ActivationOpTest, LeakyRelu) { float alpha = 0.1f; TestActivationOp("LeakyRelu", - input_values, - [alpha](float x) { return (x >= 0) ? x : alpha * x; }, - {{"alpha", alpha}}); + input_values, + [alpha](float x) { return (x >= 0) ? x : alpha * x; }, + {{"alpha", alpha}}); } TEST_F(ActivationOpTest, ThresholdedRelu) { float alpha = 0.1f; TestActivationOp( - "ThresholdedRelu", - input_values, - [alpha](float x) { return (x >= alpha) ? x : 0; }, - {{"alpha", alpha}}, true, 10); + "ThresholdedRelu", + input_values, + [alpha](float x) { return (x >= alpha) ? x : 0; }, + {{"alpha", alpha}}, true, 10); } TEST_F(ActivationOpTest, Selu) { @@ -71,9 +70,9 @@ TEST_F(ActivationOpTest, Selu) { static constexpr float gamma = 1.0507f; TestActivationOp("Selu", - input_values, - [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; }, - {{"alpha", alpha}, {"gamma", gamma}}); + input_values, + [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; }, + {{"alpha", alpha}, {"gamma", gamma}}); } TEST_F(ActivationOpTest, Selu_Attributes) { @@ -81,9 +80,9 @@ TEST_F(ActivationOpTest, Selu_Attributes) { static constexpr float gamma = 0.5f; TestActivationOp("Selu", - input_values, - [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; }, - {{"alpha", alpha}, {"gamma", gamma}}); + input_values, + [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; }, + {{"alpha", alpha}, {"gamma", gamma}}); } TEST_F(ActivationOpTest, PRelu) { @@ -146,20 +145,20 @@ TEST_F(ActivationOpTest, PRelu_MultiChannel) { TEST_F(ActivationOpTest, Softplus) { TestActivationOp("Softplus", - input_values, - [](float x) { - if (x > 0) - return x + logf(expf(-x) + 1); - else - return logf(expf(x) + 1); - }); + input_values, + [](float x) { + if (x > 0) + return x + logf(expf(-x) + 1); + else + return logf(expf(x) + 1); + }); } TEST_F(ActivationOpNoInfTest, Softsign) { TestActivationOp( - "Softsign", - input_values, - [](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches + "Softsign", + input_values, + [](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches } } // namespace test diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 71fbb7be84..73a1096da4 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -781,7 +781,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherElementsGradient) { IMPLEMENT_GRADIENT_BUILDER(GetReluGradient) { return std::vector{ - NodeDef("ReluGrad", + NodeDef(OpDef{"ReluGrad", kMSDomain, 1}, {GO(0), O(0)}, {GI(0)})}; } diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index ef1283862f..8116cb7ae9 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -322,7 +322,8 @@ OpSchema& RegisterLambOpSchema(OpSchema&& op_schema) { void RegisterTrainingOpSchemas() { ONNX_CONTRIB_OPERATOR_SCHEMA(ReluGrad) - .SinceVersion(9) + .SetDomain(kMSDomain) + .SinceVersion(1) .Input(0, "dY", "Gradient of output Y", "T") .Input(1, "X", "Input tensor", "T") .Output(0, "dX", "Gradient of input X", "T") diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index be7ef952b3..c1c84ae41f 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -523,6 +523,10 @@ TEST(GradientCheckerTest, ReduceLogSumExpGrad) { RunReductionTests(op_def); } +TEST(GradientCheckerTest, ReluGrad) { + UnaryOpGradientTest("Relu"); +} + #ifndef USE_CUDA TEST(GradientCheckerTest, CastGrad) { // A dummy test that cast float to float @@ -540,10 +544,6 @@ TEST(GradientCheckerTest, CastGrad) { } } -TEST(GradientCheckerTest, ReluGrad) { - UnaryOpGradientTest("Relu"); -} - TEST(GradientCheckerTest, SplitGrad) { TensorShape shape({9, 5}); float max_error; diff --git a/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc index 6ec19d8dc0..3508fa5899 100644 --- a/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc @@ -71,6 +71,10 @@ float GeluApproximationGrad(float dy, float x) { float result = dy * 0.5f * (tanh_value + (sech_sqr_value * (kAlpha * x + kBeta * x_cube)) + 1.0f); return result; } + +float ReluGrad(float dy, float x) { + return x > 0 ? dy : 0; +} } // namespace TEST(GeluGradTest, Basic) { @@ -139,6 +143,22 @@ TEST(BiasFastGeluGradDxTest, Basic) { {}, 1, kMSDomain); } +TEST(ReluGradTest, Basic) { + const std::vector x_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}; + const std::vector dY(7, 1.0f); + + TestElementwiseGradientOp( + "ReluGrad", + {{"dY", dY}, {"X", x_vals}}, + [](const std::vector& params) { + ORT_ENFORCE(params.size() == 2); + const auto dy = params[0], x = params[1]; + + return ReluGrad(dy, x); + }, + {}, 1, kMSDomain); +} + namespace { template void TestBiasGeluGradBroadcastBias(const std::string& op, int opset_version, const std::string& domain, diff --git a/orttraining/orttraining/test/training_ops/cuda/activations_test.cc b/orttraining/orttraining/test/training_ops/cuda/activations_test.cc index 43429513f1..ea02e651a9 100644 --- a/orttraining/orttraining/test/training_ops/cuda/activations_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/activations_test.cc @@ -63,6 +63,13 @@ TEST(CudaKernelTest, FastGeluGrad_basic) { } } +TEST(CudaKernelTest, ReluGrad_basic) { + std::vector> test_dims{{4}, {16, 2}, {8, 2, 128, 128}}; + for (const auto& test_dim : test_dims) { + TestActivations(test_dim, "ReluGrad", true /* grad_op */); + } +} + static void TestActivationsWithBroadcastBias( const std::vector& tensor_dim, const std::string& operator_name, diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index 67d753aae6..67b758c2f0 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -34,7 +34,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConvGrad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ReluGrad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSoftmaxGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, AveragePoolGrad); @@ -126,7 +126,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.cc b/orttraining/orttraining/training_ops/cpu/op_gradients.cc index 4ae95b7a31..c82eec8644 100644 --- a/orttraining/orttraining/training_ops/cpu/op_gradients.cc +++ b/orttraining/orttraining/training_ops/cpu/op_gradients.cc @@ -29,9 +29,11 @@ Status SinGrad::Compute(OpKernelContext* context) const { return Status::OK(); } -ONNX_CPU_OPERATOR_KERNEL( +ONNX_OPERATOR_KERNEL_EX( ReluGrad, - 9, + kMSDomain, + 1, + kCpuExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), ReluGrad); @@ -101,12 +103,12 @@ Status SoftmaxGrad::Compute(OpKernelContext* context) const { } ONNX_OPERATOR_KERNEL_EX( - LogSoftmaxGrad, - kMSDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - LogSoftmaxGrad); + LogSoftmaxGrad, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + LogSoftmaxGrad); template Status LogSoftmaxGrad::Compute(OpKernelContext* context) const { @@ -133,14 +135,14 @@ Status LogSoftmaxGrad::Compute(OpKernelContext* context) const { std::vector eY(nd); float* eYdata = eY.data(); - + // dX_ai = d(log Y_ai) - [sum_j d(log Y_aj)] exp(log Y_ai) gsl::copy(gsl::make_span(dYdata, nd), gsl::make_span(dXdata, nd)); math::Exp(nd, Ydata, eYdata, nullptr); for (size_t i = 0; i < N; ++i) { float sdY; - math::Sum(d, dYdata + i*d, &sdY, nullptr, nullptr); - math::Axpy(d, -sdY, eYdata + i*d, dXdata + i*d, nullptr); + math::Sum(d, dYdata + i * d, &sdY, nullptr, nullptr); + math::Axpy(d, -sdY, eYdata + i * d, dXdata + i * d, nullptr); } return Status::OK(); diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc index e45b300f68..7e4f344f44 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc @@ -45,6 +45,7 @@ namespace cuda { ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain); ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain); +ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain); } //namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h index fe4ad62fd6..a31be31ea7 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h @@ -33,5 +33,16 @@ class FastGeluGrad final : public BinaryElementwise { MAKE_FUNC_CTX_NULL() }; +template +class ReluGrad final : public BinaryElementwise { + public: + ReluGrad(const OpKernelInfo& info) : BinaryElementwise(info) {} + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + MAKE_FUNC_CTX_NULL() +}; + } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu index ff9772c0cc..caa38cac0d 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu @@ -24,6 +24,13 @@ struct OP_FastGeluGrad : public CtxGeluGrad { } }; +template +struct OP_ReluGrad : public CtxReluGrad { + __device__ __inline__ T operator()(const T& dy, const T& x) const { + return x > T {0} ? dy : T {0}; + } +}; + #define BINARY_ELEMENTWISE_IMPL(name) \ BINARY_ELEMENTWISE_IMPL_DECLARATION(name) { \ BinaryElementWiseNoBroadcastImpl(lhs_data, rhs_data, \ diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h index 8525a10327..bc5e292652 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h @@ -9,10 +9,12 @@ namespace cuda { typedef onnxruntime::cuda::CtxNull CtxGeluGrad; typedef onnxruntime::cuda::CtxNull CtxFastGeluGrad; +typedef onnxruntime::cuda::CtxNull CtxReluGrad; #define ACTIVATION_GRAD_OPS() \ ACTIVATION_GRAD_OP_NAME(GeluGrad) \ - ACTIVATION_GRAD_OP_NAME(FastGeluGrad) + ACTIVATION_GRAD_OP_NAME(FastGeluGrad) \ + ACTIVATION_GRAD_OP_NAME(ReluGrad) #define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \ template \ diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 7934c9d991..466bf67b41 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -85,6 +85,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGeluGrad_dX); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasFastGeluGrad_dX); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ReluGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ReluGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ReluGrad); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, IsFinite); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, IsFinite); @@ -205,6 +209,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo,