From ce081fe6558d97b27d62ce8d3d69a42833e9f8df Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Thu, 27 Jan 2022 16:19:55 -0800 Subject: [PATCH] Fix TopK with NAN on Cuda (#10314) * reset MIN for float/double * better logics for float/double comparision for equals --- onnxruntime/core/providers/cuda/math/topk_impl.cu | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cu b/onnxruntime/core/providers/cuda/math/topk_impl.cu index 0254813042..90f8e8ec7b 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cu +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cu @@ -197,13 +197,11 @@ __device__ __inline__ bool Equal(const T& t0, const T& t1) { } __device__ __inline__ bool Equal(const float& t0, const float& t1) { - auto t2 = t0 > t1 ? t0 - t1 : t1 - t0; - return t2 < std::numeric_limits::epsilon(); + return !(t0 > t1 || t1 > t0); } __device__ __inline__ bool Equal(const double& t0, const double& t1) { - auto t2 = t0 > t1 ? t0 - t1 : t1 - t0; - return t2 < std::numeric_limits::epsilon(); + return !(t0 > t1 || t1 > t0); } template @@ -305,7 +303,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray el } } else { if (KK > negative) { - KK = dimension - KK + 1; + KK = dimension - KK + 1; } else { sign = (T)-1; }