diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index ff71993c01..19e0c5e148 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -486,9 +486,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Sh class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Tile); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Tile); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Tile); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Transpose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Transpose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Transpose); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, InstanceNormalization); @@ -863,9 +861,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.cc b/onnxruntime/core/providers/cuda/tensor/transpose.cc index 065d41e2b4..f6a97a7fd7 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.cc +++ b/onnxruntime/core/providers/cuda/tensor/transpose.cc @@ -9,19 +9,16 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Transpose, \ - kOnnxDomain, \ - 1, \ - T, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Transpose); +ONNX_OPERATOR_KERNEL_EX(Transpose, + kOnnxDomain, + 1, + kCudaExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Transpose); // special case acceleration using cublas matrix transpose -std::tuple TryTransposeWithCublas(const std::vector& perm, const TensorShape& input_shape) { +static std::tuple TryTransposeWithCublas(const std::vector& perm, const TensorShape& input_shape) { int M = 0; int N = 0; @@ -47,7 +44,72 @@ std::tuple TryTransposeWithCublas(const std::vector& perm, con } template -Status Transpose::ComputeInternal(OpKernelContext* ctx) const { +Status TransposeWithCublas(cublasHandle_t cublas_handle, const Tensor& input, Tensor& output, int M, int N) { + typedef typename ToCudaType::MappedType CudaT; + CudaT one = ToCudaType::FromFloat(1.0f); + CudaT zero = ToCudaType::FromFloat(0.0f); + const CudaT* input_data = reinterpret_cast(input.Data()); + CudaT* output_data = reinterpret_cast(output.MutableData()); + CUBLAS_RETURN_IF_ERROR( + cublasTransposeHelper(cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_T, M, N, + &one, + input_data, + N, + &zero, + input_data, + N, + output_data, + M)); + return Status::OK(); +} + +Status Transpose::DoTranspose(const Transpose& kernel, + const std::vector& permutations, const Tensor& input, Tensor& output) { + // special case when there is a dim value of 0 in the shape. + if (output.Shape().Size() == 0) + return Status::OK(); + + auto element_type = input.GetElementType(); + if (element_type == utils::GetONNXTensorElementDataType() || + element_type == utils::GetONNXTensorElementDataType() || + element_type == utils::GetONNXTensorElementDataType()) { + auto mn = TryTransposeWithCublas(permutations, input.Shape()); + int M = std::get<0>(mn); + int N = std::get<1>(mn); + if (M != 0 && N != 0) { + if (element_type == utils::GetONNXTensorElementDataType()) { + return TransposeWithCublas(kernel.CublasHandle(), input, output, M, N); + } else if (element_type == utils::GetONNXTensorElementDataType()) { + return TransposeWithCublas(kernel.CublasHandle(), input, output, M, N); + } else { + return TransposeWithCublas(kernel.CublasHandle(), input, output, M, N); + } + } + } + + const std::vector& input_dims = input.Shape().GetDims(); + const std::vector& output_dims = output.Shape().GetDims(); + + auto rank = input_dims.size(); + CudaAsyncBuffer input_strides(&kernel, rank); + CudaAsyncBuffer perm(&kernel, permutations); + CudaAsyncBuffer fdm_output_strides(&kernel, rank); + ORT_ENFORCE(TensorPitches::Calculate(input_strides.CpuSpan(), input_dims)); + ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims)); + + ORT_RETURN_IF_ERROR(input_strides.CopyToGpu()); + ORT_RETURN_IF_ERROR(perm.CopyToGpu()); + ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu()); + + size_t element_size = input.DataType()->Size(); + auto status = TransposeImpl(element_size, rank, input_strides.GpuPtr(), perm.GpuPtr(), input.DataRaw(), + fdm_output_strides.GpuPtr(), output.MutableDataRaw(), output.Shape().Size()); + + return status; +} + +Status Transpose::ComputeInternal(OpKernelContext* ctx) const { const Tensor* X_ptr = ctx->Input(0); if (X_ptr == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& X = *X_ptr; @@ -65,66 +127,8 @@ Status Transpose::ComputeInternal(OpKernelContext* ctx) const { TensorShape output_shape{output_dims}; Tensor* Y = ctx->Output(0, output_shape); - // special case when there is a dim value of 0 in the shape. - if (output_shape.Size() == 0) - return Status::OK(); - - auto mn = TryTransposeWithCublas(*p_perm, input_shape); - int M = std::get<0>(mn); - int N = std::get<1>(mn); - if (M != 0 && N != 0) { - typedef typename ToCudaType::MappedType CudaT; - CudaT one = ToCudaType::FromFloat(1.0f); - CudaT zero = ToCudaType::FromFloat(0.0f); - const CudaT* input_data = reinterpret_cast(X.template Data()); - CudaT* output_data = reinterpret_cast(Y->template MutableData()); - CUBLAS_RETURN_IF_ERROR( - cublasTransposeHelper( - CublasHandle(), - CUBLAS_OP_T, - CUBLAS_OP_T, - M, - N, - &one, - input_data, - N, - &zero, - input_data, - N, - output_data, - M)); - return Status::OK(); - } - - CudaAsyncBuffer input_strides(this, rank); - CudaAsyncBuffer perm(this, *p_perm); - CudaAsyncBuffer fdm_output_strides(this, rank); - ORT_ENFORCE(TensorPitches::Calculate(input_strides.CpuSpan(), input_dims)); - ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims)); - - ORT_RETURN_IF_ERROR(input_strides.CopyToGpu()); - ORT_RETURN_IF_ERROR(perm.CopyToGpu()); - ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu()); - - TransposeImpl( - rank, - input_strides.GpuPtr(), - perm.GpuPtr(), - reinterpret_cast::MappedType*>(X.template Data()), - fdm_output_strides.GpuPtr(), - reinterpret_cast::MappedType*>(Y->template MutableData()), - output_shape.Size()); - - return Status::OK(); + return DoTranspose(*this, *p_perm, X, *Y); } -#define SPECIALIZED_COMPUTE(T) \ - REGISTER_KERNEL_TYPED(T) \ - template Status Transpose::ComputeInternal(OpKernelContext* ctx) const; - -SPECIALIZED_COMPUTE(float) -SPECIALIZED_COMPUTE(double) -SPECIALIZED_COMPUTE(MLFloat16) - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.h b/onnxruntime/core/providers/cuda/tensor/transpose.h index 4ca504f14b..d091584b42 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.h +++ b/onnxruntime/core/providers/cuda/tensor/transpose.h @@ -12,12 +12,14 @@ namespace onnxruntime { namespace cuda { -template class Transpose final : public CudaKernel, public TransposeBase { public: Transpose(const OpKernelInfo& info) : CudaKernel(info), TransposeBase(info) {} Status ComputeInternal(OpKernelContext* context) const override; + + static Status DoTranspose(const Transpose& transpose_kernel, + const std::vector& permutations, const Tensor& input, Tensor& output); }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index 482aa916d4..0bbdd33209 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -23,23 +23,49 @@ __global__ void _TransposeKernel(size_t shape_rank, const int64_t* input_strides output_data[id] = input_data[input_index]; } -template -void TransposeImpl(size_t shape_rank, const int64_t* input_strides, const size_t* perm, const T* input_data, - const fast_divmod* fdm_output_strides, T* output_data, size_t N) { +Status TransposeImpl(size_t element_size, size_t shape_rank, const int64_t* input_strides, const size_t* perm, + const void* input_data, const fast_divmod* fdm_output_strides, void* output_data, size_t N) { int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - _TransposeKernel<<>>( - shape_rank, input_strides, perm, input_data, - fdm_output_strides, output_data, N); + switch (element_size) { + case sizeof(int8_t): + _TransposeKernel<<>>( + shape_rank, input_strides, perm, + reinterpret_cast::MappedType*>(input_data), + fdm_output_strides, + reinterpret_cast::MappedType*>(output_data), + N); + break; + case sizeof(int16_t): + _TransposeKernel<<>>( + shape_rank, input_strides, perm, + reinterpret_cast::MappedType*>(input_data), + fdm_output_strides, + reinterpret_cast::MappedType*>(output_data), + N); + break; + case sizeof(int32_t): + _TransposeKernel<<>>( + shape_rank, input_strides, perm, + reinterpret_cast::MappedType*>(input_data), + fdm_output_strides, + reinterpret_cast::MappedType*>(output_data), + N); + break; + case sizeof(int64_t): + _TransposeKernel<<>>( + shape_rank, input_strides, perm, + reinterpret_cast::MappedType*>(input_data), + fdm_output_strides, + reinterpret_cast::MappedType*>(output_data), + N); + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ", + element_size); + } + + return Status::OK(); } -#define SPECIALIZED_IMPL(T) \ - template void TransposeImpl(size_t shape_rank, const int64_t* input_strides, const size_t* perm, \ - const T* input_data, const fast_divmod* fdm_output_strides, T* output_data, \ - size_t N); - -SPECIALIZED_IMPL(float) -SPECIALIZED_IMPL(double) -SPECIALIZED_IMPL(half) - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.h b/onnxruntime/core/providers/cuda/tensor/transpose_impl.h index 0d53abcf49..023944b73b 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.h @@ -8,9 +8,8 @@ namespace onnxruntime { namespace cuda { -template -void TransposeImpl(size_t shape_rank, const int64_t* input_strides, const size_t* perm, const T* input_data, - const fast_divmod* fdm_output_strides, T* output_data, size_t N); +Status TransposeImpl(size_t element_size, size_t shape_rank, const int64_t* input_strides, const size_t* perm, + const void* input_data, const fast_divmod* fdm_output_strides, void* output_data, size_t N); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index d09d711105..20370aa65d 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -64,20 +64,91 @@ TEST(TransposeOpTest, TwoDimNoAttrStr) { // Test 2 dimensional transpose, with permutation attribute specified TEST(TransposeOpTest, TwoDim) { std::vector input_shape({2, 3}); - std::vector input_vals = { - 1.0f, 2.0f, 3.0f, - 4.0f, 5.0f, 6.0f}; + std::vector input_vals = {1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f}; std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - auto expected_vals = { - 1.0f, 4.0f, - 2.0f, 5.0f, - 3.0f, 6.0f}; + auto expected_vals = {1.0f, 4.0f, + 2.0f, 5.0f, + 3.0f, 6.0f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } +TEST(TransposeOpTest, TwoDim_double) { + std::vector input_shape({2, 3}); + std::vector input_vals = {1.0, 2.0, 3.0, + 4.0, 5.0, 6.0}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 2}); + std::initializer_list expected_vals = {1.0, 4.0, + 2.0, 5.0, + 3.0, 6.0}; + + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); +} + +TEST(TransposeOpTest, TwoDim_int32) { + std::vector input_shape({2, 3}); + std::vector input_vals = {1, 2, 3, + 4, 5, 6}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 2}); + std::initializer_list expected_vals = {1, 4, + 2, 5, + 3, 6}; + + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); +} + +TEST(TransposeOpTest, TwoDim_int16) { + std::vector input_shape({2, 3}); + std::vector input_vals = { + 1, 2, 3, + 4, 5, 6}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 2}); + std::initializer_list expected_vals = { + 1, 4, + 2, 5, + 3, 6}; + + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); +} + +TEST(TransposeOpTest, TwoDim_mlfloat16) { + std::vector input_shape({2, 3}); + std::vector input_vals; + for (uint16_t i = 0; i < 6; ++i) + input_vals.push_back(MLFloat16(i)); + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 2}); + std::initializer_list expected_vals = {MLFloat16(1), MLFloat16(4), + MLFloat16(2), MLFloat16(5), + MLFloat16(3), MLFloat16(6)}; + + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false); +} + +TEST(TransposeOpTest, TwoDim_int8) { + std::vector input_shape({2, 3}); + std::vector input_vals = {1, 2, 3, + 4, 5, 6}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 2}); + std::initializer_list expected_vals = {1, 4, + 2, 5, + 3, 6}; + + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false); +} + TEST(TransposeOpTest, TwoDimStr) { std::vector input_shape({2, 3}); std::vector input_vals = {