diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 0f552dbfed..70b99b9a98 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -576,6 +576,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 5, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 4, Reshape_1); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Size); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 12, Tile); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Tile); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Transpose); @@ -1065,6 +1066,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/size.cc b/onnxruntime/core/providers/cuda/tensor/size.cc new file mode 100644 index 0000000000..39d30242d3 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/size.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/tensor/size.h" +#include "core/providers/cuda/cuda_fwd.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + Size, + kOnnxDomain, + 1, + kCudaExecutionProvider, + KernelDefBuilder() + // properly force CPU/GPU synch inside the kernel + .OutputMemoryType(0) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Size); + +} // namespace cuda +} // namespace onnxruntime