From 20912ba582e65a6e6c0ab77352e84dd73021faa2 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Tue, 3 Dec 2024 10:49:40 -0800 Subject: [PATCH] fix incorrect c10::SymFloat::sqrt (#141728) Fixes the silent correctness for SDPA in https://github.com/pytorch/pytorch/issues/141710 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141728 Approved by: https://github.com/Skylion007, https://github.com/ezyang, https://github.com/drisspg ghstack dependencies: #141725 --- c10/core/SymFloat.cpp | 2 +- test/dynamo/test_repros.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/c10/core/SymFloat.cpp b/c10/core/SymFloat.cpp index 267f894c23a..61e91ec76e7 100644 --- a/c10/core/SymFloat.cpp +++ b/c10/core/SymFloat.cpp @@ -146,7 +146,7 @@ SymFloat SymFloat::sqrt() const { if (!is_symbolic()) { return SymFloat(std::sqrt(data_)); } - auto other = SymFloat(-0.5); + auto other = SymFloat(0.5); auto res = normalize_symfloats(*this, other); return SymFloat(res[0]->pow(res[1])); } diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 8b9697f0e05..6ba48b97d1b 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -6421,7 +6421,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): with torch._dynamo.config.patch(assume_static_by_default=False): out_ref = f(x_ref, s0, s1, s2) out = f_compiled(x, s0, s1, s2) - self.assertFalse(torch.any(torch.isnan(out))) + self.assertEqual(out_ref, out) def test_bitwise_op_guard(self): # attempt evaluating a guard with BitwiseFn_bitwise_[and/or]