diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh
index b99e9d0c94d..2fe8f5dd2e3 100644
--- a/aten/src/ATen/native/cuda/Math.cuh
+++ b/aten/src/ATen/native/cuda/Math.cuh
@@ -758,11 +758,10 @@ const auto sinc_string = jiterator_stringify(
   T sinc(T a) {
     if (a == T(0)) {
       return T(1);
-    } else {
-      constexpr T pi = T(3.14159265358979323846L);
-      T product = pi * a;
-      return std::sin(product) / product;
     }
+    constexpr T pi = T(3.14159265358979323846L);
+    T product = pi * a;
+    return std::sin(product) / product;
   }
 ); // sinc_string
 
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 8b4382061b0..99440593c2b 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -2510,9 +2510,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
     @supported_platform
     def test_strided_backwards(self):
         shape = (1, 2, 4096, 64)
-        Q = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
-        K = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
-        V = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
+        Q = torch.randn(shape, requires_grad=True, device="cuda")
+        K = torch.randn(shape, requires_grad=True, device="cuda")
+        V = torch.randn(shape, requires_grad=True, device="cuda")
         func = torch.compile(flex_attention, dynamic=True, fullgraph=True)
 
         K_sliced = K[:, :, :-128]
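
Illustrative aside, not part of the patch: a minimal Python sketch of the early-return control flow the Math.cuh hunk adopts, spot-checked against torch.sinc, which computes the same normalized sinc (sin(pi*x)/(pi*x), with sinc(0) = 1). The helper name sinc_reference is hypothetical.

import math

import torch

def sinc_reference(a: float) -> float:
    # Early return handles the removable singularity at zero,
    # mirroring the refactored kernel; no else branch is needed.
    if a == 0.0:
        return 1.0
    product = math.pi * a
    return math.sin(product) / product

# torch.sinc computes the same normalized sinc, so the two should agree
# within floating-point tolerance.
x = torch.tensor([0.0, 0.5, 1.0, 2.5])
expected = torch.tensor([sinc_reference(v) for v in x.tolist()])
torch.testing.assert_close(torch.sinc(x), expected)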