sync threads before calling next cub function (#3758)

Co-authored-by: RandySheriffH <rashuai@microsoft.com>
This commit is contained in:
RandySheriffH 2020-04-30 14:16:46 -07:00 committed by GitHub
parent af3988198c
commit 86eaa71ec6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -148,8 +148,17 @@ __global__ void BitonicTopK(const T* X, T* V, int64_t* I, const TArray<int64_t>
// Generic equality helper: defers entirely to T's operator==.
// Overloads of Equal for specific types take precedence over this template
// when the argument types match them exactly.
template <typename T>
__device__ __inline__ bool Equal(const T& t0, const T& t1) {
  const bool same = (t0 == t1);
  return same;
}
// Approximate equality for float.
// Returns true when t0 and t1 compare exactly equal, or when their absolute
// difference is below std::numeric_limits<float>::epsilon().
// Returns false if either argument is NaN.
__device__ __inline__ bool Equal(const float& t0, const float& t1) {
  // Exact match first: for two infinities of the same sign the subtraction
  // below would yield NaN and the tolerance test would wrongly report "not equal".
  if (t0 == t1) {
    return true;
  }
  // Absolute difference, computed without calling fabs.
  auto t2 = t0 > t1 ? t0 - t1 : t1 - t0;
  return t2 < std::numeric_limits<float>::epsilon();
}
// Approximate equality for double.
// Returns true when t0 and t1 compare exactly equal, or when their absolute
// difference is below std::numeric_limits<double>::epsilon().
// Returns false if either argument is NaN.
__device__ __inline__ bool Equal(const double& t0, const double& t1) {
  // Exact match first: for two infinities of the same sign the subtraction
  // below would yield NaN and the tolerance test would wrongly report "not equal".
  if (t0 == t1) {
    return true;
  }
  // Absolute difference, computed without calling fabs.
  auto t2 = t0 > t1 ? t0 - t1 : t1 - t0;
  return t2 < std::numeric_limits<double>::epsilon();
}
template<typename T>
@ -220,6 +229,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
}
__syncthreads();
positive = BlockReduce(temp_storage.reduce).Sum(positive);
__syncthreads();
negative = BlockReduce(temp_storage.reduce).Sum(negative);
if (0 == tid) {
H[0] = positive;
@ -286,6 +296,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
__syncthreads();
all_superior = H[0];
BlockScan(temp_storage.scan).ExclusiveSum(superior, superior);
__syncthreads();
BlockScan(temp_storage.scan).ExclusiveSum(equal, equal);
__syncthreads();
auto equal_quota = K - all_superior - equal;