mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Fix TopK with NAN on Cuda (#10314)
* reset MIN for float/double * better logics for float/double comparision for equals
This commit is contained in:
parent
ff2057a817
commit
ce081fe655
1 changed files with 3 additions and 5 deletions
|
|
@ -197,13 +197,11 @@ __device__ __inline__ bool Equal(const T& t0, const T& t1) {
|
|||
}
|
||||
|
||||
__device__ __inline__ bool Equal(const float& t0, const float& t1) {
|
||||
auto t2 = t0 > t1 ? t0 - t1 : t1 - t0;
|
||||
return t2 < std::numeric_limits<float>::epsilon();
|
||||
return !(t0 > t1 || t1 > t0);
|
||||
}
|
||||
|
||||
__device__ __inline__ bool Equal(const double& t0, const double& t1) {
|
||||
auto t2 = t0 > t1 ? t0 - t1 : t1 - t0;
|
||||
return t2 < std::numeric_limits<double>::epsilon();
|
||||
return !(t0 > t1 || t1 > t0);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -305,7 +303,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> el
|
|||
}
|
||||
} else {
|
||||
if (KK > negative) {
|
||||
KK = dimension - KK + 1;
|
||||
KK = dimension - KK + 1;
|
||||
} else {
|
||||
sign = (T)-1;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue