From b133907d0ab2fe5250681ca3407b8c16fb74fdd5 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 7 Feb 2025 20:48:51 -0800 Subject: [PATCH 1/2] Update strided test to float32 (#146748) Fixes #146377 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146748 Approved by: https://github.com/BoyuanFeng, https://github.com/leijurv --- test/inductor/test_flex_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 8b4382061b0..99440593c2b 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -2510,9 +2510,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): @supported_platform def test_strided_backwards(self): shape = (1, 2, 4096, 64) - Q = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16) - K = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16) - V = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16) + Q = torch.randn(shape, requires_grad=True, device="cuda") + K = torch.randn(shape, requires_grad=True, device="cuda") + V = torch.randn(shape, requires_grad=True, device="cuda") func = torch.compile(flex_attention, dynamic=True, fullgraph=True) K_sliced = K[:, :, :-128] From 2a55311773bfb9e569aa672ac3322d21abc1af32 Mon Sep 17 00:00:00 2001 From: Davide Italiano Date: Sun, 9 Feb 2025 20:09:34 +0000 Subject: [PATCH 2/2] [cuda] Simplify the sinc function a bit. (#146774) The `else` after `return` is redundant; removing it reduces the indentation and improves readability. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146774 Approved by: https://github.com/malfet --- aten/src/ATen/native/cuda/Math.cuh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index b99e9d0c94d..2fe8f5dd2e3 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -758,11 +758,10 @@ const auto sinc_string = jiterator_stringify( T sinc(T a) { if (a == T(0)) { return T(1); - } else { - constexpr T pi = T(3.14159265358979323846L); - T product = pi * a; - return std::sin(product) / product; } + constexpr T pi = T(3.14159265358979323846L); + T product = pi * a; + return std::sin(product) / product; } ); // sinc_string