mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
LeakyRelu Gradient (#17039)
This commit is contained in:
parent
0180c0429f
commit
3e7f70bf88
18 changed files with 183 additions and 2 deletions
|
|
@ -64,6 +64,10 @@ constexpr float SigmoidGrad(float dy, float y) {
|
|||
constexpr float TanhGrad(float dy, float y) {
|
||||
return dy * (1 - y * y);
|
||||
}
|
||||
|
||||
constexpr float LeakyReluGrad(float dy, float y, float alpha) {
|
||||
return dy * (y > 0.0f ? 1.0f : alpha);
|
||||
}
|
||||
} // namespace
|
||||
#endif
|
||||
|
||||
|
|
@ -669,6 +673,23 @@ TEST(TanhGradInferenceTest, Basic) {
|
|||
},
|
||||
{}, 1, kMSDomain);
|
||||
}
|
||||
|
||||
TEST(LeakyReluGradInferenceTest, Basic) {
|
||||
const std::vector<float> y_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
|
||||
const std::vector<float> dY(7, 1.0f);
|
||||
float alpha = 0.5f;
|
||||
|
||||
TestElementwiseGradientOp(
|
||||
"LeakyReluGrad",
|
||||
{{"dY", dY}, {"Y", y_vals}},
|
||||
[alpha](const std::vector<float>& params) {
|
||||
ORT_ENFORCE(params.size() == 2);
|
||||
const auto dy = params[0], y = params[1];
|
||||
|
||||
return LeakyReluGrad(dy, y, alpha);
|
||||
},
|
||||
{{"alpha", alpha}}, 1, kMSDomain);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
|
|
@ -2065,5 +2065,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetReciprocalGradient) {
|
|||
NodeDef("Mul", {GO(0), IA("Neg_Square_O0")}, {GI(0)})};
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetLeakyReluGradient) {
|
||||
return {NodeDef(OpDef{"LeakyReluGrad", kMSDomain, 1},
|
||||
{GO(0), O(0)}, {GI(0)}, SrcNodeAttributes())};
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ DECLARE_GRADIENT_BUILDER(GetFakeQuantGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetLSTMGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGRUGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetReciprocalGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient)
|
||||
|
||||
DECLARE_GRADIENT_BUILDER(GetExternalGradient)
|
||||
|
||||
|
|
|
|||
|
|
@ -119,6 +119,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("LSTMTraining", GetLSTMGradient);
|
||||
REGISTER_GRADIENT_BUILDER("GRUTraining", GetGRUGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient);
|
||||
REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient);
|
||||
|
||||
REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2900,6 +2900,19 @@ Example 4:
|
|||
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {});
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(LeakyReluGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc("Gradient operator for LeakyRelu.")
|
||||
.Attr("alpha", "Alpha (negative slope) value.", AttributeProto::FLOAT, 0.01f)
|
||||
.AllowUncheckedAttributes()
|
||||
.Input(0, "dY", "The gradient tensor from output.", "T")
|
||||
.Input(1, "Y", "The output tensor. ", "T")
|
||||
.Output(0, "dX", "Gradient of the input.", "T")
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(LayerNormalizationGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
|
|
|
|||
|
|
@ -3033,6 +3033,12 @@ TEST(GradientCheckerTest, ReciprocalGrad) {
|
|||
UnaryOpGradientTest("Reciprocal", kOnnxDomain, 12, nullptr, &transformer);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, LeakyReluGrad) {
|
||||
// Gradient is non continuous at 0, so we need to avoid it.
|
||||
std::function<float(float)> transformer = [](float x) { return x > 0 ? x + 0.2f : x - 0.2f; };
|
||||
UnaryOpGradientTest("LeakyRelu", kOnnxDomain, 16, nullptr, &transformer);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -6162,7 +6162,39 @@ def test_reciprocal_gradient():
|
|||
pt_model = ReciprocalModel().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
pt_x = torch.zeros(3, 224, 224, requires_grad=True, device=device)
|
||||
pt_x = torch.randn(3, 224, 224, requires_grad=True, device=device)
|
||||
with torch.no_grad():
|
||||
pt_x[pt_x <= 0] -= 0.2
|
||||
pt_x[pt_x > 0] += 0.2
|
||||
ort_x = copy.deepcopy(pt_x)
|
||||
|
||||
pt_prediction, pt_loss = run_step(pt_model, pt_x)
|
||||
ort_prediction, ort_loss = run_step(ort_model, ort_x)
|
||||
_test_helpers.assert_values_are_close(pt_prediction, ort_prediction)
|
||||
_test_helpers.assert_values_are_close(pt_loss, ort_loss)
|
||||
_test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad)
|
||||
|
||||
|
||||
def test_leakyrelu_gradient():
|
||||
class LeakyReluModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.leakyrelu = nn.LeakyReLU(0.5)
|
||||
|
||||
def forward(self, x):
|
||||
return self.leakyrelu(x)
|
||||
|
||||
def run_step(model, x):
|
||||
prediction = model(x)
|
||||
loss = prediction.sum()
|
||||
loss.backward()
|
||||
return prediction, loss
|
||||
|
||||
device = "cuda"
|
||||
pt_model = LeakyReluModel().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
pt_x = torch.randn(3, 224, 224, requires_grad=True, device=device)
|
||||
with torch.no_grad():
|
||||
pt_x[pt_x <= 0] -= 0.2
|
||||
pt_x[pt_x > 0] += 0.2
|
||||
|
|
|
|||
|
|
@ -89,6 +89,10 @@ float QuickGeluGrad(float dy, float x, float alpha) {
|
|||
float sigmoid = v >= 0 ? 1.f / (1.f + std::exp(-v)) : 1.f - 1.f / (1 + std::exp(v));
|
||||
return dy * sigmoid * (1 + v * (1 - sigmoid));
|
||||
}
|
||||
|
||||
constexpr float LeakyReluGrad(float dy, float y, float alpha) {
|
||||
return dy * (y > 0.0f ? 1.0f : alpha);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(GeluGradTest, Basic) {
|
||||
|
|
@ -263,6 +267,23 @@ TEST(QuickGeluGradTest, Basic) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(LeakyReluGradTest, Basic) {
|
||||
const std::vector<float> y_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
|
||||
const std::vector<float> dY(7, 1.0f);
|
||||
float alpha = 0.5f;
|
||||
|
||||
TestElementwiseGradientOp(
|
||||
"LeakyReluGrad",
|
||||
{{"dY", dY}, {"Y", y_vals}},
|
||||
[alpha](const std::vector<float>& params) {
|
||||
ORT_ENFORCE(params.size() == 2);
|
||||
const auto dy = params[0], y = params[1];
|
||||
|
||||
return LeakyReluGrad(dy, y, alpha);
|
||||
},
|
||||
{{"alpha", alpha}}, 1, kMSDomain);
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename TComputeGeluGradScalarFn>
|
||||
void TestBiasGeluGradBroadcastBias(const std::string& op, int opset_version, const std::string& domain,
|
||||
|
|
|
|||
|
|
@ -90,6 +90,13 @@ TEST(CudaKernelTest, TanhGrad_basic) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, LeakyReluGrad_basic) {
|
||||
std::vector<std::vector<int64_t>> test_dims{{4}, {16, 2}, {8, 2, 128, 128}};
|
||||
for (const auto& test_dim : test_dims) {
|
||||
TestActivations(test_dim, "LeakyReluGrad", true /* grad_op */);
|
||||
}
|
||||
}
|
||||
|
||||
static void TestActivationsWithBroadcastBias(
|
||||
const std::vector<int64_t>& tensor_dim,
|
||||
const std::string& operator_name,
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluG
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGeluGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LeakyReluGrad);
|
||||
|
||||
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
|
||||
// However these types work on GPU implementation.
|
||||
|
|
@ -182,6 +183,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LeakyReluGrad)>,
|
||||
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
|
||||
// However these types work on GPU implementation.
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, DropoutGrad)>,
|
||||
|
|
|
|||
|
|
@ -260,5 +260,24 @@ Status QuickGeluGrad<T>::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(LeakyReluGrad, kMSDomain, 1, kCpuExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
LeakyReluGrad<float>);
|
||||
|
||||
template <typename T>
|
||||
Status LeakyReluGrad<T>::Compute(OpKernelContext* context) const {
|
||||
auto& dY = *context->Input<Tensor>(0);
|
||||
auto& Y = *context->Input<Tensor>(1);
|
||||
auto& dX = *context->Output(0, dY.Shape());
|
||||
EigenVectorArrayMap<float> dx = EigenVectorArrayMap<float>(dX.template MutableData<T>(),
|
||||
narrow<Eigen::Index>(dX.Shape().Size()));
|
||||
ConstEigenVectorArrayMap<float> y = ConstEigenVectorArrayMap<float>(Y.template Data<T>(),
|
||||
narrow<Eigen::Index>(Y.Shape().Size()));
|
||||
ConstEigenVectorArrayMap<float> dy = ConstEigenVectorArrayMap<float>(dY.template Data<T>(),
|
||||
narrow<Eigen::Index>(dY.Shape().Size()));
|
||||
dx = (y > 0.0f).select(dy, alpha_ * dy);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -78,5 +78,19 @@ class SoftmaxGrad final : public OpKernel {
|
|||
bool is_logsoftmaxgrad_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class LeakyReluGrad final : public OpKernel {
|
||||
public:
|
||||
explicit LeakyReluGrad(const OpKernelInfo& info) : OpKernel(info) {
|
||||
alpha_ = info.GetAttrOrDefault<float>("alpha", 0.01f);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LeakyReluGrad);
|
||||
float alpha_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
|
|||
ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(LeakyReluGrad, 1, kMSDomain);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -79,5 +79,20 @@ class TanhGrad final : public BinaryElementwise<ShouldNotBroadcast> {
|
|||
private:
|
||||
MAKE_FUNC_CTX_NULL()
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class LeakyReluGrad final : public BinaryElementwise<ShouldNotBroadcast> {
|
||||
public:
|
||||
LeakyReluGrad(const OpKernelInfo& info) : BinaryElementwise(info) {
|
||||
alpha_ = info.GetAttrOrDefault<float>("alpha", 0.01f);
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
MAKE_FUNC_CTX_ALPHA()
|
||||
float alpha_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -64,6 +64,13 @@ struct OP_TanhGrad : public CtxTanhGrad {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OP_LeakyReluGrad : public CtxLeakyReluGrad {
|
||||
__device__ __inline__ T operator()(const T& dy, const T& y) const {
|
||||
return dy * (y > T{0} ? T{1} : static_cast<T>(alpha));
|
||||
}
|
||||
};
|
||||
|
||||
#define BINARY_ELEMENTWISE_IMPL(name) \
|
||||
BINARY_ELEMENTWISE_IMPL_DECLARATION(name) { \
|
||||
BinaryElementWiseNoBroadcastImpl(stream, \
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ typedef onnxruntime::cuda::CtxNull CtxReluGrad;
|
|||
typedef onnxruntime::cuda::CtxNull CtxSigmoidGrad;
|
||||
typedef onnxruntime::cuda::CtxAlpha CtxQuickGeluGrad;
|
||||
typedef onnxruntime::cuda::CtxNull CtxTanhGrad;
|
||||
typedef onnxruntime::cuda::CtxAlpha CtxLeakyReluGrad;
|
||||
|
||||
#define ACTIVATION_GRAD_OPS() \
|
||||
ACTIVATION_GRAD_OP_NAME(GeluGrad) \
|
||||
|
|
@ -20,7 +21,8 @@ typedef onnxruntime::cuda::CtxNull CtxTanhGrad;
|
|||
ACTIVATION_GRAD_OP_NAME(ReluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(SigmoidGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(QuickGeluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(TanhGrad)
|
||||
ACTIVATION_GRAD_OP_NAME(TanhGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(LeakyReluGrad)
|
||||
|
||||
#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
|
||||
template <typename T> \
|
||||
|
|
|
|||
|
|
@ -123,6 +123,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, LeakyReluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, LeakyReluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, IsFinite);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, IsFinite);
|
||||
|
|
@ -366,6 +370,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, LeakyReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, LeakyReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, IsFinite)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, IsFinite)>,
|
||||
|
|
|
|||
|
|
@ -119,6 +119,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LeakyReluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, LeakyReluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, IsFinite);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, IsFinite);
|
||||
|
|
@ -316,6 +320,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LeakyReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, LeakyReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, IsFinite)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, IsFinite)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue