mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Workaround for static_cast<double>(half)
This commit is contained in:
parent
da952a9a20
commit
1059bfaf75
3 changed files with 30 additions and 20 deletions
|
|
@ -120,10 +120,18 @@ __device__ __inline__ half2 _Tanh(half2 a) {
|
|||
return __float22half2_rn(tmp);
|
||||
}
|
||||
|
||||
// TODO: temporary workaround for casting half-to-double, until ROCM/hipcc adds support.
|
||||
namespace {
|
||||
template <typename T>
|
||||
__device__ __inline__ double __cast_to_double(T x) { return static_cast<double>(x); }
|
||||
template <>
|
||||
__device__ __inline__ double __cast_to_double(half x) { return static_cast<double>(static_cast<float>(x)); }
|
||||
} // namespace
|
||||
|
||||
// Capture permutations of int32/64/float/double
|
||||
template <typename T, typename T1>
|
||||
__device__ __inline__ T _Pow(T a, T1 b) {
|
||||
return static_cast<T>(pow(static_cast<double>(a), static_cast<double>(b)));
|
||||
return static_cast<T>(pow(__cast_to_double(a), __cast_to_double(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
|
|||
|
|
@ -357,6 +357,19 @@ Status DispatchOnFirstArg(const BinaryElementwisePreparation& prepare) {
|
|||
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->template MutableData<T>()),
|
||||
prepare.output_tensor->Shape().Size());
|
||||
break;
|
||||
case on::TensorProto_DataType_FLOAT16:
|
||||
ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<MLFloat16>::MappedType>(
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
&prepare.lhs_padded_strides,
|
||||
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->template Data<T>()),
|
||||
&prepare.rhs_padded_strides,
|
||||
reinterpret_cast<const typename ToHipType<MLFloat16>::MappedType*>(prepare.rhs_tensor->template Data<MLFloat16>()),
|
||||
&prepare.fdm_output_strides,
|
||||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->template MutableData<T>()),
|
||||
prepare.output_tensor->Shape().Size());
|
||||
break;
|
||||
default:
|
||||
s = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported Y type: ",
|
||||
DataTypeImpl::ToString(prepare.rhs_tensor->DataType()));
|
||||
|
|
@ -386,6 +399,9 @@ Status Pow::ComputeInternal(OpKernelContext* context) const {
|
|||
case on::TensorProto_DataType_DOUBLE:
|
||||
s = DispatchOnFirstArg<double>(prepare);
|
||||
break;
|
||||
case on::TensorProto_DataType_FLOAT16:
|
||||
s = DispatchOnFirstArg<MLFloat16>(prepare);
|
||||
break;
|
||||
default:
|
||||
s = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported X type: ",
|
||||
DataTypeImpl::ToString(prepare.lhs_tensor->DataType()));
|
||||
|
|
|
|||
|
|
@ -137,25 +137,11 @@ SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Min)
|
|||
// create declarations for impl for Pow
|
||||
BINARY_ELEMENTWISE_IMPL_T1(Pow)
|
||||
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, int32_t)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, int64_t)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, float)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int32_t, double)
|
||||
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, int32_t)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, int64_t)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, float)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, int64_t, double)
|
||||
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, int32_t)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, int64_t)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, float)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, float, double)
|
||||
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, int32_t)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, int64_t)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, float)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(Pow, double, double)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, int32_t)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, int64_t)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, float)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, double)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, half)
|
||||
|
||||
// create declarations for impl2
|
||||
#define BINARY_OP_NAME_EXPR2(name, expr) \
|
||||
|
|
|
|||
Loading…
Reference in a new issue