mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Minor changes to AMD element-wise kernels to converge with CUDA element-wise kernels.
This commit is contained in:
parent
a9548283d0
commit
a8d549e181
8 changed files with 18 additions and 52 deletions
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include "binary_elementwise_ops_impl.h"
|
||||
#include "contrib_ops/cuda/math/binary_elementwise_ops_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/cu_inc/binary_elementwise_impl.cuh"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "binary_elementwise_ops_impl.h"
|
||||
#include "contrib_ops/rocm/math/binary_elementwise_ops_impl.h"
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include "binary_elementwise_ops_impl.h"
|
||||
#include "core/providers/cuda/math/binary_elementwise_ops_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/cu_inc/binary_elementwise_impl.cuh"
|
||||
#include "core/providers/cuda/math/binary_elementwise_ops_impl_functors.cuh"
|
||||
|
|
|
|||
|
|
@ -280,7 +280,8 @@ BINARY_OP_REGISTER_VERSIONED_CLASS_HFD(Pow, Pow_7, 7, 11)
|
|||
BINARY_LOGICALOP_TYPED(And, 7, bool)
|
||||
BINARY_LOGICALOP_TYPED(Or, 7, bool)
|
||||
BINARY_LOGICALOP_TYPED(Xor, 7, bool)
|
||||
BINARY_OP_HFD(PRelu, 7)
|
||||
BINARY_OP_VERSIONED_HFD(PRelu, 7, 8)
|
||||
BINARY_OP_HFD(PRelu, 9)
|
||||
|
||||
// Pow since version 12
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "binary_elementwise_ops_impl.h"
|
||||
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"
|
||||
#include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh"
|
||||
|
|
|
|||
|
|
@ -1,40 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
// define the device functors that perform the computation on scalars
|
||||
|
||||
#define OP_FUNCTOR_DEFINITION(name, expr) \
|
||||
template <class T, class T1, class T2> \
|
||||
struct OP_##name { \
|
||||
__device__ __inline__ T operator()(T1 a, T2 b) const { \
|
||||
return (expr); \
|
||||
} \
|
||||
};
|
||||
|
||||
#define BINARY_OP_NAME_EXPR(name, expr) \
|
||||
OP_FUNCTOR_DEFINITION(name, expr)
|
||||
|
||||
BINARY_OPS()
|
||||
|
||||
OP_FUNCTOR_DEFINITION(Pow, _Pow(a, b))
|
||||
|
||||
#undef BINARY_OP_NAME_EXPR
|
||||
|
||||
#define BINARY_OP_NAME_EXPR2(name, expr) \
|
||||
OP_FUNCTOR_DEFINITION(name, expr)
|
||||
|
||||
BINARY_OPS2()
|
||||
|
||||
#undef BINARY_OP_NAME_EXPR2
|
||||
|
||||
#undef OP_FUNCTOR_DEFINITION
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -380,9 +380,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, PRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, PRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, PRelu);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, PRelu);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, PRelu);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, PRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, PRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, PRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, PRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, And);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Or);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Xor);
|
||||
|
|
@ -1063,9 +1066,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, And)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Or)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Xor)>,
|
||||
|
|
|
|||
|
|
@ -89,7 +89,6 @@ core_ops_files = [
|
|||
'math/binary_elementwise_ops.h',
|
||||
'math/binary_elementwise_ops_impl.cu',
|
||||
'math/binary_elementwise_ops_impl.h',
|
||||
'math/binary_elementwise_ops_impl_functors.cuh',
|
||||
'math/cumsum.cc',
|
||||
'math/cumsum.h',
|
||||
'math/cumsum_impl.cu',
|
||||
|
|
|
|||
Loading…
Reference in a new issue