Miopen conv grad (#9574)

* Add source for conv_grad

* Add sources for ROCm EP.
* Transliterate sources for conv_grad for ROCm EP.

* Add conv_grad to ROCm EP

Add conv_grad to ROCm execution
provider.

* Update ROCm EP ConvGrad

Update ConvGrad for the ROCm EP to match other EP
changes and fix a build issue.
This commit is contained in:
groenenboomj 2021-10-31 13:19:46 -05:00 committed by GitHub
parent 79436a2d5b
commit 5c56fa0def
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 462 additions and 7 deletions

View file

@ -15,6 +15,8 @@ namespace onnxruntime {
namespace rocm {
// Number of MIOpen algorithm choices per convolution direction.
// The backward counts size the algorithm tables and perf-result buffers used by the ConvGrad
// algorithm search; presumably the forward count plays the same role for Conv - TODO confirm.
#define MIOPEN_CONVOLUTION_FWD_ALGO_COUNT 6
#define MIOPEN_CONVOLUTION_BWD_FILTER_ALGO_COUNT 4
#define MIOPEN_CONVOLUTION_BWD_DATA_ALGO_COUNT 6
class MiopenTensor final {
public:

View file

@ -5,13 +5,13 @@
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace cuda {
namespace contrib {
namespace test {
using namespace std;
using namespace onnxruntime::test;
#ifdef USE_CUDA
#if USE_CUDA || USE_ROCM
namespace {
struct ConvGradOpAttributes {
@ -315,8 +315,8 @@ TEST(ConvTest, Conv3D_Bias) {
TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape});
TestConvGradOp(attrs, {dY, X, W}, {dY_shape, X_shape, W_shape}, {dX, dW, dB}, {dX_shape, dW_shape, dB_shape}, true);
}
#endif // USE_CUDA
#endif // USE_CUDA || USE_ROCM
} // namespace test
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,384 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// TODO: Add the exhaustive-search and default-algorithm cases for algorithm selection.
#include "orttraining/training_ops/rocm/nn/conv_grad.h"
#include "core/providers/common.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "core/platform/ort_mutex.h"
namespace onnxruntime {
namespace rocm {
// Registers ConvGrad<T> as the ROCm execution provider kernel for the "ConvGrad" op
// (kMSDomain, version 1), constraining the op's "T" type parameter to tensors of element type T.
#define REGISTER_GRADIENT_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX(ConvGrad, kMSDomain, 1, T, kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ConvGrad<T>);
// Element types for which a ConvGrad kernel is registered.
REGISTER_GRADIENT_KERNEL_TYPED(float)
// MIOpen double support not currently implemented.
//REGISTER_GRADIENT_KERNEL_TYPED(double)
REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16)
// Both backward passes report results through the same MIOpen perf struct;
// only the algorithm enum type differs between backward-data and backward-filter.
// The distinct algorithm types are what select the GetWorkspaceSize overload and
// the AlgoSearch / AlgoIterator specialization.
using T_BwdDataPerf = miopenConvAlgoPerf_t;
using T_BwdDataAlgo = miopenConvBwdDataAlgorithm_t;
using T_BwdFilterPerf = miopenConvAlgoPerf_t;
using T_BwdFilterAlgo = miopenConvBwdWeightsAlgorithm_t;
// Queries MIOpen for the workspace size (in bytes) needed by the backward-data convolution
// (gradient w.r.t. the input) described by `args`.
//
// Params:
//   args           - descriptors and handle of the convolution.
//   (algo)         - unused; the MIOpen query is not algorithm specific. The parameter exists
//                    so that overload resolution can pick this function by algorithm type.
//   workspace_size - out: receives the required size.
// Returns the MIOpen status of the query.
//
// MIOpen takes the descriptors in the order (dyDesc, wDesc, convDesc, dxDesc). dX has the shape
// of X, so the X descriptor is passed as dxDesc. The filter descriptor must come second; passing
// the X descriptor in its place makes MIOpen size the workspace for the wrong problem.
miopenStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo /*algo*/, size_t* workspace_size) {
  return miopenConvolutionBackwardDataGetWorkSpaceSize(args.handle, args.y_tensor, args.w_desc, args.conv_desc,
                                                       args.x_tensor, workspace_size);
}
// Backward-weights overload: asks MIOpen how much workspace the filter-gradient convolution
// described by `args` needs and stores the answer in `*workspace_size`.
// `algo` only participates in overload resolution; it is not forwarded to MIOpen.
// Returns the MIOpen status of the query.
miopenStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) {
  const miopenStatus_t query_status = miopenConvolutionBackwardWeightsGetWorkSpaceSize(
      args.handle, args.y_tensor, args.x_tensor, args.conv_desc, args.w_desc, workspace_size);
  return query_status;
}
// Returns the largest workspace size reported for any of the `n_algo` algorithms in `algo`
// that still fits into 90% of the currently free device memory, or 0 if none qualifies.
//
// Algorithms whose size query fails, or that report a size of 0, are ignored.
template <typename T_Algo>
size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) {
  // Calling hipMemGetInfo is not ideal, but the ROCm allocator offers no way to get this info.
  size_t free, total;
  HIP_CALL_THROW(hipMemGetInfo(&free, &total));
  // Keep 10% of the free memory as headroom for fragmentation.
  free = static_cast<size_t>(static_cast<double>(free) * 0.9);

  size_t max_workspace_size = 0;
  for (int idx = 0; idx < n_algo; ++idx) {
    size_t workspace_size;
    if (GetWorkspaceSize(args, algo[idx], &workspace_size) != miopenStatusSuccess) {
      continue;
    }

    const bool usable = (workspace_size != 0) && (workspace_size <= free);
    if (usable && workspace_size >= max_workspace_size) {
      max_workspace_size = workspace_size;
    }
  }
  return max_workspace_size;
}
// Copies the first `n_algo` entries of `perf_results` into a vector, preserving their order.
// No filtering is applied; the function only enforces that at least one entry was returned.
template <typename T_Perf>
std::vector<T_Perf> GetValidAlgorithms(const T_Perf* perf_results, int n_algo) {
  std::vector<T_Perf> result;
  result.reserve(n_algo);

  int next = 0;
  while (next < n_algo) {
    result.push_back(perf_results[next]);
    ++next;
  }

  ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in MIOpen");
  return result;
}
// Hash functor for ConvParams: 32-bit FNV-1a over the raw bytes of the struct.
struct ConvParamsHash {
  // The object representation is read byte by byte, so ConvParams has to stay a POD.
  static_assert(std::is_pod<ConvParams>::value, "ConvParams is not POD");

  size_t operator()(const ConvParams& conv_params) const {
    constexpr uint32_t kOffsetBasis = 0x811C9DC5;
    constexpr uint32_t kPrime = 0x01000193;

    const uint8_t* const first = reinterpret_cast<const uint8_t*>(&conv_params);
    const uint8_t* const last = first + sizeof(ConvParams);

    uint32_t hash = kOffsetBasis;
    for (const uint8_t* byte = first; byte != last; ++byte) {
      hash ^= *byte;
      hash *= kPrime;
    }
    return static_cast<size_t>(hash);
  }
};
// Equality functor for ConvParams: two keys are equal iff all of their bytes are equal.
struct ConvParamsEqual {
  // The object representations are compared bytewise, so ConvParams has to stay a POD.
  static_assert(std::is_pod<ConvParams>::value, "ConvParams is not POD");

  bool operator()(const ConvParams& a, const ConvParams& b) const {
    const void* const lhs = &a;
    const void* const rhs = &b;
    const int byte_diff = memcmp(lhs, rhs, sizeof(ConvParams));
    return byte_diff == 0;
  }
};
// Mutex-guarded map from a convolution configuration (ConvParams) to the perf record of the
// algorithm chosen for it.
template <typename T_Perf>
struct AlgoPerfCache {
  mutable OrtMutex mutex;
  std::unordered_map<ConvParams, T_Perf, ConvParamsHash, ConvParamsEqual> map;

  // Looks up `params`. On a hit copies the stored record into `*result` and returns true;
  // on a miss leaves `*result` untouched and returns false.
  bool Find(const ConvParams& params, T_Perf* result) {
    std::lock_guard<OrtMutex> guard(mutex);
    const auto entry = map.find(params);
    const bool hit = (entry != map.end());
    if (hit) {
      *result = entry->second;
    }
    return hit;
  }

  // Stores `algo_perf` for `params`, replacing any record already present.
  void Insert(const ConvParams& params, const T_Perf& algo_perf) {
    std::lock_guard<OrtMutex> guard(mutex);
    map[params] = algo_perf;
  }
};
// TODO: Currently we use global AlgoPerfCache for ConvGrad only. Conv's perf cache is still per node.
// Need to apply such global cache for Conv, and move some shared code from here to conv.h/cc.
// Namespace-scope caches of the selected backward-data and backward-filter algorithms,
// keyed by ConvParams. They are handed out through AlgoSearch<...>::Cache().
AlgoPerfCache<T_BwdDataPerf> bwd_data_algos;
AlgoPerfCache<T_BwdFilterPerf> bwd_filter_algos;
// Per-algorithm-type search policy. The primary template is empty on purpose: everything
// (DEFAULT_ALGO, Cache(), FindAlgorithms()) is supplied by the explicit specializations for
// T_BwdDataAlgo and T_BwdFilterAlgo.
template <typename T_Algo>
struct AlgoSearch {};
// Search policy for the backward-data convolution (gradient w.r.t. the input).
template <>
struct AlgoSearch<T_BwdDataAlgo> {
  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdDataAlgoGEMM;

  // Cache shared by all backward-data searches.
  static AlgoPerfCache<T_BwdDataPerf>& Cache() { return bwd_data_algos; }

  // Runs MIOpen's backward-data algorithm finder for `args` and stores the returned perf
  // records in `perf_results`. Returns an error status if the MIOpen call fails.
  static Status FindAlgorithms(const ConvArgs& args, const ROCMExecutionProvider* provider,
                               std::vector<T_BwdDataPerf>& perf_results) {
    static const T_BwdDataAlgo algos[] = {
        miopenConvolutionBwdDataAlgoGEMM,
        miopenConvolutionBwdDataAlgoDirect,
        miopenConvolutionBwdDataAlgoFFT,
        miopenConvolutionBwdDataAlgoWinograd,
        miopenTransposeBwdDataAlgoGEMM,
        miopenConvolutionBwdDataAlgoImplicitGEMM
    };
    static constexpr int num_algos = MIOPEN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
    ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing MIOpen convolution backward data algorithms.");

    // Workspace budget for the search: either the largest size any algorithm asks for (provider
    // option), or the fixed default.
    size_t max_workspace_size = AlgoSearchWorkspaceSize;
    if (provider->GetMiopenConvUseMaxWorkspace()) {
      max_workspace_size = GetMaxWorkspaceSize(args, algos, num_algos);
    }

    // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached:
    // benchmarking can take a huge amount of memory, e.g. a few GBs.
    IAllocatorUniquePtr<void> workspace = provider->GetTransientScratchBuffer<void>(max_workspace_size);

    std::unique_ptr<T_BwdDataPerf[]> candidates(new T_BwdDataPerf[num_algos]);
    int perf_count;
    MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm(
        args.handle, args.y_tensor, args.dy_data, args.w_desc, args.w_data, args.conv_desc, args.x_tensor,
        args.dx_data, 1, &perf_count, candidates.get(), workspace.get(), max_workspace_size, false));

    perf_results = GetValidAlgorithms<T_BwdDataPerf>(candidates.get(), perf_count);
    return Status::OK();
  }
};
// Search policy for the backward-filter convolution (gradient w.r.t. the weights).
template <>
struct AlgoSearch<T_BwdFilterAlgo> {
  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdWeightsAlgoGEMM;

  // Cache shared by all backward-filter searches.
  static AlgoPerfCache<T_BwdFilterPerf>& Cache() { return bwd_filter_algos; }

  // Runs MIOpen's backward-weights algorithm finder for `args` and stores the returned perf
  // records in `perf_results`. Returns an error status if the MIOpen call fails.
  static Status FindAlgorithms(const ConvArgs& args, const ROCMExecutionProvider* provider,
                               std::vector<T_BwdFilterPerf>& perf_results) {
    static const T_BwdFilterAlgo algos[] = {
        miopenConvolutionBwdWeightsAlgoGEMM,
        miopenConvolutionBwdWeightsAlgoDirect,
        miopenConvolutionBwdWeightsAlgoWinograd,
        miopenConvolutionBwdWeightsAlgoImplicitGEMM
    };
    static constexpr int num_algos = MIOPEN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
    ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing MIOpen convolution backward filter algorithms.");

    // Workspace budget for the search: either the largest size any algorithm asks for (provider
    // option), or the fixed default.
    size_t max_workspace_size = AlgoSearchWorkspaceSize;
    if (provider->GetMiopenConvUseMaxWorkspace()) {
      max_workspace_size = GetMaxWorkspaceSize(args, algos, num_algos);
    }

    // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached:
    // benchmarking can take a huge amount of memory, e.g. a few GBs.
    IAllocatorUniquePtr<void> workspace = provider->GetTransientScratchBuffer<void>(max_workspace_size);

    std::unique_ptr<T_BwdFilterPerf[]> candidates(new T_BwdFilterPerf[num_algos]);
    int perf_count;
    MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardWeightsAlgorithm(
        args.handle, args.y_tensor, args.dy_data, args.x_tensor, args.x_data, args.conv_desc, args.w_desc,
        args.dw_data, 1, &perf_count, candidates.get(), workspace.get(), max_workspace_size, false));

    perf_results = GetValidAlgorithms<T_BwdFilterPerf>(candidates.get(), perf_count);
    return Status::OK();
  }
};
// Drives algorithm selection for one convolution: tries the cached algorithm first and falls
// back to a fresh MIOpen search. Holds `args` by reference, so the ConvArgs object must outlive
// the iterator.
template <typename T_Algo>
class AlgoIterator {
 public:
  AlgoIterator(const ConvArgs& args) : args_(args) {}

  // Fills `perf_results` with the default algorithm only (specialized per algorithm type).
  Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<miopenConvAlgoPerf_t>& perf_results);

  // Invokes `f` with candidate algorithms until one call returns Status::OK().
  //   1. If the cache holds an entry for this configuration, it is tried first.
  //   2. Otherwise (or if the cached entry fails) the algorithms returned by FindAlgorithms are
  //      tried in order; the first one that succeeds is written to the cache.
  // Returns the search error if FindAlgorithms fails; throws (ORT_ENFORCE) if no candidate
  // succeeds.
  Status TryAll(const ROCMExecutionProvider* provider, std::function<Status(const miopenConvAlgoPerf_t& perf)> f) {
    auto& cache = AlgoSearch<T_Algo>::Cache();

    miopenConvAlgoPerf_t cached_perf;
    if (cache.Find(args_.params, &cached_perf)) {
      if (f(cached_perf) == Status::OK()) {
        return Status::OK();
      }
    }

    std::vector<miopenConvAlgoPerf_t> perf_results;
    ORT_RETURN_IF_ERROR(AlgoSearch<T_Algo>::FindAlgorithms(args_, provider, perf_results));

    for (const auto& candidate : perf_results) {
      if (f(candidate) == Status::OK()) {
        cache.Insert(args_.params, candidate);
        return Status::OK();
      }
    }

    ORT_ENFORCE(false, "Unable to find a valid MIOpen algorithm to run convolution.");
    return Status::OK();
  }

 private:
  const ConvArgs& args_;
};
// Backward-data specialization: shrinks/grows `perf_results` to exactly one entry, sets its
// algorithm to AlgoSearch<T_BwdDataAlgo>::DEFAULT_ALGO and its `memory` field to the workspace
// size reported by GetWorkspaceSize. Returns an error status if that query fails.
template<> Status AlgoIterator<T_BwdDataAlgo>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_BwdDataPerf>& perf_results){
perf_results.resize(1);
perf_results[0].bwd_data_algo = AlgoSearch<T_BwdDataAlgo>::DEFAULT_ALGO;
MIOPEN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].bwd_data_algo, &(perf_results[0].memory)));
return Status::OK();
}
// Backward-filter specialization: shrinks/grows `perf_results` to exactly one entry, sets its
// algorithm to AlgoSearch<T_BwdFilterAlgo>::DEFAULT_ALGO and its `memory` field to the workspace
// size reported by GetWorkspaceSize. Returns an error status if that query fails.
template<> Status AlgoIterator<T_BwdFilterAlgo>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_BwdFilterPerf>& perf_results){
perf_results.resize(1);
perf_results[0].bwd_weights_algo = AlgoSearch<T_BwdFilterAlgo>::DEFAULT_ALGO;
MIOPEN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].bwd_weights_algo, &(perf_results[0].memory)));
return Status::OK();
}
// Refreshes `args_` for the tensors of the current invocation.
//
// The raw data pointers are refreshed on every call. The MIOpen descriptors and the ConvParams
// cache key are rebuilt only when the shape of `x` or `w` differs from the shapes recorded by the
// last successful rebuild.
//
// Params:
//   x, dY, w   - forward input, gradient of the forward output, and filter.
//   dB, dX, dW - gradient outputs; a null pointer leaves the matching data pointer null.
// Returns Status::OK() on success, otherwise the first validation or descriptor error.
template <typename T>
Status ConvGrad<T>::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX,
                                Tensor* dW) const {
  const TensorShape& x_shape = x.Shape();
  std::vector<int64_t> x_dims = x_shape.GetDims();
  args_.x_data = reinterpret_cast<const HipT*>(x.template Data<T>());

  const TensorShape& dy_shape = dY.Shape();
  std::vector<int64_t> dy_dims = dy_shape.GetDims();
  args_.dy_data = reinterpret_cast<const HipT*>(dY.template Data<T>());

  const TensorShape& w_shape = w.Shape();
  std::vector<int64_t> w_dims = w_shape.GetDims();
  args_.w_data = reinterpret_cast<const HipT*>(w.template Data<T>());

  args_.db_data = dB ? reinterpret_cast<HipT*>(dB->template MutableData<T>()) : nullptr;
  args_.dx_data = dX ? reinterpret_cast<HipT*>(dX->template MutableData<T>()) : nullptr;
  args_.dw_data = dW ? reinterpret_cast<HipT*>(dW->template MutableData<T>()) : nullptr;

  const bool x_dims_changed = (args_.last_x_dims != x_dims);
  const bool w_dims_changed = (args_.last_w_dims != w_dims);
  if (!x_dims_changed && !w_dims_changed) {
    // Same shapes as the last successful rebuild: descriptors and cache key are still valid.
    return Status::OK();
  }

  // Forget the recorded shapes until the rebuild below has fully succeeded. If any step fails,
  // the next call with the same shapes must rebuild again rather than reuse a half-updated state.
  args_.last_x_dims.clear();
  args_.last_w_dims.clear();

  // Update Attributes
  ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&x, &w));

  std::vector<int64_t> kernel_shape;
  ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape));
  const size_t rank = kernel_shape.size();

  std::vector<int64_t> pads(conv_attrs_.pads);
  if (pads.empty()) {
    pads.resize(rank * 2, 0);
  }
  std::vector<int64_t> dilations(conv_attrs_.dilations);
  if (dilations.empty()) {
    dilations.resize(rank, 1);
  }
  std::vector<int64_t> strides(conv_attrs_.strides);
  if (strides.empty()) {
    strides.resize(rank, 1);
  }

  // Everything below indexes these vectors by the kernel rank and copies them into the
  // fixed-size arrays of ConvParams, so reject anything that would read or write out of bounds.
  ORT_RETURN_IF_NOT(pads.size() == rank * 2 && strides.size() == rank && dilations.size() == rank,
                    "ConvGrad: pads/strides/dilations do not match the kernel rank");
  ORT_RETURN_IF_NOT(rank <= static_cast<size_t>(MAX_DIM) && x_dims.size() <= static_cast<size_t>(2 + MAX_DIM) &&
                        w_dims.size() >= x_dims.size(),
                    "ConvGrad: unsupported number of dimensions");

  // MIOpen only takes 4D or 5D x tensor, so pad dimensions if needed.
  if (rank < 2) {
    x_dims.push_back(1);
    dy_dims.push_back(1);
    w_dims.push_back(1);
    pads.insert(pads.begin() + rank, 0);
    pads.insert(pads.end(), 0);
    kernel_shape.push_back(1);
    strides.push_back(1);
    dilations.push_back(1);
  }
  // Rank after the optional padding above. The cache key must be filled with this rank;
  // using the pre-padding rank would drop the end padding of a 1-D convolution from the key.
  const size_t padded_rank = kernel_shape.size();

  const ROCMExecutionProvider* rocm_ep =
      static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());

  // ConvParams is hashed and compared bytewise, so zero it (including padding bytes) first.
  memset(&args_.params, 0, sizeof(ConvParams));
  args_.params.device_id = static_cast<int8_t>(rocm_ep->GetDeviceId());
  args_.params.data_type = MiopenTensor::GetDataType<HipT>();
  args_.params.input_dim = static_cast<uint8_t>(x_dims.size());
  for (size_t i = 0; i < x_dims.size(); i++) {
    args_.params.input_size[i] = static_cast<int>(x_dims[i]);
    args_.params.weight_size[i] = static_cast<int>(w_dims[i]);
  }
  for (size_t i = 0; i < padded_rank; i++) {
    args_.params.padding[i] = static_cast<int>(pads[i]);
    args_.params.padding[i + padded_rank] = static_cast<int>(pads[i + padded_rank]);
    args_.params.stride[i] = static_cast<int>(strides[i]);
    args_.params.dilation[i] = static_cast<int>(dilations[i]);
  }
  args_.params.groups = conv_attrs_.group;

  args_.handle = MiopenHandle();
  ORT_RETURN_IF_ERROR(args_.w_desc.Set(w_dims, args_.params.data_type));
  ORT_RETURN_IF_ERROR(args_.x_tensor.Set(x_dims, args_.params.data_type));
  ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type));
  ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                          gsl::narrow_cast<int>(conv_attrs_.group), miopenConvolution,
                                          args_.params.data_type));

  if (dB) {
    const TensorShape& db_shape = dB->Shape();
    ORT_RETURN_IF_NOT(db_shape.NumDimensions() == 1, "bias should be 1D");
    // Describe the bias as [1, C, 1, ...] so that its rank matches the (padded) y tensor.
    std::vector<int64_t> db_dims(2 + kernel_shape.size(), 1);
    db_dims[1] = db_shape[0];
    ORT_RETURN_IF_ERROR(args_.b_tensor.Set(db_dims, MiopenTensor::GetDataType<HipT>()));
  }

  // Rebuild succeeded: record the original (unpadded) shapes for the next call's comparison.
  args_.last_x_dims = x_shape.GetDims();
  args_.last_w_dims = w_shape.GetDims();
  return Status::OK();
}
// Kernel entry point. Inputs: 0 = dY, 1 = X, 2 = W. Outputs: 0 = dX (shape of X),
// 1 = dW (shape of W), 2 = dB (1-D, length W.shape[0]).
// Each gradient is computed only if the corresponding output pointer is non-null.
template <typename T>
Status ConvGrad<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* dY = context->Input<Tensor>(0);
  const Tensor* X = context->Input<Tensor>(1);
  const Tensor* W = context->Input<Tensor>(2);

  Tensor* dX = context->Output(0, X->Shape());
  Tensor* dW = context->Output(1, W->Shape());
  Tensor* dB = context->Output(2, {W->Shape()[0]});

  ORT_RETURN_IF_ERROR(PrepareArgs(*X, *dY, *W, dB, dX, dW));

  if (dX != nullptr) {
    ORT_RETURN_IF_ERROR(ComputeInputGradient());
  }
  if (dW != nullptr) {
    ORT_RETURN_IF_ERROR(ComputeWeightGradient());
  }
  if (dB != nullptr) {
    ORT_RETURN_IF_ERROR(ComputeBiasGradient());
  }
  return Status::OK();
}
// Computes dX with miopenConvolutionBackwardData, letting AlgoIterator pick the algorithm
// (cached one first, otherwise the first searched algorithm that runs successfully).
template <typename T>
Status ConvGrad<T>::ComputeInputGradient() const {
  const auto* provider = static_cast<const ROCMExecutionProvider*>(Info().GetExecutionProvider());

  // Runs the backward-data convolution with one specific algorithm. The workspace is sized from
  // the perf record of that algorithm.
  auto run_with_algo = [&](const T_BwdDataPerf& algo_perf) -> Status {
    const auto one = Consts<HipT>::One;
    const auto zero = Consts<HipT>::Zero;
    IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(algo_perf.memory);
    MIOPEN_RETURN_IF_ERROR(miopenConvolutionBackwardData(
        args_.handle, &one, args_.y_tensor, args_.dy_data, args_.w_desc, args_.w_data, args_.conv_desc,
        algo_perf.bwd_data_algo, &zero, args_.x_tensor, args_.dx_data, workspace.get(), algo_perf.memory));
    return Status::OK();
  };

  return AlgoIterator<T_BwdDataAlgo>(args_).TryAll(provider, run_with_algo);
}
// Computes dW with miopenConvolutionBackwardWeights, letting AlgoIterator pick the algorithm
// (cached one first, otherwise the first searched algorithm that runs successfully).
template <typename T>
Status ConvGrad<T>::ComputeWeightGradient() const {
  const auto* provider = static_cast<const ROCMExecutionProvider*>(Info().GetExecutionProvider());

  // Runs the backward-weights convolution with one specific algorithm. The workspace is sized
  // from the perf record of that algorithm.
  auto run_with_algo = [&](const T_BwdFilterPerf& algo_perf) -> Status {
    const auto one = Consts<HipT>::One;
    const auto zero = Consts<HipT>::Zero;
    IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(algo_perf.memory);
    MIOPEN_RETURN_IF_ERROR(miopenConvolutionBackwardWeights(
        args_.handle, &one, args_.y_tensor, args_.dy_data, args_.x_tensor, args_.x_data, args_.conv_desc,
        algo_perf.bwd_weights_algo, &zero, args_.w_desc, args_.dw_data, workspace.get(), algo_perf.memory));
    return Status::OK();
  };

  return AlgoIterator<T_BwdFilterAlgo>(args_).TryAll(provider, run_with_algo);
}
// Computes dB from dY with miopenConvolutionBackwardBias, writing into args_.db_data
// using the bias descriptor prepared in PrepareArgs. No algorithm search is involved.
// Returns an error status if the MIOpen call fails.
template <typename T>
Status ConvGrad<T>::ComputeBiasGradient() const {
// Scaling factors passed to MIOpen as alpha (one) and beta (zero).
const auto one = Consts<HipT>::One;
const auto zero = Consts<HipT>::Zero;
MIOPEN_RETURN_IF_ERROR(miopenConvolutionBackwardBias(
args_.handle, &one, args_.y_tensor, args_.dy_data, &zero,
args_.b_tensor, args_.db_data));
return Status::OK();
}
} // namespace rocm
} // namespace onnxruntime

