mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[BE] [mps] Refactor UnaryConstants to be its own kernel. (#145230)
In preparation for using this file for inductor (for erfinv). Pull Request resolved: https://github.com/pytorch/pytorch/pull/145230 Approved by: https://github.com/malfet
This commit is contained in:
parent
881eb86692
commit
e924ddbef1
3 changed files with 144 additions and 85 deletions
|
|
@ -1,80 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
const char* UNARY_KERNEL_TEMPLATE = R"METAL(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
constant float a[4] = {{0.886226899, -1.645349621, 0.914624893, -0.140543331}};
|
||||
constant float b[4] = {{-2.118377725, 1.442710462, -0.329097515, 0.012229801}};
|
||||
constant float c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}};
|
||||
constant float d[2] = {{3.543889200, 1.637067800}};
|
||||
|
||||
kernel void erfinv_kernel( device {0} *output [[buffer(0)]],
|
||||
device {1} *input [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {{
|
||||
|
||||
float y = input[index];
|
||||
float x, z, num, dem; /*working variables */
|
||||
/* coefficients in rational expansion */
|
||||
|
||||
float y_abs = abs(y);
|
||||
if (y_abs >= 1.0f) {{
|
||||
output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y));
|
||||
return;
|
||||
}}
|
||||
if (y_abs <= 0.7f) {{
|
||||
z = y * y;
|
||||
num = ((a[3] * z + a[2]) * z + a[1])*z + a[0];
|
||||
dem = (((b[3] * z + b[2]) * z + b[1]) * z +b[0]) * z + 1.0f;
|
||||
x = y * num / dem;
|
||||
}} else {{
|
||||
z = sqrt(-1.0f*log((1.0-y_abs)/2.0));
|
||||
num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
|
||||
dem = (d[1] * z + d[0]) * z + 1.0f;
|
||||
x = copysign(num, y) / dem;
|
||||
}}
|
||||
|
||||
output[index] = {0}(x);
|
||||
}}
|
||||
|
||||
kernel void exp_kernel( device {0} *output [[buffer(0)]],
|
||||
device {1} *input [[ buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {{
|
||||
output[index] = {0}(precise::exp(input[index]));
|
||||
}}
|
||||
|
||||
kernel void exp_complex_kernel( device {0}2 *output [[buffer(0)]],
|
||||
device {0}2 *input [[ buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {{
|
||||
output[index].x = {0}(precise::exp(input[index].x)*precise::cos(input[index].y));
|
||||
output[index].y = {0}(precise::exp(input[index].x)*precise::sin(input[index].y));
|
||||
}}
|
||||
|
||||
kernel void tanh_kernel( device {0} *output [[buffer(0)]],
|
||||
device {1} *input [[ buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {{
|
||||
output[index] = {0}(precise::tanh(input[index]));
|
||||
}}
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
bfloat dot(bfloat2 a, bfloat2 b) {{
|
||||
return a.x * b.x + a.y * b.y;
|
||||
}}
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
T complex_div(T a, T b) {{
|
||||
auto denom = dot(b, b);
|
||||
return T(dot(a, b), a.y * b.x - a.x * b.y)/denom;
|
||||
}}
|
||||
|
||||
kernel void tanh_complex_kernel( device {0}2 *output [[buffer(0)]],
|
||||
device {0}2 *input [[ buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {{
|
||||
//tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y));
|
||||
auto tanh_x = {0}(precise::tanh(input[index].x));
|
||||
auto tan_y = {0}(precise::tan(input[index].y));
|
||||
output[index] = complex_div({0}2(tanh_x, tan_y), {0}2({0}(1), tanh_x * tan_y));
|
||||
}}
|
||||
)METAL";
|
||||
135
aten/src/ATen/native/mps/kernels/UnaryKernel.metal
Normal file
135
aten/src/ATen/native/mps/kernels/UnaryKernel.metal
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
#include <c10/metal/utils.h>
|
||||
#include <metal_stdlib>
|
||||
using namespace c10::metal;
|
||||
using namespace metal;
|
||||
|
||||
// Coefficients of the two rational approximations used by erfinv_kernel
// below: a (numerator) / b (denominator) serve the central branch
// (|y| <= 0.7); c / d serve the tail branch (0.7 < |y| < 1).
constant float a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331};
constant float b[4] = {-2.118377725, 1.442710462, -0.329097515, 0.012229801};
constant float c[4] = {-1.970840454, -1.624906493, 3.429567803, 1.641345311};
constant float d[2] = {3.543889200, 1.637067800};
|
||||
|
||||
// Elementwise inverse error function: output[i] = T0(erfinv(input[i])).
// T0 = output scalar type, T1 = input scalar type; all arithmetic is done
// in float regardless of T0/T1. One thread per element; `index` is the flat
// position in a 1-D grid sized by the host launcher.
template <typename T0, typename T1>
kernel void erfinv_kernel(
    device T0* output [[buffer(0)]],
    device T1* input [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  float y = input[index];
  float x, z, num, dem; /*working variables */
  /* coefficients in rational expansion */

  float y_abs = abs(y);
  // erfinv is finite only on the open interval (-1, 1):
  // y == +-1 maps to +-infinity, |y| > 1 has no real value (NaN).
  if (y_abs >= 1.0f) {
    output[index] = T0(y_abs > 1.0f ? NAN : copysign(INFINITY, y));
    return;
  }
  if (y_abs <= 0.7f) {
    // Central branch: odd rational approximation in z = y^2,
    // evaluated by Horner's scheme with the a/b coefficient tables.
    z = y * y;
    num = ((a[3] * z + a[2]) * z + a[1]) * z + a[0];
    dem = (((b[3] * z + b[2]) * z + b[1]) * z + b[0]) * z + 1.0f;
    x = y * num / dem;
  } else {
    // Tail branch: approximation in z = sqrt(-log((1 - |y|) / 2)),
    // with the sign of the result taken from y via copysign.
    z = sqrt(-1.0f * log((1.0 - y_abs) / 2.0));
    num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
    dem = (d[1] * z + d[0]) * z + 1.0f;
    x = copysign(num, y) / dem;
  }

  output[index] = T0(x);
}
|
||||
|
||||
// Elementwise exponential: output[i] = T0(exp(input[i])).
// precise:: keeps the full-accuracy math variant independent of any
// fast-math compilation mode. One thread per element.
template <typename T0, typename T1>
kernel void exp_kernel(
    device T0* output [[buffer(0)]],
    device T1* input [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  const auto value = input[index];
  output[index] = T0(precise::exp(value));
}
|
||||
|
||||
// Complex exponential over interleaved (re, im) pairs:
//   exp(x + iy) = exp(x) * (cos(y) + i * sin(y)).
// T0 is the real scalar type; each element is a vec2type_t<T0>.
template <typename T0>
kernel void exp_complex_kernel(
    device vec2type_t<T0>* output [[buffer(0)]],
    device vec2type_t<T0>* input [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  const auto re = input[index].x;
  const auto im = input[index].y;
  output[index].x = T0(precise::exp(re) * precise::cos(im));
  output[index].y = T0(precise::exp(re) * precise::sin(im));
}
|
||||
|
||||
// Elementwise hyperbolic tangent: output[i] = T0(tanh(input[i])).
// precise:: keeps the full-accuracy math variant independent of any
// fast-math compilation mode. One thread per element.
template <typename T0, typename T1>
kernel void tanh_kernel(
    device T0* output [[buffer(0)]],
    device T1* input [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  const auto value = input[index];
  output[index] = T0(precise::tanh(value));
}
|
||||
|
||||
// bfloat2 has no built-in dot(); supply one so complex_div below can be
// instantiated for bfloat pairs. Guarded because the bfloat type itself
// requires Metal language version 3.1 or newer.
#if __METAL_VERSION__ >= 310
bfloat dot(bfloat2 u, bfloat2 v) {
  return u.x * v.x + u.y * v.y;
}
#endif
|
||||
|
||||
// Integer overload of dot() for short2 pairs, needed by the short2
// instantiation of complex_div (see INSTANTIATE_UNARY_KERNELS_VEC2).
short dot(short2 u, short2 v) {
  return u.x * v.x + u.y * v.y;
}
|
||||
|
||||
// Complex division via the conjugate trick:
//   num / den = (num * conj(den)) / |den|^2,
// where dot(num, den) gives the real part of num * conj(den).
// T is a 2-component vector type with a dot() overload in scope.
template <typename T>
T complex_div(T num, T den) {
  const auto scale = dot(den, den);
  return T(dot(num, den), num.y * den.x - num.x * den.y) / scale;
}
|
||||
|
||||
// Complex hyperbolic tangent via the identity
//   tanh(x + iy) = (tanh(x) + i*tan(y)) / (1 + i*tanh(x)*tan(y)),
// evaluated with complex_div. T0 is the real scalar type; each element is
// an interleaved vec2type_t<T0> pair. One thread per element.
template <typename T0>
kernel void tanh_complex_kernel(
    device vec2type_t<T0>* output [[buffer(0)]],
    device vec2type_t<T0>* input [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  const auto tanh_re = T0(precise::tanh(input[index].x));
  const auto tan_im = T0(precise::tan(input[index].y));
  const auto numerator = vec2type_t<T0>(tanh_re, tan_im);
  const auto denominator = vec2type_t<T0>(T0(1), tanh_re * tan_im);
  output[index] = complex_div(numerator, denominator);
}
|
||||
|
||||
// Explicitly instantiate the scalar unary kernels for one (output, input)
// dtype pair, giving each a stable host-visible entry point named
// "<op>_<out>_<in>" via [[host_name]] — these names are what the host side
// looks up through getPipelineStateForFunc.
// (No comments inside the macro: they would break the \ continuations.)
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1)                           \
  template [[host_name("erfinv_" #DTYPE0 "_" #DTYPE1)]] kernel void          \
  erfinv_kernel(                                                             \
      device DTYPE0* output [[buffer(0)]],                                   \
      device DTYPE1* input [[buffer(1)]],                                    \
      uint id [[thread_position_in_grid]]);                                  \
  template [[host_name("exp_" #DTYPE0 "_" #DTYPE1)]] kernel void exp_kernel( \
      device DTYPE0* output [[buffer(0)]],                                   \
      device DTYPE1* input [[buffer(1)]],                                    \
      uint id [[thread_position_in_grid]]);                                  \
  template [[host_name("tanh_" #DTYPE0 "_" #DTYPE1)]] kernel void tanh_kernel( \
      device DTYPE0* output [[buffer(0)]],                                   \
      device DTYPE1* input [[buffer(1)]],                                    \
      uint id [[thread_position_in_grid]]);

// bfloat kernels only exist where the bfloat type does (Metal >= 3.1).
#if __METAL_VERSION__ >= 310
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
#endif
INSTANTIATE_UNARY_KERNELS2(half, half);
INSTANTIATE_UNARY_KERNELS2(float, float);
// bool and integral inputs compute in and store float outputs.
INSTANTIATE_UNARY_KERNELS2(float, bool);
INSTANTIATE_UNARY_KERNELS2(float, uchar);
INSTANTIATE_UNARY_KERNELS2(float, char);
INSTANTIATE_UNARY_KERNELS2(float, short);
INSTANTIATE_UNARY_KERNELS2(float, int);
INSTANTIATE_UNARY_KERNELS2(float, long);

// Same scheme for the complex kernels, which are templated on the real
// scalar type and operate on vec2type_t<DTYPE0> pairs; the host looks up
// "exp_complex_<t>_<t>" / "tanh_complex_<t>_<t>".
#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE0, DTYPE1)                   \
  template [[host_name("exp_complex_" #DTYPE0 "_" #DTYPE1)]] kernel void \
  exp_complex_kernel<DTYPE0>(                                            \
      device vec2type_t<DTYPE0> * output [[buffer(0)]],                  \
      device vec2type_t<DTYPE0> * input [[buffer(1)]],                   \
      uint did [[thread_position_in_grid]]);                             \
  template [[host_name("tanh_complex_" #DTYPE0 "_" #DTYPE1)]] kernel void \
  tanh_complex_kernel<DTYPE0>(                                           \
      device vec2type_t<DTYPE0> * output [[buffer(0)]],                  \
      device vec2type_t<DTYPE0> * input [[buffer(1)]],                   \
      uint did [[thread_position_in_grid]]);

// NOTE(review): the visible host code only requests "float"/"half" complex
// names — confirm what reaches the short instantiation (half storage?).
INSTANTIATE_UNARY_KERNELS_VEC2(short, short);
INSTANTIATE_UNARY_KERNELS_VEC2(float, float);
|
|
@ -2,7 +2,6 @@
|
|||
#include <ATen/mps/MPSProfiler.h>
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/UnaryConstants.h>
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
|
|
@ -15,7 +14,12 @@
|
|||
#include <fmt/format.h>
|
||||
|
||||
namespace at::native {
|
||||
static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2);
|
||||
|
||||
#ifndef PYTORCH_JIT_COMPILE_SHADERS
|
||||
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
|
||||
#else
|
||||
#include <ATen/native/mps/UnaryKernel_metallib.h>
|
||||
#endif
|
||||
|
||||
static void exec_unary_kernel(const Tensor& self, const Tensor& output_, const std::string& name) {
|
||||
Tensor inputTensor = self.contiguous();
|
||||
|
|
@ -30,10 +34,10 @@ static void exec_unary_kernel(const Tensor& self, const Tensor& output_, const s
|
|||
id<MTLComputePipelineState> cplState = nil;
|
||||
if (c10::isComplexType(self.scalar_type())) {
|
||||
auto scalarStr = self.scalar_type() == kComplexFloat ? "float" : "half";
|
||||
cplState = lib.getPipelineStateForFunc(name + "_complex_kernel", {scalarStr, scalarStr});
|
||||
cplState = lib.getPipelineStateForFunc(fmt::format("{}_complex_{}_{}", name, scalarStr, scalarStr));
|
||||
} else {
|
||||
cplState = lib.getPipelineStateForFunc(name + "_kernel",
|
||||
{scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});
|
||||
cplState = lib.getPipelineStateForFunc(
|
||||
fmt::format("{}_{}_{}", name, scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)));
|
||||
}
|
||||
|
||||
if (!outputTensor.is_contiguous()) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue