From d993ec313f73c5ee6a655d13690e0770a0d6cd6e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 6 Nov 2024 09:53:49 -0800 Subject: [PATCH] [CUDA] Fix NumericLimits (#22738) ### Description * Fix `NumericLimits` that used infinity as max, which is not consistent with `std::numeric_limits::max()` In Windows, (float)(1e+300) is used for INFINITY, which causes compiler error in Visual Studio 2022 v17.12 Preview 5. * Rename `NumericLimits::Min` to Lowest to be consistent with std::numeric_limits * Fix topk implementation: use `NumericLimits` instead of `NumericLimits` in kernel. That could avoid defining a confusing defintion of `NumericLimits` that returns half instead of MLFloat16. * Use CUDART_MAX_NORMAL_FP16 if possible. It sets bits value directly, which is faster than converting float to half. Note that NumericLimits does not support __nv_bfloat16 and _nv_fp8_e4m3 and __nv_fp8_e5m2 right now. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/22728 --- .../cuda/transformers/beam_search_topk.cu | 2 +- .../transformers/greedy_search_top_one.cu | 6 ++- .../core/providers/cuda/math/topk_impl.cuh | 10 ++--- .../providers/cuda/shared_inc/cuda_utils.h | 44 +++++-------------- 4 files changed, 21 insertions(+), 41 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu index 5ac10f6321..44be2ef237 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu @@ -60,7 +60,7 @@ struct TopK { __device__ __forceinline__ void Init() { for (int i = 0; i < max_k; i++) { key[i] = -1; - value[i] = NumericLimits::Min(); + value[i] = NumericLimits::Lowest(); } } }; diff --git a/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu b/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu index 68a2e16482..b2969194ff 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu @@ -5,6 +5,7 @@ #include + #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cuda/cu_inc/common.cuh" @@ -19,7 +20,10 @@ struct TopOne { int32_t key; T value; - __device__ __host__ __forceinline__ TopOne(int32_t key = -1, T value = NumericLimits::Min()) : key(key), value(value) { + __device__ __host__ __forceinline__ TopOne() : key(-1), value(NumericLimits::Lowest()) { + } + + __device__ __host__ __forceinline__ TopOne(int32_t key, T value) : key(key), value(value) { } __device__ __forceinline__ void Reduce(int32_t k, T v) { diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cuh b/onnxruntime/core/providers/cuda/math/topk_impl.cuh index cbde6da457..112566e54b 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cuh @@ -412,7 +412,7 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, if (aligned_dimension <= GridDim::maxThreadsPerBlock) { BitonicTopK<<), stream>>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, - aligned_dimension, NumericLimits::Min(), NumericLimits::Max()); + aligned_dimension, NumericLimits::Lowest(), NumericLimits::Max()); } else if (K <= BT * 16 || 0 == sorted) { if (use_deterministic_compute) { static std::once_flag log_warning; @@ -425,19 +425,19 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, if (BT * 2 >= K || 0 == sorted) { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } else if (BT * 4 >= K) { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } else if (BT * 8 >= K) { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } else { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } } else { auto input_key_buffer = kernel->GetScratchBuffer(dimension, ort_stream); diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index ed642754af..f9433642f0 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "core/framework/float16.h" @@ -120,7 +121,7 @@ constexpr int kNumBitsPerBitmaskElement = std::numeric_limits struct NumericLimits { - __inline__ __host__ __device__ static T Min() { + __inline__ __host__ __device__ static T Lowest() { return std::numeric_limits::lowest(); } __inline__ __host__ __device__ static T Max() { @@ -128,43 +129,18 @@ struct NumericLimits { } }; -template <> -struct NumericLimits { - __inline__ __host__ __device__ static half Min() { - return -65504.0; - } - __inline__ __host__ __device__ static half Max() { - return 65504.0; - } -}; - template <> struct NumericLimits { - __inline__ __host__ __device__ static half Min() { - return -65504.0; + __inline__ __host__ __device__ static half Lowest() { + return -65504.0f; } + __inline__ __host__ __device__ static half Max() { - return 65504.0; - } -}; - -template <> -struct NumericLimits { - __inline__ __host__ __device__ static float Min() { - return -INFINITY; - } - __inline__ __host__ __device__ static float Max() { - return INFINITY; - } -}; - -template <> -struct NumericLimits { - __inline__ __host__ __device__ static double Min() { - return -HUGE_VAL; - } - __inline__ __host__ __device__ static double Max() { - return HUGE_VAL; +#ifdef CUDART_MAX_NORMAL_FP16 // defined in cuda 12.3 or later + return CUDART_MAX_NORMAL_FP16; +#else + return 65504.0f; +#endif } };