fix sfdp pattern 13 accuracy issue (#110001)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110001
Approved by: https://github.com/eellison
leslie-fang-intel 2023-09-25 17:56:02 -07:00 committed by PyTorch MergeBot
parent 2393864070
commit 0dcea70bfd
2 changed files with 18 additions and 8 deletions
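For context: sfdp pattern 13 is one of Inductor's fuse_attention graph rewrites, which match a decomposed attention subgraph and replace it with a single scaled_dot_product_attention call. A minimal sketch of the pair follows; the pattern body is paraphrased from memory and may differ from the actual code in torch/_inductor/fx_passes/fuse_attention.py, while the replacement mirrors the lines in the diff below:

import torch

def sfdp_pattern_13(query, key, value, dropout_p):
    # Decomposed, unscaled bmm attention with dropout -- the graph shape
    # the pattern matcher looks for.
    attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
    attn_weight = torch.nn.functional.dropout(attn_weight, dropout_p)
    return torch.bmm(attn_weight, value)

def sfdp_replacement_13(query, key, value, dropout_p):
    # Fused form: add a leading batch dim for the 4D SDPA op; scale=1.0
    # preserves the unscaled semantics of the bmm pattern above.
    return torch.nn.functional.scaled_dot_product_attention(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        dropout_p=dropout_p,
        scale=1.0,
    ).squeeze(0)

The bug fixed here: the replacement passed value where key belonged (and key where value belonged), which presumably went unnoticed because the dropout test did not previously compare outputs; hence the new override_check_equal=True below.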


@@ -1,4 +1,5 @@
 # Owner(s): ["module: inductor"]
+import functools
 import itertools
 import math
@@ -543,7 +544,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         self._check_common(dot_prod_attention, check_train=False)
 
     @skipIfRocm
-    def _test_sdpa_rewriter_13(self):
+    def _test_sdpa_rewriter_13(self, dtype):
         def dot_prod_attention(
             query: torch.Tensor,
             key: torch.Tensor,
@@ -559,13 +560,19 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         tensor_shape = (4, 8, 16)
         args = [
-            torch.randn(tensor_shape, device=self.device, dtype=torch.half),
-            torch.randn(tensor_shape, device=self.device, dtype=torch.half),
-            torch.randn(tensor_shape, device=self.device, dtype=torch.half),
+            torch.randn(tensor_shape, device=self.device, dtype=dtype),
+            torch.randn(tensor_shape, device=self.device, dtype=dtype),
+            torch.randn(tensor_shape, device=self.device, dtype=dtype),
         ]
         self._check_common(
-            dot_prod_attention, check_train=False, args1=args, has_dropout=True
+            dot_prod_attention,
+            check_train=False,
+            args1=args,
+            has_dropout=True,
+            override_check_equal=True,
+            atol=1e-2,
+            rtol=1e-2,
         )
@@ -621,8 +628,8 @@ if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:
     test_sdpa_prev_13_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
     test_sdpa_prev_14_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
     test_sdpa_prev_15_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
-    test_sdpa_rewriter_13_cuda = (
-        TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13
+    test_sdpa_rewriter_13_cuda = functools.partialmethod(
+        TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.half
     )
@@ -651,6 +658,9 @@ if HAS_CPU:
     test_sdpa_prev_13_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
     test_sdpa_prev_14_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
     test_sdpa_prev_15_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
+    test_sdpa_rewriter_13_cpu = functools.partialmethod(
+        TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.float32
+    )
 
 if __name__ == "__main__":
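Note on the test wiring above: _test_sdpa_rewriter_13 is now dtype-parametrized, and functools.partialmethod pre-binds dtype per device (torch.half on CUDA, torch.float32 on CPU) while still yielding a bound method on the test class; the old plain attribute assignment sufficed only while the template took no extra arguments. A self-contained sketch of that mechanism, with invented names:

import functools

class Template:
    def _run(self, dtype):
        print(f"running with dtype={dtype}")

class CudaTests(Template):
    # Attribute access on an instance binds self and freezes dtype,
    # just like test_sdpa_rewriter_13_cuda above.
    test_run_half = functools.partialmethod(Template._run, dtype="half")

CudaTests().test_run_half()  # prints: running with dtype=half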


@@ -314,7 +314,7 @@ def _sfdp_replacement_13(query, key, value, dropout_p):
     counters["inductor"]["fuse_attention"] += 1
     return aten.scaled_dot_product_attention(
         query.unsqueeze(0),
-        value.unsqueeze(0),
         key.unsqueeze(0),
+        value.unsqueeze(0),
         dropout_p=dropout_p,
         scale=1.0,
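Why the reordering in _sfdp_replacement_13 matters: scaled_dot_product_attention takes query, key, value positionally, so the pre-fix order computed softmax(Q V^T) K instead of softmax(Q K^T) V. A standalone sketch (not part of the diff) checking the fixed order against the bmm pattern, using the public functional API in place of the aten op, with dropout disabled for determinism:

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(4, 8, 16) for _ in range(3))

# The bmm attention graph that pattern 13 targets (dropout_p = 0).
ref = torch.bmm(torch.bmm(q, k.transpose(1, 2)).softmax(dim=-1), v)

# Fixed replacement order: query, key, value; scale=1.0 disables 1/sqrt(d).
out = F.scaled_dot_product_attention(
    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), dropout_p=0.0, scale=1.0
).squeeze(0)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

# Pre-fix order (key and value swapped) diverges from the reference.
bad = F.scaled_dot_product_attention(
    q.unsqueeze(0), v.unsqueeze(0), k.unsqueeze(0), dropout_p=0.0, scale=1.0
).squeeze(0)
assert not torch.allclose(bad, ref)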