From 3e7f70bf8858c9dffde45be4f4670e9ffd71a7eb Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 10 Aug 2023 20:45:34 -0700 Subject: [PATCH] LeakyRelu Gradient (#17039) --- .../cpu/activation/activation_op_test.cc | 21 ++++++++++++ .../core/graph/gradient_builder.cc | 5 +++ .../orttraining/core/graph/gradient_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 1 + .../core/graph/training_op_defs.cc | 13 +++++++ .../test/gradient/gradient_ops_test.cc | 6 ++++ .../python/orttraining_test_ortmodule_api.py | 34 ++++++++++++++++++- .../cpu/activation/activation_op_test.cc | 21 ++++++++++++ .../training_ops/cuda/activations_test.cc | 7 ++++ .../training_ops/cpu/cpu_training_kernels.cc | 2 ++ .../training_ops/cpu/op_gradients.cc | 19 +++++++++++ .../training_ops/cpu/op_gradients.h | 14 ++++++++ .../cuda/activation/activations_grad.cc | 1 + .../cuda/activation/activations_grad.h | 15 ++++++++ .../cuda/activation/activations_grad_impl.cu | 7 ++++ .../cuda/activation/activations_grad_impl.h | 4 ++- .../cuda/cuda_training_kernels.cc | 7 ++++ .../rocm/rocm_training_kernels.cc | 7 ++++ 18 files changed, 183 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index b7383ca690..7ec9e0f345 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -64,6 +64,10 @@ constexpr float SigmoidGrad(float dy, float y) { constexpr float TanhGrad(float dy, float y) { return dy * (1 - y * y); } + +constexpr float LeakyReluGrad(float dy, float y, float alpha) { + return dy * (y > 0.0f ? 1.0f : alpha); +} } // namespace #endif @@ -669,6 +673,23 @@ TEST(TanhGradInferenceTest, Basic) { }, {}, 1, kMSDomain); } + +TEST(LeakyReluGradInferenceTest, Basic) { + const std::vector y_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}; + const std::vector dY(7, 1.0f); + float alpha = 0.5f; + + TestElementwiseGradientOp( + "LeakyReluGrad", + {{"dY", dY}, {"Y", y_vals}}, + [alpha](const std::vector& params) { + ORT_ENFORCE(params.size() == 2); + const auto dy = params[0], y = params[1]; + + return LeakyReluGrad(dy, y, alpha); + }, + {{"alpha", alpha}}, 1, kMSDomain); +} #endif } // namespace test diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index dfb8f6d3b1..f8e0545574 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -2065,5 +2065,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetReciprocalGradient) { NodeDef("Mul", {GO(0), IA("Neg_Square_O0")}, {GI(0)})}; } +IMPLEMENT_GRADIENT_BUILDER(GetLeakyReluGradient) { + return {NodeDef(OpDef{"LeakyReluGrad", kMSDomain, 1}, + {GO(0), O(0)}, {GI(0)}, SrcNodeAttributes())}; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 8d7b005e6f..ca86777d36 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -87,6 +87,7 @@ DECLARE_GRADIENT_BUILDER(GetFakeQuantGradient) DECLARE_GRADIENT_BUILDER(GetLSTMGradient) DECLARE_GRADIENT_BUILDER(GetGRUGradient) DECLARE_GRADIENT_BUILDER(GetReciprocalGradient) +DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient) DECLARE_GRADIENT_BUILDER(GetExternalGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 1d3e7ede77..cc9a762ff8 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -119,6 +119,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("LSTMTraining", GetLSTMGradient); REGISTER_GRADIENT_BUILDER("GRUTraining", GetGRUGradient); REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient); + REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient); REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient); }; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index a0a48503ec..8ae914b666 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -2900,6 +2900,19 @@ Example 4: return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {}); }); + ONNX_CONTRIB_OPERATOR_SCHEMA(LeakyReluGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc("Gradient operator for LeakyRelu.") + .Attr("alpha", "Alpha (negative slope) value.", AttributeProto::FLOAT, 0.01f) + .AllowUncheckedAttributes() + .Input(0, "dY", "The gradient tensor from output.", "T") + .Input(1, "Y", "The output tensor. ", "T") + .Output(0, "dX", "Gradient of the input.", "T") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); + ONNX_CONTRIB_OPERATOR_SCHEMA(LayerNormalizationGrad) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 4834451616..39cc6bdd11 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -3033,6 +3033,12 @@ TEST(GradientCheckerTest, ReciprocalGrad) { UnaryOpGradientTest("Reciprocal", kOnnxDomain, 12, nullptr, &transformer); } +TEST(GradientCheckerTest, LeakyReluGrad) { + // Gradient is non continuous at 0, so we need to avoid it. + std::function transformer = [](float x) { return x > 0 ? x + 0.2f : x - 0.2f; }; + UnaryOpGradientTest("LeakyRelu", kOnnxDomain, 16, nullptr, &transformer); +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 8d2bd19bff..b62e959556 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6162,7 +6162,39 @@ def test_reciprocal_gradient(): pt_model = ReciprocalModel().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) - pt_x = torch.zeros(3, 224, 224, requires_grad=True, device=device) + pt_x = torch.randn(3, 224, 224, requires_grad=True, device=device) + with torch.no_grad(): + pt_x[pt_x <= 0] -= 0.2 + pt_x[pt_x > 0] += 0.2 + ort_x = copy.deepcopy(pt_x) + + pt_prediction, pt_loss = run_step(pt_model, pt_x) + ort_prediction, ort_loss = run_step(ort_model, ort_x) + _test_helpers.assert_values_are_close(pt_prediction, ort_prediction) + _test_helpers.assert_values_are_close(pt_loss, ort_loss) + _test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad) + + +def test_leakyrelu_gradient(): + class LeakyReluModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.leakyrelu = nn.LeakyReLU(0.5) + + def forward(self, x): + return self.leakyrelu(x) + + def run_step(model, x): + prediction = model(x) + loss = prediction.sum() + loss.backward() + return prediction, loss + + device = "cuda" + pt_model = LeakyReluModel().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x = torch.randn(3, 224, 224, requires_grad=True, device=device) with torch.no_grad(): pt_x[pt_x <= 0] -= 0.2 pt_x[pt_x > 0] += 0.2 diff --git a/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc index 3f0d78c7ee..c700c73086 100644 --- a/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc @@ -89,6 +89,10 @@ float QuickGeluGrad(float dy, float x, float alpha) { float sigmoid = v >= 0 ? 1.f / (1.f + std::exp(-v)) : 1.f - 1.f / (1 + std::exp(v)); return dy * sigmoid * (1 + v * (1 - sigmoid)); } + +constexpr float LeakyReluGrad(float dy, float y, float alpha) { + return dy * (y > 0.0f ? 1.0f : alpha); +} } // namespace TEST(GeluGradTest, Basic) { @@ -263,6 +267,23 @@ TEST(QuickGeluGradTest, Basic) { } } +TEST(LeakyReluGradTest, Basic) { + const std::vector y_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}; + const std::vector dY(7, 1.0f); + float alpha = 0.5f; + + TestElementwiseGradientOp( + "LeakyReluGrad", + {{"dY", dY}, {"Y", y_vals}}, + [alpha](const std::vector& params) { + ORT_ENFORCE(params.size() == 2); + const auto dy = params[0], y = params[1]; + + return LeakyReluGrad(dy, y, alpha); + }, + {{"alpha", alpha}}, 1, kMSDomain); +} + namespace { template void TestBiasGeluGradBroadcastBias(const std::string& op, int opset_version, const std::string& domain, diff --git a/orttraining/orttraining/test/training_ops/cuda/activations_test.cc b/orttraining/orttraining/test/training_ops/cuda/activations_test.cc index 656db10e90..3173610597 100644 --- a/orttraining/orttraining/test/training_ops/cuda/activations_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/activations_test.cc @@ -90,6 +90,13 @@ TEST(CudaKernelTest, TanhGrad_basic) { } } +TEST(CudaKernelTest, LeakyReluGrad_basic) { + std::vector> test_dims{{4}, {16, 2}, {8, 2, 128, 128}}; + for (const auto& test_dim : test_dims) { + TestActivations(test_dim, "LeakyReluGrad", true /* grad_op */); + } +} + static void TestActivationsWithBroadcastBias( const std::vector& tensor_dim, const std::string& operator_name, diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index cd2d504f41..e1e0e520bb 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -53,6 +53,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluG class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGeluGrad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LeakyReluGrad); // REVIEW(mzs): ConstEigenVectorArrayMap.cast, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // REVIEW(mzs): ConstEigenVectorArrayMap.cast, diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.cc b/orttraining/orttraining/training_ops/cpu/op_gradients.cc index 46c5646b48..c3476161c1 100644 --- a/orttraining/orttraining/training_ops/cpu/op_gradients.cc +++ b/orttraining/orttraining/training_ops/cpu/op_gradients.cc @@ -260,5 +260,24 @@ Status QuickGeluGrad::Compute(OpKernelContext* context) const { return Status::OK(); } +ONNX_OPERATOR_KERNEL_EX(LeakyReluGrad, kMSDomain, 1, kCpuExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + LeakyReluGrad); + +template +Status LeakyReluGrad::Compute(OpKernelContext* context) const { + auto& dY = *context->Input(0); + auto& Y = *context->Input(1); + auto& dX = *context->Output(0, dY.Shape()); + EigenVectorArrayMap dx = EigenVectorArrayMap(dX.template MutableData(), + narrow(dX.Shape().Size())); + ConstEigenVectorArrayMap y = ConstEigenVectorArrayMap(Y.template Data(), + narrow(Y.Shape().Size())); + ConstEigenVectorArrayMap dy = ConstEigenVectorArrayMap(dY.template Data(), + narrow(dY.Shape().Size())); + dx = (y > 0.0f).select(dy, alpha_ * dy); + return Status::OK(); +} + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.h b/orttraining/orttraining/training_ops/cpu/op_gradients.h index be1268dac5..64dad382a8 100644 --- a/orttraining/orttraining/training_ops/cpu/op_gradients.h +++ b/orttraining/orttraining/training_ops/cpu/op_gradients.h @@ -78,5 +78,19 @@ class SoftmaxGrad final : public OpKernel { bool is_logsoftmaxgrad_; }; +template +class LeakyReluGrad final : public OpKernel { + public: + explicit LeakyReluGrad(const OpKernelInfo& info) : OpKernel(info) { + alpha_ = info.GetAttrOrDefault("alpha", 0.01f); + } + + Status Compute(OpKernelContext* context) const override; + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LeakyReluGrad); + float alpha_; +}; + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc index 8dcfba33a9..7fde69d758 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc @@ -49,6 +49,7 @@ ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain); ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain); ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain); ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain); +ACTIVATION_GRAD_OP_HFD(LeakyReluGrad, 1, kMSDomain); } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h index 34de4ef8bb..2f60bc2cf2 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h @@ -79,5 +79,20 @@ class TanhGrad final : public BinaryElementwise { private: MAKE_FUNC_CTX_NULL() }; + +template +class LeakyReluGrad final : public BinaryElementwise { + public: + LeakyReluGrad(const OpKernelInfo& info) : BinaryElementwise(info) { + alpha_ = info.GetAttrOrDefault("alpha", 0.01f); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + MAKE_FUNC_CTX_ALPHA() + float alpha_; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu index 2c23a3ed87..164aba8667 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu @@ -64,6 +64,13 @@ struct OP_TanhGrad : public CtxTanhGrad { } }; +template +struct OP_LeakyReluGrad : public CtxLeakyReluGrad { + __device__ __inline__ T operator()(const T& dy, const T& y) const { + return dy * (y > T{0} ? T{1} : static_cast(alpha)); + } +}; + #define BINARY_ELEMENTWISE_IMPL(name) \ BINARY_ELEMENTWISE_IMPL_DECLARATION(name) { \ BinaryElementWiseNoBroadcastImpl(stream, \ diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h index 8e925f0484..0686dc4129 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h @@ -13,6 +13,7 @@ typedef onnxruntime::cuda::CtxNull CtxReluGrad; typedef onnxruntime::cuda::CtxNull CtxSigmoidGrad; typedef onnxruntime::cuda::CtxAlpha CtxQuickGeluGrad; typedef onnxruntime::cuda::CtxNull CtxTanhGrad; +typedef onnxruntime::cuda::CtxAlpha CtxLeakyReluGrad; #define ACTIVATION_GRAD_OPS() \ ACTIVATION_GRAD_OP_NAME(GeluGrad) \ @@ -20,7 +21,8 @@ typedef onnxruntime::cuda::CtxNull CtxTanhGrad; ACTIVATION_GRAD_OP_NAME(ReluGrad) \ ACTIVATION_GRAD_OP_NAME(SigmoidGrad) \ ACTIVATION_GRAD_OP_NAME(QuickGeluGrad) \ - ACTIVATION_GRAD_OP_NAME(TanhGrad) + ACTIVATION_GRAD_OP_NAME(TanhGrad) \ + ACTIVATION_GRAD_OP_NAME(LeakyReluGrad) #define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \ template \ diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index c80859c31d..6aac9ad7ec 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -123,6 +123,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, LeakyReluGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, LeakyReluGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, IsFinite); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, IsFinite); @@ -366,6 +370,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 82631fc04f..2321aa23dd 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -119,6 +119,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LeakyReluGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, LeakyReluGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LeakyReluGrad); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, IsFinite); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, IsFinite); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, IsFinite); @@ -316,6 +320,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo,