From a8d549e181534e80a519cced7ee3aedebeef1b1d Mon Sep 17 00:00:00 2001 From: Jesse Benson Date: Mon, 14 Dec 2020 14:00:28 -0800 Subject: [PATCH] Minor changes to AMD element-wise kernels to converge with CUDA element-wise kernels. --- .../cuda/math/binary_elementwise_ops_impl.cu | 2 +- .../rocm/math/binary_elementwise_ops_impl.cu | 2 +- .../cuda/math/binary_elementwise_ops_impl.cu | 2 +- .../rocm/math/binary_elementwise_ops.cc | 3 +- .../rocm/math/binary_elementwise_ops_impl.cu | 2 +- .../binary_elementwise_ops_impl_functors.cuh | 40 ------------------- .../providers/rocm/rocm_execution_provider.cc | 18 ++++++--- tools/ci_build/amd_hipify.py | 1 - 8 files changed, 18 insertions(+), 52 deletions(-) delete mode 100644 onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh diff --git a/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.cu b/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.cu index 85c02d73f1..c6b977ddbe 100644 --- a/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.cu +++ b/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.cu @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include -#include "binary_elementwise_ops_impl.h" +#include "contrib_ops/cuda/math/binary_elementwise_ops_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cu_inc/binary_elementwise_impl.cuh" diff --git a/onnxruntime/contrib_ops/rocm/math/binary_elementwise_ops_impl.cu b/onnxruntime/contrib_ops/rocm/math/binary_elementwise_ops_impl.cu index 48d2fae862..fc325ac549 100644 --- a/onnxruntime/contrib_ops/rocm/math/binary_elementwise_ops_impl.cu +++ b/onnxruntime/contrib_ops/rocm/math/binary_elementwise_ops_impl.cu @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include -#include "binary_elementwise_ops_impl.h" +#include "contrib_ops/rocm/math/binary_elementwise_ops_impl.h" #include "core/providers/rocm/cu_inc/common.cuh" #include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh" diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu index 60b976848e..e4cab89128 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include -#include "binary_elementwise_ops_impl.h" +#include "core/providers/cuda/math/binary_elementwise_ops_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cu_inc/binary_elementwise_impl.cuh" #include "core/providers/cuda/math/binary_elementwise_ops_impl_functors.cuh" diff --git a/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.cc index a276044a8f..0acd228134 100644 --- a/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.cc @@ -280,7 +280,8 @@ BINARY_OP_REGISTER_VERSIONED_CLASS_HFD(Pow, Pow_7, 7, 11) BINARY_LOGICALOP_TYPED(And, 7, bool) BINARY_LOGICALOP_TYPED(Or, 7, bool) BINARY_LOGICALOP_TYPED(Xor, 7, bool) -BINARY_OP_HFD(PRelu, 7) +BINARY_OP_VERSIONED_HFD(PRelu, 7, 8) +BINARY_OP_HFD(PRelu, 9) // Pow since version 12 ONNX_OPERATOR_VERSIONED_KERNEL_EX( diff --git a/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl.cu b/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl.cu index ba157c3a02..a469398a95 100644 --- a/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl.cu @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include -#include "binary_elementwise_ops_impl.h" +#include "core/providers/rocm/math/binary_elementwise_ops_impl.h" #include "core/providers/rocm/cu_inc/common.cuh" #include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh" #include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh" diff --git a/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh b/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh deleted file mode 100644 index b05d7c7e28..0000000000 --- a/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/math/binary_elementwise_ops_impl.h" - -namespace onnxruntime { -namespace rocm { - -// define the device functors that perform the computation on scalars - -#define OP_FUNCTOR_DEFINITION(name, expr) \ - template \ - struct OP_##name { \ - __device__ __inline__ T operator()(T1 a, T2 b) const { \ - return (expr); \ - } \ - }; - -#define BINARY_OP_NAME_EXPR(name, expr) \ - OP_FUNCTOR_DEFINITION(name, expr) - -BINARY_OPS() - -OP_FUNCTOR_DEFINITION(Pow, _Pow(a, b)) - -#undef BINARY_OP_NAME_EXPR - -#define BINARY_OP_NAME_EXPR2(name, expr) \ - OP_FUNCTOR_DEFINITION(name, expr) - -BINARY_OPS2() - -#undef BINARY_OP_NAME_EXPR2 - -#undef OP_FUNCTOR_DEFINITION - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 38c1e9cabb..4732bfcf6f 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -380,9 +380,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, PRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, PRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, PRelu); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, PRelu); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, PRelu); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, PRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, PRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, PRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, PRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, And); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Or); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Xor); @@ -1063,9 +1066,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 8c916ab244..c28f6b5e84 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -89,7 +89,6 @@ core_ops_files = [ 'math/binary_elementwise_ops.h', 'math/binary_elementwise_ops_impl.cu', 'math/binary_elementwise_ops_impl.h', - 'math/binary_elementwise_ops_impl_functors.cuh', 'math/cumsum.cc', 'math/cumsum.h', 'math/cumsum_impl.cu',