View file

@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/miopen_common.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/providers/rocm/nn/conv.h"
namespace onnxruntime {
namespace rocm {
// Maximum number of spatial dimensions the fixed-size arrays in ConvParams can hold.
constexpr int MAX_DIM = 3;
// Plain-data description of one convolution configuration.
// It is used as the key of the algorithm perf cache, where it is hashed and compared byte by
// byte. It must therefore stay a POD and be zero-filled (memset) before its fields are set.
// Do not reorder or resize fields without keeping that in mind.
struct ConvParams {
int8_t device_id;
miopenDataType_t data_type;
// Batch, channel and up to MAX_DIM spatial extents of the input.
int input_size[2 + MAX_DIM];
// Number of entries of input_size that are in use.
uint8_t input_dim;
// Filter dimensions, same number of used entries as input_size.
int weight_size[2 + MAX_DIM];
// Begin paddings followed by end paddings.
int padding[MAX_DIM * 2];
int stride[MAX_DIM];
int dilation[MAX_DIM];
int64_t groups;
};
// Per-kernel state handed to MIOpen: descriptors, handle, cache key and the data pointers of
// the current invocation.
//
// All raw pointers, the handle and `params` are given defined initial values, so a ConvArgs
// that has not been populated yet never exposes indeterminate values.
struct ConvArgs {
  // Shapes seen by the last descriptor rebuild. Update needed if x or w's dims changed.
  std::vector<int64_t> last_x_dims;
  std::vector<int64_t> last_w_dims;

