diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
index c5c39a9c99d..32c1846f32a 100644
--- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
+++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
@@ -1,3 +1,4 @@
+#include <c10/metal/special_math.h>
 #include <metal_stdlib>
 
 using namespace metal;
@@ -227,3 +228,28 @@ kernel void complex_kernel(
 
 REGISTER_COMPLEX_OUT_OP(float);
 REGISTER_COMPLEX_OUT_OP(half);
+
+template <typename T>
+kernel void zeta(
+    constant void* input_ [[buffer(0)]],
+    constant void* other_ [[buffer(1)]],
+    device void* out_ [[buffer(2)]],
+    constant uint3* offsets [[buffer(3)]],
+    uint tid [[thread_position_in_grid]]) {
+  device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
+  constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
+  constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
+
+  *out = static_cast<T>(c10::metal::zeta(*input, *other));
+}
+
+#define REGISTER_ZETA_OP(DTYPE)                                   \
+  template [[host_name("zeta_" #DTYPE)]] kernel void zeta<DTYPE>( \
+      constant void* input_ [[buffer(0)]],                        \
+      constant void* other_ [[buffer(1)]],                        \
+      device void* out_ [[buffer(2)]],                            \
+      constant uint3* offsets [[buffer(3)]],                      \
+      uint tid [[thread_position_in_grid]]);
+
+REGISTER_ZETA_OP(float);
+REGISTER_ZETA_OP(half);
diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
index c5e8b2caf4a..fcbc9b00d8c 100644
--- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -109,10 +109,16 @@ static void nextafter_mps_kernel(TensorIteratorBase& iter) {
   mps::binary_mps_impl(iter, "nextafter_kernel");
 }
 
+static void zeta_mps_kernel(TensorIteratorBase& iter) {
+  TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "zeta_mps not implemented for non-floating types");
+  mps::binary_mps_impl(iter, "zeta");
+}
+
 REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
 REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
 REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
 REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel)
+REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel)
 
 Tensor& polar_out_mps(const Tensor& abs, const Tensor& angle, Tensor& output) {
   auto new_size = at::infer_size(abs.sizes(), angle.sizes());
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4d916f1a657..723fa1f9b7e 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -13463,7 +13463,7 @@
   python_module: special
   variants: function
   dispatch:
-    CPU, CUDA: special_zeta_out
+    CPU, CUDA, MPS: special_zeta_out
   tags: pointwise
 
 - func: special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py
index 2bbd960877f..8e557cad1ee 100644
--- a/test/inductor/test_mps_basic.py
+++ b/test/inductor/test_mps_basic.py
@@ -104,6 +104,13 @@ class MPSBasicTests(TestCase):
     def test_pointwise_digamma(self):
         self.common(torch.special.digamma, (torch.rand(128, 128),), check_lowp=False)
 
+    def test_pointwise_zeta(self):
+        self.common(
+            torch.special.zeta,
+            (torch.rand(128, 128), torch.rand(128, 128)),
+            check_lowp=False,
+        )
+
     def test_broadcast(self):
         self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024)))
 
diff --git a/test/test_mps.py b/test/test_mps.py
index a5f79e96d11..12999221b1d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -102,6 +102,8 @@ def mps_ops_grad_modifier(ops):
         'exponential': [torch.float16, torch.float32],
 
         # CPU errors
+        # derivative for zeta is not implemented
+        'special.zeta': None,
         # derivative for aten::nextafter is not implemented on CPU
         'nextafter': None,
         # derivative for aten::floor_divide is not implemented on CPU
@@ -331,6 +333,7 @@ def mps_ops_modifier(ops):
         'select',
         'sgn',
         'slice',
+        'special.zeta',
         'split',
         'split_with_sizes',
         'split_with_sizes_copy',
@@ -791,7 +794,6 @@
         'special.scaled_modified_bessel_k1': None,
         'special.spherical_bessel_j0': None,
         'special.xlog1py': None,
-        'special.zeta': None,
         'svd_lowrank': None,
         'symeig': None,
         'take': None,
@@ -840,6 +842,9 @@
         'atan2': [torch.int64],
         'angle': [torch.int64],
 
+        # zeta isn't supported for integral types
+        'special.zeta': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+
         # GEMM on MPS is not supported for integral types
         'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
         'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py
index 2e6c3b9cf51..7b3b9cd2920 100644
--- a/torch/_inductor/codegen/mps.py
+++ b/torch/_inductor/codegen/mps.py
@@ -355,6 +355,10 @@ class MetalOverrides(OpOverrides):
         cast_b = f"static_cast<decltype({a}+{b})>({b})"
         return f"metal::pow({cast_a}, {cast_b})"
 
+    @staticmethod
+    def zeta(a: CSEVariable, b: CSEVariable) -> str:
+        return f"c10::metal::zeta({a}, {b})"
+
 
 MetalOverrides._initialize_pointwise_overrides("mps")
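
For reviewers who want to exercise the new dispatch path end to end, a minimal smoke test is sketched below. It is illustrative only, not part of the patch: the shapes, value ranges, and tolerances are arbitrary choices, picked so that zeta(x, q) (the Hurwitz zeta function) stays finite and the MPS result can be checked against the CPU reference.

# Illustrative smoke test for the MPS zeta path (not part of the patch).
import torch

if torch.backends.mps.is_available():
    # Keep x > 1 and q > 0 so values are finite and away from the pole at x == 1.
    x = torch.rand(128, 128) + 2.0
    q = torch.rand(128, 128) + 1.0
    ref = torch.special.zeta(x, q)                      # CPU reference kernel
    out = torch.special.zeta(x.to("mps"), q.to("mps"))  # new Metal kernel
    torch.testing.assert_close(out.cpu(), ref, rtol=1e-4, atol=1e-4)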