From 86eaa71ec69fe8961dd58a364052d61831dae5ed Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Thu, 30 Apr 2020 14:16:46 -0700 Subject: [PATCH] sync threads before calling next cub function (#3758) Co-authored-by: RandySheriffH --- onnxruntime/core/providers/cuda/math/topk_impl.cu | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cu b/onnxruntime/core/providers/cuda/math/topk_impl.cu index 8f447644ba..851270c798 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cu +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cu @@ -148,8 +148,17 @@ __global__ void BitonicTopK(const T* X, T* V, int64_t* I, const TArray template __device__ __inline__ bool Equal(const T& t0, const T& t1) { + return t0 == t1; +} + +__device__ __inline__ bool Equal(const float& t0, const float& t1) { auto t2 = t0 > t1 ? t0 - t1 : t1 - t0; - return (double)t2 < 1.0e-5; + return t2 < std::numeric_limits::epsilon(); +} + +__device__ __inline__ bool Equal(const double& t0, const double& t1) { + auto t2 = t0 > t1 ? t0 - t1 : t1 - t0; + return t2 < std::numeric_limits::epsilon(); } template @@ -220,6 +229,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray el } __syncthreads(); positive = BlockReduce(temp_storage.reduce).Sum(positive); + __syncthreads(); negative = BlockReduce(temp_storage.reduce).Sum(negative); if (0 == tid) { H[0] = positive; @@ -286,6 +296,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray el __syncthreads(); all_superior = H[0]; BlockScan(temp_storage.scan).ExclusiveSum(superior, superior); + __syncthreads(); BlockScan(temp_storage.scan).ExclusiveSum(equal, equal); __syncthreads(); auto equal_quota = K - all_superior - equal;