From 1ea3e8633ca118b3e97d7ba9e68e8a4b5531fe28 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Tue, 28 May 2019 13:52:13 -0700 Subject: [PATCH] CUDA opset9: Update Cast/MatMul version, add Erf (#1106) * CUDA opset9: Update Cast/MatMul version, add Erf * Address CR * More fixes on node placement logic * Fix typo * Update CUDA ops Gemm and BatchNormalization to be registered in opset10 --- .../core/providers/cuda/cu_inc/common.cuh | 12 + .../providers/cuda/cuda_execution_provider.cc | 213 ++++++++++++------ .../providers/cuda/cuda_execution_provider.h | 4 +- onnxruntime/core/providers/cuda/math/gemm.cc | 9 + .../core/providers/cuda/math/matmul.cc | 11 +- .../cuda/math/unary_elementwise_ops.cc | 23 +- .../cuda/math/unary_elementwise_ops.h | 7 + .../cuda/math/unary_elementwise_ops_impl.cu | 1 + .../cuda/math/unary_elementwise_ops_impl.h | 3 +- .../core/providers/cuda/nn/batch_norm.cc | 15 +- .../core/providers/cuda/tensor/cast_op.cc | 12 +- 11 files changed, 223 insertions(+), 87 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 750c7416d8..26354287f0 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -105,6 +105,18 @@ __device__ __inline__ double _Sqrt(double a) { return sqrt(a); } template <> __device__ __inline__ half _Sqrt(half a) { return half(sqrtf((float)a)); } +template +__device__ __inline__ T _Erf(T a); + +template <> +__device__ __inline__ float _Erf(float a) { return erff(a); } + +template <> +__device__ __inline__ double _Erf(double a) { return erf(a); } + +template <> +__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); } + template __device__ __inline__ T _Exp(T a); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 65965ea92d..a430e385c9 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -249,12 +249,18 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Sq class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Identity); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Gather); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, float, Gemm); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, double, Gemm); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, MatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, MatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, float, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, double, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, MLFloat16, MatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, MatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, MatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Tile); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Tile); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Tile); @@ -409,9 +415,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Exp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Erf); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Erf); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Erf); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, LRN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, LRN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, LRN); @@ -472,18 +484,30 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ReduceLogSumExp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ReduceLogSumExp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, int8_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, int16_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, int32_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, int64_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, uint8_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, uint16_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, uint32_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, uint64_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, bool, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, float, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, double, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, MLFloat16, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, int8_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, int16_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, int32_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, int64_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, uint8_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, uint16_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, uint32_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, uint64_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, bool, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int8_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int16_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint8_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint16_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint64_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, bool, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, float, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, MLFloat16, Pad); @@ -542,12 +566,18 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -705,9 +735,15 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -768,18 +804,30 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -843,9 +891,9 @@ std::shared_ptr CUDAExecutionProvider::GetKernelRegistry() const return kernel_registry; } -bool CUDAExecutionProvider::RNNNeedFallbackToCPU(const onnxruntime::Node& node, - const std::vector activations_supported, - const std::string& op_type) const { +static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, + const std::vector activations_supported, + const std::string& op_type) { auto node_attributes = node.GetAttributes(); // Check attributes for (auto& attr : node_attributes) { @@ -899,7 +947,7 @@ bool CUDAExecutionProvider::RNNNeedFallbackToCPU(const onnxruntime::Node& node, return false; } -bool CUDAExecutionProvider::ConvNeedFallbackToCPU(const onnxruntime::Node& node) const { +static bool ConvNeedFallbackToCPU(const onnxruntime::Node& node) { auto node_attributes = node.GetAttributes(); // Check attributes for (auto& attr : node_attributes) { @@ -923,6 +971,24 @@ bool CUDAExecutionProvider::ConvNeedFallbackToCPU(const onnxruntime::Node& node) return false; } +static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { + auto node_attributes = node.GetAttributes(); + // Check attributes + for (auto& attr : node_attributes) { + auto attr_name = attr.first; + auto attr_value = attr.second; + + // string is not supported + if ("to" == attr_name && ::ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT == attr_value.type()) { + auto to_type = attr_value.i(); + if (to_type == ::ONNX_NAMESPACE::TensorProto_DataType_STRING) + return true; + } + } + + return false; +} + std::vector> CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const std::vector&) const { @@ -945,20 +1011,51 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, bool not_supported = false; bool force_outside = false; - + bool force_inside = false; // for some compute heavy ops, we'll force it to run inside CUDA if ("LSTM" == node.OpType()) { // the supported activations covers the bidirectional mode std::vector activations_supported{"sigmoid", "tanh", "tanh", "sigmoid", "tanh", "tanh"}; not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType()); + force_inside = !not_supported; } else if ("RNN" == node.OpType()) { std::vector activations_supported{"tanh", "tanh"}; not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType()); + force_inside = !not_supported; } else if ("GRU" == node.OpType()) { std::vector activations_supported{"sigmoid", "tanh", "sigmoid", "tanh"}; not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType()); + force_inside = !not_supported; } else if ("Conv" == node.OpType()) { not_supported = ConvNeedFallbackToCPU(node); + force_inside = !not_supported; + } else if ("Cast" == node.OpType()) { + not_supported = CastNeedFallbackToCPU(node); + // cast is not compute heavy, and may be placed outside + } + + if (!not_supported && !force_inside) { + // Note that nodes with only inputs from initializer would not be place on CUDA + // Ideally, those nodes should be eliminated in constant folding + bool all_non_initializer_inputs_from_outside = true; + node.ForEachWithIndex( + node.InputDefs(), + [&](const NodeArg& def, size_t) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (!graph.GetInitializedTensor(def.Name(), initializer) && !defs_outside_cuda.count(&def)) + all_non_initializer_inputs_from_outside = false; + return Status::OK(); + }); + if (all_non_initializer_inputs_from_outside) { + force_outside = true; + } + } + + if (!force_inside && (not_supported || force_outside)) { + defs_outside_cuda.insert(node.OutputDefs().cbegin(), node.OutputDefs().cend()); + if (not_supported) + LOGS_DEFAULT(WARNING) << "Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); } else { + // for nodes placed on CUDA, check if its output is on CPU node.ForEachWithIndex( node.OutputDefs(), [&](const NodeArg& def, size_t out_index) { @@ -966,28 +1063,6 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, defs_outside_cuda.insert(&def); return Status::OK(); }); - - // Note that nodes with only inputs from initializer would not be place on CUDA - // Ideally, those nodes should be eliminated in constant folding - bool all_non_initializer_inputs_on_cpu = true; - node.ForEachWithIndex( - node.InputDefs(), - [&](const NodeArg& def, size_t) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (!graph.GetInitializedTensor(def.Name(), initializer) && !defs_outside_cuda.count(&def)) - all_non_initializer_inputs_on_cpu = false; - return Status::OK(); - }); - if (all_non_initializer_inputs_on_cpu) { - force_outside = true; - } - } - - if (not_supported || force_outside) { - defs_outside_cuda.insert(node.OutputDefs().cbegin(), node.OutputDefs().cend()); - if (not_supported) - LOGS_DEFAULT(WARNING) << "Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); - } else { std::unique_ptr sub_graph = std::make_unique(); sub_graph->nodes.push_back(node.Index()); result.push_back(std::make_unique(std::move(sub_graph))); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index e9f3848e8f..945a1d8f88 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -90,6 +90,7 @@ class CUDAExecutionProvider : public IExecutionProvider { virtual std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const std::vector& kernel_registries) const override; + private: cudaStream_t streams_[kTotalCudaStreams]; int device_id_; @@ -169,9 +170,6 @@ class CUDAExecutionProvider : public IExecutionProvider { mutable OrtMutex context_pool_mutex_; void ReleasePerThreadStuffs() const; - - bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, const std::vector activations_supported, const std::string& op_type) const; - bool ConvNeedFallbackToCPU(const onnxruntime::Node& node) const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 0bfb1f2454..9fcf089f31 100644 --- a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -14,6 +14,15 @@ namespace cuda { Gemm, \ kOnnxDomain, \ 7, \ + 8, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Gemm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Gemm, \ + kOnnxDomain, \ 9, \ T, \ kCudaExecutionProvider, \ diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index c0c189625b..ad49d793be 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -10,10 +10,19 @@ namespace onnxruntime { namespace cuda { #define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + MatMul, \ + kOnnxDomain, \ + 1, 8, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + MatMul); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ MatMul, \ kOnnxDomain, \ - 1, \ + 9, \ T, \ kCudaExecutionProvider, \ KernelDefBuilder() \ diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index de118b9ee7..323a7ad997 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -23,17 +23,17 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ x); -#define UNARY_ELEMENTWISE_COMPUTE(x, T) \ - template <> \ - Status x::ComputeInternal(OpKernelContext* context) const { \ - UnaryElementwisePreparation p; \ - UnaryElementwise::Prepare(context, &p); \ - Impl_##x( \ - reinterpret_cast::MappedType*>(p.input_tensor->template Data()), \ - reinterpret_cast::MappedType*>(p.output_tensor->template MutableData()),\ - p.output_tensor->Shape().Size()); \ - \ - return Status::OK(); \ +#define UNARY_ELEMENTWISE_COMPUTE(x, T) \ + template <> \ + Status x::ComputeInternal(OpKernelContext* context) const { \ + UnaryElementwisePreparation p; \ + UnaryElementwise::Prepare(context, &p); \ + Impl_##x( \ + reinterpret_cast::MappedType*>(p.input_tensor->template Data()), \ + reinterpret_cast::MappedType*>(p.output_tensor->template MutableData()), \ + p.output_tensor->Shape().Size()); \ + \ + return Status::OK(); \ } #define UNARY_OP_TYPED(name, ver, T) \ @@ -81,6 +81,7 @@ UNARY_OP_HFD(Reciprocal, 6) UNARY_OP_HFD(Sqrt, 6) UNARY_OP_HFD(Log, 6) UNARY_OP_HFD(Exp, 6) +UNARY_OP_HFD(Erf, 9) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 5b719ae4e8..8ae3753ba4 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -77,5 +77,12 @@ class Exp final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +template +class Erf final : public UnaryElementwise { + public: + Erf(const OpKernelInfo& info) : UnaryElementwise(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 49f05b3a75..4049cc1e90 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -76,6 +76,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Log) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Exp) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf) // When casting, half needs to be converted via float type from most other types template diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index db990bc692..daaa3a5304 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -21,7 +21,8 @@ namespace cuda { UNARY_OP_NAME_EXPR(Reciprocal, T(1) / a) \ UNARY_OP_NAME_EXPR(Sqrt, _Sqrt(a)) \ UNARY_OP_NAME_EXPR(Exp, _Exp(a)) \ - UNARY_OP_NAME_EXPR(Log, _Log(a)) + UNARY_OP_NAME_EXPR(Log, _Log(a)) \ + UNARY_OP_NAME_EXPR(Erf, _Erf(a)) #define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \ template \ diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.cc b/onnxruntime/core/providers/cuda/nn/batch_norm.cc index 1c40d850e5..60c2055bd4 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.cc @@ -15,7 +15,20 @@ namespace cuda { ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ BatchNormalization, \ kOnnxDomain, \ - 7, 9, \ + 7, 8, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("X", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("scale", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("B", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("mean", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("var", DataTypeImpl::GetTensorType()), \ + BatchNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + BatchNormalization, \ + kOnnxDomain, \ + 9, \ T, \ kCudaExecutionProvider, \ KernelDefBuilder() \ diff --git a/onnxruntime/core/providers/cuda/tensor/cast_op.cc b/onnxruntime/core/providers/cuda/tensor/cast_op.cc index ef46535e13..d86a1c2c66 100644 --- a/onnxruntime/core/providers/cuda/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/cast_op.cc @@ -24,10 +24,20 @@ const std::vector castOpTypeConstraints{ DataTypeImpl::GetTensorType()}; #define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Cast, \ + kOnnxDomain, \ + 6, 8, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", castOpTypeConstraints), \ + Cast); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Cast, \ kOnnxDomain, \ - 6, \ + 9, \ T, \ kCudaExecutionProvider, \ KernelDefBuilder() \