[MPSInductor] Fix min/max for bfloat16 (#146552)

By introducing a full template specialization that upcasts both operands to float, as bfloat does not have native min/max overloads.

Tested by running `test_min_max_reduction`.
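
As a hedged illustration (not part of the commit), the snippet below shows the kind of compiled bfloat16 min/max reduction on the MPS backend that `test_min_max_reduction` exercises; the availability guard and tensor shape are assumptions:

```python
import torch

# Sketch only: a compiled bfloat16 min/max reduction on the "mps" device,
# the pattern this change targets. bfloat16 on MPS generally requires
# macOS 14+ (Metal 3.1), matching the __METAL_VERSION__ >= 310 gate below.
if torch.backends.mps.is_available():
    def fn(x):
        return x.max(), x.min()

    x = torch.randn(8, 8, device="mps", dtype=torch.bfloat16)
    eager = fn(x)
    compiled = torch.compile(fn)(x)
    print(eager, compiled)  # should agree after this fix
```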

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146552
Approved by: https://github.com/dcci
Nikita Shulga 2025-02-05 17:21:00 -08:00 committed by PyTorch MergeBot
parent 1f8baf09ea
commit 36c6e09528
4 changed files with 19 additions and 2 deletions

@@ -106,6 +106,20 @@ template <typename T>
return ::metal::min(a, b);
}
#if __METAL_VERSION__ >= 310
template <>
bfloat min(bfloat a, bfloat b) {
return bfloat(
::metal::isunordered(a, b) ? NAN : ::metal::min(float(a), float(b)));
}
template <>
bfloat max(bfloat a, bfloat b) {
return bfloat(
::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b)));
}
#endif
template <typename T>
using vec2type_t = typename detail::vectypes<T>::type2;
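
The `::metal::isunordered` check makes the specialization NaN-propagating, matching `torch.minimum`/`torch.maximum`, which return NaN whenever either operand is NaN. A small eager-mode illustration (not from the commit) of the semantics being preserved:

```python
import torch

# torch.minimum/torch.maximum propagate NaN from either operand; the
# bfloat specialization above preserves this via the isunordered check.
a = torch.tensor([1.0, float("nan")], dtype=torch.bfloat16)
b = torch.tensor([2.0, 3.0], dtype=torch.bfloat16)
print(torch.minimum(a, b))  # tensor([1., nan], dtype=torch.bfloat16)
print(torch.maximum(a, b))  # tensor([2., nan], dtype=torch.bfloat16)
```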

@@ -131,6 +131,7 @@ class MPSBasicTests(TestCase):
# Copy tests
for test_name in [
"test_min_max_reduction",
"test_add_const_int",
"test_add_inplace_permuted",
"test_addmm",

@@ -2408,7 +2408,7 @@ class CommonTemplate:
)
dtypes = [torch.float, torch.float16]
if not (self.device == "cuda" and not SM80OrLater):
if self.is_dtype_supported(torch.bfloat16):
dtypes += [torch.bfloat16]
for dtype in dtypes:
self.common(fn, (torch.randn(8, 8).to(dtype), torch.randn(8, 8).to(dtype)))
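
`is_dtype_supported` generalizes the previous CUDA-only SM80 guard into a per-device check. A plausible sketch of such a helper (an assumption, not the actual implementation in test_torchinductor.py):

```python
import torch
from torch.testing._internal.common_cuda import SM80OrLater

# Plausible sketch only: gate bfloat16 on device capability instead of
# hard-coding the CUDA SM80 check at every call site.
def is_dtype_supported(self, dtype: torch.dtype) -> bool:
    if self.device == "cuda" and dtype == torch.bfloat16:
        return SM80OrLater  # bfloat16 requires Ampere (sm_80) or newer
    return True
```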

@@ -451,7 +451,9 @@ class MetalKernel(SIMDKernel):
)
if reduction_type in ["max", "min", "argmax", "argmin"]:
acc_buf = self._new_accvar(src_dtype, reduction_dim.numel)
self.body.splice(f"{acc_buf}[{reduction_dim.name}] = {value};")
self.body.splice(
f"{acc_buf}[{reduction_dim.name}] = static_cast<{DTYPE_TO_METAL[src_dtype]}>({value});"
)
return self.cse.generate(
self.body,
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",