mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPS] Implement support for zeta (both eager and inductor). (#146465)
A test was failing in inductor (`test_pointwise_zeta`) -- and I realized the operation was missing also from eager. Implemented for both, leveraging the kernel. Happy to split in two (one PR for eager, one for inductor) if folks prefer. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146465 Approved by: https://github.com/malfet
This commit is contained in:
parent
fd0cd6a08f
commit
8a2000fd42
6 changed files with 50 additions and 2 deletions
|
|
@ -1,3 +1,4 @@
|
|||
#include <c10/metal/special_math.h>
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
|
|
@ -227,3 +228,28 @@ kernel void complex_kernel(
|
|||
|
||||
REGISTER_COMPLEX_OUT_OP(float);
|
||||
REGISTER_COMPLEX_OUT_OP(half);
|
||||
|
||||
template <typename T>
|
||||
kernel void zeta(
|
||||
constant void* input_ [[buffer(0)]],
|
||||
constant void* other_ [[buffer(1)]],
|
||||
device void* out_ [[buffer(2)]],
|
||||
constant uint3* offsets [[buffer(3)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
|
||||
constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
|
||||
constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
|
||||
|
||||
*out = static_cast<T>(c10::metal::zeta(*input, *other));
|
||||
}
|
||||
|
||||
#define REGISTER_ZETA_OP(DTYPE) \
|
||||
template [[host_name("zeta_" #DTYPE)]] kernel void zeta<DTYPE>( \
|
||||
constant void* input_ [[buffer(0)]], \
|
||||
constant void* other_ [[buffer(1)]], \
|
||||
device void* out_ [[buffer(2)]], \
|
||||
constant uint3* offsets [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
REGISTER_ZETA_OP(float);
|
||||
REGISTER_ZETA_OP(half);
|
||||
|
|
|
|||
|
|
@ -109,10 +109,16 @@ static void nextafter_mps_kernel(TensorIteratorBase& iter) {
|
|||
mps::binary_mps_impl(iter, "nextafter_kernel");
|
||||
}
|
||||
|
||||
static void zeta_mps_kernel(TensorIteratorBase& iter) {
|
||||
TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "zeta_mps not implemented for non-floating types");
|
||||
mps::binary_mps_impl(iter, "zeta");
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
|
||||
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
|
||||
REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel)
|
||||
REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel)
|
||||
REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel)
|
||||
|
||||
Tensor& polar_out_mps(const Tensor& abs, const Tensor& angle, Tensor& output) {
|
||||
auto new_size = at::infer_size(abs.sizes(), angle.sizes());
|
||||
|
|
|
|||
|
|
@ -13463,7 +13463,7 @@
|
|||
python_module: special
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA: special_zeta_out
|
||||
CPU, CUDA, MPS: special_zeta_out
|
||||
tags: pointwise
|
||||
|
||||
- func: special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
|
|
|
|||
|
|
@ -104,6 +104,13 @@ class MPSBasicTests(TestCase):
|
|||
def test_pointwise_digamma(self):
|
||||
self.common(torch.special.digamma, (torch.rand(128, 128),), check_lowp=False)
|
||||
|
||||
def test_pointwise_zeta(self):
|
||||
self.common(
|
||||
torch.special.zeta,
|
||||
(torch.rand(128, 128), torch.rand(128, 128)),
|
||||
check_lowp=False,
|
||||
)
|
||||
|
||||
def test_broadcast(self):
|
||||
self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024)))
|
||||
|
||||
|
|
|
|||
|
|
@ -102,6 +102,8 @@ def mps_ops_grad_modifier(ops):
|
|||
'exponential': [torch.float16, torch.float32],
|
||||
|
||||
# CPU errors
|
||||
# derivative for zeta is not implemented
|
||||
'special.zeta': None,
|
||||
# derivative for aten::nextafter is not implemented on CPU
|
||||
'nextafter': None,
|
||||
# derivative for aten::floor_divide is not implemented on CPU
|
||||
|
|
@ -331,6 +333,7 @@ def mps_ops_modifier(ops):
|
|||
'select',
|
||||
'sgn',
|
||||
'slice',
|
||||
'special.zeta',
|
||||
'split',
|
||||
'split_with_sizes',
|
||||
'split_with_sizes_copy',
|
||||
|
|
@ -791,7 +794,6 @@ def mps_ops_modifier(ops):
|
|||
'special.scaled_modified_bessel_k1': None,
|
||||
'special.spherical_bessel_j0': None,
|
||||
'special.xlog1py': None,
|
||||
'special.zeta': None,
|
||||
'svd_lowrank': None,
|
||||
'symeig': None,
|
||||
'take': None,
|
||||
|
|
@ -840,6 +842,9 @@ def mps_ops_modifier(ops):
|
|||
'atan2': [torch.int64],
|
||||
'angle': [torch.int64],
|
||||
|
||||
# zeta isn't supported for integral types
|
||||
'special.zeta': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
|
||||
# GEMM on MPS is not supported for integral types
|
||||
'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
|
|
|
|||
|
|
@ -355,6 +355,10 @@ class MetalOverrides(OpOverrides):
|
|||
cast_b = f"static_cast<decltype({a}+{b})>({b})"
|
||||
return f"metal::pow({cast_a}, {cast_b})"
|
||||
|
||||
@staticmethod
|
||||
def zeta(a: CSEVariable, b: CSEVariable) -> str:
|
||||
return f"c10::metal::zeta({a}, {b})"
|
||||
|
||||
|
||||
MetalOverrides._initialize_pointwise_overrides("mps")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue