[BE] [mps] Refactor UnaryConstants to be its own kernel. (#145230)

In preparation for using this file for inductor (for erfinv).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145230
Approved by: https://github.com/malfet
This commit is contained in:
Davide Italiano 2025-01-23 20:58:43 +00:00 committed by PyTorch MergeBot
parent 881eb86692
commit e924ddbef1
3 changed files with 144 additions and 85 deletions

View file

@ -1,80 +0,0 @@
#pragma once
const char* UNARY_KERNEL_TEMPLATE = R"METAL(
#include <metal_stdlib>
using namespace metal;
constant float a[4] = {{0.886226899, -1.645349621, 0.914624893, -0.140543331}};
constant float b[4] = {{-2.118377725, 1.442710462, -0.329097515, 0.012229801}};
constant float c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}};
constant float d[2] = {{3.543889200, 1.637067800}};
kernel void erfinv_kernel( device {0} *output [[buffer(0)]],
device {1} *input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {{
float y = input[index];
float x, z, num, dem; /*working variables */
/* coefficients in rational expansion */
float y_abs = abs(y);
if (y_abs >= 1.0f) {{
output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y));
return;
}}
if (y_abs <= 0.7f) {{
z = y * y;
num = ((a[3] * z + a[2]) * z + a[1])*z + a[0];
dem = (((b[3] * z + b[2]) * z + b[1]) * z +b[0]) * z + 1.0f;
x = y * num / dem;
}} else {{
z = sqrt(-1.0f*log((1.0-y_abs)/2.0));
num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
dem = (d[1] * z + d[0]) * z + 1.0f;
x = copysign(num, y) / dem;
}}
output[index] = {0}(x);
}}
kernel void exp_kernel( device {0} *output [[buffer(0)]],
device {1} *input [[ buffer(1)]],
uint index [[thread_position_in_grid]]) {{
output[index] = {0}(precise::exp(input[index]));
}}
kernel void exp_complex_kernel( device {0}2 *output [[buffer(0)]],
device {0}2 *input [[ buffer(1)]],
uint index [[thread_position_in_grid]]) {{
output[index].x = {0}(precise::exp(input[index].x)*precise::cos(input[index].y));
output[index].y = {0}(precise::exp(input[index].x)*precise::sin(input[index].y));
}}
kernel void tanh_kernel( device {0} *output [[buffer(0)]],
device {1} *input [[ buffer(1)]],
uint index [[thread_position_in_grid]]) {{
output[index] = {0}(precise::tanh(input[index]));
}}
#if __METAL_VERSION__ >= 310
bfloat dot(bfloat2 a, bfloat2 b) {{
return a.x * b.x + a.y * b.y;
}}
#endif
template<typename T>
T complex_div(T a, T b) {{
auto denom = dot(b, b);
return T(dot(a, b), a.y * b.x - a.x * b.y)/denom;
}}
kernel void tanh_complex_kernel( device {0}2 *output [[buffer(0)]],
device {0}2 *input [[ buffer(1)]],
uint index [[thread_position_in_grid]]) {{
//tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y));
auto tanh_x = {0}(precise::tanh(input[index].x));
auto tan_y = {0}(precise::tan(input[index].y));
output[index] = complex_div({0}2(tanh_x, tan_y), {0}2({0}(1), tanh_x * tan_y));
}}
)METAL";

View file

