diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index f3d5b7e8de..0ad7015519 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -498,9 +498,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 13, MLFloat16, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, float, LRN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, double, LRN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, LRN); @@ -1112,6 +1112,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization); #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub); @@ -1307,9 +1310,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1919,6 +1922,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.cc b/onnxruntime/core/providers/cuda/nn/batch_norm.cc index 56d297429b..1a09ec8f45 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.cc @@ -21,10 +21,19 @@ namespace cuda { (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ BatchNorm); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + BatchNormalization, \ + kOnnxDomain, \ + 9, 13, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + BatchNorm); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ BatchNormalization, \ kOnnxDomain, \ - 9, \ + 14, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.h b/onnxruntime/core/providers/cuda/nn/batch_norm.h index 62241ef52d..e8ded6a571 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.h +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.h @@ -34,6 +34,13 @@ class BatchNorm final : public CudaKernel { if (op_kernel_info.GetAttr("momentum", &tmp_momentum).IsOK()) { momentum_ = static_cast(tmp_momentum); } + + is_training_mode_ = (op_kernel_info.GetAttrOrDefault("training_mode", 0) == 1); + const auto& node = op_kernel_info.node(); + auto opset = node.SinceVersion(); + + // batch norm opset 14 is not implemented for training mode + ORT_ENFORCE(!(is_training_mode_ && opset==14), "Training mode does not support BN opset 14 yet."); } Status ComputeInternal(OpKernelContext* context) const override; @@ -43,6 +50,7 @@ class BatchNorm final : public CudaKernel { int64_t spatial_ = 1; // default as per spec cudnnBatchNormMode_t cudnn_batch_norm_mode_; double momentum_; + bool is_training_mode_ = 0; //default as per spec }; } // namespace cuda