diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index d0b2251edb9..559db76b0b3 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -42,6 +42,8 @@ class MPSBasicTests(TestCase): test_add_inplace_permuted_mps = CommonTemplate.test_add_inplace_permuted test_addmm = CommonTemplate.test_addmm test_argmax_min_int32 = CommonTemplate.test_argmax_min_int32 + test_avg_pool2d5 = CommonTemplate.test_avg_pool2d5 + test_avg_pool2d8 = CommonTemplate.test_avg_pool2d8 test_div1 = CommonTemplate.test_div1 test_div3 = CommonTemplate.test_div3 test_cat_empty = CommonTemplate.test_cat_empty diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 7ea6dec89a4..2a07c6fc524 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -63,6 +63,16 @@ class MetalExprPrinter(ExprPrinter_): mod = self.doprint(mod) return f"({x}) % ({mod})" + def _print_Min(self, expr: sympy.Expr) -> str: + if len(expr.args) != 2: + raise RuntimeError("metal::min only supported for 2 args") + return f"metal::min({', '.join(map(self._print, expr.args))})" + + def _print_Max(self, expr: sympy.Expr) -> str: + if len(expr.args) != 2: + raise RuntimeError("metal::max only supported for 2 args") + return f"metal::max({', '.join(map(self._print, expr.args))})" + class MetalOverrides(OpOverrides): @staticmethod