From 6cc57721f4cf6e10cb7f829ca66cf7aaa1a70ef4 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 15 Nov 2019 10:36:28 +1000 Subject: [PATCH] Change CUDA implementation of Transpose to support all fixed size tensor types (#2387) * Change CUDA implementation of Transpose to not use a typed kernel so we can support more types with minimum binary size. Add support for 8, 16, 32 and 64 bit types. Add unit tests. Add method so the implementation can be called directly (will be used by CUDA Scan very soon). * Disable TensorRT for MLFloat16 and int8 unit tests. * Address PR comment and add support for calling cublas implementation if type is mlfloat16. --- .../providers/cuda/cuda_execution_provider.cc | 8 +- .../core/providers/cuda/tensor/transpose.cc | 146 +++++++++--------- .../core/providers/cuda/tensor/transpose.h | 4 +- .../providers/cuda/tensor/transpose_impl.cu | 56 +++++-- .../providers/cuda/tensor/transpose_impl.h | 5 +- .../providers/cpu/tensor/transpose_test.cc | 85 +++++++++- 6 files changed, 201 insertions(+), 103 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index ff71993c01..19e0c5e148 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -486,9 +486,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Sh class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Tile); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Tile); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Tile); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Transpose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Transpose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Transpose); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, InstanceNormalization); @@ -863,9 +861,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.cc b/onnxruntime/core/providers/cuda/tensor/transpose.cc index 065d41e2b4..f6a97a7fd7 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.cc +++ b/onnxruntime/core/providers/cuda/tensor/transpose.cc @@ -9,19 +9,16 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Transpose, \ - kOnnxDomain, \ - 1, \ - T, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Transpose); +ONNX_OPERATOR_KERNEL_EX(Transpose, + kOnnxDomain, + 1, + kCudaExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Transpose); // special case acceleration using cublas matrix transpose -std::tuple TryTransposeWithCublas(const std::vector& perm, const TensorShape& input_shape) { +static std::tuple TryTransposeWithCublas(const std::vector& perm, const TensorShape& input_shape) { int M = 0; int N = 0; @@ -47,7 +44,72 @@ std::tuple TryTransposeWithCublas(const std::vector& perm, con } template -Status Transpose::ComputeInternal(OpKernelContext* ctx) const { +Status TransposeWithCublas(cublasHandle_t cublas_handle, const Tensor& input, Tensor& output, int M, int N) { + typedef typename ToCudaType::MappedType CudaT; + CudaT one = ToCudaType::FromFloat(1.0f); + CudaT zero = ToCudaType::FromFloat(0.0f); + const CudaT* input_data = reinterpret_cast(input.Data()); + CudaT* output_data = reinterpret_cast(output.MutableData()); + CUBLAS_RETURN_IF_ERROR( + cublasTransposeHelper(cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_T, M, N, + &one, + input_data, + N, + &zero, + input_data, + N, + output_data, + M)); + return Status::OK(); +} + +Status Transpose::DoTranspose(const Transpose& kernel, + const std::vector& permutations, const Tensor& input, Tensor& output) { + // special case when there is a dim value of 0 in the shape. + if (output.Shape().Size() == 0) + return Status::OK(); + + auto element_type = input.GetElementType(); + if (element_type == utils::GetONNXTensorElementDataType() || + element_type == utils::GetONNXTensorElementDataType() || + element_type == utils::GetONNXTensorElementDataType()) { + auto mn = TryTransposeWithCublas(permutations, input.Shape()); + int M = std::get<0>(mn); + int N = std::get<1>(mn); + if (M != 0 && N != 0) { + if (element_type == utils::GetONNXTensorElementDataType()) { + return TransposeWithCublas(kernel.CublasHandle(), input, output, M, N); + } else if (element_type == utils::GetONNXTensorElementDataType()) { + return TransposeWithCublas(kernel.CublasHandle(), input, output, M, N); + } else { + return TransposeWithCublas(kernel.CublasHandle(), input, output, M, N); + } + } + } + + const std::vector& input_dims = input.Shape().GetDims(); + const std::vector& output_dims = output.Shape().GetDims(); + + auto rank = input_dims.size(); + CudaAsyncBuffer input_strides(&kernel, rank); + CudaAsyncBuffer perm(&kernel, permutations); + CudaAsyncBuffer fdm_output_strides(&kernel, rank); + ORT_ENFORCE(TensorPitches::Calculate(input_strides.CpuSpan(), input_dims)); + ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims)); + + ORT_RETURN_IF_ERROR(input_strides.CopyToGpu()); + ORT_RETURN_IF_ERROR(perm.CopyToGpu()); + ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu()); + + size_t element_size = input.DataType()->Size(); + auto status = TransposeImpl(element_size, rank, input_strides.GpuPtr(), perm.GpuPtr(), input.DataRaw(), + fdm_output_strides.GpuPtr(), output.MutableDataRaw(), output.Shape().Size()); + + return status; +} + +Status Transpose::ComputeInternal(OpKernelContext* ctx) const { const Tensor* X_ptr = ctx->Input(0); if (X_ptr == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& X = *X_ptr; @@ -65,66 +127,8 @@ Status Transpose::ComputeInternal(OpKernelContext* ctx) const { TensorShape output_shape{output_dims}; Tensor* Y = ctx->Output(0, output_shape); - // special case when there is a dim value of 0 in the shape. - if (output_shape.Size() == 0) - return Status::OK(); - - auto mn = TryTransposeWithCublas(*p_perm, input_shape); - int M = std::get<0>(mn); - int N = std::get<1>(mn); - if (M != 0 && N != 0) { - typedef typename ToCudaType::MappedType CudaT; - CudaT one = ToCudaType::FromFloat(1.0f); - CudaT zero = ToCudaType::FromFloat(0.0f); - const CudaT* input_data = reinterpret_cast(X.template Data()); - CudaT* output_data = reinterpret_cast(Y->template MutableData()); - CUBLAS_RETURN_IF_ERROR( - cublasTransposeHelper( - CublasHandle(), - CUBLAS_OP_T, - CUBLAS_OP_T, - M, - N, - &one, - input_data, - N, - &zero, - input_data, - N, - output_data, - M)); - return Status::OK(); - } - - CudaAsyncBuffer input_strides(this, rank); - CudaAsyncBuffer perm(this, *p_perm); - CudaAsyncBuffer fdm_output_strides(this, rank); - ORT_ENFORCE(TensorPitches::Calculate(input_strides.CpuSpan(), input_dims)); - ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims)); - - ORT_RETURN_IF_ERROR(input_strides.CopyToGpu()); - ORT_RETURN_IF_ERROR(perm.CopyToGpu()); - ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu()); - - TransposeImpl( - rank, - input_strides.GpuPtr(), - perm.GpuPtr(), - reinterpret_cast::MappedType*>(X.template Data()), - fdm_output_strides.GpuPtr(), - reinterpret_cast::MappedType*>(Y->template MutableData()), - output_shape.Size()); - - return Status::OK(); + return DoTranspose(*this, *p_perm, X, *Y); } -#define SPECIALIZED_COMPUTE(T) \ - REGISTER_KERNEL_TYPED(T) \ - template Status Transpose::ComputeInternal(OpKernelContext* ctx) const; - -SPECIALIZED_COMPUTE(float) -SPECIALIZED_COMPUTE(double) -SPECIALIZED_COMPUTE(MLFloat16) - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.h b/onnxruntime/core/providers/cuda/tensor/transpose.h index 4ca504f14b..d091584b42 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.h +++ b/onnxruntime/core/providers/cuda/tensor/transpose.h @@ -12,12 +12,14 @@ namespace onnxruntime { namespace cuda { -template class Transpose final : public CudaKernel, public TransposeBase { public: Transpose(const OpKernelInfo& info) : CudaKernel(info), TransposeBase(info) {} Status ComputeInternal(OpKernelContext* context) const override; + + static Status DoTranspose(const Transpose& transpose_kernel, + const std::vector& permutations, const Tensor& input, Tensor& output); }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index 482aa916d4..0bbdd33209 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -23,23 +23,49 @@ __global__ void _TransposeKernel(size_t shape_rank, const int64_t* input_strides output_data[id] = input_data[input_index]; } -template -void TransposeImpl(size_t shape_rank, const int64_t* input_strides, const size_t* perm, const T* input_data, - const fast_divmod* fdm_output_strides, T* output_data, size_t N) { +Status TransposeImpl(size_t element_size, size_t shape_rank, const int64_t* input_strides, const size_t* perm, + const void* input_data, const fast_divmod* fdm_output_strides, void* output_data, size_t N) { int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - _TransposeKernel<<>>( - shape_rank, input_strides, perm, input_data, - fdm_output_strides, output_data, N); + switch (element_size) { + case sizeof(int8_t): + _TransposeKernel<<>>( + shape_rank, input_strides, perm, + reinterpret_cast::MappedType*>(input_data), + fdm_output_strides, + reinterpret_cast::MappedType*>(output_data), + N); + break; + case sizeof(int16_t): + _TransposeKernel<<>>( + shape_rank, input_strides, perm, + reinterpret_cast::MappedType*>(input_data), + fdm_output_strides, + reinterpret_cast::MappedType*>(output_data), + N); + break; + case sizeof(int32_t): + _TransposeKernel<<>>( + shape_rank, input_strides, perm, + reinterpret_cast::MappedType*>(input_data), + fdm_output_strides, + reinterpret_cast::MappedType*>(output_data), + N); + break; + case sizeof(int64_t): + _TransposeKernel<<>>( + shape_rank, input_strides, perm, + reinterpret_cast::MappedType*>(input_data), + fdm_output_strides, + reinterpret_cast::MappedType*>(output_data), + N); + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ", + element_size); + } + + return Status::OK(); } -#define SPECIALIZED_IMPL(T) \ - template void TransposeImpl(size_t shape_rank, const int64_t* input_strides, const size_t* perm, \ - const T* input_data, const fast_divmod* fdm_output_strides, T* output_data, \ - size_t N); - -SPECIALIZED_IMPL(float) -SPECIALIZED_IMPL(double) -SPECIALIZED_IMPL(half) - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.h b/onnxruntime/core/providers/cuda/tensor/transpose_impl.h index 0d53abcf49..023944b73b 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.h @@ -8,9 +8,8 @@ namespace onnxruntime { namespace cuda { -template -void TransposeImpl(size_t shape_rank, const int64_t* input_strides, const size_t* perm, const T* input_data, - const fast_divmod* fdm_output_strides, T* output_data, size_t N); +Status TransposeImpl(size_t element_size, size_t shape_rank, const int64_t* input_strides, const size_t* perm, + const void* input_data, const fast_divmod* fdm_output_strides, void* output_data, size_t N); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index d09d711105..20370aa65d 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -64,20 +64,91 @@ TEST(TransposeOpTest, TwoDimNoAttrStr) { // Test 2 dimensional transpose, with permutation attribute specified TEST(TransposeOpTest, TwoDim) { std::vector input_shape({2, 3}); - std::vector input_vals = { - 1.0f, 2.0f, 3.0f, - 4.0f, 5.0f, 6.0f}; + std::vector input_vals = {1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f}; std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - auto expected_vals = { - 1.0f, 4.0f, - 2.0f, 5.0f, - 3.0f, 6.0f}; + auto expected_vals = {1.0f, 4.0f, + 2.0f, 5.0f, + 3.0f, 6.0f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } +TEST(TransposeOpTest, TwoDim_double) { + std::vector input_shape({2, 3}); + std::vector input_vals = {1.0, 2.0, 3.0, + 4.0, 5.0, 6.0}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 2}); + std::initializer_list expected_vals = {1.0, 4.0, + 2.0, 5.0, + 3.0, 6.0}; + + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); +} + +TEST(TransposeOpTest, TwoDim_int32) { + std::vector input_shape({2, 3}); + std::vector input_vals = {1, 2, 3, + 4, 5, 6}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 2}); + std::initializer_list expected_vals = {1, 4, + 2, 5, + 3, 6}; + + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); +} + +TEST(TransposeOpTest, TwoDim_int16) { + std::vector input_shape({2, 3}); + std::vector input_vals = { + 1, 2, 3, + 4, 5, 6}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 2}); + std::initializer_list expected_vals = { + 1, 4, + 2, 5, + 3, 6}; + + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); +} + +TEST(TransposeOpTest, TwoDim_mlfloat16) { + std::vector input_shape({2, 3}); + std::vector input_vals; + for (uint16_t i = 0; i < 6; ++i) + input_vals.push_back(MLFloat16(i)); + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 2}); + std::initializer_list expected_vals = {MLFloat16(1), MLFloat16(4), + MLFloat16(2), MLFloat16(5), + MLFloat16(3), MLFloat16(6)}; + + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false); +} + +TEST(TransposeOpTest, TwoDim_int8) { + std::vector input_shape({2, 3}); + std::vector input_vals = {1, 2, 3, + 4, 5, 6}; + + std::vector perm = {1, 0}; + std::vector expected_shape({3, 2}); + std::initializer_list expected_vals = {1, 4, + 2, 5, + 3, 6}; + + TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false); +} + TEST(TransposeOpTest, TwoDimStr) { std::vector input_shape({2, 3}); std::vector input_vals = {