mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
sync threads before calling next cub function (#3758)
Co-authored-by: RandySheriffH <rashuai@microsoft.com>
This commit is contained in:
parent
af3988198c
commit
86eaa71ec6
1 changed files with 12 additions and 1 deletions
|
|
@ -148,8 +148,17 @@ __global__ void BitonicTopK(const T* X, T* V, int64_t* I, const TArray<int64_t>
|
|||
|
||||
template <typename T>
|
||||
__device__ __inline__ bool Equal(const T& t0, const T& t1) {
|
||||
return t0 == t1;
|
||||
}
|
||||
|
||||
__device__ __inline__ bool Equal(const float& t0, const float& t1) {
|
||||
auto t2 = t0 > t1 ? t0 - t1 : t1 - t0;
|
||||
return (double)t2 < 1.0e-5;
|
||||
return t2 < std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
|
||||
__device__ __inline__ bool Equal(const double& t0, const double& t1) {
|
||||
auto t2 = t0 > t1 ? t0 - t1 : t1 - t0;
|
||||
return t2 < std::numeric_limits<double>::epsilon();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -220,6 +229,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
|
|||
}
|
||||
__syncthreads();
|
||||
positive = BlockReduce(temp_storage.reduce).Sum(positive);
|
||||
__syncthreads();
|
||||
negative = BlockReduce(temp_storage.reduce).Sum(negative);
|
||||
if (0 == tid) {
|
||||
H[0] = positive;
|
||||
|
|
@ -286,6 +296,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
|
|||
__syncthreads();
|
||||
all_superior = H[0];
|
||||
BlockScan(temp_storage.scan).ExclusiveSum(superior, superior);
|
||||
__syncthreads();
|
||||
BlockScan(temp_storage.scan).ExclusiveSum(equal, equal);
|
||||
__syncthreads();
|
||||
auto equal_quota = K - all_superior - equal;
|
||||
|
|
|
|||
Loading…
Reference in a new issue