adding fp16 support for topk cuda kernel (#6082)

* adding fp16 support for topk.

* Disable fp16 tests for the CPU execution provider (EP)

Co-authored-by: Du Li <duli@OrtTrainingDev0.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Du Li 2020-12-10 11:04:19 -08:00 committed by GitHub
parent 7ddeafdfcc
commit e945b5fcf6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 16 deletions

View file

@ -95,9 +95,9 @@ Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
if (IS_PRIM_TYPE(int16_t)) return TOPKIMPL(int16_t);
if (IS_PRIM_TYPE(int32_t)) return TOPKIMPL(int32_t);
if (IS_PRIM_TYPE(int64_t)) return TOPKIMPL(int64_t);
if (IS_PRIM_TYPE(MLFloat16)) return TOPKIMPL(MLFloat16);
if (IS_PRIM_TYPE(float)) return TOPKIMPL(float);
if (IS_PRIM_TYPE(double)) return TOPKIMPL(double);
if (IS_PRIM_TYPE(uint8_t)) return TOPKIMPL(uint8_t);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for TopK operator");
}

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "topk_impl.h"
#include "core/framework/data_types.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "device_atomic_functions.h"
#include "cub/cub.cuh"
@ -163,7 +164,11 @@ __device__ __inline__ bool Equal(const double& t0, const double& t1) {
// Returns true when *t0 and *t1 agree on every bit at position `skip` and above.
// XOR leaves a 1 wherever the two operands differ; the right shift then drops the
// low `skip` bits, so a zero result means no differing bit remains in the prefix.
template<typename T>
__device__ bool SamePrefix(const T* t0, const T* t1, int64_t skip) {
return (((*t0)^(*t1))>>skip) == 0;
// NOTE(review): this second return is unreachable as written. It also parses the same
// as the one above, because `>>` binds tighter than `==`.
return ((*t0)^(*t1))>>skip == 0;
}
// half overload of SamePrefix: views both operands' storage as uint16_t and
// delegates the prefix comparison to the generic integer implementation.
__device__ bool SamePrefix(const half* f0, const half* f1, int64_t skip) {
  const uint16_t* lhs_bits = reinterpret_cast<const uint16_t*>(f0);
  const uint16_t* rhs_bits = reinterpret_cast<const uint16_t*>(f1);
  return SamePrefix(lhs_bits, rhs_bits, skip);
}
__device__ bool SamePrefix(const float* f0, const float* f1, int64_t skip) {
@ -179,6 +184,10 @@ __device__ int32_t Radix(const T* t, int64_t skip) {
return ((*t)>>skip)&255;
}
// half overload of Radix: views the value's storage as a uint16_t and lets the
// generic Radix pull out the 8-bit digit that sits `skip` bits above the bottom.
__device__ int32_t Radix(const half* f, int64_t skip) {
  const uint16_t* raw_bits = reinterpret_cast<const uint16_t*>(f);
  return Radix(raw_bits, skip);
}
// float overload of Radix: views the value's storage as an int32_t and lets the
// generic Radix pull out the 8-bit digit that sits `skip` bits above the bottom.
__device__ int32_t Radix(const float* f, int64_t skip) {
  const int32_t* raw_bits = reinterpret_cast<const int32_t*>(f);
  return Radix(raw_bits, skip);
}
@ -192,6 +201,10 @@ __device__ void SetByte(T* t, int64_t byte) {
(*t) |= byte;
}
// half overload of SetByte: exposes the value's storage as a uint16_t and
// forwards to the generic SetByte, which updates that storage in place.
__device__ void SetByte(half* f, int64_t byte) {
  uint16_t* raw_bits = reinterpret_cast<uint16_t*>(f);
  SetByte(raw_bits, byte);
}
// float overload of SetByte: exposes the value's storage as an int32_t and
// forwards to the generic SetByte, which updates that storage in place.
__device__ void SetByte(float* f, int64_t byte) {
  int32_t* raw_bits = reinterpret_cast<int32_t*>(f);
  SetByte(raw_bits, byte);
}
@ -221,9 +234,9 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
uint32_t positive = 0, negative = 0;
for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
T x = X[FROM(x_i)];
if (x > 0) {
if (x > (T)0) {
++positive;
} else if (x < 0) {
} else if (x < (T)0) {
++negative;
}
}
@ -261,7 +274,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
auto skip = 8 * byte, prev_skip = 8 * (byte + 1);
for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
T x = sign*X[FROM(x_i)];
if (x > 0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) {
if (x > (T)0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) {
atomicAdd(&H[Radix(&x, skip)], 1);
}
}
@ -381,24 +394,28 @@ __global__ void ExcludeOutput(int64_t* output_i, int64_t K, int64_t dimension) {
template <typename T>
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
typedef typename ToCudaType<T>::MappedType CudaT;
const CudaT* input_x_ptr = reinterpret_cast<const CudaT*>(input_x);
CudaT* output_v_ptr = reinterpret_cast<CudaT*>(output_v);
auto aligned_K = ALIGN(K);
auto aligned_dimension = ALIGN(dimension);
if (aligned_dimension <= GridDim::maxThreadsPerBlock) {
BitonicTopK<T><<<N, GridDim::maxThreadsPerBlock, aligned_dimension * sizeof(KV<T>)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, aligned_dimension, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
BitonicTopK<CudaT><<<N, GridDim::maxThreadsPerBlock, aligned_dimension * sizeof(KV<CudaT>)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, aligned_dimension, std::numeric_limits<CudaT>::lowest(), std::numeric_limits<CudaT>::max());
} else if (K <= BT*16 || 0 == sorted) {
auto XPT = static_cast<int64_t>(ceil(static_cast<double>(dimension) / GridDim::maxThreadsPerBlock));
if (BT*2 >= K || 0 == sorted) {
RadixTopK<T,BT,2><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
RadixTopK<CudaT,BT,2><<<N,BT,256*sizeof(uint32_t)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<CudaT>::lowest(), std::numeric_limits<CudaT>::max());
} else if (BT*4>=K) {
RadixTopK<T,BT,4><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
RadixTopK<CudaT,BT,4><<<N,BT,256*sizeof(uint32_t)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<CudaT>::lowest(), std::numeric_limits<CudaT>::max());
} else if (BT*8>=K) {
RadixTopK<T,BT,8><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
RadixTopK<CudaT,BT,8><<<N,BT,256*sizeof(uint32_t)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<CudaT>::lowest(), std::numeric_limits<CudaT>::max());
} else {
RadixTopK<T,BT,16><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
RadixTopK<CudaT,BT,16><<<N,BT,256*sizeof(uint32_t)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<CudaT>::lowest(), std::numeric_limits<CudaT>::max());
}
} else {
auto input_key_buffer = kernel->GetScratchBuffer<T>(dimension);
auto output_key_buffer = kernel->GetScratchBuffer<T>(dimension);
auto input_key_buffer = kernel->GetScratchBuffer<CudaT>(dimension);
auto output_key_buffer = kernel->GetScratchBuffer<CudaT>(dimension);
auto input_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
auto output_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
auto* input_key = input_key_buffer.get();
@ -412,14 +429,15 @@ Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t
auto blocks_per_grid_D = (int)(ceil(static_cast<float>(dimension) / BT));
auto blocks_per_grid_K = (int)(ceil(static_cast<float>(K) / BT));
for (int64_t i = 0; i < N; i++) {
FillInput<T><<<blocks_per_grid_D, BT, 0>>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension);
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension));
FillInput<CudaT><<<blocks_per_grid_D, BT, 0>>>(input_x_ptr, input_key, input_value, elem_nums, size, axis, K, i, dimension);
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension)
: cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension));
if (1 == sorted) {
FillOutput<T><<<blocks_per_grid_K, BT, 0>>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
FillOutput<CudaT><<<blocks_per_grid_K, BT, 0>>>(output_key, output_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension);
} else { //reorder by ascending index
ExcludeOutput<<<blocks_per_grid_D, BT, 0>>>(output_value, K, dimension);
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension));
FillOutput<T><<<blocks_per_grid_K, BT, 0>>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
FillOutput<CudaT><<<blocks_per_grid_K, BT, 0>>>(input_key, input_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension);
}
}
}
@ -448,6 +466,7 @@ TOPKIMPLE(int16_t);
TOPKIMPLE(int32_t);
TOPKIMPLE(int64_t);
TOPKIMPLE(float);
TOPKIMPLE(MLFloat16);
TOPKIMPLE(double);
} // namespace cuda

View file

@ -4,6 +4,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
namespace onnxruntime {
namespace test {
@ -438,6 +439,23 @@ TEST(TopKOperator, NthElement) {
RunTest(11, 4, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
}
// Runs TopK with k = 4 over a six-element 1-D MLFloat16 tensor and checks that
// the four largest values come back together with their original indices.
TEST(TopKOperator, NthElementHalf) {
  // Nothing to verify without a suitable CUDA environment.
  // NOTE(review): 600 presumably encodes a minimum compute capability - confirm.
  if (!HasCudaEnvironment(600)) {
    return;
  }

  constexpr int kInputCount = 6;
  constexpr int kOutputCount = 4;

  // Shapes and expected indices do not depend on the element type.
  std::vector<int64_t> input_dimensions = {6};
  std::vector<int64_t> expected_dimensions = {4};
  std::vector<int64_t> expected_indices = {0, 1, 2, 5};

  // Author the data in float, then convert it to half precision for the operator.
  std::vector<float> input_vals_f = {10.0f, 8.0f, 7.0f, 4.0f, 5.0f, 6.0f};
  std::vector<MLFloat16> input_vals(kInputCount);
  ConvertFloatToMLFloat16(input_vals_f.data(), input_vals.data(), kInputCount);

  std::vector<float> expected_vals_f = {10.0f, 8.0f, 7.0f, 6.0f};
  std::vector<MLFloat16> expected_vals(kOutputCount);
  ConvertFloatToMLFloat16(expected_vals_f.data(), expected_vals.data(), kOutputCount);

  RunTest(11, 4, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
}
// test dimension in range (GridDim::maxThreadsPerBlock, GridDim::maxThreadsPerBlock * 2], i.e. [257, 512]
TEST(TopKOperator, SmallArrayTopKSorted) {
std::vector<float> input_vals(400, 0.0f);