Workaround for static_cast<double>(half)

This commit is contained in:
Jesse Benson 2021-01-08 11:55:07 -08:00 committed by Jesse Benson
parent da952a9a20
commit 1059bfaf75
3 changed files with 30 additions and 20 deletions

View file

@ -120,10 +120,18 @@ __device__ __inline__ half2 _Tanh(half2 a) {
return __float22half2_rn(tmp);
}
// TODO: temporary workaround for casting half-to-double, until ROCM/hipcc adds support.
namespace {
template <typename T>
__device__ __inline__ double __cast_to_double(T x) { return static_cast<double>(x); }
template <>
__device__ __inline__ double __cast_to_double(half x) { return static_cast<double>(static_cast<float>(x)); }
} // namespace
// Capture permutations of int32/64/float/double
template <typename T, typename T1>
__device__ __inline__ T _Pow(T a, T1 b) {
return static_cast<T>(pow(static_cast<double>(a), static_cast<double>(b)));
return static_cast<T>(pow(__cast_to_double(a), __cast_to_double(b)));
}
template <>

View file

@ -357,6 +357,19 @@ Status DispatchOnFirstArg(const BinaryElementwisePreparation& prepare) {
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->template MutableData<T>()),
prepare.output_tensor->Shape().Size());
break;
case on::TensorProto_DataType_FLOAT16:
ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<MLFloat16>::MappedType>(
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->template Data<T>()),
&prepare.rhs_padded_strides,
reinterpret_cast<const typename ToHipType<MLFloat16>::MappedType*>(prepare.rhs_tensor->template Data<MLFloat16>()),
&prepare.fdm_output_strides,
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->template MutableData<T>()),
prepare.output_tensor->Shape().Size());
break;
default:
s = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported Y type: ",
DataTypeImpl::ToString(prepare.rhs_tensor->DataType()));
@ -386,6 +399,9 @@ Status Pow::ComputeInternal(OpKernelContext* context) const {
case on::TensorProto_DataType_DOUBLE:
s = DispatchOnFirstArg<double>(prepare);
break;
case on::TensorProto_DataType_FLOAT16:
s = DispatchOnFirstArg<MLFloat16>(prepare);
break;
default:
s = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported X type: ",
DataTypeImpl::ToString(prepare.lhs_tensor->DataType()));

View file

@ -137,25 +137,11 @@ SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Min)
// create declarations for impl for Pow
BINARY_ELEMENTWISE_IMPL_T1(Pow)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, int32_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, int64_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, float)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, double)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, int32_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, int64_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, float)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, double)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, int32_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, int64_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, float)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, double)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, int32_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, int64_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, float)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, double)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, int32_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, int64_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, float)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, double)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, half)
// create declarations for impl2
#define BINARY_OP_NAME_EXPR2(name, expr) \