Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
The logsumexp tensor was intended for internal use only, but it is apparently exposed to unit tests and to Inductor. The stream should be selected after picking the current device; otherwise the code checks the default device's architecture.

Fixes #131316, #137414

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137717
Approved by: https://github.com/drisspg
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
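A minimal sketch of the ordering the commit message describes (an illustrative assumption, not code from this patch; `query_device_state` is a hypothetical helper): the target device must be made current before the stream is queried, because stream and architecture lookups are scoped to the current device.

import torch

def query_device_state(device_index: int):
    # Make the target device current first. torch.cuda.current_stream() and
    # device-property lookups are device-scoped, so querying them before
    # switching would report the default device's stream and architecture.
    with torch.cuda.device(device_index):
        stream = torch.cuda.current_stream()
        props = torch.cuda.get_device_properties(device_index)
    return stream, props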
72 lines
2.2 KiB
Python
# Owner(s): ["module: inductor"]
import glob
import math
import os
import shutil
import tempfile

import torch
import torch._dynamo
import torch._inductor.config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA


try:
    import pydot  # noqa: F401

    HAS_PYDOT = True
except ImportError:
    HAS_PYDOT = False


# Drawing the graphs also needs Graphviz's `dot` executable on PATH.
HAS_DOT = shutil.which("dot") is not None


class TestGraphTransformObserver(TestCase):
    def test_sdpa_rewriter(self):
        if not (
            HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION and HAS_PYDOT and HAS_DOT
        ):
            return

        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )
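
        # `dot_prod_attention` above is the eager pattern Inductor's SDPA
        # rewriter is expected to match and replace with a fused attention call.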
        # Log graph-transform traces into a fresh temp dir, and disable caching
        # so the rewrite pass actually runs and writes its graph dumps.
        log_url = tempfile.mkdtemp()
        inductor_config.trace.log_url_for_graph_xform = log_url
        inductor_config.force_disable_caches = True
        compiled_fn = torch.compile(dot_prod_attention, fullgraph=True)

        tensor_shape = (4, 2, 16, 32)
        q = torch.randn(tensor_shape, device="cuda")
        k = torch.randn(tensor_shape, device="cuda")
        v = torch.randn(tensor_shape, device="cuda")
        compiled_fn(q, k, v)

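        # The graph transform observer should have dumped the graph to a
        # .dot file both before and after the rewrite.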
        found_input_dot = False
        found_output_dot = False
        for filepath_object in glob.glob(log_url + "/*"):
            if os.path.isfile(filepath_object):
                if filepath_object.endswith("input_graph.dot"):
                    found_input_dot = True
                elif filepath_object.endswith("output_graph.dot"):
                    found_output_dot = True

        self.assertTrue(found_input_dot)
        self.assertTrue(found_output_dot)


if __name__ == "__main__":
    if IS_LINUX:
        run_tests()