Fix CUDA BatchNorm bugs and add support for NHWC (#19742)

### Description
- Fix incorrect running_mean / running_var in training mode due to
incorrect momentum and missing input mean/var. running_var could be
correct, but has too high an epsilon.
- Fix incorrect checks when using NHWC
- Pass NHWC flag to NormalizeDims to get correct new dimensions from
x_shape
- Register missing double operations to get parity between NHWC/NCHW
This commit is contained in:
Markus Tavenrath 2024-03-05 17:09:42 +01:00 committed by GitHub
parent cd56ea4a74
commit bdf678df93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 66 additions and 21 deletions

View file

@ -25,6 +25,8 @@ class BatchNormHelper {
const Tensor* var,
bool is_spatial = true,
bool is_nhwc = false) {
// NHWC dependent shape: X
// All other shapes are assumed to be in NCHW layout?
const auto& x_dims = X->Shape().GetDims();
// If x_dims size < 2, num_channels defaults to 1.
@ -48,16 +50,22 @@ class BatchNormHelper {
// validate 'scales' shape
const auto& scale_dims = scale->Shape().GetDims();
if (static_cast<int>(scale_dims.size()) != kNumInputScaleDimensions) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
}
if (scale_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: 0th dimension != ", num_channels);
}
// N & C do not belong to features
// skip the first element for NHWC and the first two elements for NCHW.
int feature_offset = is_nhwc ? 1 : 2;
// in non-spatial cases - the other dims of 'scale' must be validated
if (!is_spatial) {
for (int feature = 0; feature < num_feature_dims; ++feature) {
if (scale_dims[1 + feature] != x_dims[2 + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
if (scale_dims[1 + feature] != x_dims[feature_offset + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature),
" dimension != ", x_dims[feature_offset + feature]);
}
}
}
@ -65,7 +73,8 @@ class BatchNormHelper {
// validate 'B' shape
const auto& B_dims = B->Shape().GetDims();
if (static_cast<int>(B_dims.size()) != kNumInputBiasDimensions) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
}
if (B_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: 0th dimension != ", num_channels);
@ -73,8 +82,9 @@ class BatchNormHelper {
// in non-spatial cases - the other dims of 'B' must be validated
if (!is_spatial) {
for (int feature = 0; feature < num_feature_dims; ++feature) {
if (B_dims[1 + feature] != x_dims[2 + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
if (B_dims[1 + feature] != x_dims[feature_offset + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature),
" dimension != ", x_dims[feature_offset + feature]);
}
}
}
@ -82,16 +92,19 @@ class BatchNormHelper {
// validate 'mean' shape
const auto& mean_dims = mean->Shape().GetDims();
if (static_cast<int>(mean_dims.size()) != kNumInputMeanDimensions) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
}
if (mean_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 0th dimension != ", num_channels);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid input mean: 0th dimension != ", num_channels);
}
// in non-spatial cases - the other dims of 'mean' must be validated
if (!is_spatial) {
for (int feature = 0; feature < num_feature_dims; ++feature) {
if (mean_dims[1 + feature] != x_dims[2 + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
if (mean_dims[1 + feature] != x_dims[feature_offset + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature),
" dimension != ", x_dims[feature_offset + feature]);
}
}
}
@ -99,7 +112,8 @@ class BatchNormHelper {
// validate 'var' shape
const auto& var_dims = var->Shape().GetDims();
if (static_cast<int>(var_dims.size()) != kNumInputVarianceDimensions) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
}
if (var_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: 0th dimension != ", num_channels);
@ -107,8 +121,9 @@ class BatchNormHelper {
// in non-spatial cases - the other dims of 'var' must be validated
if (!is_spatial) {
for (int feature = 0; feature < num_feature_dims; ++feature) {
if (var_dims[1 + feature] != x_dims[2 + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
if (var_dims[1 + feature] != x_dims[feature_offset + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature),
" dimension != ", x_dims[feature_offset + feature]);
}
}
}

View file

@ -1202,9 +1202,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin);
@ -2107,9 +2110,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Mul)>,

View file

@ -18,10 +18,14 @@ namespace onnxruntime::cuda {
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float,
@ -72,10 +76,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalN
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16,
BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float,
BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double,
BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16,
BatchNormalization);
@ -86,18 +94,26 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider,

View file

@ -87,7 +87,7 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
CudnnTensor data_desc;
vector<int64_t> new_dims;
BatchNormHelper::NormalizeDims(x_shape, new_dims);
BatchNormHelper::NormalizeDims(x_shape, new_dims, NHWC);
ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType<CudaT>(), NHWC));
// For half data type, the alpha, beta, scale, B, mean, var need to be float type
@ -137,6 +137,12 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
auto saved_mean_data = reinterpret_cast<CudaT*>(saved_mean->MutableData<T>());
auto saved_inv_var_data = reinterpret_cast<CudaT*>(saved_var->MutableData<T>());
auto stream = static_cast<cudaStream_t>(p_op_kernel_context->GetComputeStream()->GetHandle());
CUDA_RETURN_IF_ERROR(
cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream));
CUDA_RETURN_IF_ERROR(
cudaMemcpyAsync(running_var_data, var_data, var->SizeInBytes(), cudaMemcpyDeviceToDevice, stream));
CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper(
GetCudnnHandle(p_op_kernel_context),
cudnn_batch_norm_mode_,
@ -149,7 +155,7 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
bn_tensor_desc,
scale_data,
b_data,
momentum_,
1.0 - momentum_,
running_mean_data,
running_var_data,
epsilon_,
@ -186,6 +192,7 @@ SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false)
#ifdef ENABLE_CUDA_NHWC_OPS
SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true)
SPECIALIZED_COMPUTE(double, kMSInternalNHWCDomain, true)
SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true)
#endif
} // namespace cuda

View file

@ -916,6 +916,7 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
// exclude CUDA Execution Provider due to flakiness
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
// TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1
{kCudaExecutionProvider, kRocmExecutionProvider,
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
}