pytorch/test/inductor/test_graph_transform_observer.py
Xinya Zhang 770fcaf2ab Fix the Rank of logsumexp Tensor and mGPU support. (#137717)
The logsumexp tensor was considered to be for internal use only, but it is apparently exposed to unit tests and to Inductor.

The stream should be selected after picking the current device. Otherwise the code is checking the default device's architecture.

Fixes #131316 #137414

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137717
Approved by: https://github.com/drisspg

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
2024-10-17 21:58:14 +00:00

72 lines
2.2 KiB
Python

# Owner(s): ["module: inductor"]
import glob
import math
import os
import shutil
import tempfile
import torch
import torch._dynamo
import torch._inductor.config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
# Probe for the optional ``pydot`` Python package; the import is only attempted
# to learn whether it is installed, hence the unused-import suppression.
try:
    import pydot  # noqa: F401
except ImportError:
    HAS_PYDOT = False
else:
    HAS_PYDOT = True

# True when an executable named ``dot`` can be located on the current PATH.
HAS_DOT = shutil.which("dot") is not None
class TestGraphTransformObserver(TestCase):
    """Tests that graph-transform logging leaves DOT files in the configured directory."""

    def test_sdpa_rewriter(self):
        """Compile a dot-product attention function and check for graph dumps.

        The test points ``trace.log_url_for_graph_xform`` at a temporary
        directory, compiles and runs a small attention function on CUDA, and
        then expects that directory to contain at least one file ending in
        ``input_graph.dot`` and one ending in ``output_graph.dot``.

        The test is skipped (rather than silently passing) when CUDA, fused
        attention support, the ``pydot`` package, or the ``dot`` executable is
        unavailable.
        """
        prerequisites = (
            ("CUDA", HAS_CUDA),
            ("fused attention support", PLATFORM_SUPPORTS_FUSED_ATTENTION),
            ("the pydot package", HAS_PYDOT),
            ("the 'dot' executable", HAS_DOT),
        )
        missing = [label for label, available in prerequisites if not available]
        if missing:
            self.skipTest("requires " + ", ".join(missing))

        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        tensor_shape = (4, 2, 16, 32)

        # Both the temporary directory and the config overrides are scoped to
        # this block so that nothing leaks into other tests: the directory is
        # deleted and the previous config values are restored on exit.
        with tempfile.TemporaryDirectory() as log_url, inductor_config.patch(
            {
                "trace.log_url_for_graph_xform": log_url,
                "force_disable_caches": True,
            }
        ):
            compiled_fn = torch.compile(dot_prod_attention, fullgraph=True)
            q = torch.randn(tensor_shape, device="cuda")
            k = torch.randn(tensor_shape, device="cuda")
            v = torch.randn(tensor_shape, device="cuda")
            compiled_fn(q, k, v)

            # Snapshot the file names while the directory still exists.
            dumped_files = [
                path
                for path in glob.glob(os.path.join(log_url, "*"))
                if os.path.isfile(path)
            ]

        found_input_dot = any(
            path.endswith("input_graph.dot") for path in dumped_files
        )
        found_output_dot = any(
            path.endswith("output_graph.dot") for path in dumped_files
        )

        self.assertTrue(
            found_input_dot, f"no *input_graph.dot file among {dumped_files}"
        )
        self.assertTrue(
            found_output_dot, f"no *output_graph.dot file among {dumped_files}"
        )
# Script entry point: the suite is only launched when this file is executed
# directly and the platform flag IS_LINUX is truthy.
if __name__ == "__main__" and IS_LINUX:
    run_tests()