mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPSInductor] Fix min/max for bfloat16 (#146552)
By introducing a full specialization that upcasts everything to float, as bfloat does not have a native min/max. Test by running `test_min_max_reduction`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146552 Approved by: https://github.com/dcci
This commit is contained in:
parent
1f8baf09ea
commit
36c6e09528
4 changed files with 19 additions and 2 deletions
|
|
@ -106,6 +106,20 @@ template <typename T>
|
|||
return ::metal::min(a, b);
|
||||
}
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
bfloat min(bfloat a, bfloat b) {
|
||||
return bfloat(
|
||||
::metal::isunordered(a, b) ? NAN : ::metal::min(float(a), float(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
bfloat max(bfloat a, bfloat b) {
|
||||
return bfloat(
|
||||
::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b)));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
using vec2type_t = typename detail::vectypes<T>::type2;
|
||||
|
||||
|
|
|
|||
|
|
@ -131,6 +131,7 @@ class MPSBasicTests(TestCase):
|
|||
|
||||
# Copy tests
|
||||
for test_name in [
|
||||
"test_min_max_reduction",
|
||||
"test_add_const_int",
|
||||
"test_add_inplace_permuted",
|
||||
"test_addmm",
|
||||
|
|
|
|||
|
|
@ -2408,7 +2408,7 @@ class CommonTemplate:
|
|||
)
|
||||
|
||||
dtypes = [torch.float, torch.float16]
|
||||
if not (self.device == "cuda" and not SM80OrLater):
|
||||
if self.is_dtype_supported(torch.bfloat16):
|
||||
dtypes += [torch.bfloat16]
|
||||
for dtype in dtypes:
|
||||
self.common(fn, (torch.randn(8, 8).to(dtype), torch.randn(8, 8).to(dtype)))
|
||||
|
|
|
|||
|
|
@ -451,7 +451,9 @@ class MetalKernel(SIMDKernel):
|
|||
)
|
||||
if reduction_type in ["max", "min", "argmax", "argmin"]:
|
||||
acc_buf = self._new_accvar(src_dtype, reduction_dim.numel)
|
||||
self.body.splice(f"{acc_buf}[{reduction_dim.name}] = {value};")
|
||||
self.body.splice(
|
||||
f"{acc_buf}[{reduction_dim.name}] = static_cast<{DTYPE_TO_METAL[src_dtype]}>({value});"
|
||||
)
|
||||
return self.cse.generate(
|
||||
self.body,
|
||||
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
|
||||
|
|
|
|||
Loading…
Reference in a new issue