mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Fix CUDA BatchNorm bugs and add support for NHWC (#19742)
### Description - Fix incorrect running_mean / running_var in training mode due to incorrect momentum and missing input mean/var. runnig_var could be correct, but has a too high epsilon. - Fix incorrect checks when using NHWC - Pass NHWC flag to NormalizeDims to get correct new dimensions from x_shape - Register missing double operations to get parity between NHWC/NCHW
This commit is contained in:
parent
cd56ea4a74
commit
bdf678df93
5 changed files with 66 additions and 21 deletions
|
|
@ -25,6 +25,8 @@ class BatchNormHelper {
|
|||
const Tensor* var,
|
||||
bool is_spatial = true,
|
||||
bool is_nhwc = false) {
|
||||
// NHWC dependent shape: X
|
||||
// All other shapes are assumed to be in NCHW layout?
|
||||
const auto& x_dims = X->Shape().GetDims();
|
||||
|
||||
// If x_dims size < 2, num_channels defaults to 1.
|
||||
|
|
@ -48,16 +50,22 @@ class BatchNormHelper {
|
|||
// validate 'scales' shape
|
||||
const auto& scale_dims = scale->Shape().GetDims();
|
||||
if (static_cast<int>(scale_dims.size()) != kNumInputScaleDimensions) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
|
||||
}
|
||||
if (scale_dims[0] != num_channels) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: 0th dimension != ", num_channels);
|
||||
}
|
||||
// N & C do not belong to features
|
||||
// skip the first element for NHWC and the first two elements for NCHW.
|
||||
int feature_offset = is_nhwc ? 1 : 2;
|
||||
|
||||
// in non-spatial cases - the other dims of 'scale' must be validated
|
||||
if (!is_spatial) {
|
||||
for (int feature = 0; feature < num_feature_dims; ++feature) {
|
||||
if (scale_dims[1 + feature] != x_dims[2 + feature]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
|
||||
if (scale_dims[1 + feature] != x_dims[feature_offset + feature]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature),
|
||||
" dimension != ", x_dims[feature_offset + feature]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -65,7 +73,8 @@ class BatchNormHelper {
|
|||
// validate 'B' shape
|
||||
const auto& B_dims = B->Shape().GetDims();
|
||||
if (static_cast<int>(B_dims.size()) != kNumInputBiasDimensions) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
|
||||
}
|
||||
if (B_dims[0] != num_channels) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: 0th dimension != ", num_channels);
|
||||
|
|
@ -73,8 +82,9 @@ class BatchNormHelper {
|
|||
// in non-spatial cases - the other dims of 'B' must be validated
|
||||
if (!is_spatial) {
|
||||
for (int feature = 0; feature < num_feature_dims; ++feature) {
|
||||
if (B_dims[1 + feature] != x_dims[2 + feature]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
|
||||
if (B_dims[1 + feature] != x_dims[feature_offset + feature]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature),
|
||||
" dimension != ", x_dims[feature_offset + feature]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -82,16 +92,19 @@ class BatchNormHelper {
|
|||
// validate 'mean' shape
|
||||
const auto& mean_dims = mean->Shape().GetDims();
|
||||
if (static_cast<int>(mean_dims.size()) != kNumInputMeanDimensions) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
|
||||
}
|
||||
if (mean_dims[0] != num_channels) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 0th dimension != ", num_channels);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Invalid input mean: 0th dimension != ", num_channels);
|
||||
}
|
||||
// in non-spatial cases - the other dims of 'mean' must be validated
|
||||
if (!is_spatial) {
|
||||
for (int feature = 0; feature < num_feature_dims; ++feature) {
|
||||
if (mean_dims[1 + feature] != x_dims[2 + feature]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
|
||||
if (mean_dims[1 + feature] != x_dims[feature_offset + feature]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature),
|
||||
" dimension != ", x_dims[feature_offset + feature]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -99,7 +112,8 @@ class BatchNormHelper {
|
|||
// validate 'var' shape
|
||||
const auto& var_dims = var->Shape().GetDims();
|
||||
if (static_cast<int>(var_dims.size()) != kNumInputVarianceDimensions) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
|
||||
}
|
||||
if (var_dims[0] != num_channels) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: 0th dimension != ", num_channels);
|
||||
|
|
@ -107,8 +121,9 @@ class BatchNormHelper {
|
|||
// in non-spatial cases - the other dims of 'var' must be validated
|
||||
if (!is_spatial) {
|
||||
for (int feature = 0; feature < num_feature_dims; ++feature) {
|
||||
if (var_dims[1 + feature] != x_dims[2 + feature]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
|
||||
if (var_dims[1 + feature] != x_dims[feature_offset + feature]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature),
|
||||
" dimension != ", x_dims[feature_offset + feature]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1202,9 +1202,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin);
|
||||
|
|
@ -2107,9 +2110,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Mul)>,
|
||||
|
|
|
|||
|
|
@ -18,10 +18,14 @@ namespace onnxruntime::cuda {
|
|||
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float,
|
||||
|
|
@ -72,10 +76,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalN
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16,
|
||||
BatchNormalization);
|
||||
|
||||
|
|
@ -86,18 +94,26 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
|
|||
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
|
||||
MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
|
||||
float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
|
||||
double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, Conv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider,
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
|
|||
|
||||
CudnnTensor data_desc;
|
||||
vector<int64_t> new_dims;
|
||||
BatchNormHelper::NormalizeDims(x_shape, new_dims);
|
||||
BatchNormHelper::NormalizeDims(x_shape, new_dims, NHWC);
|
||||
ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType<CudaT>(), NHWC));
|
||||
|
||||
// For half data type, the alpha, beta, scale, B, mean, var need to be float type
|
||||
|
|
@ -137,6 +137,12 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
|
|||
auto saved_mean_data = reinterpret_cast<CudaT*>(saved_mean->MutableData<T>());
|
||||
auto saved_inv_var_data = reinterpret_cast<CudaT*>(saved_var->MutableData<T>());
|
||||
|
||||
auto stream = static_cast<cudaStream_t>(p_op_kernel_context->GetComputeStream()->GetHandle());
|
||||
CUDA_RETURN_IF_ERROR(
|
||||
cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream));
|
||||
CUDA_RETURN_IF_ERROR(
|
||||
cudaMemcpyAsync(running_var_data, var_data, var->SizeInBytes(), cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper(
|
||||
GetCudnnHandle(p_op_kernel_context),
|
||||
cudnn_batch_norm_mode_,
|
||||
|
|
@ -149,7 +155,7 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
|
|||
bn_tensor_desc,
|
||||
scale_data,
|
||||
b_data,
|
||||
momentum_,
|
||||
1.0 - momentum_,
|
||||
running_mean_data,
|
||||
running_var_data,
|
||||
epsilon_,
|
||||
|
|
@ -186,6 +192,7 @@ SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false)
|
|||
|
||||
#ifdef ENABLE_CUDA_NHWC_OPS
|
||||
SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true)
|
||||
SPECIALIZED_COMPUTE(double, kMSInternalNHWCDomain, true)
|
||||
SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true)
|
||||
#endif
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -916,6 +916,7 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
|
|||
// exclude CUDA Execution Provider due to flakiness
|
||||
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
// TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1
|
||||
{kCudaExecutionProvider, kRocmExecutionProvider,
|
||||
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue