From 36c6e09528a7e071edecde083254da70cba26c95 Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Wed, 5 Feb 2025 17:21:00 -0800
Subject: [PATCH] [MPSInductor] Fix min/max for bfloat16 (#146552)

By introducing full specializations that upcast everything to float, as
bfloat does not have native min/max overloads.

Tested by running `test_min_max_reduction`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146552
Approved by: https://github.com/dcci
---
 c10/metal/utils.h                   | 14 ++++++++++++++
 test/inductor/test_mps_basic.py     |  1 +
 test/inductor/test_torchinductor.py |  2 +-
 torch/_inductor/codegen/mps.py      |  4 +++-
 4 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/c10/metal/utils.h b/c10/metal/utils.h
index b73898173de..0429787951c 100644
--- a/c10/metal/utils.h
+++ b/c10/metal/utils.h
@@ -106,6 +106,20 @@ template
   return ::metal::min(a, b);
 }
 
+#if __METAL_VERSION__ >= 310
+template <>
+bfloat min(bfloat a, bfloat b) {
+  return bfloat(
+      ::metal::isunordered(a, b) ? NAN : ::metal::min(float(a), float(b)));
+}
+
+template <>
+bfloat max(bfloat a, bfloat b) {
+  return bfloat(
+      ::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b)));
+}
+#endif
+
 template <typename T>
 using vec2type_t = typename detail::vectypes<T>::type2;
 
diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py
index 8e557cad1ee..6295e85c4e9 100644
--- a/test/inductor/test_mps_basic.py
+++ b/test/inductor/test_mps_basic.py
@@ -131,6 +131,7 @@ class MPSBasicTests(TestCase):
 
 # Copy tests
 for test_name in [
+    "test_min_max_reduction",
     "test_add_const_int",
     "test_add_inplace_permuted",
     "test_addmm",
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 8cb033c9b72..5ef7a6ec5b6 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -2408,7 +2408,7 @@ class CommonTemplate:
         )
 
         dtypes = [torch.float, torch.float16]
-        if not (self.device == "cuda" and not SM80OrLater):
+        if self.is_dtype_supported(torch.bfloat16):
             dtypes += [torch.bfloat16]
         for dtype in dtypes:
             self.common(fn, (torch.randn(8, 8).to(dtype), torch.randn(8, 8).to(dtype)))
diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py
index 8bc4781fd38..38922c5859c 100644
--- a/torch/_inductor/codegen/mps.py
+++ b/torch/_inductor/codegen/mps.py
@@ -451,7 +451,9 @@ class MetalKernel(SIMDKernel):
             )
         if reduction_type in ["max", "min", "argmax", "argmin"]:
             acc_buf = self._new_accvar(src_dtype, reduction_dim.numel)
-            self.body.splice(f"{acc_buf}[{reduction_dim.name}] = {value};")
+            self.body.splice(
+                f"{acc_buf}[{reduction_dim.name}] = static_cast<{DTYPE_TO_METAL[src_dtype]}>({value});"
+            )
             return self.cse.generate(
                 self.body,
                 f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
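
For readers without a Metal toolchain, here is a minimal host-side C++ sketch of the NaN-propagating semantics the new `bfloat` specializations encode. The names `bfloat_stand_in`, `nan_propagating_min`, and `nan_propagating_max` are illustrative inventions, and plain `float` stands in for Metal's `bfloat` (host C++ has no native bfloat16 type), with `<cmath>`/`<algorithm>` functions replacing the `::metal::` intrinsics; this is a restatement of the idea, not the patch's code.

```cpp
// Host-side sketch of the NaN-propagating min/max semantics from the patch.
// Hypothetical stand-ins: `float` plays the role of Metal's `bfloat`, and
// std::isunordered/std::min/std::max replace the ::metal:: intrinsics.
#include <algorithm>
#include <cassert>
#include <cmath>

using bfloat_stand_in = float; // assumption: host C++ lacks a bfloat16 type

bfloat_stand_in nan_propagating_min(bfloat_stand_in a, bfloat_stand_in b) {
  // isunordered is true iff either operand is NaN; without this guard an
  // fmin-style min (as ::metal::min behaves for floats) would prefer the
  // non-NaN operand, silently dropping NaNs from the reduction. The cast
  // through float mirrors the upcast in the Metal code, where bfloat lacks
  // native min/max overloads.
  return std::isunordered(a, b) ? NAN : std::min(float(a), float(b));
}

bfloat_stand_in nan_propagating_max(bfloat_stand_in a, bfloat_stand_in b) {
  return std::isunordered(a, b) ? NAN : std::max(float(a), float(b));
}

int main() {
  assert(nan_propagating_min(1.0f, 2.0f) == 1.0f);
  assert(nan_propagating_max(1.0f, 2.0f) == 2.0f);
  // A NaN on either side must poison the result, matching torch.min/torch.max
  // reduction semantics.
  assert(std::isnan(nan_propagating_min(NAN, 2.0f)));
  assert(std::isnan(nan_propagating_max(1.0f, NAN)));
  return 0;
}
```

The design point is that the NaN check has to be explicit: the built-in float min/max quietly prefer a non-NaN operand, which is the wrong answer for `torch.min`/`torch.max`-style reductions, and upcasting to `float` works around the missing native `bfloat` overloads.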