diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 0d5ffee4ef..95224b6e3b 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -144,18 +144,10 @@ __device__ __inline__ half2 _Tanh(half2 a) { return __float22half2_rn(tmp); } -// TODO: temporary workaround for casting half-to-double, until ROCM/hipcc adds support. -namespace { -template -__device__ __inline__ double __cast_to_double(T x) { return static_cast(x); } -template <> -__device__ __inline__ double __cast_to_double(half x) { return static_cast(static_cast(x)); } -} // namespace - // Capture permutations of int32/64/float/double template __device__ __inline__ T _Pow(T a, T1 b) { - return static_cast(pow(__cast_to_double(a), __cast_to_double(b))); + return static_cast(pow(static_cast(a), static_cast(b))); } template <>