[MPS] Implement support for zeta (both eager and inductor). (#146465)

An inductor test (`test_pointwise_zeta`) was failing, and I realized the operation was also missing from eager.
Implemented for both, leveraging the shared `c10::metal::zeta` kernel from `c10/metal/special_math.h`. Happy to split this into two PRs (one for eager, one for inductor) if folks prefer.
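
For reference, a minimal usage sketch (assuming a build where both MPS eager mode and the inductor MPS backend are available):

```python
import torch

x = torch.rand(128, 128, device="mps") + 1.0  # s > 1 keeps the Hurwitz series well-behaved
q = torch.rand(128, 128, device="mps") + 1.0

# Eager: dispatches to the new Metal "zeta" kernel.
eager_out = torch.special.zeta(x, q)

# Inductor: lowered through the new MetalOverrides.zeta entry.
compiled_out = torch.compile(torch.special.zeta)(x, q)

torch.testing.assert_close(eager_out, compiled_out)
```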

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146465
Approved by: https://github.com/malfet
Davide Italiano 2025-02-05 13:55:50 +00:00 committed by PyTorch MergeBot
parent fd0cd6a08f
commit 8a2000fd42
6 changed files with 50 additions and 2 deletions


@@ -1,3 +1,4 @@
#include <c10/metal/special_math.h>
#include <metal_stdlib>
using namespace metal;
@@ -227,3 +228,28 @@ kernel void complex_kernel(
REGISTER_COMPLEX_OUT_OP(float);
REGISTER_COMPLEX_OUT_OP(half);
template <typename T>
kernel void zeta(
constant void* input_ [[buffer(0)]],
constant void* other_ [[buffer(1)]],
device void* out_ [[buffer(2)]],
constant uint3* offsets [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
*out = static_cast<T>(c10::metal::zeta(*input, *other));
}
#define REGISTER_ZETA_OP(DTYPE) \
template [[host_name("zeta_" #DTYPE)]] kernel void zeta<DTYPE>( \
constant void* input_ [[buffer(0)]], \
constant void* other_ [[buffer(1)]], \
device void* out_ [[buffer(2)]], \
constant uint3* offsets [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);
REGISTER_ZETA_OP(float);
REGISTER_ZETA_OP(half);

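As a quick numerical sanity check for the kernel above, one could compare against the existing CPU implementation (a sketch, assuming an MPS device is present):

```python
import torch

x = torch.rand(64, 64) + 1.0
q = torch.rand(64, 64) + 1.0

cpu = torch.special.zeta(x, q)
mps = torch.special.zeta(x.to("mps"), q.to("mps")).cpu()

# float32 Metal math vs the CPU reference: allow modest tolerances.
torch.testing.assert_close(mps, cpu, rtol=1e-4, atol=1e-4)
```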

@@ -109,10 +109,16 @@ static void nextafter_mps_kernel(TensorIteratorBase& iter) {
mps::binary_mps_impl(iter, "nextafter_kernel");
}
static void zeta_mps_kernel(TensorIteratorBase& iter) {
TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "zeta_mps not implemented for non-floating types");
mps::binary_mps_impl(iter, "zeta");
}
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel)
REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel)
Tensor& polar_out_mps(const Tensor& abs, const Tensor& angle, Tensor& output) {
auto new_size = at::infer_size(abs.sizes(), angle.sizes());

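The `TORCH_CHECK_TYPE` guard above rejects non-floating dtypes rather than silently promoting them; per the `test_mps.py` skip list later in this diff, integral inputs are expected to fail on MPS. A sketch of what that looks like from Python (hedging on the exact exception type the check maps to):

```python
import torch

a = torch.ones(4, dtype=torch.int32, device="mps")
b = torch.ones(4, dtype=torch.int32, device="mps")
try:
    torch.special.zeta(a, b)
except (TypeError, RuntimeError) as e:
    print(e)  # zeta_mps not implemented for non-floating types
```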

@@ -13463,7 +13463,7 @@
python_module: special
variants: function
dispatch:
CPU, CUDA: special_zeta_out
CPU, CUDA, MPS: special_zeta_out
tags: pointwise
- func: special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

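With `MPS` added to the dispatch entry, the `out=` variant also routes to the Metal kernel (a sketch):

```python
import torch

x = torch.rand(32, device="mps") + 1.0
q = torch.rand(32, device="mps") + 1.0
out = torch.empty_like(x)
torch.special.zeta(x, q, out=out)  # writes through special_zeta_out on MPS
```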

@@ -104,6 +104,13 @@ class MPSBasicTests(TestCase):
def test_pointwise_digamma(self):
self.common(torch.special.digamma, (torch.rand(128, 128),), check_lowp=False)
def test_pointwise_zeta(self):
self.common(
torch.special.zeta,
(torch.rand(128, 128), torch.rand(128, 128)),
check_lowp=False,
)
def test_broadcast(self):
self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024)))

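`self.common` compares the eager result against the `torch.compile`d one; conceptually it behaves like the simplified model below (a sketch, not the actual test harness):

```python
import torch

def common(fn, args, atol=1e-5, rtol=1e-5):
    # Simplified model: run the op eagerly and under torch.compile on
    # the MPS device, then check that the two results agree.
    mps_args = tuple(a.to("mps") for a in args)
    expected = fn(*mps_args)
    actual = torch.compile(fn)(*mps_args)
    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)

common(torch.special.zeta, (torch.rand(128, 128), torch.rand(128, 128)))
```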

@@ -102,6 +102,8 @@ def mps_ops_grad_modifier(ops):
'exponential': [torch.float16, torch.float32],
# CPU errors
# derivative for zeta is not implemented
'special.zeta': None,
# derivative for aten::nextafter is not implemented on CPU
'nextafter': None,
# derivative for aten::floor_divide is not implemented on CPU
@@ -331,6 +333,7 @@ def mps_ops_modifier(ops):
'select',
'sgn',
'slice',
'special.zeta',
'split',
'split_with_sizes',
'split_with_sizes_copy',
@@ -791,7 +794,6 @@ def mps_ops_modifier(ops):
'special.scaled_modified_bessel_k1': None,
'special.spherical_bessel_j0': None,
'special.xlog1py': None,
'special.zeta': None,
'svd_lowrank': None,
'symeig': None,
'take': None,
@@ -840,6 +842,9 @@ def mps_ops_modifier(ops):
'atan2': [torch.int64],
'angle': [torch.int64],
# zeta isn't supported for integral types
'special.zeta': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
# GEMM on MPS is not supported for integral types
'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

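The grad-modifier entry reflects that the derivative for zeta is not implemented, so backward through it raises; e.g. (a sketch):

```python
import torch

x = torch.rand(8, device="mps", requires_grad=True) + 1.0
q = torch.rand(8, device="mps") + 1.0
try:
    torch.special.zeta(x, q).sum().backward()
except RuntimeError as e:
    print(e)  # derivative for zeta is not implemented
```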

@@ -355,6 +355,10 @@ class MetalOverrides(OpOverrides):
cast_b = f"static_cast<decltype({a}+{b})>({b})"
return f"metal::pow({cast_a}, {cast_b})"
@staticmethod
def zeta(a: CSEVariable, b: CSEVariable) -> str:
return f"c10::metal::zeta({a}, {b})"
MetalOverrides._initialize_pointwise_overrides("mps")
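
With this override in place, the Metal codegen emits a `c10::metal::zeta(...)` call for the pointwise node; the generated source can be inspected with inductor's output-code logging (a sketch):

```python
# Run with TORCH_LOGS="output_code" to dump the generated Metal kernel,
# which should contain the c10::metal::zeta(...) call from this override.
import torch

fn = torch.compile(torch.special.zeta)
fn(torch.rand(16, device="mps") + 1.0, torch.rand(16, device="mps") + 1.0)
```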