[mps/inductor] Introduce a metal approx for erf() and use it. (#145161)

Probably we can do better, but this is a start.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145161
Approved by: https://github.com/malfet
This commit is contained in:
Davide Italiano 2025-01-19 02:29:03 +00:00 committed by PyTorch MergeBot
parent 893ca1dfe1
commit 8cc415774f
3 changed files with 33 additions and 0 deletions

View file

@ -4,6 +4,32 @@
namespace c10 {
namespace metal {
// Translated to metal from https://www.johndcook.com/cpp_erf.html
template <typename T>
T erf(T x) {
T a1 = 0.254829592;
T a2 = -0.284496736;
T a3 = 1.421413741;
T a4 = -1.453152027;
T a5 = 1.061405429;
T p = 0.3275911;
// Save the sign of x
int sign = 1;
if (x < 0)
sign = -1;
x = ::metal::fabs(x);
// A&S formula 7.1.26
T t = 1.0 / (1.0 + p * x);
T y = 1.0 -
(((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *
::metal::exp(-x * x);
return sign * y;
}
/*
* For licensing information and documentation, please refer to the cpu
* implementation located in "ATen/native/Math.h".

View file

@ -131,6 +131,9 @@ class MPSBasicTests(TestCase):
def test_pointwise_i1(self):
self.common(torch.special.i1, (torch.rand(128, 128),), check_lowp=False)
def test_pointwise_erf(self):
self.common(torch.special.erf, (torch.rand(128, 128),), check_lowp=False)
def test_broadcast(self):
self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024)))

View file

@ -190,6 +190,10 @@ class MetalOverrides(OpOverrides):
def i1(x: CSEVariable) -> str:
return f"c10::metal::i1({x})"
@staticmethod
def erf(x: CSEVariable) -> str:
return f"c10::metal::erf({x})"
@staticmethod
def tan(x: CSEVariable) -> str:
return f"metal::tan({x})"