diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index ae3781f77c..36df3a5022 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -120,10 +120,18 @@ __device__ __inline__ half2 _Tanh(half2 a) { return __float22half2_rn(tmp); } +// TODO: temporary workaround for casting half-to-double, until ROCM/hipcc adds support. +namespace { +template +__device__ __inline__ double __cast_to_double(T x) { return static_cast(x); } +template <> +__device__ __inline__ double __cast_to_double(half x) { return static_cast(static_cast(x)); } +} // namespace + // Capture permutations of int32/64/float/double template __device__ __inline__ T _Pow(T a, T1 b) { - return static_cast(pow(static_cast(a), static_cast(b))); + return static_cast(pow(__cast_to_double(a), __cast_to_double(b))); } template <> diff --git a/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.cc index 8e7ebebfc1..f84d1e65dc 100644 --- a/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.cc @@ -357,6 +357,19 @@ Status DispatchOnFirstArg(const BinaryElementwisePreparation& prepare) { reinterpret_cast::MappedType*>(prepare.output_tensor->template MutableData()), prepare.output_tensor->Shape().Size()); break; + case on::TensorProto_DataType_FLOAT16: + ImplT1_Pow::MappedType, typename ToHipType::MappedType>( + prepare.output_rank_or_simple_broadcast, + &prepare.lhs_padded_strides, + reinterpret_cast::MappedType*>(prepare.lhs_tensor->template Data()), + &prepare.rhs_padded_strides, + reinterpret_cast::MappedType*>(prepare.rhs_tensor->template Data()), + &prepare.fdm_output_strides, + prepare.fdm_H, + prepare.fdm_C, + reinterpret_cast::MappedType*>(prepare.output_tensor->template MutableData()), + prepare.output_tensor->Shape().Size()); + break; default: s = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported Y type: ", DataTypeImpl::ToString(prepare.rhs_tensor->DataType())); @@ -386,6 +399,9 @@ Status Pow::ComputeInternal(OpKernelContext* context) const { case on::TensorProto_DataType_DOUBLE: s = DispatchOnFirstArg(prepare); break; + case on::TensorProto_DataType_FLOAT16: + s = DispatchOnFirstArg(prepare); + break; default: s = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported X type: ", DataTypeImpl::ToString(prepare.lhs_tensor->DataType())); diff --git a/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl.cu b/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl.cu index f4b3f90762..ec305b04d0 100644 --- a/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl.cu @@ -137,25 +137,11 @@ SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Min) // create declarations for impl for Pow BINARY_ELEMENTWISE_IMPL_T1(Pow) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, int32_t) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, int64_t) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, float) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, double) - -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, int32_t) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, int64_t) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, float) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, double) - -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, int32_t) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, int64_t) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, float) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, double) - -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, int32_t) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, int64_t) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, float) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, double) +SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, int32_t) +SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, int64_t) +SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, float) +SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, double) +SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, half) // create declarations for impl2 #define BINARY_OP_NAME_EXPR2(name, expr) \