fix sfdp pattern 13 accuracy issue (#110001)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110001 Approved by: https://github.com/eellison
parent 2393864070
commit 0dcea70bfd

2 changed files with 18 additions and 8 deletions
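For context, SFDP pattern 13 is one of inductor's fuse_attention rewrites, which replace a hand-written batched dot-product attention with a single fused scaled_dot_product_attention call. The sketch below shows the general shape of the unfused computation the rewriter-13 tests exercise; it is reconstructed from the test and replacement code in this diff rather than copied from the pattern file, the function name is hypothetical, and the absence of 1/sqrt(d) scaling mirrors the scale=1.0 argument in the replacement further down.

import torch

def dot_prod_attention_sketch(query, key, value, dropout_p=0.0):
    # Batched attention over (batch, seq, head_dim) tensors, e.g. the
    # (4, 8, 16) shapes used by the test below. No 1/sqrt(d) scaling,
    # matching scale=1.0 in the fused replacement.
    attn = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
    attn = torch.nn.functional.dropout(attn, dropout_p)
    return torch.bmm(attn, value)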
@@ -1,4 +1,5 @@
 # Owner(s): ["module: inductor"]
+import functools
 import itertools
 import math
 
@@ -543,7 +544,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
         self._check_common(dot_prod_attention, check_train=False)
 
     @skipIfRocm
-    def _test_sdpa_rewriter_13(self):
+    def _test_sdpa_rewriter_13(self, dtype):
         def dot_prod_attention(
             query: torch.Tensor,
             key: torch.Tensor,
@@ -559,13 +560,19 @@ class TestSDPAPatternRewriterTemplate(TestCase):
 
         tensor_shape = (4, 8, 16)
         args = [
-            torch.randn(tensor_shape, device=self.device, dtype=torch.half),
-            torch.randn(tensor_shape, device=self.device, dtype=torch.half),
-            torch.randn(tensor_shape, device=self.device, dtype=torch.half),
+            torch.randn(tensor_shape, device=self.device, dtype=dtype),
+            torch.randn(tensor_shape, device=self.device, dtype=dtype),
+            torch.randn(tensor_shape, device=self.device, dtype=dtype),
         ]
 
         self._check_common(
-            dot_prod_attention, check_train=False, args1=args, has_dropout=True
+            dot_prod_attention,
+            check_train=False,
+            args1=args,
+            has_dropout=True,
+            override_check_equal=True,
+            atol=1e-2,
+            rtol=1e-2,
         )
 
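The new override_check_equal=True together with atol/rtol of 1e-2 loosens the output comparison for the half-precision variant of the test, which cannot match a reference bit for bit. As a standalone illustration of why 1e-2 is the right order of magnitude (torch.testing.assert_close is used here only for illustration; the repo's _check_common helper is not shown in this diff):

import torch

torch.manual_seed(0)
ref = torch.randn(4, 8, 16)

# A round trip through float16 perturbs values at roughly the 1e-3 level,
# similar in size to the error the torch.half CUDA variant of the test sees.
approx = ref.half().float()

# The default float32 tolerances would reject this; 1e-2 comfortably passes.
torch.testing.assert_close(approx, ref, atol=1e-2, rtol=1e-2)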
@@ -621,8 +628,8 @@ if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:
         test_sdpa_prev_13_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
         test_sdpa_prev_14_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
         test_sdpa_prev_15_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
-        test_sdpa_rewriter_13_cuda = (
-            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13
+        test_sdpa_rewriter_13_cuda = functools.partialmethod(
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.half
         )
 
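functools.partialmethod is what lets one template method serve both device classes with different dtypes: the CUDA registration binds dtype=torch.half, and the CPU registration in the next hunk binds dtype=torch.float32. A minimal standalone sketch of the mechanism, with hypothetical class and method names:

import functools

class _Template:
    def _check(self, dtype):
        print(f"checking with dtype={dtype}")

class CudaTests(_Template):
    # Accessed as a bound method: equivalent to self._check(dtype="float16").
    test_rewriter_13 = functools.partialmethod(_Template._check, dtype="float16")

class CpuTests(_Template):
    test_rewriter_13 = functools.partialmethod(_Template._check, dtype="float32")

CudaTests().test_rewriter_13()  # checking with dtype=float16
CpuTests().test_rewriter_13()   # checking with dtype=float32

The plain attribute assignment the old code used could not pass extra arguments, which is why adding the dtype parameter forces this switch.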
@@ -651,6 +658,9 @@ if HAS_CPU:
         test_sdpa_prev_13_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
         test_sdpa_prev_14_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
         test_sdpa_prev_15_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
+        test_sdpa_rewriter_13_cpu = functools.partialmethod(
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.float32
+        )
 
 
 if __name__ == "__main__":
@@ -314,7 +314,7 @@ def _sfdp_replacement_13(query, key, value, dropout_p):
     counters["inductor"]["fuse_attention"] += 1
     return aten.scaled_dot_product_attention(
         query.unsqueeze(0),
-        value.unsqueeze(0),
+        key.unsqueeze(0),
         value.unsqueeze(0),
         dropout_p=dropout_p,
         scale=1.0,
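The one-line change above is the substance of the fix: the second argument to scaled_dot_product_attention is now key rather than value, restoring the (query, key, value) order the op expects. A quick standalone sanity check of that ordering against an unfused reference (illustrative only, run in float32; the tolerances here are a choice for this sketch, not from the patch):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(4, 8, 16) for _ in range(3))

# Unfused reference: softmax(Q @ K^T) @ V with no extra scaling.
ref = torch.bmm(torch.bmm(q, k.transpose(1, 2)).softmax(dim=-1), v)

# Fused call in the corrected (query, key, value) order; unsqueeze/squeeze
# add and drop the leading batch dimension, as in the replacement above.
out = F.scaled_dot_product_attention(
    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), dropout_p=0.0, scale=1.0
).squeeze(0)

torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)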