mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Add BatchNorm kernel for ROCm (#9014)
* Add BatchNorm kernel for ROCm, update BN test * correct epsilon_ setting; limit min epsilon
This commit is contained in:
parent
e83cc534d4
commit
a1021a1cf4
8 changed files with 433 additions and 8 deletions
|
|
@ -9,6 +9,8 @@
|
|||
#include "core/framework/tensor.h"
|
||||
#include <cfloat>
|
||||
|
||||
// Lower bound for the batch-norm epsilon: ClampMiopenBatchNormEpsilon() raises any smaller value to this.
const double MIOPEN_BN_MIN_EPSILON = 1e-5;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
|
|
@ -56,5 +58,14 @@ struct ReduceConsts {
|
|||
static const ElemType One;
|
||||
};
|
||||
|
||||
inline double ClampMiopenBatchNormEpsilon(double epsilon) {
|
||||
if (epsilon < MIOPEN_BN_MIN_EPSILON) {
|
||||
if (MIOPEN_BN_MIN_EPSILON - epsilon > FLT_EPSILON)
|
||||
LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON. Setting it to CUDNN_BN_MIN_EPSILON";
|
||||
return MIOPEN_BN_MIN_EPSILON;
|
||||
}
|
||||
return epsilon;
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/common/status.h"
|
||||
#include "core/providers/rocm/rocm_pch.h"
|
||||
#include "core/providers/rocm/shared_inc/rocm_call.h"
|
||||
|
|
|
|||
|
|
@ -8,15 +8,13 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace test {
|
||||
|
||||
using namespace onnxruntime::test;
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#if USE_CUDA || USE_ROCM
|
||||
static void TestBatchNormInternal(bool test_double = false, bool T_is_half = false,
|
||||
bool T1_is_half = false, bool T2_is_half = false,
|
||||
const std::vector<int64_t>& input_output_dims = {2, 2, 2, 2}) {
|
||||
|
|
@ -138,9 +136,11 @@ TEST(CudaKernelTest, BNInternalBasic) { // float case
|
|||
TestBatchNormInternal();
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM // MIOpen does not support double type
|
||||
TEST(CudaKernelTest, BNInternalDouble) { // double case
|
||||
TestBatchNormInternal(true);
|
||||
}
|
||||
#endif // ndef USE_ROCM
|
||||
|
||||
TEST(CudaKernelTest, BNInternalHalf) { // half case
|
||||
TestBatchNormInternal(false, true, true, true);
|
||||
|
|
@ -195,7 +195,7 @@ TEST(CudaKernelTest, BNInternal1DInput) { // float case, 1d input
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kCpuExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
#endif // USE_CUDA
|
||||
#endif // USE_CUDA || USE_ROCM
|
||||
|
||||
} // namespace test
|
||||
} // namespace contrib
|
||||
|
|
|
|||
137
orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.cc
Normal file
137
orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.cc
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "batch_norm_grad.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/rocm/miopen_common.h"
|
||||
#include "core/providers/cpu/nn/batch_norm_helper.h"
|
||||
#include "core/providers/rocm/math/unary_elementwise_ops_impl.h"
|
||||
|
||||
using namespace std;
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
// Registers BatchNormalizationGrad<T, T1, T2> for kRocmExecutionProvider under kMSDomain.
// The type tag of the registration is the pasted token T_T1_T2, and the "T", "T1" and "T2"
// type constraints are bound to the tensor types of the corresponding macro arguments.
#define REGISTER_GRADIENT_KERNEL_TYPED(T, T1, T2)                   \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                    \
      BatchNormalizationGrad,                                       \
      kMSDomain,                                                    \
      1,                                                            \
      T##_##T1##_##T2,                                              \
      kRocmExecutionProvider,                                       \
      (*KernelDefBuilder::Create())                                 \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())    \
          .TypeConstraint("T1", DataTypeImpl::GetTensorType<T1>())  \
          .TypeConstraint("T2", DataTypeImpl::GetTensorType<T2>()), \
      BatchNormalizationGrad<T, T1, T2>);
|
||||
|
||||
template <typename T, typename T1, typename T2>
|
||||
Status BatchNormalizationGrad<T, T1, T2>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
typedef typename ToHipType<T1>::MappedType HipT1;
|
||||
typedef typename ToHipType<T2>::MappedType HipT2;
|
||||
|
||||
const Tensor* dY = ctx->Input<Tensor>(0);
|
||||
const Tensor* X = ctx->Input<Tensor>(1);
|
||||
const Tensor* Scale = ctx->Input<Tensor>(2);
|
||||
const Tensor* saved_mean = ctx->Input<Tensor>(3);
|
||||
// miopenBatchNormalizationBackward() claims to use `savedInvVariance`, but the value
|
||||
// is actually equal to the batch inv_std, so we use name `saved_inv_std` here.
|
||||
const Tensor* saved_inv_std = ctx->Input<Tensor>(4);
|
||||
const TensorShape input_shape = X->Shape();
|
||||
const TensorShape channel_shape = saved_mean->Shape();
|
||||
|
||||
// no B here, but B has same size as Scale, so can validate inputs for gradient with this substitute
|
||||
ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, Scale, Scale, saved_mean, saved_inv_std));
|
||||
|
||||
auto dY_data = reinterpret_cast<const HipT*>(dY->template Data<T>());
|
||||
auto X_data = reinterpret_cast<const HipT*>(X->template Data<T>());
|
||||
auto Scale_data = reinterpret_cast<const HipT1*>(Scale->template Data<T1>());
|
||||
auto saved_mean_data = reinterpret_cast<const HipT2*>(saved_mean->template Data<T2>());
|
||||
auto saved_inv_std_data = reinterpret_cast<const HipT2*>(saved_inv_std->template Data<T2>());
|
||||
|
||||
auto dX_data = reinterpret_cast<HipT*>(ctx->Output(0, input_shape)->template MutableData<T>());
|
||||
auto dScale_data = reinterpret_cast<HipT1*>(ctx->Output(1, channel_shape)->template MutableData<T1>());
|
||||
auto dBias_data = reinterpret_cast<HipT1*>(ctx->Output(2, channel_shape)->template MutableData<T1>());
|
||||
|
||||
const auto alpha = Consts<HipT>::One;
|
||||
const auto beta = Consts<HipT>::Zero;
|
||||
|
||||
MiopenTensor input_tensor, scale_bias_tensor;
|
||||
vector<int64_t> new_dims;
|
||||
BatchNormHelper::NormalizeDims(input_shape, new_dims);
|
||||
ORT_RETURN_IF_ERROR(input_tensor.Set(new_dims, MiopenTensor::GetDataType<HipT>()));
|
||||
ORT_RETURN_IF_ERROR(scale_bias_tensor.Set(input_tensor, miopen_batch_norm_mode_));
|
||||
|
||||
const int64_t C = new_dims[1];
|
||||
auto p_scale = reinterpret_cast<const void*>(Scale_data);
|
||||
auto p_saved_mean = reinterpret_cast<const void*>(saved_mean_data);
|
||||
auto p_saved_inv_std = reinterpret_cast<const void*>(saved_inv_std_data);
|
||||
auto p_dScale = reinterpret_cast<void*>(dScale_data);
|
||||
auto p_dBias = reinterpret_cast<void*>(dBias_data);
|
||||
|
||||
IAllocatorUniquePtr<float> p_f_scale, p_f_dScale, p_f_dBias, p_f_saved_mean, p_f_saved_inv_std;
|
||||
|
||||
if (std::is_same<T1, MLFloat16>::value) {
|
||||
p_f_scale = GetScratchBuffer<float>(C);
|
||||
p_f_dScale = GetScratchBuffer<float>(C);
|
||||
p_f_dBias = GetScratchBuffer<float>(C);
|
||||
|
||||
Impl_Cast<HipT1, float>(Stream(), Scale_data, p_f_scale.get(), C);
|
||||
|
||||
p_scale = p_f_scale.get();
|
||||
p_dScale = p_f_dScale.get();
|
||||
p_dBias = p_f_dBias.get();
|
||||
}
|
||||
|
||||
if (std::is_same<T2, MLFloat16>::value) {
|
||||
p_f_saved_mean = GetScratchBuffer<float>(C);
|
||||
p_f_saved_inv_std = GetScratchBuffer<float>(C);
|
||||
|
||||
Impl_Cast<HipT2, float>(Stream(), saved_mean_data, p_f_saved_mean.get(), C);
|
||||
Impl_Cast<HipT2, float>(Stream(), saved_inv_std_data, p_f_saved_inv_std.get(), C);
|
||||
|
||||
p_saved_mean = p_f_saved_mean.get();
|
||||
p_saved_inv_std = p_f_saved_inv_std.get();
|
||||
}
|
||||
|
||||
MIOPEN_RETURN_IF_ERROR(miopenBatchNormalizationBackward(
|
||||
MiopenHandle(),
|
||||
miopen_batch_norm_mode_,
|
||||
&alpha,
|
||||
&beta,
|
||||
&alpha,
|
||||
&beta,
|
||||
input_tensor,
|
||||
X_data,
|
||||
input_tensor,
|
||||
dY_data,
|
||||
input_tensor,
|
||||
dX_data,
|
||||
scale_bias_tensor,
|
||||
p_scale,
|
||||
p_dScale,
|
||||
p_dBias,
|
||||
epsilon_,
|
||||
p_saved_mean,
|
||||
p_saved_inv_std));
|
||||
|
||||
if (std::is_same<T1, MLFloat16>::value) {
|
||||
Impl_Cast<float, HipT1>(Stream(), reinterpret_cast<float*>(p_dScale), dScale_data, C);
|
||||
Impl_Cast<float, HipT1>(Stream(), reinterpret_cast<float*>(p_dBias), dBias_data, C);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define SPECIALIZED_GRADIENT(T, T1, T2) \
|
||||
REGISTER_GRADIENT_KERNEL_TYPED(T, T1, T2) \
|
||||
template Status BatchNormalizationGrad<T, T1, T2>::ComputeInternal(OpKernelContext* ctx) const;
|
||||
|
||||
SPECIALIZED_GRADIENT(float, float, float)
|
||||
// MIOpen kernel does not support double, disable for now.
|
||||
// SPECIALIZED_GRADIENT(double, double, double)
|
||||
SPECIALIZED_GRADIENT(MLFloat16, MLFloat16, MLFloat16)
|
||||
SPECIALIZED_GRADIENT(MLFloat16, MLFloat16, float)
|
||||
SPECIALIZED_GRADIENT(MLFloat16, float, float)
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gsl/gsl"
|
||||
#include "core/providers/rocm/rocm_kernel.h"
|
||||
#include "core/providers/rocm/miopen_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T, typename T1, typename T2>
|
||||
class BatchNormalizationGrad final : public RocmKernel {
|
||||
public:
|
||||
BatchNormalizationGrad(const OpKernelInfo& info)
|
||||
: RocmKernel{info},
|
||||
miopen_batch_norm_mode_(miopenBNSpatial) {
|
||||
float tmp_epsilon;
|
||||
ORT_ENFORCE(info.GetAttr<float>("epsilon", &tmp_epsilon).IsOK());
|
||||
epsilon_ = ClampMiopenBatchNormEpsilon(static_cast<double>(tmp_epsilon));
|
||||
|
||||
// spatial or not
|
||||
int64_t tmp_spatial;
|
||||
if (info.GetAttr<int64_t>("spatial", &tmp_spatial).IsOK()) {
|
||||
spatial_ = tmp_spatial;
|
||||
}
|
||||
|
||||
if (spatial_ == 0) {
|
||||
miopen_batch_norm_mode_ = miopenBNPerActivation;
|
||||
}
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
double epsilon_;
|
||||
int64_t spatial_ = 1; // default as per spec
|
||||
miopenBatchNormMode_t miopen_batch_norm_mode_;
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "batch_norm_internal.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/rocm/miopen_common.h"
|
||||
#include "core/providers/cpu/nn/batch_norm_helper.h"
|
||||
#include "core/providers/rocm/math/unary_elementwise_ops_impl.h"
|
||||
|
||||
using namespace std;
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
// Registers BatchNormInternal<T, T1, T2> for kRocmExecutionProvider under kMSDomain.
// The type tag of the registration is the pasted token T_T1_T2. The kernel definition carries
// the aliases (3, 1) and (4, 2) and binds the "T", "T1" and "T2" type constraints to the tensor
// types of the corresponding macro arguments.
#define REGISTER_KERNEL_TYPED(T, T1, T2)                            \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                    \
      BatchNormInternal,                                            \
      kMSDomain,                                                    \
      1,                                                            \
      T##_##T1##_##T2,                                              \
      kRocmExecutionProvider,                                       \
      (*KernelDefBuilder::Create())                                 \
          .Alias(3, 1)                                              \
          .Alias(4, 2)                                              \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())    \
          .TypeConstraint("T1", DataTypeImpl::GetTensorType<T1>())  \
          .TypeConstraint("T2", DataTypeImpl::GetTensorType<T2>()), \
      BatchNormInternal<T, T1, T2>);
|
||||
|
||||
template <typename T, typename T1, typename T2>
|
||||
Status BatchNormInternal<T, T1, T2>::ComputeInternal(OpKernelContext* p_op_kernel_context) const {
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
typedef typename ToHipType<T1>::MappedType HipT1;
|
||||
typedef typename ToHipType<T2>::MappedType HipT2;
|
||||
|
||||
const Tensor* X = p_op_kernel_context->Input<Tensor>(0);
|
||||
const Tensor* scale = p_op_kernel_context->Input<Tensor>(1);
|
||||
const Tensor* B = p_op_kernel_context->Input<Tensor>(2);
|
||||
const Tensor* mean = p_op_kernel_context->Input<Tensor>(3);
|
||||
const Tensor* var = p_op_kernel_context->Input<Tensor>(4);
|
||||
|
||||
ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, spatial_ == 1));
|
||||
|
||||
const TensorShape& x_shape = X->Shape();
|
||||
const TensorShape& channel_shape = mean->Shape();
|
||||
|
||||
Tensor* Y = p_op_kernel_context->Output(0, x_shape);
|
||||
Tensor* running_mean = p_op_kernel_context->Output(1, channel_shape);
|
||||
Tensor* running_var = p_op_kernel_context->Output(2, channel_shape);
|
||||
Tensor* saved_mean = p_op_kernel_context->Output(3, channel_shape);
|
||||
// miopenBatchNormalizationForwardTraining() claims to output `resultSaveInvVariance`, but the value
|
||||
// is actually equal to the batch inv_std, so we use name `saved_inv_std` here.
|
||||
Tensor* saved_inv_std = p_op_kernel_context->Output(4, channel_shape);
|
||||
|
||||
auto x_data = reinterpret_cast<const HipT*>(X->template Data<T>());
|
||||
auto scale_data = reinterpret_cast<const HipT1*>(scale->template Data<T1>());
|
||||
auto b_data = reinterpret_cast<const HipT1*>(B->template Data<T1>());
|
||||
auto mean_data = reinterpret_cast<const HipT2*>(mean->template Data<T2>());
|
||||
auto var_data = reinterpret_cast<const HipT2*>(var->template Data<T2>());
|
||||
|
||||
auto y_data = reinterpret_cast<HipT*>(Y->template MutableData<T>());
|
||||
|
||||
// In MIOpenBatchNormForward, alpha and beta are not const.
|
||||
float alpha = 1.0;
|
||||
float beta = 0.0;
|
||||
|
||||
MiopenTensor data_desc, bn_tensor_desc;
|
||||
vector<int64_t> new_dims;
|
||||
BatchNormHelper::NormalizeDims(x_shape, new_dims);
|
||||
ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, MiopenTensor::GetDataType<HipT>()));
|
||||
ORT_RETURN_IF_ERROR(bn_tensor_desc.Set(data_desc, miopen_batch_norm_mode_));
|
||||
|
||||
auto running_mean_data = reinterpret_cast<HipT2*>(running_mean->template MutableData<T2>());
|
||||
auto running_var_data = reinterpret_cast<HipT2*>(running_var->template MutableData<T2>());
|
||||
auto saved_mean_data = reinterpret_cast<HipT2*>(saved_mean->template MutableData<T2>());
|
||||
auto saved_inv_std_data = reinterpret_cast<HipT2*>(saved_inv_std->template MutableData<T2>());
|
||||
|
||||
auto p_scale = reinterpret_cast<const void*>(scale_data);
|
||||
auto p_B = reinterpret_cast<const void*>(b_data);
|
||||
auto p_running_mean = reinterpret_cast<void*>(running_mean_data);
|
||||
auto p_running_var = reinterpret_cast<void*>(running_var_data);
|
||||
auto p_saved_mean = reinterpret_cast<void*>(saved_mean_data);
|
||||
auto p_saved_inv_std = reinterpret_cast<void*>(saved_inv_std_data);
|
||||
|
||||
|
||||
const int64_t C = new_dims[1];
|
||||
IAllocatorUniquePtr<float> p_f_scale, p_f_B, p_f_running_mean, p_f_running_var, p_f_saved_mean, p_f_saved_inv_std;
|
||||
|
||||
if (std::is_same<T1, MLFloat16>::value) {
|
||||
// Convert scale/B to float
|
||||
p_f_scale = GetScratchBuffer<float>(C);
|
||||
p_f_B = GetScratchBuffer<float>(C);
|
||||
|
||||
Impl_Cast<HipT1, float>(Stream(), scale_data, p_f_scale.get(), C);
|
||||
Impl_Cast<HipT1, float>(Stream(), b_data, p_f_B.get(), C);
|
||||
|
||||
p_scale = p_f_scale.get();
|
||||
p_B = p_f_B.get();
|
||||
}
|
||||
|
||||
if (std::is_same<T2, MLFloat16>::value) {
|
||||
// Convert mean/var to float
|
||||
p_f_running_mean = GetScratchBuffer<float>(C);
|
||||
p_f_running_var = GetScratchBuffer<float>(C);
|
||||
p_f_saved_mean = GetScratchBuffer<float>(C);
|
||||
p_f_saved_inv_std = GetScratchBuffer<float>(C);
|
||||
|
||||
Impl_Cast<HipT2, float>(Stream(), mean_data, p_f_running_mean.get(), C);
|
||||
Impl_Cast<HipT2, float>(Stream(), var_data, p_f_running_var.get(), C);
|
||||
|
||||
p_running_mean = p_f_running_mean.get();
|
||||
p_running_var = p_f_running_var.get();
|
||||
p_saved_mean = p_f_saved_mean.get();
|
||||
p_saved_inv_std = p_f_saved_inv_std.get();
|
||||
} else if (mean_data != running_mean_data) {
|
||||
HIP_RETURN_IF_ERROR(
|
||||
hipMemcpyAsync(running_mean_data, mean_data, C * sizeof(T2), hipMemcpyDeviceToDevice, Stream()));
|
||||
HIP_RETURN_IF_ERROR(
|
||||
hipMemcpyAsync(running_var_data, var_data, C * sizeof(T2), hipMemcpyDeviceToDevice, Stream()));
|
||||
}
|
||||
|
||||
// NOTE: in miopenBatchNorm, biased std/var is used when calculating `save_inv_std` and `y`, while
|
||||
// `running_var` is updated using unbiased `batch_var`:
|
||||
// running_var = (1 - momentum_) * unbiased_batch_var + momentum_ * running_var
|
||||
// This is inconsistent with BatchNormalization Onnx spec, which uses population variance (biased).
|
||||
MIOPEN_RETURN_IF_ERROR(miopenBatchNormalizationForwardTraining(
|
||||
MiopenHandle(),
|
||||
miopen_batch_norm_mode_,
|
||||
&alpha,
|
||||
&beta,
|
||||
data_desc,
|
||||
x_data,
|
||||
data_desc,
|
||||
y_data,
|
||||
bn_tensor_desc,
|
||||
const_cast<void*>(p_scale),
|
||||
const_cast<void*>(p_B),
|
||||
1.0 - momentum_,
|
||||
p_running_mean,
|
||||
p_running_var,
|
||||
epsilon_,
|
||||
p_saved_mean,
|
||||
p_saved_inv_std));
|
||||
|
||||
if (std::is_same<T2, MLFloat16>::value) {
|
||||
Impl_Cast<float, HipT2>(Stream(), reinterpret_cast<float*>(p_running_mean), running_mean_data, C);
|
||||
Impl_Cast<float, HipT2>(Stream(), reinterpret_cast<float*>(p_running_var), running_var_data, C);
|
||||
Impl_Cast<float, HipT2>(Stream(), reinterpret_cast<float*>(p_saved_mean), saved_mean_data, C);
|
||||
Impl_Cast<float, HipT2>(Stream(), reinterpret_cast<float*>(p_saved_inv_std), saved_inv_std_data, C);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define SPECIALIZED_COMPUTE(T, T1, T2) \
|
||||
REGISTER_KERNEL_TYPED(T, T1, T2) \
|
||||
template Status BatchNormInternal<T, T1, T2>::ComputeInternal(OpKernelContext* ctx) const;
|
||||
|
||||
|
||||
SPECIALIZED_COMPUTE(float, float, float)
|
||||
// MIOpen kernel does not support double, disable for now.
|
||||
// SPECIALIZED_COMPUTE(double, double, double)
|
||||
SPECIALIZED_COMPUTE(MLFloat16, MLFloat16, MLFloat16)
|
||||
SPECIALIZED_COMPUTE(MLFloat16, MLFloat16, float)
|
||||
SPECIALIZED_COMPUTE(MLFloat16, float, float)
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gsl/gsl"
|
||||
#include "core/providers/rocm/rocm_kernel.h"
|
||||
#include "core/providers/rocm/miopen_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T, typename T1, typename T2>
|
||||
class BatchNormInternal final : public RocmKernel {
|
||||
public:
|
||||
BatchNormInternal(const OpKernelInfo& op_kernel_info)
|
||||
: RocmKernel{op_kernel_info},
|
||||
miopen_batch_norm_mode_(miopenBNSpatial),
|
||||
momentum_(0.9) {
|
||||
float tmp_epsilon;
|
||||
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &tmp_epsilon).IsOK());
|
||||
epsilon_ = ClampMiopenBatchNormEpsilon(static_cast<double>(tmp_epsilon));
|
||||
|
||||
// spatial or not
|
||||
int64_t tmp_spatial;
|
||||
if (op_kernel_info.GetAttr<int64_t>("spatial", &tmp_spatial).IsOK()) {
|
||||
spatial_ = tmp_spatial;
|
||||
}
|
||||
|
||||
if (spatial_ == 0) {
|
||||
miopen_batch_norm_mode_ = miopenBNPerActivation;
|
||||
}
|
||||
|
||||
float tmp_momentum;
|
||||
if (op_kernel_info.GetAttr<float>("momentum", &tmp_momentum).IsOK()) {
|
||||
momentum_ = static_cast<double>(tmp_momentum);
|
||||
}
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
double epsilon_;
|
||||
int64_t spatial_ = 1; // default as per spec
|
||||
miopenBatchNormMode_t miopen_batch_norm_mode_;
|
||||
double momentum_;
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -67,8 +67,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LogSoftmaxGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, LogSoftmaxGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LogSoftmaxGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, BatchNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, BatchNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormInternal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormInternal);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DropoutGrad);
|
||||
|
||||
|
|
@ -208,8 +216,16 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossInternal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossInternalGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossInternalGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, BatchNormalizationGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_float, BatchNormInternal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_float, BatchNormInternal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DivGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, DivGrad)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue