mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
adding fp16 support for topk cuda kernel (#6082)
* adding fp16 support for topk. * disable fp16 tests for cpu ep Co-authored-by: Du Li <duli@OrtTrainingDev0.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
7ddeafdfcc
commit
e945b5fcf6
3 changed files with 53 additions and 16 deletions
|
|
@ -95,9 +95,9 @@ Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
if (IS_PRIM_TYPE(int16_t)) return TOPKIMPL(int16_t);
|
||||
if (IS_PRIM_TYPE(int32_t)) return TOPKIMPL(int32_t);
|
||||
if (IS_PRIM_TYPE(int64_t)) return TOPKIMPL(int64_t);
|
||||
if (IS_PRIM_TYPE(MLFloat16)) return TOPKIMPL(MLFloat16);
|
||||
if (IS_PRIM_TYPE(float)) return TOPKIMPL(float);
|
||||
if (IS_PRIM_TYPE(double)) return TOPKIMPL(double);
|
||||
if (IS_PRIM_TYPE(uint8_t)) return TOPKIMPL(uint8_t);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for TopK operator");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "topk_impl.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "device_atomic_functions.h"
|
||||
#include "cub/cub.cuh"
|
||||
|
|
@ -163,7 +164,11 @@ __device__ __inline__ bool Equal(const double& t0, const double& t1) {
|
|||
|
||||
// Reports whether *t0 and *t1 agree on every bit above the lowest `skip` bits.
//
// The two values are XOR-ed (differing bits become 1), the result is shifted
// right by `skip`, and the outcome is compared against zero.
//
// t0, t1: pointers to the values to compare; T must support ^ and >>.
// skip:   number of low-order bits to ignore.
// Returns true when no differing bit remains after the shift.
template <typename T>
__device__ bool SamePrefix(const T* t0, const T* t1, int64_t skip) {
  // A single return: the shift is parenthesized explicitly so the
  // shift-then-compare evaluation order is obvious to the reader.
  return (((*t0) ^ (*t1)) >> skip) == 0;
}
|
||||
|
||||
// half overload: reinterprets both operands as uint16_t and defers to the
// generic SamePrefix so the comparison is done on the raw 16-bit patterns.
__device__ bool SamePrefix(const half* f0, const half* f1, int64_t skip) {
  const uint16_t* lhs_bits = reinterpret_cast<const uint16_t*>(f0);
  const uint16_t* rhs_bits = reinterpret_cast<const uint16_t*>(f1);
  return SamePrefix(lhs_bits, rhs_bits, skip);
}
|
||||
|
||||
__device__ bool SamePrefix(const float* f0, const float* f1, int64_t skip) {
|
||||
|
|
@ -179,6 +184,10 @@ __device__ int32_t Radix(const T* t, int64_t skip) {
|
|||
return ((*t)>>skip)&255;
|
||||
}
|
||||
|
||||
// half overload: views the value as uint16_t and forwards to the generic
// Radix, which shifts right by `skip` and masks with 255.
__device__ int32_t Radix(const half* f, int64_t skip) {
  const uint16_t* raw_bits = reinterpret_cast<const uint16_t*>(f);
  return Radix(raw_bits, skip);
}
|
||||
|
||||
// float overload: views the value as int32_t and forwards to the generic
// Radix, which shifts right by `skip` and masks with 255.
__device__ int32_t Radix(const float* f, int64_t skip) {
  const int32_t* raw_bits = reinterpret_cast<const int32_t*>(f);
  return Radix(raw_bits, skip);
}
|
||||
|
|
@ -192,6 +201,10 @@ __device__ void SetByte(T* t, int64_t byte) {
|
|||
(*t) |= byte;
|
||||
}
|
||||
|
||||
// half overload: applies the generic SetByte to the value's storage viewed
// as uint16_t, so `byte` is merged into the raw bit pattern in place.
__device__ void SetByte(half* f, int64_t byte) {
  uint16_t* raw_bits = reinterpret_cast<uint16_t*>(f);
  SetByte(raw_bits, byte);
}
|
||||
|
||||
// float overload: applies the generic SetByte to the value's storage viewed
// as int32_t, so `byte` is merged into the raw bit pattern in place.
__device__ void SetByte(float* f, int64_t byte) {
  int32_t* raw_bits = reinterpret_cast<int32_t*>(f);
  SetByte(raw_bits, byte);
}
|
||||
|
|
@ -221,9 +234,9 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
|
|||
uint32_t positive = 0, negative = 0;
|
||||
for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
|
||||
T x = X[FROM(x_i)];
|
||||
if (x > 0) {
|
||||
if (x > (T)0) {
|
||||
++positive;
|
||||
} else if (x < 0) {
|
||||
} else if (x < (T)0) {
|
||||
++negative;
|
||||
}
|
||||
}
|
||||
|
|
@ -261,7 +274,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
|
|||
auto skip = 8 * byte, prev_skip = 8 * (byte + 1);
|
||||
for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
|
||||
T x = sign*X[FROM(x_i)];
|
||||
if (x > 0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) {
|
||||
if (x > (T)0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) {
|
||||
atomicAdd(&H[Radix(&x, skip)], 1);
|
||||
}
|
||||
}
|
||||
|
|
@ -381,24 +394,28 @@ __global__ void ExcludeOutput(int64_t* output_i, int64_t K, int64_t dimension) {
|
|||
|
||||
template <typename T>
|
||||
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const CudaT* input_x_ptr = reinterpret_cast<const CudaT*>(input_x);
|
||||
CudaT* output_v_ptr = reinterpret_cast<CudaT*>(output_v);
|
||||
|
||||
auto aligned_K = ALIGN(K);
|
||||
auto aligned_dimension = ALIGN(dimension);
|
||||
if (aligned_dimension <= GridDim::maxThreadsPerBlock) {
|
||||
BitonicTopK<T><<<N, GridDim::maxThreadsPerBlock, aligned_dimension * sizeof(KV<T>)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, aligned_dimension, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
|
||||
BitonicTopK<CudaT><<<N, GridDim::maxThreadsPerBlock, aligned_dimension * sizeof(KV<CudaT>)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, aligned_dimension, std::numeric_limits<CudaT>::lowest(), std::numeric_limits<CudaT>::max());
|
||||
} else if (K <= BT*16 || 0 == sorted) {
|
||||
auto XPT = static_cast<int64_t>(ceil(static_cast<double>(dimension) / GridDim::maxThreadsPerBlock));
|
||||
if (BT*2 >= K || 0 == sorted) {
|
||||
RadixTopK<T,BT,2><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
|
||||
RadixTopK<CudaT,BT,2><<<N,BT,256*sizeof(uint32_t)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<CudaT>::lowest(), std::numeric_limits<CudaT>::max());
|
||||
} else if (BT*4>=K) {
|
||||
RadixTopK<T,BT,4><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
|
||||
RadixTopK<CudaT,BT,4><<<N,BT,256*sizeof(uint32_t)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<CudaT>::lowest(), std::numeric_limits<CudaT>::max());
|
||||
} else if (BT*8>=K) {
|
||||
RadixTopK<T,BT,8><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
|
||||
RadixTopK<CudaT,BT,8><<<N,BT,256*sizeof(uint32_t)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<CudaT>::lowest(), std::numeric_limits<CudaT>::max());
|
||||
} else {
|
||||
RadixTopK<T,BT,16><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
|
||||
RadixTopK<CudaT,BT,16><<<N,BT,256*sizeof(uint32_t)>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<CudaT>::lowest(), std::numeric_limits<CudaT>::max());
|
||||
}
|
||||
} else {
|
||||
auto input_key_buffer = kernel->GetScratchBuffer<T>(dimension);
|
||||
auto output_key_buffer = kernel->GetScratchBuffer<T>(dimension);
|
||||
auto input_key_buffer = kernel->GetScratchBuffer<CudaT>(dimension);
|
||||
auto output_key_buffer = kernel->GetScratchBuffer<CudaT>(dimension);
|
||||
auto input_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
|
||||
auto output_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
|
||||
auto* input_key = input_key_buffer.get();
|
||||
|
|
@ -412,14 +429,15 @@ Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t
|
|||
auto blocks_per_grid_D = (int)(ceil(static_cast<float>(dimension) / BT));
|
||||
auto blocks_per_grid_K = (int)(ceil(static_cast<float>(K) / BT));
|
||||
for (int64_t i = 0; i < N; i++) {
|
||||
FillInput<T><<<blocks_per_grid_D, BT, 0>>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension);
|
||||
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension));
|
||||
FillInput<CudaT><<<blocks_per_grid_D, BT, 0>>>(input_x_ptr, input_key, input_value, elem_nums, size, axis, K, i, dimension);
|
||||
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension)
|
||||
: cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension));
|
||||
if (1 == sorted) {
|
||||
FillOutput<T><<<blocks_per_grid_K, BT, 0>>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
|
||||
FillOutput<CudaT><<<blocks_per_grid_K, BT, 0>>>(output_key, output_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension);
|
||||
} else { //reorder by ascending index
|
||||
ExcludeOutput<<<blocks_per_grid_D, BT, 0>>>(output_value, K, dimension);
|
||||
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension));
|
||||
FillOutput<T><<<blocks_per_grid_K, BT, 0>>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
|
||||
FillOutput<CudaT><<<blocks_per_grid_K, BT, 0>>>(input_key, input_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -448,6 +466,7 @@ TOPKIMPLE(int16_t);
|
|||
TOPKIMPLE(int32_t);
|
||||
TOPKIMPLE(int64_t);
|
||||
TOPKIMPLE(float);
|
||||
TOPKIMPLE(MLFloat16);
|
||||
TOPKIMPLE(double);
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/common/cuda_op_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
|
@ -438,6 +439,23 @@ TEST(TopKOperator, NthElement) {
|
|||
RunTest(11, 4, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
|
||||
}
|
||||
|
||||
// Runs TopK over a six-element 1-D MLFloat16 tensor and expects four
// values/indices back. The float fixtures are converted to MLFloat16 before
// being handed to RunTest.
TEST(TopKOperator, NthElementHalf) {
  // Nothing to exercise unless HasCudaEnvironment(600) succeeds.
  // NOTE(review): 600 presumably encodes a minimum CUDA architecture - confirm.
  if (!HasCudaEnvironment(600)) {
    return;
  }

  constexpr int kInputCount = 6;
  constexpr int kOutputCount = 4;

  // Reference data authored in float for readability.
  std::vector<float> float_input = {10.0f, 8.0f, 7.0f, 4.0f, 5.0f, 6.0f};
  std::vector<float> float_expected = {10.0f, 8.0f, 7.0f, 6.0f};

  // Half-precision copies of the fixtures above.
  std::vector<MLFloat16> half_input(kInputCount);
  ConvertFloatToMLFloat16(float_input.data(), half_input.data(), kInputCount);
  std::vector<MLFloat16> half_expected(kOutputCount);
  ConvertFloatToMLFloat16(float_expected.data(), half_expected.data(), kOutputCount);

  std::vector<int64_t> input_shape = {6};
  std::vector<int64_t> output_shape = {4};
  // Positions in float_input of the four expected values.
  std::vector<int64_t> top_indices = {0, 1, 2, 5};

  RunTest(11, 4, half_input, input_shape, half_expected, top_indices, output_shape, false);
}
|
||||
|
||||
// Test a dimension in the range (GridDim::maxThreadsPerBlock, GridDim::maxThreadsPerBlock * 2], i.e. [257, 512].
|
||||
TEST(TopKOperator, SmallArrayTopKSorted) {
|
||||
std::vector<float> input_vals(400, 0.0f);
|
||||
|
|
|
|||
Loading…
Reference in a new issue