mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPS] Add op_math_t (#145808)
Similar to `at::opmath_t` to be used for reduction (and int mms) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145808 Approved by: https://github.com/dcci
This commit is contained in:
parent
5382ab57d7
commit
3fd4691908
1 changed files with 33 additions and 0 deletions
|
|
@ -51,6 +51,37 @@ struct vectypes<long> {
|
|||
using type2 = short2;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OpMathType {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpMathType<half> {
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpMathType<short> {
|
||||
using type = int;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpMathType<char> {
|
||||
using type = int;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpMathType<uchar> {
|
||||
using type = int;
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
struct OpMathType<bfloat> {
|
||||
using type = float;
|
||||
};
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -79,5 +110,7 @@ using vec2type_t = typename detail::vectypes<T>::type2;
|
|||
template <typename T>
|
||||
using vec4type_t = typename detail::vectypes<T>::type4;
|
||||
|
||||
template <typename T>
|
||||
using opmath_t = typename detail::OpMathType<T>::type;
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
|
|
|||
Loading…
Reference in a new issue