diff --git a/onnxruntime/core/providers/rocm/miopen_common.h b/onnxruntime/core/providers/rocm/miopen_common.h index 32491b0cb8..3bb8c2d8f0 100644 --- a/onnxruntime/core/providers/rocm/miopen_common.h +++ b/onnxruntime/core/providers/rocm/miopen_common.h @@ -9,6 +9,8 @@ #include "core/framework/tensor.h" #include +constexpr double MIOPEN_BN_MIN_EPSILON = 1e-5; + namespace onnxruntime { namespace rocm { @@ -56,5 +58,14 @@ struct ReduceConsts { static const ElemType One; }; +inline double ClampMiopenBatchNormEpsilon(double epsilon) { + if (epsilon < MIOPEN_BN_MIN_EPSILON) { + if (MIOPEN_BN_MIN_EPSILON - epsilon > FLT_EPSILON) + LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than MIOPEN_BN_MIN_EPSILON. Setting it to MIOPEN_BN_MIN_EPSILON"; + return MIOPEN_BN_MIN_EPSILON; + } + return epsilon; +} + } // namespace rocm } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_common.h b/onnxruntime/core/providers/rocm/rocm_common.h index eca6c4cb1a..ddc99bf87c 100644 --- a/onnxruntime/core/providers/rocm/rocm_common.h +++ b/onnxruntime/core/providers/rocm/rocm_common.h @@ -3,6 +3,7 @@ #pragma once +#include "core/providers/shared_library/provider_api.h" #include "core/common/status.h" #include "core/providers/rocm/rocm_pch.h" #include "core/providers/rocm/shared_inc/rocm_call.h" diff --git a/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc b/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc index 0e329537d7..d00f8e6613 100644 --- a/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc @@ -8,15 +8,13 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" -using namespace std; - namespace onnxruntime { namespace contrib { namespace test { using namespace onnxruntime::test; -#ifdef USE_CUDA +#if USE_CUDA || USE_ROCM static void TestBatchNormInternal(bool test_double = false, bool T_is_half = false, bool 
T1_is_half = false, bool T2_is_half = false, const std::vector& input_output_dims = {2, 2, 2, 2}) { @@ -138,9 +136,11 @@ TEST(CudaKernelTest, BNInternalBasic) { // float case TestBatchNormInternal(); } +#ifndef USE_ROCM // MIOpen does not support double type TEST(CudaKernelTest, BNInternalDouble) { // double case TestBatchNormInternal(true); } +#endif // ndef USE_ROCM TEST(CudaKernelTest, BNInternalHalf) { // half case TestBatchNormInternal(false, true, true, true); @@ -195,7 +195,7 @@ TEST(CudaKernelTest, BNInternal1DInput) { // float case, 1d input test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } -#endif // USE_CUDA +#endif // USE_CUDA || USE_ROCM } // namespace test } // namespace contrib diff --git a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.cc b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.cc new file mode 100644 index 0000000000..0a65545e3a --- /dev/null +++ b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.cc @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "batch_norm_grad.h" +#include "core/providers/common.h" +#include "core/providers/rocm/miopen_common.h" +#include "core/providers/cpu/nn/batch_norm_helper.h" +#include "core/providers/rocm/math/unary_elementwise_ops_impl.h" + +using namespace std; +namespace onnxruntime { +namespace rocm { + +#define REGISTER_GRADIENT_KERNEL_TYPED(T, T1, T2) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + BatchNormalizationGrad, \ + kMSDomain, \ + 1, \ + T##_##T1##_##T2, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + BatchNormalizationGrad); + +template +Status BatchNormalizationGrad::ComputeInternal(OpKernelContext* ctx) const { + typedef typename ToHipType::MappedType HipT; + typedef typename ToHipType::MappedType HipT1; + typedef typename ToHipType::MappedType HipT2; + + const Tensor* dY = ctx->Input(0); + const Tensor* X = ctx->Input(1); + const Tensor* Scale = ctx->Input(2); + const Tensor* saved_mean = ctx->Input(3); + // miopenBatchNormalizationBackward() claims to use `savedInvVariance`, but the value + // is actually equal to the batch inv_std, so we use name `saved_inv_std` here. 
+ const Tensor* saved_inv_std = ctx->Input(4); + const TensorShape input_shape = X->Shape(); + const TensorShape channel_shape = saved_mean->Shape(); + + // no B here, but B has same size as Scale, so can validate inputs for gradient with this substitute + ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, Scale, Scale, saved_mean, saved_inv_std)); + + auto dY_data = reinterpret_cast(dY->template Data()); + auto X_data = reinterpret_cast(X->template Data()); + auto Scale_data = reinterpret_cast(Scale->template Data()); + auto saved_mean_data = reinterpret_cast(saved_mean->template Data()); + auto saved_inv_std_data = reinterpret_cast(saved_inv_std->template Data()); + + auto dX_data = reinterpret_cast(ctx->Output(0, input_shape)->template MutableData()); + auto dScale_data = reinterpret_cast(ctx->Output(1, channel_shape)->template MutableData()); + auto dBias_data = reinterpret_cast(ctx->Output(2, channel_shape)->template MutableData()); + + const auto alpha = Consts::One; + const auto beta = Consts::Zero; + + MiopenTensor input_tensor, scale_bias_tensor; + vector new_dims; + BatchNormHelper::NormalizeDims(input_shape, new_dims); + ORT_RETURN_IF_ERROR(input_tensor.Set(new_dims, MiopenTensor::GetDataType())); + ORT_RETURN_IF_ERROR(scale_bias_tensor.Set(input_tensor, miopen_batch_norm_mode_)); + + const int64_t C = new_dims[1]; + auto p_scale = reinterpret_cast(Scale_data); + auto p_saved_mean = reinterpret_cast(saved_mean_data); + auto p_saved_inv_std = reinterpret_cast(saved_inv_std_data); + auto p_dScale = reinterpret_cast(dScale_data); + auto p_dBias = reinterpret_cast(dBias_data); + + IAllocatorUniquePtr p_f_scale, p_f_dScale, p_f_dBias, p_f_saved_mean, p_f_saved_inv_std; + + if (std::is_same::value) { + p_f_scale = GetScratchBuffer(C); + p_f_dScale = GetScratchBuffer(C); + p_f_dBias = GetScratchBuffer(C); + + Impl_Cast(Stream(), Scale_data, p_f_scale.get(), C); + + p_scale = p_f_scale.get(); + p_dScale = p_f_dScale.get(); + p_dBias = p_f_dBias.get(); + } + 
+ if (std::is_same::value) { + p_f_saved_mean = GetScratchBuffer(C); + p_f_saved_inv_std = GetScratchBuffer(C); + + Impl_Cast(Stream(), saved_mean_data, p_f_saved_mean.get(), C); + Impl_Cast(Stream(), saved_inv_std_data, p_f_saved_inv_std.get(), C); + + p_saved_mean = p_f_saved_mean.get(); + p_saved_inv_std = p_f_saved_inv_std.get(); + } + + MIOPEN_RETURN_IF_ERROR(miopenBatchNormalizationBackward( + MiopenHandle(), + miopen_batch_norm_mode_, + &alpha, + &beta, + &alpha, + &beta, + input_tensor, + X_data, + input_tensor, + dY_data, + input_tensor, + dX_data, + scale_bias_tensor, + p_scale, + p_dScale, + p_dBias, + epsilon_, + p_saved_mean, + p_saved_inv_std)); + + if (std::is_same::value) { + Impl_Cast(Stream(), reinterpret_cast(p_dScale), dScale_data, C); + Impl_Cast(Stream(), reinterpret_cast(p_dBias), dBias_data, C); + } + + return Status::OK(); +} + +#define SPECIALIZED_GRADIENT(T, T1, T2) \ + REGISTER_GRADIENT_KERNEL_TYPED(T, T1, T2) \ + template Status BatchNormalizationGrad::ComputeInternal(OpKernelContext* ctx) const; + +SPECIALIZED_GRADIENT(float, float, float) +// MIOpen kernel does not support double, disable for now. +// SPECIALIZED_GRADIENT(double, double, double) +SPECIALIZED_GRADIENT(MLFloat16, MLFloat16, MLFloat16) +SPECIALIZED_GRADIENT(MLFloat16, MLFloat16, float) +SPECIALIZED_GRADIENT(MLFloat16, float, float) + +} // namespace rocm +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.h b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.h new file mode 100644 index 0000000000..4995692b1d --- /dev/null +++ b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.h @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "gsl/gsl" +#include "core/providers/rocm/rocm_kernel.h" +#include "core/providers/rocm/miopen_common.h" + +namespace onnxruntime { +namespace rocm { + +template +class BatchNormalizationGrad final : public RocmKernel { + public: + BatchNormalizationGrad(const OpKernelInfo& info) + : RocmKernel{info}, + miopen_batch_norm_mode_(miopenBNSpatial) { + float tmp_epsilon; + ORT_ENFORCE(info.GetAttr("epsilon", &tmp_epsilon).IsOK()); + epsilon_ = ClampMiopenBatchNormEpsilon(static_cast(tmp_epsilon)); + + // spatial or not + int64_t tmp_spatial; + if (info.GetAttr("spatial", &tmp_spatial).IsOK()) { + spatial_ = tmp_spatial; + } + + if (spatial_ == 0) { + miopen_batch_norm_mode_ = miopenBNPerActivation; + } + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + double epsilon_; + int64_t spatial_ = 1; // default as per spec + miopenBatchNormMode_t miopen_batch_norm_mode_; +}; + +} // namespace rocm +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.cc b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.cc new file mode 100644 index 0000000000..3c5c8f35d9 --- /dev/null +++ b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.cc @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "batch_norm_internal.h" +#include "core/providers/common.h" +#include "core/providers/rocm/miopen_common.h" +#include "core/providers/cpu/nn/batch_norm_helper.h" +#include "core/providers/rocm/math/unary_elementwise_ops_impl.h" + +using namespace std; +namespace onnxruntime { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T, T1, T2) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + BatchNormInternal, \ + kMSDomain, \ + 1, \ + T##_##T1##_##T2, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .Alias(3, 1) \ + .Alias(4, 2) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + BatchNormInternal); + +template +Status BatchNormInternal::ComputeInternal(OpKernelContext* p_op_kernel_context) const { + typedef typename ToHipType::MappedType HipT; + typedef typename ToHipType::MappedType HipT1; + typedef typename ToHipType::MappedType HipT2; + + const Tensor* X = p_op_kernel_context->Input(0); + const Tensor* scale = p_op_kernel_context->Input(1); + const Tensor* B = p_op_kernel_context->Input(2); + const Tensor* mean = p_op_kernel_context->Input(3); + const Tensor* var = p_op_kernel_context->Input(4); + + ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, spatial_ == 1)); + + const TensorShape& x_shape = X->Shape(); + const TensorShape& channel_shape = mean->Shape(); + + Tensor* Y = p_op_kernel_context->Output(0, x_shape); + Tensor* running_mean = p_op_kernel_context->Output(1, channel_shape); + Tensor* running_var = p_op_kernel_context->Output(2, channel_shape); + Tensor* saved_mean = p_op_kernel_context->Output(3, channel_shape); + // miopenBatchNormalizationForwardTraining() claims to output `resultSaveInvVariance`, but the value + // is actually equal to the batch inv_std, so we use name `saved_inv_std` here. 
+ Tensor* saved_inv_std = p_op_kernel_context->Output(4, channel_shape); + + auto x_data = reinterpret_cast(X->template Data()); + auto scale_data = reinterpret_cast(scale->template Data()); + auto b_data = reinterpret_cast(B->template Data()); + auto mean_data = reinterpret_cast(mean->template Data()); + auto var_data = reinterpret_cast(var->template Data()); + + auto y_data = reinterpret_cast(Y->template MutableData()); + + // In MIOpenBatchNormForward, alpha and beta are not const. + float alpha = 1.0; + float beta = 0.0; + + MiopenTensor data_desc, bn_tensor_desc; + vector new_dims; + BatchNormHelper::NormalizeDims(x_shape, new_dims); + ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, MiopenTensor::GetDataType())); + ORT_RETURN_IF_ERROR(bn_tensor_desc.Set(data_desc, miopen_batch_norm_mode_)); + + auto running_mean_data = reinterpret_cast(running_mean->template MutableData()); + auto running_var_data = reinterpret_cast(running_var->template MutableData()); + auto saved_mean_data = reinterpret_cast(saved_mean->template MutableData()); + auto saved_inv_std_data = reinterpret_cast(saved_inv_std->template MutableData()); + + auto p_scale = reinterpret_cast(scale_data); + auto p_B = reinterpret_cast(b_data); + auto p_running_mean = reinterpret_cast(running_mean_data); + auto p_running_var = reinterpret_cast(running_var_data); + auto p_saved_mean = reinterpret_cast(saved_mean_data); + auto p_saved_inv_std = reinterpret_cast(saved_inv_std_data); + + + const int64_t C = new_dims[1]; + IAllocatorUniquePtr p_f_scale, p_f_B, p_f_running_mean, p_f_running_var, p_f_saved_mean, p_f_saved_inv_std; + + if (std::is_same::value) { + // Convert scale/B to float + p_f_scale = GetScratchBuffer(C); + p_f_B = GetScratchBuffer(C); + + Impl_Cast(Stream(), scale_data, p_f_scale.get(), C); + Impl_Cast(Stream(), b_data, p_f_B.get(), C); + + p_scale = p_f_scale.get(); + p_B = p_f_B.get(); + } + + if (std::is_same::value) { + // Convert mean/var to float + p_f_running_mean = GetScratchBuffer(C); 
+ p_f_running_var = GetScratchBuffer(C); + p_f_saved_mean = GetScratchBuffer(C); + p_f_saved_inv_std = GetScratchBuffer(C); + + Impl_Cast(Stream(), mean_data, p_f_running_mean.get(), C); + Impl_Cast(Stream(), var_data, p_f_running_var.get(), C); + + p_running_mean = p_f_running_mean.get(); + p_running_var = p_f_running_var.get(); + p_saved_mean = p_f_saved_mean.get(); + p_saved_inv_std = p_f_saved_inv_std.get(); + } else if (mean_data != running_mean_data) { + HIP_RETURN_IF_ERROR( + hipMemcpyAsync(running_mean_data, mean_data, C * sizeof(T2), hipMemcpyDeviceToDevice, Stream())); + HIP_RETURN_IF_ERROR( + hipMemcpyAsync(running_var_data, var_data, C * sizeof(T2), hipMemcpyDeviceToDevice, Stream())); + } + + // NOTE: in miopenBatchNorm, biased std/var is used when calculating `save_inv_std` and `y`, while + // `running_var` is updated using unbiased `batch_var`: + // running_var = (1 - momentum_) * unbiased_batch_var + momentum_ * running_var + // This is inconsistent with BatchNormalization Onnx spec, which uses population variance (biased). 
+ MIOPEN_RETURN_IF_ERROR(miopenBatchNormalizationForwardTraining( + MiopenHandle(), + miopen_batch_norm_mode_, + &alpha, + &beta, + data_desc, + x_data, + data_desc, + y_data, + bn_tensor_desc, + const_cast(p_scale), + const_cast(p_B), + 1.0 - momentum_, + p_running_mean, + p_running_var, + epsilon_, + p_saved_mean, + p_saved_inv_std)); + + if (std::is_same::value) { + Impl_Cast(Stream(), reinterpret_cast(p_running_mean), running_mean_data, C); + Impl_Cast(Stream(), reinterpret_cast(p_running_var), running_var_data, C); + Impl_Cast(Stream(), reinterpret_cast(p_saved_mean), saved_mean_data, C); + Impl_Cast(Stream(), reinterpret_cast(p_saved_inv_std), saved_inv_std_data, C); + } + + return Status::OK(); +} + +#define SPECIALIZED_COMPUTE(T, T1, T2) \ + REGISTER_KERNEL_TYPED(T, T1, T2) \ + template Status BatchNormInternal::ComputeInternal(OpKernelContext* ctx) const; + + +SPECIALIZED_COMPUTE(float, float, float) +// MIOpen kernel does not support double, disable for now. +// SPECIALIZED_COMPUTE(double, double, double) +SPECIALIZED_COMPUTE(MLFloat16, MLFloat16, MLFloat16) +SPECIALIZED_COMPUTE(MLFloat16, MLFloat16, float) +SPECIALIZED_COMPUTE(MLFloat16, float, float) + +} // namespace rocm +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.h b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.h new file mode 100644 index 0000000000..cb2817951e --- /dev/null +++ b/orttraining/orttraining/training_ops/rocm/nn/batch_norm_internal.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "gsl/gsl" +#include "core/providers/rocm/rocm_kernel.h" +#include "core/providers/rocm/miopen_common.h" + +namespace onnxruntime { +namespace rocm { + +template +class BatchNormInternal final : public RocmKernel { + public: + BatchNormInternal(const OpKernelInfo& op_kernel_info) + : RocmKernel{op_kernel_info}, + miopen_batch_norm_mode_(miopenBNSpatial), + momentum_(0.9) { + float tmp_epsilon; + ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &tmp_epsilon).IsOK()); + epsilon_ = ClampMiopenBatchNormEpsilon(static_cast(tmp_epsilon)); + + // spatial or not + int64_t tmp_spatial; + if (op_kernel_info.GetAttr("spatial", &tmp_spatial).IsOK()) { + spatial_ = tmp_spatial; + } + + if (spatial_ == 0) { + miopen_batch_norm_mode_ = miopenBNPerActivation; + } + + float tmp_momentum; + if (op_kernel_info.GetAttr("momentum", &tmp_momentum).IsOK()) { + momentum_ = static_cast(tmp_momentum); + } + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + double epsilon_; + int64_t spatial_ = 1; // default as per spec + miopenBatchNormMode_t miopen_batch_norm_mode_; + double momentum_; +}; + +} // namespace rocm +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 2fa9c1d133..0d71c697e7 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -67,8 +67,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LogSoftmaxGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, LogSoftmaxGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LogSoftmaxGrad); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, BatchNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, BatchNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormInternal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormInternal); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DropoutGrad); @@ -208,8 +216,16 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo,