mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
[ROCm] Add AveragePool, GlobalAveragePool, MaxPool, GlobalMaxPool Ops (#11968)
* [ROCm] disable expected failure tests PoolTest.MaxPool_10_DilationPadding_?d * [ROCm] Add AveragePool, GlobalAveragePool, MaxPool, GlobalMaxPool Ops * (To squash after review) Replace rocm/nn/pool.cc with amd_hipify.py changes * [ROCM] Replace miCompat with Helper functions * (to squash) fix the compiling error of SetPoolingNdDescriptorHelper
This commit is contained in:
parent
d1497bdf62
commit
77cab7a3a5
6 changed files with 123 additions and 42 deletions
|
|
@ -229,5 +229,28 @@ SetLRNDescriptorHelper(cudnnLRNDescriptor_t normDesc,
|
|||
return cudnnSetLRNDescriptor(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK);
|
||||
}
|
||||
|
||||
// Thin forwarding wrapper around cudnnPoolingForward.
//
// Passes all eight arguments through unchanged, in the same order, and
// returns the cudnnStatus_t produced by cuDNN. It performs no validation or
// conversion of its own.
//
// NOTE(review): presumably exists so pooling call sites use one helper name
// that a non-CUDA backend can redefine with its own library call -- confirm
// against the provider that consumes the generated sources.
inline cudnnStatus_t
|
||||
PoolingForwardHelper(cudnnHandle_t handle,
|
||||
const cudnnPoolingDescriptor_t poolingDesc,
|
||||
const void *alpha,
|
||||
const cudnnTensorDescriptor_t xDesc,
|
||||
const void *x,
|
||||
const void *beta,
|
||||
const cudnnTensorDescriptor_t yDesc,
|
||||
void *y) {
|
||||
return cudnnPoolingForward(handle, poolingDesc, alpha, xDesc, x, beta, yDesc, y);
|
||||
}
|
||||
|
||||
// Thin forwarding wrapper around cudnnSetPoolingNdDescriptor.
//
// Passes the descriptor, pooling mode, NaN-propagation option, dimension
// count and the three per-dimension arrays (window, padding, stride) through
// unchanged, and returns the cudnnStatus_t produced by cuDNN.
//
// NOTE(review): the arrays are presumably expected to hold nbDims entries
// each, as required by cuDNN -- this helper does not check that.
inline cudnnStatus_t
|
||||
SetPoolingNdDescriptorHelper(cudnnPoolingDescriptor_t poolingDesc,
|
||||
const cudnnPoolingMode_t mode,
|
||||
const cudnnNanPropagation_t maxpoolingNanOpt,
|
||||
int nbDims,
|
||||
const int windowDimA[],
|
||||
const int paddingA[],
|
||||
const int strideA[]) {
|
||||
return cudnnSetPoolingNdDescriptor(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ class CudnnPoolingDescriptor final {
|
|||
for (int i = 0; i < rank; i++) {
|
||||
stride[i] = gsl::narrow_cast<int>(strides[i]);
|
||||
}
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetPoolingNdDescriptor(
|
||||
CUDNN_RETURN_IF_ERROR(SetPoolingNdDescriptorHelper(
|
||||
desc_,
|
||||
mode,
|
||||
CUDNN_PROPAGATE_NAN,
|
||||
|
|
@ -212,7 +212,8 @@ Status Pool<T, PoolType>::ComputeInternal(OpKernelContext* context) const {
|
|||
IAllocatorUniquePtr<float> temp_X = GetScratchBuffer<float>(input_count);
|
||||
auto temp_Y = GetScratchBuffer<float>(output_count);
|
||||
Impl_Cast<CudaT, float>(Stream(), reinterpret_cast<const CudaT*>(x_data), temp_X.get(), input_count);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnPoolingForward(CudnnHandle(), pooling_desc, &alpha, x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get()));
|
||||
CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(CudnnHandle(), pooling_desc, &alpha,
|
||||
x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get()));
|
||||
Impl_Cast<float, CudaT>(Stream(), temp_Y.get(), y_data, output_count);
|
||||
} else {
|
||||
const auto alpha = Consts<CudaT>::One;
|
||||
|
|
@ -222,7 +223,8 @@ Status Pool<T, PoolType>::ComputeInternal(OpKernelContext* context) const {
|
|||
ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
|
||||
ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
|
||||
|
||||
CUDNN_RETURN_IF_ERROR(cudnnPoolingForward(CudnnHandle(), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data));
|
||||
CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(CudnnHandle(), pooling_desc, &alpha,
|
||||
x_tensor, x_data, &beta, y_tensor, y_data));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -196,5 +196,28 @@ SetLRNDescriptorHelper(miopenLRNDescriptor_t normDesc,
|
|||
return miopenSetLRNDescriptor(normDesc, miopenLRNCrossChannel, lrnN, lrnAlpha, lrnBeta, lrnK);
|
||||
}
|
||||
|
||||
// Thin wrapper around miopenPoolingForward.
//
// The eight parameters are forwarded unchanged and in order; the
// miopenStatus_t produced by MIOpen is returned as-is.
inline miopenStatus_t
|
||||
PoolingForwardHelper(miopenHandle_t handle,
|
||||
const miopenPoolingDescriptor_t poolDesc,
|
||||
const void* alpha,
|
||||
const miopenTensorDescriptor_t xDesc,
|
||||
const void* x,
|
||||
const void* beta,
|
||||
const miopenTensorDescriptor_t yDesc,
|
||||
void* y) {
|
||||
// miopenPoolingForward takes three trailing arguments that this helper's
// callers do not supply; they are fixed to (false, nullptr, 0) here.
// NOTE(review): per the MIOpen API these are presumably do_backward,
// workSpace and workSpaceSize, i.e. forward-only with no workspace --
// confirm against miopen.h for the targeted MIOpen version.
return miopenPoolingForward(handle, poolDesc, alpha, xDesc, x, beta, yDesc, y, false, nullptr, 0);
|
||||
}
|
||||
|
||||
// Thin wrapper around miopenSetNdPoolingDescriptor.
//
// Forwards the descriptor, pooling mode, dimension count and the three
// per-dimension arrays (window, padding, strides) and returns the
// miopenStatus_t produced by MIOpen.
//
// The third parameter (NaN propagation) is accepted but unnamed and is NOT
// passed on: the miopenSetNdPoolingDescriptor call below has no such
// argument, so whatever value the caller supplies has no effect here.
//
// NOTE(review): the array parameters are non-const int*, presumably to match
// the parameter types of miopenSetNdPoolingDescriptor in the targeted MIOpen
// version -- confirm before tightening them to const.
inline miopenStatus_t
|
||||
SetPoolingNdDescriptorHelper(miopenPoolingDescriptor_t poolDesc,
|
||||
const miopenPoolingMode_t mode,
|
||||
miopenNanPropagation_t /* unavailable */,
|
||||
int nbDims,
|
||||
int* windowDimA,
|
||||
int* padA,
|
||||
int* stridesA) {
|
||||
return miopenSetNdPoolingDescriptor(poolDesc, mode, nbDims, windowDimA, padA, stridesA);
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1433,21 +1433,36 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ConvTranspose)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, double, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalAveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalAveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 7, double, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 7, MLFloat16, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 9, float, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 9, double, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 9, MLFloat16, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalMaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalMaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
7, 9, float, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
7, 9, double, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
7, 9, MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
1, float, GlobalAveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
1, double, GlobalAveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
1, MLFloat16, GlobalAveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
1, 7, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
1, 7, double, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
1, 7, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
8, 9, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
8, 9, double, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
8, 9, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
1, float, GlobalMaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
1, double, GlobalMaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
1, MLFloat16, GlobalMaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMax)>,
|
||||
|
|
@ -1604,13 +1619,19 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, RandomUniformLike)>,
|
||||
|
||||
// opset 10
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
10, 10, float, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
10, 10, double, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
10, 10, MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 11, Dropout)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
10, 10, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
10, 10, double, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
10, 10, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, Resize)>,
|
||||
|
|
@ -1718,12 +1739,18 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, ConvTranspose)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ConvTranspose)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
11, float, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
11, double, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
11, MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
11, 11, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
11, 11, double, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
11, 11, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Resize)>,
|
||||
|
|
@ -1756,11 +1783,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
// OpSet 12
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Clip)>,
|
||||
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, float, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, double, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, int8_t, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, uint8_t, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
12, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
12, double, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
12, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
12, int8_t, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
||||
12, uint8_t, MaxPool)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Pow)>,
|
||||
|
||||
|
|
|
|||
|
|
@ -400,7 +400,8 @@ TEST(PoolTest, MaxPool_10_DilationPadding_1d) {
|
|||
|
||||
test.AddInput<float>("X", x_dims, x_vals);
|
||||
test.AddOutput<float>("Y", expected_dims, expected_vals);
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(PoolTest, MaxPool_10_Dilation_2d) {
|
||||
|
|
@ -474,7 +475,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_2d) {
|
|||
test.AddInput<float>("X", x_dims, x_vals);
|
||||
test.AddOutput<float>("Y", expected_dims, expected_vals);
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kCudaExecutionProvider, kTensorrtExecutionProvider});
|
||||
{kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) {
|
||||
|
|
@ -579,7 +580,8 @@ TEST(PoolTest, MaxPool_10_DilationPadding_3d) {
|
|||
|
||||
test.AddInput<float>("X", x_dims, x_vals);
|
||||
test.AddOutput<float>("Y", expected_dims, expected_vals);
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(PoolTest, GlobalMaxPool) {
|
||||
|
|
|
|||
|
|
@ -113,10 +113,6 @@ provider_excluded_files = [
|
|||
"nn/conv.h",
|
||||
"nn/conv_transpose.cc",
|
||||
"nn/conv_transpose.h",
|
||||
"nn/max_pool_with_index.cu",
|
||||
"nn/max_pool_with_index.h",
|
||||
"nn/pool.cc",
|
||||
"nn/pool.h",
|
||||
"reduction/reduction_ops.cc",
|
||||
"rnn/cudnn_rnn_base.cc",
|
||||
"rnn/cudnn_rnn_base.h",
|
||||
|
|
@ -312,6 +308,9 @@ def hipify(src_file_path, dst_file_path):
|
|||
s = s.replace("MIOPEN_BATCHNORM_SPATIAL", "miopenBNSpatial")
|
||||
s = s.replace("MIOPEN_BATCHNORM_PER_ACTIVATION", "miopenBNPerActivation")
|
||||
s = s.replace("MIOPEN_LRN_CROSS_CHANNEL", "miopenLRNCrossChannel")
|
||||
s = s.replace("MIOPEN_POOLING_MAX", "miopenPoolingMax")
|
||||
s = s.replace("MIOPEN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING", "miopenPoolingAverageInclusive")
|
||||
s = s.replace("MIOPEN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING", "miopenPoolingAverage")
|
||||
|
||||
# CUSPARSE -> HIPSPARSE
|
||||
s = s.replace("CUSPARSE", "HIPSPARSE")
|
||||
|
|
|
|||
Loading…
Reference in a new issue