mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[mps/inductor] Introduce a metal approx for erf() and use it. (#145161)
Probably we can do better, but this is a start. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145161 Approved by: https://github.com/malfet
This commit is contained in:
parent
893ca1dfe1
commit
8cc415774f
3 changed files with 33 additions and 0 deletions
|
|
@ -4,6 +4,32 @@
|
|||
namespace c10 {
|
||||
namespace metal {
|
||||
|
||||
// Translated to metal from https://www.johndcook.com/cpp_erf.html
|
||||
|
||||
template <typename T>
|
||||
T erf(T x) {
|
||||
T a1 = 0.254829592;
|
||||
T a2 = -0.284496736;
|
||||
T a3 = 1.421413741;
|
||||
T a4 = -1.453152027;
|
||||
T a5 = 1.061405429;
|
||||
T p = 0.3275911;
|
||||
|
||||
// Save the sign of x
|
||||
int sign = 1;
|
||||
if (x < 0)
|
||||
sign = -1;
|
||||
x = ::metal::fabs(x);
|
||||
|
||||
// A&S formula 7.1.26
|
||||
T t = 1.0 / (1.0 + p * x);
|
||||
T y = 1.0 -
|
||||
(((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *
|
||||
::metal::exp(-x * x);
|
||||
|
||||
return sign * y;
|
||||
}
|
||||
|
||||
/*
|
||||
* For licensing information and documentation, please refer to the cpu
|
||||
* implementation located in "ATen/native/Math.h".
|
||||
|
|
|
|||
|
|
@ -131,6 +131,9 @@ class MPSBasicTests(TestCase):
|
|||
def test_pointwise_i1(self):
|
||||
self.common(torch.special.i1, (torch.rand(128, 128),), check_lowp=False)
|
||||
|
||||
def test_pointwise_erf(self):
|
||||
self.common(torch.special.erf, (torch.rand(128, 128),), check_lowp=False)
|
||||
|
||||
def test_broadcast(self):
|
||||
self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024)))
|
||||
|
||||
|
|
|
|||
|
|
@ -190,6 +190,10 @@ class MetalOverrides(OpOverrides):
|
|||
def i1(x: CSEVariable) -> str:
|
||||
return f"c10::metal::i1({x})"
|
||||
|
||||
@staticmethod
|
||||
def erf(x: CSEVariable) -> str:
|
||||
return f"c10::metal::erf({x})"
|
||||
|
||||
@staticmethod
|
||||
def tan(x: CSEVariable) -> str:
|
||||
return f"metal::tan({x})"
|
||||
|
|
|
|||
Loading…
Reference in a new issue