[MPS] Add op_math_t (#145808)

Similar to `at::opmath_t` to be used for reduction (and int mms)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145808
Approved by: https://github.com/dcci
This commit is contained in:
Nikita Shulga 2025-01-27 17:57:27 -08:00 committed by PyTorch MergeBot
parent 5382ab57d7
commit 3fd4691908

View file

@ -51,6 +51,37 @@ struct vectypes<long> {
using type2 = short2;
};
template <typename T>
struct OpMathType {
using type = T;
};
template <>
struct OpMathType<half> {
using type = float;
};
template <>
struct OpMathType<short> {
using type = int;
};
template <>
struct OpMathType<char> {
using type = int;
};
template <>
struct OpMathType<uchar> {
using type = int;
};
#if __METAL_VERSION__ >= 310
template <>
struct OpMathType<bfloat> {
using type = float;
};
#endif
} // namespace detail
template <typename T>
@ -79,5 +110,7 @@ using vec2type_t = typename detail::vectypes<T>::type2;
template <typename T>
using vec4type_t = typename detail::vectypes<T>::type4;
template <typename T>
using opmath_t = typename detail::OpMathType<T>::type;
} // namespace metal
} // namespace c10