  miopenHandle_t handle = nullptr;
  // Key for the algorithm perf cache (see ConvParams).
  ConvParams params = {};

  MiopenTensor x_tensor, y_tensor, b_tensor;
  MiopenTensorDescriptor w_desc;
  MiopenConvolutionDescriptor conv_desc;

  // Inputs of the current invocation.
  const void* x_data = nullptr;
  const void* w_data = nullptr;
  const void* dy_data = nullptr;
  // Gradient outputs of the current invocation; null when the output is not requested.
  void* dx_data = nullptr;
  void* dw_data = nullptr;
  void* db_data = nullptr;
};
// ROCm kernel for the ConvGrad training op: computes dX, dW and dB from dY, X and W via MIOpen.
// T is the ONNX Runtime element type; HipT is the matching device type.
template <typename T>
class ConvGrad final : public RocmKernel {
 public:
  using HipT = typename ToHipType<T>::MappedType;

  // Reads the convolution attributes from `info`. Pads come as begin/end pairs, hence the
  // requirement of an even count.
  ConvGrad(const OpKernelInfo& info) : RocmKernel(info), conv_attrs_(info) {
    auto pads_size = conv_attrs_.pads.size();
    ORT_ENFORCE(pads_size % 2 == 0);
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 protected:
  // Refreshes args_ (data pointers always, descriptors when shapes changed).
  Status PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX, Tensor* dW) const;

  // Mutable because it is updated from the const Compute path.
  mutable ConvArgs args_;
  ConvAttributes conv_attrs_;

 private:
  // One helper per gradient output; each operates on the state prepared in args_.
  Status ComputeInputGradient() const;
  Status ComputeWeightGradient() const;
  Status ComputeBiasGradient() const;
};
} // namespace rocm
} // namespace onnxruntime

View file

@ -75,7 +75,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormalizationGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormalizationGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ConvGrad);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ConvGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DropoutGrad);
@ -262,9 +262,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormalizationGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormalizationGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormalizationGrad)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvGrad)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ConvGrad)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DivGrad)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, DivGrad)>,