Size Op - CUDA kernel support (#4868)

* cuda kernel support

* on comments

* test UT

* test UT

* revert settings

* attempt to fix broken UT

* corrected UT fix

Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
ytaous 2020-08-25 14:26:41 -07:00 committed by GitHub
parent 294eaca9ef
commit cb2dfee31c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 0 deletions

View file

@ -576,6 +576,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 5, Reshape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 4, Reshape_1);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Size);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 12, Tile);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Tile);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Transpose);
@ -1065,6 +1066,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 5, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 4, Reshape_1)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Size)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 12, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Transpose)>,

View file

@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cpu/tensor/size.h"
#include "core/providers/cuda/cuda_fwd.h"
namespace onnxruntime {
namespace cuda {
ONNX_OPERATOR_KERNEL_EX(
Size,
kOnnxDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder()
// properly force CPU/GPU synch inside the kernel
.OutputMemoryType<OrtMemTypeCPUInput>(0)
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Size);
} // namespace cuda
} // namespace onnxruntime