[ROCm] Add AveragePool, GlobalAveragePool, MaxPool, GlobalMaxPool Ops (#11968)

* [ROCm] disable expected failure tests PoolTest.MaxPool_10_DilationPadding_?d

* [ROCm] Add AveragePool, GlobalAveragePool, MaxPool, GlobalMaxPool Ops

* (To squash after review) Replace rocm/nn/pool.cc with amd_hipify.py changes

* [ROCM] Replace miCompat with Helper functions

* (to squash) fix the compiling error of SetPoolingNdDescriptorHelper
This commit is contained in:
Xinya Zhang 2022-08-03 16:36:36 -05:00 committed by GitHub
parent d1497bdf62
commit 77cab7a3a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 123 additions and 42 deletions

View file

@ -229,5 +229,28 @@ SetLRNDescriptorHelper(cudnnLRNDescriptor_t normDesc,
return cudnnSetLRNDescriptor(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK);
}
inline cudnnStatus_t
PoolingForwardHelper(cudnnHandle_t handle,
const cudnnPoolingDescriptor_t poolingDesc,
const void *alpha,
const cudnnTensorDescriptor_t xDesc,
const void *x,
const void *beta,
const cudnnTensorDescriptor_t yDesc,
void *y) {
return cudnnPoolingForward(handle, poolingDesc, alpha, xDesc, x, beta, yDesc, y);
}
inline cudnnStatus_t
SetPoolingNdDescriptorHelper(cudnnPoolingDescriptor_t poolingDesc,
const cudnnPoolingMode_t mode,
const cudnnNanPropagation_t maxpoolingNanOpt,
int nbDims,
const int windowDimA[],
const int paddingA[],
const int strideA[]) {
return cudnnSetPoolingNdDescriptor(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA);
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -128,7 +128,7 @@ class CudnnPoolingDescriptor final {
for (int i = 0; i < rank; i++) {
stride[i] = gsl::narrow_cast<int>(strides[i]);
}
CUDNN_RETURN_IF_ERROR(cudnnSetPoolingNdDescriptor(
CUDNN_RETURN_IF_ERROR(SetPoolingNdDescriptorHelper(
desc_,
mode,
CUDNN_PROPAGATE_NAN,
@ -212,7 +212,8 @@ Status Pool<T, PoolType>::ComputeInternal(OpKernelContext* context) const {
IAllocatorUniquePtr<float> temp_X = GetScratchBuffer<float>(input_count);
auto temp_Y = GetScratchBuffer<float>(output_count);
Impl_Cast<CudaT, float>(Stream(), reinterpret_cast<const CudaT*>(x_data), temp_X.get(), input_count);
CUDNN_RETURN_IF_ERROR(cudnnPoolingForward(CudnnHandle(), pooling_desc, &alpha, x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get()));
CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(CudnnHandle(), pooling_desc, &alpha,
x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get()));
Impl_Cast<float, CudaT>(Stream(), temp_Y.get(), y_data, output_count);
} else {
const auto alpha = Consts<CudaT>::One;
@ -222,7 +223,8 @@ Status Pool<T, PoolType>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
CUDNN_RETURN_IF_ERROR(cudnnPoolingForward(CudnnHandle(), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data));
CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(CudnnHandle(), pooling_desc, &alpha,
x_tensor, x_data, &beta, y_tensor, y_data));
}
return Status::OK();

View file

@ -196,5 +196,28 @@ SetLRNDescriptorHelper(miopenLRNDescriptor_t normDesc,
return miopenSetLRNDescriptor(normDesc, miopenLRNCrossChannel, lrnN, lrnAlpha, lrnBeta, lrnK);
}
inline miopenStatus_t
PoolingForwardHelper(miopenHandle_t handle,
const miopenPoolingDescriptor_t poolDesc,
const void* alpha,
const miopenTensorDescriptor_t xDesc,
const void* x,
const void* beta,
const miopenTensorDescriptor_t yDesc,
void* y) {
return miopenPoolingForward(handle, poolDesc, alpha, xDesc, x, beta, yDesc, y, false, nullptr, 0);
}
inline miopenStatus_t
SetPoolingNdDescriptorHelper(miopenPoolingDescriptor_t poolDesc,
const miopenPoolingMode_t mode,
miopenNanPropagation_t /* unavailable */,
int nbDims,
int* windowDimA,
int* padA,
int* stridesA) {
return miopenSetNdPoolingDescriptor(poolDesc, mode, nbDims, windowDimA, padA, stridesA);
}
} // namespace rocm
} // namespace onnxruntime

View file

@ -1433,21 +1433,36 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ConvTranspose)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, double, AveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, AveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalAveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalAveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 7, double, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 7, MLFloat16, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 9, float, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 9, double, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 9, MLFloat16, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalMaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
7, 9, float, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
7, 9, double, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
7, 9, MLFloat16, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
1, float, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
1, double, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
1, MLFloat16, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
1, 7, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
1, 7, double, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
1, 7, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
8, 9, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
8, 9, double, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
8, 9, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
1, float, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
1, double, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
1, MLFloat16, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMax)>,
@ -1604,13 +1619,19 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, RandomUniformLike)>,
// opset 10
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, AveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, AveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
10, 10, float, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
10, 10, double, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
10, 10, MLFloat16, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 11, Dropout)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
10, 10, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
10, 10, double, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
10, 10, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, Resize)>,
@ -1718,12 +1739,18 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, ConvTranspose)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ConvTranspose)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, AveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, AveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, AveragePool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
11, float, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
11, double, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
11, MLFloat16, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
11, 11, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
11, 11, double, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
11, 11, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Resize)>,
@ -1756,11 +1783,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
// OpSet 12
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Clip)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, float, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, double, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, int8_t, MaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, uint8_t, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
12, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
12, double, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
12, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
12, int8_t, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
12, uint8_t, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Pow)>,

