diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 728b92457c0..d4d53b30dc6 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -66,6 +66,7 @@ class MPSBasicTests(TestCase): test_remove_no_ops = CommonTemplate.test_remove_no_ops test_reflection_pad2d = CommonTemplate.test_reflection_pad2d test_rsqrt = CommonTemplate.test_rsqrt + test_scalar_cpu_tensor_arg = CommonTemplate.test_scalar_cpu_tensor_arg test_scalar_output = CommonTemplate.test_scalar_output test_setitem_with_int_parameter = CommonTemplate.test_setitem_with_int_parameter test_signbit = CommonTemplate.test_signbit diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 83efb231d6d..84f1e46aad5 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -11778,6 +11778,8 @@ class CommonTemplate: torch.bfloat16, ] for cpu_dtype in test_dtypes: + if not self.is_dtype_supported(cpu_dtype): + continue x = torch.rand([20], device=GPU_TYPE) y = torch.rand([4], device="cpu", dtype=cpu_dtype) self.common(