mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPSInductor] Add min/max to MetalExprPrinter (#144798)
After that `GPUTests::test_avg_pool2d8_mps` and `GPUTests::test_avg_pool2d5_mps` passes Pull Request resolved: https://github.com/pytorch/pytorch/pull/144798 Approved by: https://github.com/dcci ghstack dependencies: #144509
This commit is contained in:
parent
9199c79a9c
commit
e2251fffbb
2 changed files with 12 additions and 0 deletions
|
|
@ -42,6 +42,8 @@ class MPSBasicTests(TestCase):
|
|||
test_add_inplace_permuted_mps = CommonTemplate.test_add_inplace_permuted
|
||||
test_addmm = CommonTemplate.test_addmm
|
||||
test_argmax_min_int32 = CommonTemplate.test_argmax_min_int32
|
||||
test_avg_pool2d5 = CommonTemplate.test_avg_pool2d5
|
||||
test_avg_pool2d8 = CommonTemplate.test_avg_pool2d8
|
||||
test_div1 = CommonTemplate.test_div1
|
||||
test_div3 = CommonTemplate.test_div3
|
||||
test_cat_empty = CommonTemplate.test_cat_empty
|
||||
|
|
|
|||
|
|
@ -63,6 +63,16 @@ class MetalExprPrinter(ExprPrinter_):
|
|||
mod = self.doprint(mod)
|
||||
return f"({x}) % ({mod})"
|
||||
|
||||
def _print_Min(self, expr: sympy.Expr) -> str:
|
||||
if len(expr.args) != 2:
|
||||
raise RuntimeError("metal::min only supported for 2 args")
|
||||
return f"metal::min({', '.join(map(self._print, expr.args))})"
|
||||
|
||||
def _print_Max(self, expr: sympy.Expr) -> str:
|
||||
if len(expr.args) != 2:
|
||||
raise RuntimeError("metal::max only supported for 2 args")
|
||||
return f"metal::max({', '.join(map(self._print, expr.args))})"
|
||||
|
||||
|
||||
class MetalOverrides(OpOverrides):
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Reference in a new issue