Minor changes to AMD element-wise kernels to converge with CUDA element-wise kernels.

This commit is contained in:
Jesse Benson 2020-12-14 14:00:28 -08:00 committed by Jesse Benson
parent a9548283d0
commit a8d549e181
8 changed files with 18 additions and 52 deletions

View file

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include <cuda_runtime.h>
#include "binary_elementwise_ops_impl.h"
#include "contrib_ops/cuda/math/binary_elementwise_ops_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cu_inc/binary_elementwise_impl.cuh"

View file

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "binary_elementwise_ops_impl.h"
#include "contrib_ops/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"

View file

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include <cuda_runtime.h>
#include "binary_elementwise_ops_impl.h"
#include "core/providers/cuda/math/binary_elementwise_ops_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cu_inc/binary_elementwise_impl.cuh"
#include "core/providers/cuda/math/binary_elementwise_ops_impl_functors.cuh"

View file

@ -280,7 +280,8 @@ BINARY_OP_REGISTER_VERSIONED_CLASS_HFD(Pow, Pow_7, 7, 11)
BINARY_LOGICALOP_TYPED(And, 7, bool)
BINARY_LOGICALOP_TYPED(Or, 7, bool)
BINARY_LOGICALOP_TYPED(Xor, 7, bool)
BINARY_OP_HFD(PRelu, 7)
BINARY_OP_VERSIONED_HFD(PRelu, 7, 8)
BINARY_OP_HFD(PRelu, 9)
// Pow since version 12
ONNX_OPERATOR_VERSIONED_KERNEL_EX(

View file

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "binary_elementwise_ops_impl.h"
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"
#include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh"

View file

@ -1,40 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
namespace onnxruntime {
namespace rocm {
// define the device functors that perform the computation on scalars
#define OP_FUNCTOR_DEFINITION(name, expr) \
template <class T, class T1, class T2> \
struct OP_##name { \
__device__ __inline__ T operator()(T1 a, T2 b) const { \
return (expr); \
} \
};
#define BINARY_OP_NAME_EXPR(name, expr) \
OP_FUNCTOR_DEFINITION(name, expr)
BINARY_OPS()
OP_FUNCTOR_DEFINITION(Pow, _Pow(a, b))
#undef BINARY_OP_NAME_EXPR
#define BINARY_OP_NAME_EXPR2(name, expr) \
OP_FUNCTOR_DEFINITION(name, expr)
BINARY_OPS2()
#undef BINARY_OP_NAME_EXPR2
#undef OP_FUNCTOR_DEFINITION
} // namespace rocm
} // namespace onnxruntime

View file

@ -380,9 +380,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, PRelu);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, PRelu);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, PRelu);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, And);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Or);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Xor);
@ -1063,9 +1066,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, And)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Or)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Xor)>,

View file

@ -89,7 +89,6 @@ core_ops_files = [
'math/binary_elementwise_ops.h',
'math/binary_elementwise_ops_impl.cu',
'math/binary_elementwise_ops_impl.h',
'math/binary_elementwise_ops_impl_functors.cuh',
'math/cumsum.cc',
'math/cumsum.h',
'math/cumsum_impl.cu',