[mps/inductor] Add support for ceil (#144715)

inductor/test_index_dynamic_shapes passes after this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144715
Approved by: https://github.com/malfet
This commit is contained in:
Davide Italiano 2025-01-14 01:16:47 +00:00 committed by PyTorch MergeBot
parent 64bcf39180
commit de9d6a25d7
2 changed files with 5 additions and 0 deletions

View file

@ -49,6 +49,7 @@ class MPSBasicTests(TestCase):
test_floordiv = CommonTemplate.test_floordiv
test_fmod = CommonTemplate.test_fmod
test_fmod_zero_dim = CommonTemplate.test_fmod_zero_dim
test_index_dynamic_shapes = CommonTemplate.test_index_dynamic_shapes
test_inf = CommonTemplate.test_inf
test_isinf = CommonTemplate.test_isinf
test_isinf2 = CommonTemplate.test_isinf2

View file

@ -235,6 +235,10 @@ class MetalOverrides(OpOverrides):
float_b = f"static_cast<float>({b})" if b.dtype != torch.float else b
return f"metal::trunc({float_a}/{float_b})"
@staticmethod
def ceil(x: CSEVariable) -> str:
return f"metal::ceil({x})"
class MetalKernel(SIMDKernel):
overrides = MetalOverrides # type: ignore[assignment]