From 0dcea70bfd4ba2432db2c069f7c81dd6df9e474d Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Mon, 25 Sep 2023 17:56:02 -0700
Subject: [PATCH] fix sfdp pattern 13 accuracy issue (#110001)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110001
Approved by: https://github.com/eellison
---
 test/inductor/test_fused_attention.py       | 24 +++++++++++++++------
 torch/_inductor/fx_passes/fuse_attention.py |  2 +-
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py
index 41a01f9d039..e8878c3886b 100644
--- a/test/inductor/test_fused_attention.py
+++ b/test/inductor/test_fused_attention.py
@@ -1,4 +1,5 @@
 # Owner(s): ["module: inductor"]
+import functools
 import itertools
 import math
 
@@ -543,7 +544,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         self._check_common(dot_prod_attention, check_train=False)
 
     @skipIfRocm
-    def _test_sdpa_rewriter_13(self):
+    def _test_sdpa_rewriter_13(self, dtype):
         def dot_prod_attention(
             query: torch.Tensor,
             key: torch.Tensor,
@@ -559,13 +560,19 @@ class TestSDPAPatternRewriterTemplate(TestCase):
 
         tensor_shape = (4, 8, 16)
         args = [
-            torch.randn(tensor_shape, device=self.device, dtype=torch.half),
-            torch.randn(tensor_shape, device=self.device, dtype=torch.half),
-            torch.randn(tensor_shape, device=self.device, dtype=torch.half),
+            torch.randn(tensor_shape, device=self.device, dtype=dtype),
+            torch.randn(tensor_shape, device=self.device, dtype=dtype),
+            torch.randn(tensor_shape, device=self.device, dtype=dtype),
         ]
         self._check_common(
-            dot_prod_attention, check_train=False, args1=args, has_dropout=True
+            dot_prod_attention,
+            check_train=False,
+            args1=args,
+            has_dropout=True,
+            override_check_equal=True,
+            atol=1e-2,
+            rtol=1e-2,
         )
 
@@ -621,8 +628,8 @@ if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:
     test_sdpa_prev_13_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
     test_sdpa_prev_14_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
     test_sdpa_prev_15_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
-    test_sdpa_rewriter_13_cuda = (
-        TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13
+    test_sdpa_rewriter_13_cuda = functools.partialmethod(
+        TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.half
     )
 
@@ -651,6 +658,9 @@ if HAS_CPU:
     test_sdpa_prev_13_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
     test_sdpa_prev_14_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
     test_sdpa_prev_15_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
+    test_sdpa_rewriter_13_cpu = functools.partialmethod(
+        TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.float32
+    )
 
 
 if __name__ == "__main__":
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index 5813b8c1930..8001d589852 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -314,7 +314,7 @@ def _sfdp_replacement_13(query, key, value, dropout_p):
     counters["inductor"]["fuse_attention"] += 1
     return aten.scaled_dot_product_attention(
         query.unsqueeze(0),
-        value.unsqueeze(0),
+        key.unsqueeze(0),
         value.unsqueeze(0),
         dropout_p=dropout_p,
         scale=1.0,
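
Why this one-token change matters: pattern 13 rewrites a batched-matmul attention into a single scaled_dot_product_attention call, and before the fix the replacement passed `value` where the key tensor belongs, so the fused graph effectively computed attention over the wrong scores whenever key != value. The sketch below is not part of the patch; it assumes the unfused form of the pattern is the unscaled softmax(Q @ K^T) @ V (consistent with scale=1.0 and the unsqueeze(0) in the replacement) and uses the (4, 8, 16) shapes from the test, with dropout disabled so the comparison is deterministic:

# Illustrative repro only; assumes the bmm/softmax pattern form described above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
query, key, value = (torch.randn(4, 8, 16) for _ in range(3))

# Unfused pattern 13 with dropout_p=0: softmax(Q @ K^T) @ V, no 1/sqrt(d) scaling.
attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
expected = torch.bmm(attn_weight, value)

# Buggy replacement: `value` passed as both the key and the value argument.
buggy = F.scaled_dot_product_attention(
    query.unsqueeze(0), value.unsqueeze(0), value.unsqueeze(0), scale=1.0
).squeeze(0)

# Fixed replacement: the key tensor goes in the key slot.
fixed = F.scaled_dot_product_attention(
    query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0), scale=1.0
).squeeze(0)

print(torch.allclose(expected, fixed, atol=1e-5))  # True
print(torch.allclose(expected, buggy, atol=1e-5))  # False

The test-side changes fit this picture: the dtype is parametrized via functools.partialmethod (torch.half on CUDA, torch.float32 on CPU), and atol/rtol are loosened to 1e-2 with override_check_equal=True, presumably because a fused half-precision kernel with dropout is not expected to match the unfused computation bit-for-bit.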