View file

@ -400,7 +400,8 @@ TEST(PoolTest, MaxPool_10_DilationPadding_1d) {
test.AddInput<float>("X", x_dims, x_vals);
test.AddOutput<float>("Y", expected_dims, expected_vals);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
}
TEST(PoolTest, MaxPool_10_Dilation_2d) {
@ -474,7 +475,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_2d) {
test.AddInput<float>("X", x_dims, x_vals);
test.AddOutput<float>("Y", expected_dims, expected_vals);
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kCudaExecutionProvider, kTensorrtExecutionProvider});
{kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
}
TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) {
@ -579,7 +580,8 @@ TEST(PoolTest, MaxPool_10_DilationPadding_3d) {
test.AddInput<float>("X", x_dims, x_vals);
test.AddOutput<float>("Y", expected_dims, expected_vals);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
}
TEST(PoolTest, GlobalMaxPool) {

View file

@ -113,10 +113,6 @@ provider_excluded_files = [
"nn/conv.h",
"nn/conv_transpose.cc",
"nn/conv_transpose.h",
"nn/max_pool_with_index.cu",
"nn/max_pool_with_index.h",
"nn/pool.cc",
"nn/pool.h",
"reduction/reduction_ops.cc",
"rnn/cudnn_rnn_base.cc",
"rnn/cudnn_rnn_base.h",
@ -312,6 +308,9 @@ def hipify(src_file_path, dst_file_path):
s = s.replace("MIOPEN_BATCHNORM_SPATIAL", "miopenBNSpatial")
s = s.replace("MIOPEN_BATCHNORM_PER_ACTIVATION", "miopenBNPerActivation")
s = s.replace("MIOPEN_LRN_CROSS_CHANNEL", "miopenLRNCrossChannel")
s = s.replace("MIOPEN_POOLING_MAX", "miopenPoolingMax")
s = s.replace("MIOPEN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING", "miopenPoolingAverageInclusive")
s = s.replace("MIOPEN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING", "miopenPoolingAverage")
# CUSPARSE -> HIPSPARSE
s = s.replace("CUSPARSE", "HIPSPARSE")