Remove ROCM workaround for half-to-double cast.

This commit is contained in:
Jesse Benson 2021-04-06 14:17:21 -07:00 committed by Jesse Benson
parent 25e261f196
commit 2ec452cdad

View file

@ -144,18 +144,10 @@ __device__ __inline__ half2 _Tanh(half2 a) {
return __float22half2_rn(tmp);
}
// TODO: temporary workaround for casting half-to-double, until ROCM/hipcc adds support.
namespace {
template <typename T>
__device__ __inline__ double __cast_to_double(T x) { return static_cast<double>(x); }
template <>
__device__ __inline__ double __cast_to_double(half x) { return static_cast<double>(static_cast<float>(x)); }
} // namespace
// Capture permutations of int32/64/float/double
template <typename T, typename T1>
__device__ __inline__ T _Pow(T a, T1 b) {
return static_cast<T>(pow(__cast_to_double(a), __cast_to_double(b)));
return static_cast<T>(pow(static_cast<double>(a), static_cast<double>(b)));
}
template <>