mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Relu grad kernel (#4864)
* create branch for debug * move unit test * more changes * move relu to activations_grad* * Fix ReluGrad Domain and opset version * added unit test, CudaKernelTest.Relu_basic doesn't work yet * remove CudaKernelTest.Relu_basic * PR comment * add unit test ReluGradTest_Basic Co-authored-by: Jingyan Wang <jingywa@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net> Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
dce2ce7a4f
commit
fa68bbc82e
13 changed files with 119 additions and 62 deletions
|
|
@ -7,63 +7,62 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
|
||||
TEST_F(ActivationOpTest, Sigmoid) {
|
||||
TestActivationOp("Sigmoid",
|
||||
input_values,
|
||||
[](float x) {
|
||||
auto y = 1.f / (1.f + std::exp(-std::abs(x))); // safe sigmoid
|
||||
y = x > 0 ? y : 1 - y;
|
||||
return y;
|
||||
});
|
||||
input_values,
|
||||
[](float x) {
|
||||
auto y = 1.f / (1.f + std::exp(-std::abs(x))); // safe sigmoid
|
||||
y = x > 0 ? y : 1 - y;
|
||||
return y;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, HardSigmoid) {
|
||||
float alpha = 0.2f;
|
||||
float beta = 0.5f;
|
||||
TestActivationOp("HardSigmoid",
|
||||
input_values,
|
||||
[alpha, beta](float x) {
|
||||
return std::max(std::min((alpha * x + beta), 1.0f), 0.0f);
|
||||
},
|
||||
{{"alpha", alpha}, {"beta", beta}});
|
||||
input_values,
|
||||
[alpha, beta](float x) {
|
||||
return std::max(std::min((alpha * x + beta), 1.0f), 0.0f);
|
||||
},
|
||||
{{"alpha", alpha}, {"beta", beta}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Tanh) {
|
||||
TestActivationOp("Tanh",
|
||||
input_values,
|
||||
[](float x) { return std::tanh(x); });
|
||||
input_values,
|
||||
[](float x) { return std::tanh(x); });
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Relu) {
|
||||
TestActivationOp("Relu",
|
||||
input_values,
|
||||
[](float x) { return std::max(x, 0.0f); });
|
||||
input_values,
|
||||
[](float x) { return std::max(x, 0.0f); });
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Elu) {
|
||||
float alpha = 0.1f;
|
||||
TestActivationOp("Elu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= 0) ? x : alpha * (exp(x) - 1); },
|
||||
{{"alpha", alpha}});
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= 0) ? x : alpha * (exp(x) - 1); },
|
||||
{{"alpha", alpha}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, LeakyRelu) {
|
||||
float alpha = 0.1f;
|
||||
TestActivationOp("LeakyRelu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= 0) ? x : alpha * x; },
|
||||
{{"alpha", alpha}});
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= 0) ? x : alpha * x; },
|
||||
{{"alpha", alpha}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, ThresholdedRelu) {
|
||||
float alpha = 0.1f;
|
||||
TestActivationOp(
|
||||
"ThresholdedRelu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= alpha) ? x : 0; },
|
||||
{{"alpha", alpha}}, true, 10);
|
||||
"ThresholdedRelu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= alpha) ? x : 0; },
|
||||
{{"alpha", alpha}}, true, 10);
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Selu) {
|
||||
|
|
@ -71,9 +70,9 @@ TEST_F(ActivationOpTest, Selu) {
|
|||
static constexpr float gamma = 1.0507f;
|
||||
|
||||
TestActivationOp("Selu",
|
||||
input_values,
|
||||
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
|
||||
{{"alpha", alpha}, {"gamma", gamma}});
|
||||
input_values,
|
||||
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
|
||||
{{"alpha", alpha}, {"gamma", gamma}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, Selu_Attributes) {
|
||||
|
|
@ -81,9 +80,9 @@ TEST_F(ActivationOpTest, Selu_Attributes) {
|
|||
static constexpr float gamma = 0.5f;
|
||||
|
||||
TestActivationOp("Selu",
|
||||
input_values,
|
||||
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
|
||||
{{"alpha", alpha}, {"gamma", gamma}});
|
||||
input_values,
|
||||
[](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
|
||||
{{"alpha", alpha}, {"gamma", gamma}});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpTest, PRelu) {
|
||||
|
|
@ -146,20 +145,20 @@ TEST_F(ActivationOpTest, PRelu_MultiChannel) {
|
|||
|
||||
TEST_F(ActivationOpTest, Softplus) {
|
||||
TestActivationOp("Softplus",
|
||||
input_values,
|
||||
[](float x) {
|
||||
if (x > 0)
|
||||
return x + logf(expf(-x) + 1);
|
||||
else
|
||||
return logf(expf(x) + 1);
|
||||
});
|
||||
input_values,
|
||||
[](float x) {
|
||||
if (x > 0)
|
||||
return x + logf(expf(-x) + 1);
|
||||
else
|
||||
return logf(expf(x) + 1);
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(ActivationOpNoInfTest, Softsign) {
|
||||
TestActivationOp(
|
||||
"Softsign",
|
||||
input_values,
|
||||
[](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches
|
||||
"Softsign",
|
||||
input_values,
|
||||
[](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
|
|
@ -781,7 +781,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherElementsGradient) {
|
|||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetReluGradient) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("ReluGrad",
|
||||
NodeDef(OpDef{"ReluGrad", kMSDomain, 1},
|
||||
{GO(0), O(0)},
|
||||
{GI(0)})};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -322,7 +322,8 @@ OpSchema& RegisterLambOpSchema(OpSchema&& op_schema) {
|
|||
|
||||
void RegisterTrainingOpSchemas() {
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ReluGrad)
|
||||
.SinceVersion(9)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.Input(0, "dY", "Gradient of output Y", "T")
|
||||
.Input(1, "X", "Input tensor", "T")
|
||||
.Output(0, "dX", "Gradient of input X", "T")
|
||||
|
|
|
|||
|
|
@ -523,6 +523,10 @@ TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
|
|||
RunReductionTests(op_def);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ReluGrad) {
|
||||
UnaryOpGradientTest("Relu");
|
||||
}
|
||||
|
||||
#ifndef USE_CUDA
|
||||
TEST(GradientCheckerTest, CastGrad) {
|
||||
// A dummy test that cast float to float
|
||||
|
|
@ -540,10 +544,6 @@ TEST(GradientCheckerTest, CastGrad) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ReluGrad) {
|
||||
UnaryOpGradientTest("Relu");
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, SplitGrad) {
|
||||
TensorShape shape({9, 5});
|
||||
float max_error;
|
||||
|
|
|
|||
|
|
@ -71,6 +71,10 @@ float GeluApproximationGrad(float dy, float x) {
|
|||
float result = dy * 0.5f * (tanh_value + (sech_sqr_value * (kAlpha * x + kBeta * x_cube)) + 1.0f);
|
||||
return result;
|
||||
}
|
||||
|
||||
float ReluGrad(float dy, float x) {
|
||||
return x > 0 ? dy : 0;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(GeluGradTest, Basic) {
|
||||
|
|
@ -139,6 +143,22 @@ TEST(BiasFastGeluGradDxTest, Basic) {
|
|||
{}, 1, kMSDomain);
|
||||
}
|
||||
|
||||
TEST(ReluGradTest, Basic) {
|
||||
const std::vector<float> x_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
|
||||
const std::vector<float> dY(7, 1.0f);
|
||||
|
||||
TestElementwiseGradientOp(
|
||||
"ReluGrad",
|
||||
{{"dY", dY}, {"X", x_vals}},
|
||||
[](const std::vector<float>& params) {
|
||||
ORT_ENFORCE(params.size() == 2);
|
||||
const auto dy = params[0], x = params[1];
|
||||
|
||||
return ReluGrad(dy, x);
|
||||
},
|
||||
{}, 1, kMSDomain);
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename TComputeGeluGradScalarFn>
|
||||
void TestBiasGeluGradBroadcastBias(const std::string& op, int opset_version, const std::string& domain,
|
||||
|
|
|
|||
|
|
@ -63,6 +63,13 @@ TEST(CudaKernelTest, FastGeluGrad_basic) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, ReluGrad_basic) {
|
||||
std::vector<std::vector<int64_t>> test_dims{{4}, {16, 2}, {8, 2, 128, 128}};
|
||||
for (const auto& test_dim : test_dims) {
|
||||
TestActivations(test_dim, "ReluGrad", true /* grad_op */);
|
||||
}
|
||||
}
|
||||
|
||||
static void TestActivationsWithBroadcastBias(
|
||||
const std::vector<int64_t>& tensor_dim,
|
||||
const std::string& operator_name,
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConvGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ReluGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSoftmaxGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, AveragePoolGrad);
|
||||
|
|
@ -126,7 +126,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConvGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSoftmaxGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, AveragePoolGrad)>,
|
||||
|
|
|
|||
|
|
@ -29,9 +29,11 @@ Status SinGrad<T>::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
ReluGrad,
|
||||
9,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ReluGrad<float>);
|
||||
|
||||
|
|
@ -101,12 +103,12 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
LogSoftmaxGrad,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
LogSoftmaxGrad<float>);
|
||||
LogSoftmaxGrad,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
LogSoftmaxGrad<float>);
|
||||
|
||||
template <typename T>
|
||||
Status LogSoftmaxGrad<T>::Compute(OpKernelContext* context) const {
|
||||
|
|
@ -133,14 +135,14 @@ Status LogSoftmaxGrad<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
std::vector<float> eY(nd);
|
||||
float* eYdata = eY.data();
|
||||
|
||||
|
||||
// dX_ai = d(log Y_ai) - [sum_j d(log Y_aj)] exp(log Y_ai)
|
||||
gsl::copy(gsl::make_span(dYdata, nd), gsl::make_span(dXdata, nd));
|
||||
math::Exp<float, CPUMathUtil>(nd, Ydata, eYdata, nullptr);
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
float sdY;
|
||||
math::Sum<float, CPUMathUtil>(d, dYdata + i*d, &sdY, nullptr, nullptr);
|
||||
math::Axpy<float, CPUMathUtil>(d, -sdY, eYdata + i*d, dXdata + i*d, nullptr);
|
||||
math::Sum<float, CPUMathUtil>(d, dYdata + i * d, &sdY, nullptr, nullptr);
|
||||
math::Axpy<float, CPUMathUtil>(d, -sdY, eYdata + i * d, dXdata + i * d, nullptr);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ namespace cuda {
|
|||
|
||||
ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
|
||||
|
||||
} //namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -33,5 +33,16 @@ class FastGeluGrad final : public BinaryElementwise<ShouldNotBroadcast> {
|
|||
MAKE_FUNC_CTX_NULL()
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ReluGrad final : public BinaryElementwise<ShouldNotBroadcast> {
|
||||
public:
|
||||
ReluGrad(const OpKernelInfo& info) : BinaryElementwise(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
MAKE_FUNC_CTX_NULL()
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -24,6 +24,13 @@ struct OP_FastGeluGrad : public CtxGeluGrad {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OP_ReluGrad : public CtxReluGrad {
|
||||
__device__ __inline__ T operator()(const T& dy, const T& x) const {
|
||||
return x > T {0} ? dy : T {0};
|
||||
}
|
||||
};
|
||||
|
||||
#define BINARY_ELEMENTWISE_IMPL(name) \
|
||||
BINARY_ELEMENTWISE_IMPL_DECLARATION(name) { \
|
||||
BinaryElementWiseNoBroadcastImpl(lhs_data, rhs_data, \
|
||||
|
|
|
|||
|
|
@ -9,10 +9,12 @@ namespace cuda {
|
|||
|
||||
typedef onnxruntime::cuda::CtxNull CtxGeluGrad;
|
||||
typedef onnxruntime::cuda::CtxNull CtxFastGeluGrad;
|
||||
typedef onnxruntime::cuda::CtxNull CtxReluGrad;
|
||||
|
||||
#define ACTIVATION_GRAD_OPS() \
|
||||
ACTIVATION_GRAD_OP_NAME(GeluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(FastGeluGrad)
|
||||
ACTIVATION_GRAD_OP_NAME(FastGeluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(ReluGrad)
|
||||
|
||||
#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
|
||||
template <typename T> \
|
||||
|
|
|
|||
|
|
@ -85,6 +85,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGeluGrad_dX);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasFastGeluGrad_dX);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ReluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ReluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ReluGrad);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, IsFinite);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, IsFinite);
|
||||
|
|
@ -205,6 +209,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FastGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGeluGrad_dX)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasFastGeluGrad_dX)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, IsFinite)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, IsFinite)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue