[MPSInductor] Add min/max to MetalExprPrinter (#144798)

After that `GPUTests::test_avg_pool2d8_mps` and `GPUTests::test_avg_pool2d5_mps` passes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144798
Approved by: https://github.com/dcci
ghstack dependencies: #144509
This commit is contained in:
Nikita Shulga 2025-01-14 16:53:23 -08:00 committed by PyTorch MergeBot
parent 9199c79a9c
commit e2251fffbb
2 changed files with 12 additions and 0 deletions

View file

@ -42,6 +42,8 @@ class MPSBasicTests(TestCase):
test_add_inplace_permuted_mps = CommonTemplate.test_add_inplace_permuted
test_addmm = CommonTemplate.test_addmm
test_argmax_min_int32 = CommonTemplate.test_argmax_min_int32
test_avg_pool2d5 = CommonTemplate.test_avg_pool2d5
test_avg_pool2d8 = CommonTemplate.test_avg_pool2d8
test_div1 = CommonTemplate.test_div1
test_div3 = CommonTemplate.test_div3
test_cat_empty = CommonTemplate.test_cat_empty

View file

@ -63,6 +63,16 @@ class MetalExprPrinter(ExprPrinter_):
mod = self.doprint(mod)
return f"({x}) % ({mod})"
def _print_Min(self, expr: sympy.Expr) -> str:
if len(expr.args) != 2:
raise RuntimeError("metal::min only supported for 2 args")
return f"metal::min({', '.join(map(self._print, expr.args))})"
def _print_Max(self, expr: sympy.Expr) -> str:
if len(expr.args) != 2:
raise RuntimeError("metal::max only supported for 2 args")
return f"metal::max({', '.join(map(self._print, expr.args))})"
class MetalOverrides(OpOverrides):
@staticmethod