mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[CUDA] Fix NumericLimits (#22738)
### Description * Fix `NumericLimits<float>` that used infinity as max, which is not consistent with `std::numeric_limits<float>::max()` In Windows, (float)(1e+300) is used for INFINITY, which causes compiler error in Visual Studio 2022 v17.12 Preview 5. * Rename `NumericLimits<T>::Min` to Lowest to be consistent with std::numeric_limits * Fix topk implementation: use `NumericLimits<CudaT>` instead of `NumericLimits<T>` in kernel. That could avoid defining a confusing defintion of `NumericLimits<MLFloat16>` that returns half instead of MLFloat16. * Use CUDART_MAX_NORMAL_FP16 if possible. It sets bits value directly, which is faster than converting float to half. Note that NumericLimits does not support __nv_bfloat16 and _nv_fp8_e4m3 and __nv_fp8_e5m2 right now. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/22728
This commit is contained in:
parent
1cb5ceedf3
commit
d993ec313f
4 changed files with 21 additions and 41 deletions
|
|
@ -60,7 +60,7 @@ struct TopK {
|
|||
__device__ __forceinline__ void Init() {
|
||||
for (int i = 0; i < max_k; i++) {
|
||||
key[i] = -1;
|
||||
value[i] = NumericLimits<T>::Min();
|
||||
value[i] = NumericLimits<T>::Lowest();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
|
||||
|
|
@ -19,7 +20,10 @@ struct TopOne {
|
|||
int32_t key;
|
||||
T value;
|
||||
|
||||
__device__ __host__ __forceinline__ TopOne(int32_t key = -1, T value = NumericLimits<T>::Min()) : key(key), value(value) {
|
||||
__device__ __host__ __forceinline__ TopOne() : key(-1), value(NumericLimits<T>::Lowest()) {
|
||||
}
|
||||
|
||||
__device__ __host__ __forceinline__ TopOne(int32_t key, T value) : key(key), value(value) {
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void Reduce(int32_t k, T v) {
|
||||
|
|
|
|||
|
|
@ -412,7 +412,7 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute,
|
|||
if (aligned_dimension <= GridDim::maxThreadsPerBlock) {
|
||||
BitonicTopK<CudaT><<<N, GridDim::maxThreadsPerBlock, aligned_dimension * sizeof(KV<CudaT>), stream>>>(
|
||||
input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension,
|
||||
aligned_dimension, NumericLimits<T>::Min(), NumericLimits<T>::Max());
|
||||
aligned_dimension, NumericLimits<CudaT>::Lowest(), NumericLimits<CudaT>::Max());
|
||||
} else if (K <= BT * 16 || 0 == sorted) {
|
||||
if (use_deterministic_compute) {
|
||||
static std::once_flag log_warning;
|
||||
|
|
@ -425,19 +425,19 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute,
|
|||
if (BT * 2 >= K || 0 == sorted) {
|
||||
RadixTopK<CudaT, BT, 2><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
|
||||
input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
|
||||
NumericLimits<T>::Min(), NumericLimits<T>::Max());
|
||||
NumericLimits<CudaT>::Lowest(), NumericLimits<CudaT>::Max());
|
||||
} else if (BT * 4 >= K) {
|
||||
RadixTopK<CudaT, BT, 4><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
|
||||
input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
|
||||
NumericLimits<T>::Min(), NumericLimits<T>::Max());
|
||||
NumericLimits<CudaT>::Lowest(), NumericLimits<CudaT>::Max());
|
||||
} else if (BT * 8 >= K) {
|
||||
RadixTopK<CudaT, BT, 8><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
|
||||
input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
|
||||
NumericLimits<T>::Min(), NumericLimits<T>::Max());
|
||||
NumericLimits<CudaT>::Lowest(), NumericLimits<CudaT>::Max());
|
||||
} else {
|
||||
RadixTopK<CudaT, BT, 16><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
|
||||
input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
|
||||
NumericLimits<T>::Min(), NumericLimits<T>::Max());
|
||||
NumericLimits<CudaT>::Lowest(), NumericLimits<CudaT>::Max());
|
||||
}
|
||||
} else {
|
||||
auto input_key_buffer = kernel->GetScratchBuffer<CudaT>(dimension, ort_stream);
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
|
||||
#include <gsl/gsl>
|
||||
#include "core/framework/float16.h"
|
||||
|
|
@ -120,7 +121,7 @@ constexpr int kNumBitsPerBitmaskElement = std::numeric_limits<BitmaskElementType
|
|||
|
||||
template <typename T>
|
||||
struct NumericLimits {
|
||||
__inline__ __host__ __device__ static T Min() {
|
||||
__inline__ __host__ __device__ static T Lowest() {
|
||||
return std::numeric_limits<T>::lowest();
|
||||
}
|
||||
__inline__ __host__ __device__ static T Max() {
|
||||
|
|
@ -128,43 +129,18 @@ struct NumericLimits {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<MLFloat16> {
|
||||
__inline__ __host__ __device__ static half Min() {
|
||||
return -65504.0;
|
||||
}
|
||||
__inline__ __host__ __device__ static half Max() {
|
||||
return 65504.0;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<half> {
|
||||
__inline__ __host__ __device__ static half Min() {
|
||||
return -65504.0;
|
||||
__inline__ __host__ __device__ static half Lowest() {
|
||||
return -65504.0f;
|
||||
}
|
||||
|
||||
__inline__ __host__ __device__ static half Max() {
|
||||
return 65504.0;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<float> {
|
||||
__inline__ __host__ __device__ static float Min() {
|
||||
return -INFINITY;
|
||||
}
|
||||
__inline__ __host__ __device__ static float Max() {
|
||||
return INFINITY;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<double> {
|
||||
__inline__ __host__ __device__ static double Min() {
|
||||
return -HUGE_VAL;
|
||||
}
|
||||
__inline__ __host__ __device__ static double Max() {
|
||||
return HUGE_VAL;
|
||||
#ifdef CUDART_MAX_NORMAL_FP16 // defined in cuda 12.3 or later
|
||||
return CUDART_MAX_NORMAL_FP16;
|
||||
#else
|
||||
return 65504.0f;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue