[Inductor Cutlass backend] Move tests to separate file (#121489)
Move the Cutlass backend-related tests to test/inductor/test_cutlass_backend.py; no changes to the tests themselves.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121489
Approved by: https://github.com/jansel
This commit is contained in:
parent
844bfbbd2e
commit
a5ec45f2ec
3 changed files with 360 additions and 313 deletions
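All of the moved tests follow the same recipe: define a small matmul, compile it under a torch._inductor.config.patch(...) block that turns on max-autotune with the CUTLASS backend enabled, and compare the compiled result against eager. A minimal sketch of that recipe (not part of the diff; assumes a CUDA GPU and a CUTLASS checkout such as the bundled third_party/cutlass):

import os

import torch
from torch._inductor import config


def mm(a, b):
    return a @ b


# fp16 CUDA inputs, matching the shapes the tests use
a = torch.randn(100, 10).cuda().half()
b = torch.randn(10, 100).cuda().half()

# The same config keys the tests patch; "cuda.cutlass_dir" must point
# at a CUTLASS checkout (repo-relative path assumed here).
with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS,Triton,ATen",
        "cuda.cutlass_dir": os.path.join("third_party", "cutlass"),
        "cuda.cutlass_max_profiling_configs": 2,
    }
):
    torch.testing.assert_close(torch.compile(mm, dynamic=False)(a, b), mm(a, b))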
test/inductor/test_cutlass_backend.py  (new file, 358 lines)
@@ -0,0 +1,358 @@
# Owner(s): ["module: inductor"]
import logging
import os
import unittest
from typing import Callable, List

import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch._inductor import config
from torch.testing._internal.common_cuda import SM75OrLater, SM90OrLater
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)

from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

torch.set_float32_matmul_precision("high")
if HAS_CUDA:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")

_CUTLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/")

log = logging.getLogger(__name__)


def _get_path_without_sccache() -> str:
    """
    Get the PATH environment variable without sccache.
    """
    path_envs = os.environ.get("PATH", "").split(":")
    path_envs = [env for env in path_envs if "/opt/cache/bin" not in env]
    return ":".join(path_envs)


@instantiate_parametrized_tests
class TestCutlassBackend(TestCase):
    def setUp(self):
        super().setUp()
        torch.random.manual_seed(1234)

    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_precompile(self):
        """
        Make sure autotuning mm in subprocesses works without crashes.
        """

        if torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b

        a = torch.randn(100, 10).cuda().half()
        b = torch.randn(10, 100).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": "CUTLASS,Triton,ATen",
                "compile_threads": 4,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=False)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_regular_mm(
        self, dynamic: bool, max_autotune_gemm_backends: str
    ):
        """
        Make sure autotuning mm in subprocesses works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b

        a = torch.randn(100, 10).cuda().half()
        b = torch.randn(10, 100).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": False,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

    def _test_max_autotune_cutlass_backend_epilogue_fusion(
        self,
        dynamic: bool = False,
        max_autotune_gemm_backends: str = "CUTLASS",
        mixed_precision=False,
        fp16=True,
        expected_fuse_count=1,
        mm: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    ):
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            mixed_precision
        )

        # Note: the available ops also depend on the alignment of the shapes,
        # so if these shapes don't all align to at least 8 elements, it can
        # happen that no fusion-capable Cutlass 3.x op is available.
        a = torch.randn(256, 32).cuda()
        b = torch.randn(32, 256).cuda()
        if fp16:
            a = a.half()
            b = b.half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": False,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 4,
                "cuda.cutlass_only_evt_capable_ops": True,
                "cuda.version": "12.2",  # required to enable the kernels we need
            }
        ):
            counters["inductor"]["cuda_epilogue_fusion_counter"] = 0
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"]
            assert (
                actual_count == expected_fuse_count
            ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
            torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.0

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return (a @ b) * 3.0

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0)

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self):
        def mm(a, b):
            # This should not be fused, since the output dtype differs from the matmul dtype
            return (a @ b).to(torch.float32) * 0.00001

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self):
        def mm(a, b):
            return (a @ b) / b.size(1)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_mm_bias(
        self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
    ):
        """
        Make sure autotuning mm with bias in subprocesses works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b, bias):
            return torch.nn.functional.linear(a, b, bias)

        a = torch.randn(2048, 4096).cuda().half()
        bias = torch.randn(2048).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": False,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            Y = mm(a, a, bias)
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, a, bias)
            torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1)

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_addmm(
        self, dynamic, max_autotune_gemm_backends
    ):
        """
        Make sure autotuning addmm in subprocesses works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def addmm(x, a, b, alpha, beta):
            return torch.addmm(x, a, b, alpha=alpha, beta=beta)

        def compare_results(
            m: int, k: int, n: int, alpha: float, beta: float, x_shape: List[int]
        ) -> None:
            x = torch.randn(x_shape).cuda().half()
            a = torch.randn(m, k).cuda().half()
            b = torch.randn(k, n).cuda().half()
            y_expected = addmm(x, a, b, alpha, beta)

            compiled_fn = torch.compile(addmm, dynamic=dynamic)
            y = compiled_fn(x, a, b, alpha, beta)
            torch.testing.assert_close(y, y_expected)

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": False,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            # No broadcast
            compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 2048])
            # Broadcast first dim.
            compare_results(4096, 25728, 2048, 2.0, 0.4, [2048])
            # Broadcast last dim.
            if not SM90OrLater and max_autotune_gemm_backends == "CUTLASS":
                with self.assertRaisesRegex(RuntimeError, "No choices to select"):
                    # CUTLASS2 doesn't support Bias last-dim broadcast.
                    compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 1])
            else:
                compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 1])


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    # Set env to make it work in CI.
    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
        run_tests()
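The (256, 32) x (32, 256) shapes used by _test_max_autotune_cutlass_backend_epilogue_fusion above are deliberate: as the in-code note explains, if the shapes do not all align to at least 8 elements, it can happen that no fusion-capable Cutlass 3.x op is available. A hypothetical helper (not part of this commit) that makes the alignment condition concrete:

def aligned_for_cutlass(dims, alignment=8):
    # Every GEMM dimension must be a multiple of the alignment;
    # 8 fp16 elements are 16 bytes, i.e. one 128-bit vectorized access.
    return all(d % alignment == 0 for d in dims)


assert aligned_for_cutlass([256, 32, 256])      # m, k, n of the fusion tests
assert not aligned_for_cutlass([100, 10, 100])  # the plain mm shapes would not qualify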
test/inductor/test_max_autotune.py

@@ -1,6 +1,5 @@
# Owner(s): ["module: inductor"]
import os
import unittest

from typing import Callable, List, Optional

@@ -9,7 +8,6 @@ from torch import multiprocessing as mp
from torch._dynamo import reset
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import reset_rng_state
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.autotune_process import (
    BenchmarkRequest,
@@ -29,7 +27,6 @@ from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM75OrLater, SM90OrLater
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
@@ -361,261 +358,6 @@ class TestMaxAutotune(TestCase):
        finally:
            V.set_debug_handler(old_debug_handler)

    [255 lines removed: test_max_autotune_precompile, test_max_autotune_cutlass_backend_regular_mm, _test_max_autotune_cutlass_backend_epilogue_fusion with its fusion test cases, and test_max_autotune_cutlass_backend_mm_bias, all moved verbatim to test/inductor/test_cutlass_backend.py as shown above]

    @parametrize("dynamic", (False, True))
    def test_max_autotune_addmm(self, dynamic=False):
        """
@@ -650,60 +392,6 @@ class TestMaxAutotune(TestCase):
        with config.patch({"max_autotune": True}):
            torch.compile(addmm, dynamic=dynamic)(x, a, b)

    [54 lines removed: test_max_autotune_cutlass_backend_addmm, moved verbatim to test/inductor/test_cutlass_backend.py as shown above]

    @skipIfRocm
    def test_autotune_conv1x1(self):
        # Assuming input has 3 channels and we want to produce 16 channels as output

test/run_test.py
@@ -242,7 +242,8 @@ CI_SERIAL_LIST = [
     "test_autocast",  # OOM
     "test_native_mha",  # OOM
     "test_module_hooks",  # OOM
-    "inductor/test_max_autotune",  # Testing, probably revert later
+    "inductor/test_max_autotune",
+    "inductor/test_cutlass_backend",  # slow due to many nvcc compilation steps
     "inductor/test_torchinductor",  # OOM on test_large_block_sizes
     "inductor/test_torchinductor_dynamic_shapes",  # OOM on test_large_block_sizes
     "inductor/test_torchinductor_codegen_dynamic_shapes",  # OOM on test_large_block_sizes