@ -0,0 +1,135 @@
#include <c10/metal/utils.h>
#include <metal_stdlib>
using namespace c10::metal;
using namespace metal;
constant float a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331};
constant float b[4] = {-2.118377725, 1.442710462, -0.329097515, 0.012229801};
constant float c[4] = {-1.970840454, -1.624906493, 3.429567803, 1.641345311};
constant float d[2] = {3.543889200, 1.637067800};
template <typename T0, typename T1>
kernel void erfinv_kernel(
device T0* output [[buffer(0)]],
device T1* input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
float y = input[index];
float x, z, num, dem; /*working variables */
/* coefficients in rational expansion */
float y_abs = abs(y);
if (y_abs >= 1.0f) {
output[index] = T0(y_abs > 1.0f ? NAN : copysign(INFINITY, y));
return;
}
if (y_abs <= 0.7f) {
z = y * y;
num = ((a[3] * z + a[2]) * z + a[1]) * z + a[0];
dem = (((b[3] * z + b[2]) * z + b[1]) * z + b[0]) * z + 1.0f;
x = y * num / dem;
} else {
z = sqrt(-1.0f * log((1.0 - y_abs) / 2.0));
num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
dem = (d[1] * z + d[0]) * z + 1.0f;
x = copysign(num, y) / dem;
}
output[index] = T0(x);
}
template <typename T0, typename T1>
kernel void exp_kernel(
device T0* output [[buffer(0)]],
device T1* input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
output[index] = T0(precise::exp(input[index]));
}
template <typename T0>
kernel void exp_complex_kernel(
device vec2type_t<T0>* output [[buffer(0)]],
device vec2type_t<T0>* input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
output[index].x =
T0(precise::exp(input[index].x) * precise::cos(input[index].y));
output[index].y =
T0(precise::exp(input[index].x) * precise::sin(input[index].y));
}
template <typename T0, typename T1>
kernel void tanh_kernel(
device T0* output [[buffer(0)]],
device T1* input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
output[index] = T0(precise::tanh(input[index]));
}
#if __METAL_VERSION__ >= 310
bfloat dot(bfloat2 a, bfloat2 b) {
return a.x * b.x + a.y * b.y;
}
#endif
short dot(short2 a, short2 b) {
return a.x * b.x + a.y * b.y;
}
template <typename T>
T complex_div(T a, T b) {
auto denom = dot(b, b);
return T(dot(a, b), a.y * b.x - a.x * b.y) / denom;
}
template <typename T0>
kernel void tanh_complex_kernel(
device vec2type_t<T0>* output [[buffer(0)]],
device vec2type_t<T0>* input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
// tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y));
auto tanh_x = T0(precise::tanh(input[index].x));
auto tan_y = T0(precise::tan(input[index].y));
output[index] = complex_div(
vec2type_t<T0>(tanh_x, tan_y), vec2type_t<T0>(T0(1), tanh_x * tan_y));
}
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
template [[host_name("erfinv_" #DTYPE0 "_" #DTYPE1)]] kernel void \
erfinv_kernel( \
device DTYPE0* output [[buffer(0)]], \
device DTYPE1* input [[buffer(1)]], \
uint id [[thread_position_in_grid]]); \
template [[host_name("exp_" #DTYPE0 "_" #DTYPE1)]] kernel void exp_kernel( \
device DTYPE0* output [[buffer(0)]], \
device DTYPE1* input [[buffer(1)]], \
uint id [[thread_position_in_grid]]); \
template [[host_name("tanh_" #DTYPE0 "_" #DTYPE1)]] kernel void tanh_kernel( \
device DTYPE0* output [[buffer(0)]], \
device DTYPE1* input [[buffer(1)]], \
uint id [[thread_position_in_grid]]);
#if __METAL_VERSION__ >= 310
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
#endif
INSTANTIATE_UNARY_KERNELS2(half, half);
INSTANTIATE_UNARY_KERNELS2(float, float);
INSTANTIATE_UNARY_KERNELS2(float, bool);
INSTANTIATE_UNARY_KERNELS2(float, uchar);
INSTANTIATE_UNARY_KERNELS2(float, char);
INSTANTIATE_UNARY_KERNELS2(float, short);
INSTANTIATE_UNARY_KERNELS2(float, int);
INSTANTIATE_UNARY_KERNELS2(float, long);
#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE0, DTYPE1) \
template [[host_name("exp_complex_" #DTYPE0 "_" #DTYPE1)]] kernel void \
exp_complex_kernel<DTYPE0>( \
device vec2type_t<DTYPE0> * output [[buffer(0)]], \
device vec2type_t<DTYPE0> * input [[buffer(1)]], \
uint did [[thread_position_in_grid]]); \
template [[host_name("tanh_complex_" #DTYPE0 "_" #DTYPE1)]] kernel void \
tanh_complex_kernel<DTYPE0>( \
device vec2type_t<DTYPE0> * output [[buffer(0)]], \
device vec2type_t<DTYPE0> * input [[buffer(1)]], \
uint did [[thread_position_in_grid]]);
INSTANTIATE_UNARY_KERNELS_VEC2(short, short);
INSTANTIATE_UNARY_KERNELS_VEC2(float, float);

View file

@ -2,7 +2,6 @@
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/UnaryConstants.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@ -15,7 +14,12 @@
#include <fmt/format.h>
namespace at::native {
static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2);
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/UnaryKernel_metallib.h>
#endif
static void exec_unary_kernel(const Tensor& self, const Tensor& output_, const std::string& name) {
Tensor inputTensor = self.contiguous();
@ -30,10 +34,10 @@ static void exec_unary_kernel(const Tensor& self, const Tensor& output_, const s
id<MTLComputePipelineState> cplState = nil;
if (c10::isComplexType(self.scalar_type())) {
auto scalarStr = self.scalar_type() == kComplexFloat ? "float" : "half";
cplState = lib.getPipelineStateForFunc(name + "_complex_kernel", {scalarStr, scalarStr});
cplState = lib.getPipelineStateForFunc(fmt::format("{}_complex_{}_{}", name, scalarStr, scalarStr));
} else {
cplState = lib.getPipelineStateForFunc(name + "_kernel",
{scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});
cplState = lib.getPipelineStateForFunc(
fmt::format("{}_{}_{}", name, scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)));
}
if (!outputTensor.is_contiguous()) {