diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 13ec1869c27..3600a3bf1d1 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -49,6 +49,7 @@ class MPSBasicTests(TestCase): test_floordiv = CommonTemplate.test_floordiv test_fmod = CommonTemplate.test_fmod test_fmod_zero_dim = CommonTemplate.test_fmod_zero_dim + test_index_dynamic_shapes = CommonTemplate.test_index_dynamic_shapes test_inf = CommonTemplate.test_inf test_isinf = CommonTemplate.test_isinf test_isinf2 = CommonTemplate.test_isinf2 diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 87d916eddd3..0b49cd47405 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -235,6 +235,10 @@ class MetalOverrides(OpOverrides): float_b = f"static_cast({b})" if b.dtype != torch.float else b return f"metal::trunc({float_a}/{float_b})" + @staticmethod + def ceil(x: CSEVariable) -> str: + return f"metal::ceil({x})" + class MetalKernel(SIMDKernel): overrides = MetalOverrides # type: ignore[assignment]