diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index 9ea3794d5a..ac2f349958 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -229,5 +229,28 @@ SetLRNDescriptorHelper(cudnnLRNDescriptor_t normDesc, return cudnnSetLRNDescriptor(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); } +inline cudnnStatus_t +PoolingForwardHelper(cudnnHandle_t handle, + const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, + const cudnnTensorDescriptor_t xDesc, + const void *x, + const void *beta, + const cudnnTensorDescriptor_t yDesc, + void *y) { + return cudnnPoolingForward(handle, poolingDesc, alpha, xDesc, x, beta, yDesc, y); +} + +inline cudnnStatus_t +SetPoolingNdDescriptorHelper(cudnnPoolingDescriptor_t poolingDesc, + const cudnnPoolingMode_t mode, + const cudnnNanPropagation_t maxpoolingNanOpt, + int nbDims, + const int windowDimA[], + const int paddingA[], + const int strideA[]) { + return cudnnSetPoolingNdDescriptor(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index 3a7868094a..5f7d89824d 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -128,7 +128,7 @@ class CudnnPoolingDescriptor final { for (int i = 0; i < rank; i++) { stride[i] = gsl::narrow_cast(strides[i]); } - CUDNN_RETURN_IF_ERROR(cudnnSetPoolingNdDescriptor( + CUDNN_RETURN_IF_ERROR(SetPoolingNdDescriptorHelper( desc_, mode, CUDNN_PROPAGATE_NAN, @@ -212,7 +212,8 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count); auto temp_Y = GetScratchBuffer(output_count); Impl_Cast(Stream(), reinterpret_cast(x_data), temp_X.get(), input_count); - CUDNN_RETURN_IF_ERROR(cudnnPoolingForward(CudnnHandle(), pooling_desc, &alpha, x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get())); + CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(CudnnHandle(), pooling_desc, &alpha, + x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get())); Impl_Cast(Stream(), temp_Y.get(), y_data, output_count); } else { const auto alpha = Consts::One; @@ -222,7 +223,8 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); - CUDNN_RETURN_IF_ERROR(cudnnPoolingForward(CudnnHandle(), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data)); + CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(CudnnHandle(), pooling_desc, &alpha, + x_tensor, x_data, &beta, y_tensor, y_data)); } return Status::OK(); diff --git a/onnxruntime/core/providers/rocm/miopen_common.h b/onnxruntime/core/providers/rocm/miopen_common.h index 6e5b9a02d9..003ebab910 100644 --- a/onnxruntime/core/providers/rocm/miopen_common.h +++ b/onnxruntime/core/providers/rocm/miopen_common.h @@ -196,5 +196,28 @@ SetLRNDescriptorHelper(miopenLRNDescriptor_t normDesc, return miopenSetLRNDescriptor(normDesc, miopenLRNCrossChannel, lrnN, lrnAlpha, lrnBeta, lrnK); } +inline miopenStatus_t +PoolingForwardHelper(miopenHandle_t handle, + const miopenPoolingDescriptor_t poolDesc, + const void* alpha, + const miopenTensorDescriptor_t xDesc, + const void* x, + const void* beta, + const miopenTensorDescriptor_t yDesc, + void* y) { + return miopenPoolingForward(handle, poolDesc, alpha, xDesc, x, beta, yDesc, y, false, nullptr, 0); +} + +inline miopenStatus_t +SetPoolingNdDescriptorHelper(miopenPoolingDescriptor_t poolDesc, + const miopenPoolingMode_t mode, + miopenNanPropagation_t /* unavailable */, + int nbDims, + int* windowDimA, + int* padA, + int* stridesA) { + return miopenSetNdPoolingDescriptor(poolDesc, mode, nbDims, windowDimA, padA, stridesA); +} + } // namespace rocm } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index f867672dfe..4cfa91cde5 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1433,21 +1433,36 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1604,13 +1619,19 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // opset 10 - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1718,12 +1739,18 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1756,11 +1783,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // OpSet 12 BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 7712b8d0c0..ab395057c2 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -400,7 +400,8 @@ TEST(PoolTest, MaxPool_10_DilationPadding_1d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_2d) { @@ -474,7 +475,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) { @@ -579,7 +580,8 @@ TEST(PoolTest, MaxPool_10_DilationPadding_3d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(PoolTest, GlobalMaxPool) { diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 558a2d863a..c27a50d274 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -113,10 +113,6 @@ provider_excluded_files = [ "nn/conv.h", "nn/conv_transpose.cc", "nn/conv_transpose.h", - "nn/max_pool_with_index.cu", - "nn/max_pool_with_index.h", - "nn/pool.cc", - "nn/pool.h", "reduction/reduction_ops.cc", "rnn/cudnn_rnn_base.cc", "rnn/cudnn_rnn_base.h", @@ -312,6 +308,9 @@ def hipify(src_file_path, dst_file_path): s = s.replace("MIOPEN_BATCHNORM_SPATIAL", "miopenBNSpatial") s = s.replace("MIOPEN_BATCHNORM_PER_ACTIVATION", "miopenBNPerActivation") s = s.replace("MIOPEN_LRN_CROSS_CHANNEL", "miopenLRNCrossChannel") + s = s.replace("MIOPEN_POOLING_MAX", "miopenPoolingMax") + s = s.replace("MIOPEN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING", "miopenPoolingAverageInclusive") + s = s.replace("MIOPEN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING", "miopenPoolingAverage") # CUSPARSE -> HIPSPARSE s = s.replace("CUSPARSE", "HIPSPARSE")