From e945b5fcf649f2301e72d64fcd0df9d3c994fade Mon Sep 17 00:00:00 2001 From: Du Li Date: Thu, 10 Dec 2020 11:04:19 -0800 Subject: [PATCH] adding fp16 support for topk cuda kernel (#6082) * adding fp16 support for topk. * disable fp16 tests for cpu ep Co-authored-by: Du Li --- onnxruntime/core/providers/cuda/math/topk.cc | 2 +- .../core/providers/cuda/math/topk_impl.cu | 49 +++++++++++++------ .../test/providers/cpu/math/topk_op_test.cc | 18 +++++++ 3 files changed, 53 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/cuda/math/topk.cc b/onnxruntime/core/providers/cuda/math/topk.cc index de78cd6833..0cc508bc0c 100644 --- a/onnxruntime/core/providers/cuda/math/topk.cc +++ b/onnxruntime/core/providers/cuda/math/topk.cc @@ -95,9 +95,9 @@ Status TopK::ComputeInternal(OpKernelContext* ctx) const { if (IS_PRIM_TYPE(int16_t)) return TOPKIMPL(int16_t); if (IS_PRIM_TYPE(int32_t)) return TOPKIMPL(int32_t); if (IS_PRIM_TYPE(int64_t)) return TOPKIMPL(int64_t); + if (IS_PRIM_TYPE(MLFloat16)) return TOPKIMPL(MLFloat16); if (IS_PRIM_TYPE(float)) return TOPKIMPL(float); if (IS_PRIM_TYPE(double)) return TOPKIMPL(double); - if (IS_PRIM_TYPE(uint8_t)) return TOPKIMPL(uint8_t); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for TopK operator"); } diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cu b/onnxruntime/core/providers/cuda/math/topk_impl.cu index 851270c798..fca3a5b06e 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cu +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cu @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "topk_impl.h" +#include "core/framework/data_types.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "device_atomic_functions.h" #include "cub/cub.cuh" @@ -163,7 +164,11 @@ __device__ __inline__ bool Equal(const double& t0, const double& t1) { template __device__ bool SamePrefix(const T* t0, const T* t1, int64_t skip) { - return (((*t0)^(*t1))>>skip) == 0; + return ((*t0)^(*t1))>>skip == 0; +} + +__device__ bool SamePrefix(const half* f0, const half* f1, int64_t skip) { + return SamePrefix((const uint16_t*)f0, (const uint16_t*)f1, skip); } __device__ bool SamePrefix(const float* f0, const float* f1, int64_t skip) { @@ -179,6 +184,10 @@ __device__ int32_t Radix(const T* t, int64_t skip) { return ((*t)>>skip)&255; } +__device__ int32_t Radix(const half* f, int64_t skip) { + return Radix((const uint16_t*)f, skip); +} + __device__ int32_t Radix(const float* f, int64_t skip) { return Radix((const int32_t*)f, skip); } @@ -192,6 +201,10 @@ __device__ void SetByte(T* t, int64_t byte) { (*t) |= byte; } +__device__ void SetByte(half* f, int64_t byte) { + SetByte((uint16_t*)f, byte); +} + __device__ void SetByte(float* f, int64_t byte) { SetByte((int32_t*)f, byte); } @@ -221,9 +234,9 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray el uint32_t positive = 0, negative = 0; for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) { T x = X[FROM(x_i)]; - if (x > 0) { + if (x > (T)0) { ++positive; - } else if (x < 0) { + } else if (x < (T)0) { ++negative; } } @@ -261,7 +274,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray el auto skip = 8 * byte, prev_skip = 8 * (byte + 1); for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) { T x = sign*X[FROM(x_i)]; - if (x > 0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) { + if (x > (T)0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) { atomicAdd(&H[Radix(&x, skip)], 1); } } @@ -381,24 +394,28 @@ __global__ void ExcludeOutput(int64_t* output_i, int64_t K, int64_t dimension) { template Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const TArray& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) { + typedef typename ToCudaType::MappedType CudaT; + const CudaT* input_x_ptr = reinterpret_cast(input_x); + CudaT* output_v_ptr = reinterpret_cast(output_v); + auto aligned_K = ALIGN(K); auto aligned_dimension = ALIGN(dimension); if (aligned_dimension <= GridDim::maxThreadsPerBlock) { - BitonicTopK<<)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, aligned_dimension, std::numeric_limits::lowest(), std::numeric_limits::max()); + BitonicTopK<<)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, aligned_dimension, std::numeric_limits::lowest(), std::numeric_limits::max()); } else if (K <= BT*16 || 0 == sorted) { auto XPT = static_cast(ceil(static_cast(dimension) / GridDim::maxThreadsPerBlock)); if (BT*2 >= K || 0 == sorted) { - RadixTopK<<>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); + RadixTopK<<>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); } else if (BT*4>=K) { - RadixTopK<<>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); + RadixTopK<<>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); } else if (BT*8>=K) { - RadixTopK<<>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); + RadixTopK<<>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); } else { - RadixTopK<<>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); + RadixTopK<<>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); } } else { - auto input_key_buffer = kernel->GetScratchBuffer(dimension); - auto output_key_buffer = kernel->GetScratchBuffer(dimension); + auto input_key_buffer = kernel->GetScratchBuffer(dimension); + auto output_key_buffer = kernel->GetScratchBuffer(dimension); auto input_value_buffer = kernel->GetScratchBuffer(dimension); auto output_value_buffer = kernel->GetScratchBuffer(dimension); auto* input_key = input_key_buffer.get(); @@ -412,14 +429,15 @@ Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t auto blocks_per_grid_D = (int)(ceil(static_cast(dimension) / BT)); auto blocks_per_grid_K = (int)(ceil(static_cast(K) / BT)); for (int64_t i = 0; i < N; i++) { - FillInput<<>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension); - CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension)); + FillInput<<>>(input_x_ptr, input_key, input_value, elem_nums, size, axis, K, i, dimension); + CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) + : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension)); if (1 == sorted) { - FillOutput<<>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension); + FillOutput<<>>(output_key, output_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension); } else { //reorder by ascending index ExcludeOutput<<>>(output_value, K, dimension); CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension)); - FillOutput<<>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension); + FillOutput<<>>(input_key, input_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension); } } } @@ -448,6 +466,7 @@ TOPKIMPLE(int16_t); TOPKIMPLE(int32_t); TOPKIMPLE(int64_t); TOPKIMPLE(float); +TOPKIMPLE(MLFloat16); TOPKIMPLE(double); } // namespace cuda diff --git a/onnxruntime/test/providers/cpu/math/topk_op_test.cc b/onnxruntime/test/providers/cpu/math/topk_op_test.cc index eb8f775c67..68f1eac63b 100644 --- a/onnxruntime/test/providers/cpu/math/topk_op_test.cc +++ b/onnxruntime/test/providers/cpu/math/topk_op_test.cc @@ -4,6 +4,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/cuda_op_test_utils.h" namespace onnxruntime { namespace test { @@ -438,6 +439,23 @@ TEST(TopKOperator, NthElement) { RunTest(11, 4, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); } +TEST(TopKOperator, NthElementHalf) { + if (!HasCudaEnvironment(600)) { + return; + } + + std::vector input_vals_f = {10.0f, 8.0f, 7.0f, 4.0f, 5.0f, 6.0f}; + std::vector expected_vals_f = {10.0f, 8.0f, 7.0f, 6.0f}; + std::vector input_vals(6); + std::vector expected_vals(4); + ConvertFloatToMLFloat16(input_vals_f.data(), input_vals.data(), 6); + ConvertFloatToMLFloat16(expected_vals_f.data(), expected_vals.data(), 4); + std::vector input_dimensions = {6}; + std::vector expected_indices = {0, 1, 2, 5}; + std::vector expected_dimensions = {4}; + RunTest(11, 4, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); +} + // test dimension in range (GridDim::maxThreadsPerBlock, GridDim::maxThreadsPerBlock * 2], ie. [257, 512] TEST(TopKOperator, SmallArrayTopKSorted) { std::vector input_vals(400, 0.0f);