diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py
new file mode 100644
index 00000000000..61abe782301
--- /dev/null
+++ b/test/inductor/test_cutlass_backend.py
@@ -0,0 +1,358 @@
+# Owner(s): ["module: inductor"]
+import logging
+import os
+import unittest
+from typing import Callable, List
+
+import torch
+from torch._dynamo.test_case import run_tests, TestCase
+from torch._dynamo.utils import counters
+from torch._inductor import config
+from torch.testing._internal.common_cuda import SM75OrLater, SM90OrLater
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
+
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
+
+torch.set_float32_matmul_precision("high")
+if HAS_CUDA:
+    torch.cuda.memory._set_allocator_settings("expandable_segments:False")
+
+_CUTLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/")
+
+log = logging.getLogger(__name__)
+
+
+def _get_path_without_sccache() -> str:
+    """
+    Get the PATH environment variable without sccache.
+    """
+    path_envs = os.environ.get("PATH", "").split(":")
+    path_envs = [env for env in path_envs if "/opt/cache/bin" not in env]
+    return ":".join(path_envs)
+
+
+@instantiate_parametrized_tests
+class TestCutlassBackend(TestCase):
+    def setUp(self):
+        super().setUp()
+        torch.random.manual_seed(1234)
+
+    @unittest.skipIf(not SM75OrLater, "need sm_75")
+    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
+    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    def test_max_autotune_precompile(self):
+        """
+        Make sure autotuning mm in subprocesses works without crashes.
+        """
+
+        if torch.version.hip:
+            return
+
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+
+        def mm(a, b):
+            return a @ b
+
+        a = torch.randn(100, 10).cuda().half()
+        b = torch.randn(10, 100).cuda().half()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "autotune_in_subproc": True,
+                "max_autotune_gemm_backends": "CUTLASS,Triton,ATen",
+                "compile_threads": 4,
+                "cuda.cutlass_dir": _CUTLASS_DIR,
+                "cuda.cutlass_max_profiling_configs": 2,
+            }
+        ):
+            Y_compiled = torch.compile(mm, dynamic=False)(a, b)
+            Y = mm(a, b)
+            torch.testing.assert_close(Y_compiled, Y)
+
+    # TODO: Enable dynamic test cases when dynamic support is added.
+    @unittest.skipIf(not SM75OrLater, "need sm_75")
+    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
+    @parametrize("dynamic", (False,))
+    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
+    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    def test_max_autotune_cutlass_backend_regular_mm(
+        self, dynamic: bool, max_autotune_gemm_backends: str
+    ):
+        """
+        Make sure autotuning mm with the CUTLASS backend works without crashes.
+ """ + + if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip: + return + + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + + def mm(a, b): + return a @ b + + a = torch.randn(100, 10).cuda().half() + b = torch.randn(10, 100).cuda().half() + + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": False, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "cuda.cutlass_dir": _CUTLASS_DIR, + "cuda.cutlass_max_profiling_configs": 2, + } + ): + Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) + Y = mm(a, b) + torch.testing.assert_close(Y_compiled, Y) + + def _test_max_autotune_cutlass_backend_epilogue_fusion( + self, + dynamic: bool = False, + max_autotune_gemm_backends: str = "CUTLASS", + mixed_precision=False, + fp16=True, + expected_fuse_count=1, + mm: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, + ): + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = ( + mixed_precision + ) + + # Note: The ops that are available + # also depend on the alignment of the shapes + # so if these shapes don't all align to at least 8 elements + # it can happen that no Cutlass 3.x op is available + # that allows fusions + a = torch.randn(256, 32).cuda() + b = torch.randn(32, 256).cuda() + if fp16: + a = a.half() + b = b.half() + + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": False, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "cuda.cutlass_dir": _CUTLASS_DIR, + "cuda.cutlass_max_profiling_configs": 4, + "cuda.cutlass_only_evt_capable_ops": True, + "cuda.version": "12.2", # required to enable the Kernels we need + } + ): + counters["inductor"]["cuda_epilogue_fusion_counter"] = 0 + Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) + Y = mm(a, b) + actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"] + assert ( + actual_count == expected_fuse_count + ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}" + torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_simple_fusion_fp16(self): + def mm(a, b): + return (a @ b) * 3.0 + + # The pointwise ops seem to be pre-fused into a single Pointwise + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self): + def mm(a, b): + return (a @ b) * 3.0 + + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_chained_fusion_fp16(self): + def mm(a, b): + return (a @ b) * 3.3 - 1.234 + + # The pointwise ops seem to be pre-fused into a single Pointwise + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, 
"need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self): + def mm(a, b): + return (a @ b) * 3.3 - 1.234 + + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_relu_fusion_fp16(self): + def mm(a, b): + return torch.nn.functional.relu((a @ b) * 3.3 - 1.234) + + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self): + def mm(a, b): + return torch.nn.functional.relu((a @ b) * 3.3 - 1.234) + + # The pointwise ops seem to be pre-fused into a single Pointwise + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self): + def mm(a, b): + return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0) + + # The pointwise ops seem to be pre-fused into a single Pointwise + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self): + def mm(a, b): + # this should not be fused, since the output dtype is different from the matmul dtype + return (a @ b).to(torch.float32) * 0.00001 + + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self): + def mm(a, b): + return (a @ b) / b.size(1) + + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm + ) + + # TODO: Enable dynamic test cases when dynamic support is added. 
+    @unittest.skipIf(not SM75OrLater, "need sm_75")
+    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
+    @parametrize("dynamic", (False,))
+    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
+    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    def test_max_autotune_cutlass_backend_mm_bias(
+        self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
+    ):
+        """
+        Make sure autotuning mm with bias works without crashes.
+        """
+
+        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
+            return
+
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+
+        def mm(a, b, bias):
+            return torch.nn.functional.linear(a, b, bias)
+
+        a = torch.randn(2048, 4096).cuda().half()
+        bias = torch.randn(2048).cuda().half()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "autotune_in_subproc": False,
+                "max_autotune_gemm_backends": max_autotune_gemm_backends,
+                "cuda.cutlass_dir": _CUTLASS_DIR,
+                "cuda.cutlass_max_profiling_configs": 2,
+            }
+        ):
+            Y = mm(a, a, bias)
+            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, a, bias)
+            torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1)
+
+    # TODO: Enable dynamic test cases when dynamic support is added.
+    @unittest.skipIf(not SM75OrLater, "need sm_75")
+    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
+    @parametrize("dynamic", (False,))
+    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
+    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    def test_max_autotune_cutlass_backend_addmm(
+        self, dynamic, max_autotune_gemm_backends
+    ):
+        """
+        Make sure autotuning addmm works without crashes.
+        """
+
+        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
+            return
+
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+
+        def addmm(x, a, b, alpha, beta):
+            return torch.addmm(x, a, b, alpha=alpha, beta=beta)
+
+        def compare_results(
+            m: int, k: int, n: int, alpha: float, beta: float, x_shape: List[int]
+        ) -> None:
+            x = torch.randn(x_shape).cuda().half()
+            a = torch.randn(m, k).cuda().half()
+            b = torch.randn(k, n).cuda().half()
+            y_expected = addmm(x, a, b, alpha, beta)
+
+            compiled_fn = torch.compile(addmm, dynamic=dynamic)
+            y = compiled_fn(x, a, b, alpha, beta)
+            torch.testing.assert_close(y, y_expected)
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "autotune_in_subproc": False,
+                "max_autotune_gemm_backends": max_autotune_gemm_backends,
+                "cuda.cutlass_dir": _CUTLASS_DIR,
+                "cuda.cutlass_max_profiling_configs": 2,
+            }
+        ):
+            # No broadcast
+            compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 2048])
+            # Broadcast first dim.
+            compare_results(4096, 25728, 2048, 2.0, 0.4, [2048])
+            # Broadcast last dim.
+            if not SM90OrLater and max_autotune_gemm_backends == "CUTLASS":
+                with self.assertRaisesRegex(RuntimeError, "No choices to select"):
+                    # CUTLASS 2.x doesn't support bias last-dim broadcast.
+                    compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 1])
+            else:
+                compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 1])
+
+
+if __name__ == "__main__":
+    from torch._inductor.utils import is_big_gpu
+
+    # Set env to make it work in CI.
+    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
+        run_tests()
diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index 8589fb8101f..b8a222bec59 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -1,6 +1,5 @@
 # Owner(s): ["module: inductor"]
 import os
-import unittest
 
 from typing import Callable, List, Optional
 
@@ -9,7 +8,6 @@ from torch import multiprocessing as mp
 from torch._dynamo import reset
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._dynamo.testing import reset_rng_state
-from torch._dynamo.utils import counters
 from torch._inductor import config
 from torch._inductor.autotune_process import (
     BenchmarkRequest,
@@ -29,7 +27,6 @@ from torch._inductor.utils import run_and_get_code
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
-from torch.testing._internal.common_cuda import SM75OrLater, SM90OrLater
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -361,261 +358,6 @@ class TestMaxAutotune(TestCase):
         finally:
             V.set_debug_handler(old_debug_handler)
 
-    @unittest.skipIf(not SM75OrLater, "need sm_75")
-    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
-    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
-    def test_max_autotune_precompile(self):
-        """
-        Make sure autotuning mm in sub processes work without crashes.
-        """
-
-        if torch.version.hip:
-            return
-
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
-        def mm(a, b):
-            return a @ b
-
-        a = torch.randn(100, 10).cuda().half()
-        b = torch.randn(10, 100).cuda().half()
-
-        with config.patch(
-            {
-                "max_autotune": True,
-                "autotune_in_subproc": True,
-                "max_autotune_gemm_backends": "CUTLASS,Triton,ATen",
-                "compile_threads": 4,
-                "cuda.cutlass_dir": _CUTLASS_DIR,
-                "cuda.cutlass_max_profiling_configs": 2,
-            }
-        ):
-            Y_compiled = torch.compile(mm, dynamic=False)(a, b)
-            Y = mm(a, b)
-            torch.testing.assert_close(Y_compiled, Y)
-
-    # TODO: Enable dynamic test cases when dynamic support is added.
-    @unittest.skipIf(not SM75OrLater, "need sm_75")
-    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
-    @parametrize("dynamic", (False,))
-    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
-    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
-    def test_max_autotune_cutlass_backend_regular_mm(
-        self, dynamic: bool, max_autotune_gemm_backends: str
-    ):
-        """
-        Make sure autotuning mm in sub processes work without crashes.
- """ - - if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip: - return - - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False - - def mm(a, b): - return a @ b - - a = torch.randn(100, 10).cuda().half() - b = torch.randn(10, 100).cuda().half() - - with config.patch( - { - "max_autotune": True, - "autotune_in_subproc": False, - "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_dir": _CUTLASS_DIR, - "cuda.cutlass_max_profiling_configs": 2, - } - ): - Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) - Y = mm(a, b) - torch.testing.assert_close(Y_compiled, Y) - - def _test_max_autotune_cutlass_backend_epilogue_fusion( - self, - dynamic: bool = False, - max_autotune_gemm_backends: str = "CUTLASS", - mixed_precision=False, - fp16=True, - expected_fuse_count=1, - mm: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, - ): - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = ( - mixed_precision - ) - - # Note: The ops that are available - # also depend on the alignment of the shapes - # so if these shapes don't all align to at least 8 elements - # it can happen that no Cutlass 3.x op is available - # that allows fusions - a = torch.randn(256, 32).cuda() - b = torch.randn(32, 256).cuda() - if fp16: - a = a.half() - b = b.half() - - with config.patch( - { - "max_autotune": True, - "autotune_in_subproc": False, - "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_dir": _CUTLASS_DIR, - "cuda.cutlass_max_profiling_configs": 4, - "cuda.cutlass_only_evt_capable_ops": True, - "cuda.version": "12.2", # required to enable the Kernels we need - } - ): - counters["inductor"]["cuda_epilogue_fusion_counter"] = 0 - Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) - Y = mm(a, b) - actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"] - assert ( - actual_count == expected_fuse_count - ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}" - torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2) - - @unittest.skipIf(not SM90OrLater, "need sm_90") - @unittest.skipIf(torch.version.hip, "HIP not supported") - @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") - def test_max_autotune_cutlass_backend_simple_fusion_fp16(self): - def mm(a, b): - return (a @ b) * 3.0 - - # The pointwise ops seem to be pre-fused into a single Pointwise - self._test_max_autotune_cutlass_backend_epilogue_fusion( - mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm - ) - - @unittest.skipIf(not SM90OrLater, "need sm_90") - @unittest.skipIf(torch.version.hip, "HIP not supported") - @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") - def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self): - def mm(a, b): - return (a @ b) * 3.0 - - self._test_max_autotune_cutlass_backend_epilogue_fusion( - mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm - ) - - @unittest.skipIf(not SM90OrLater, "need sm_90") - @unittest.skipIf(torch.version.hip, "HIP not supported") - @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") - def test_max_autotune_cutlass_backend_chained_fusion_fp16(self): - def mm(a, b): - return (a @ b) * 3.3 - 1.234 - - # The pointwise ops seem to be pre-fused into a single Pointwise - self._test_max_autotune_cutlass_backend_epilogue_fusion( - mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm - ) - - @unittest.skipIf(not SM90OrLater, 
"need sm_90") - @unittest.skipIf(torch.version.hip, "HIP not supported") - @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") - def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self): - def mm(a, b): - return (a @ b) * 3.3 - 1.234 - - self._test_max_autotune_cutlass_backend_epilogue_fusion( - mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm - ) - - @unittest.skipIf(not SM90OrLater, "need sm_90") - @unittest.skipIf(torch.version.hip, "HIP not supported") - @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") - def test_max_autotune_cutlass_backend_relu_fusion_fp16(self): - def mm(a, b): - return torch.nn.functional.relu((a @ b) * 3.3 - 1.234) - - self._test_max_autotune_cutlass_backend_epilogue_fusion( - mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm - ) - - @unittest.skipIf(not SM90OrLater, "need sm_90") - @unittest.skipIf(torch.version.hip, "HIP not supported") - @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") - def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self): - def mm(a, b): - return torch.nn.functional.relu((a @ b) * 3.3 - 1.234) - - # The pointwise ops seem to be pre-fused into a single Pointwise - self._test_max_autotune_cutlass_backend_epilogue_fusion( - mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm - ) - - @unittest.skipIf(not SM90OrLater, "need sm_90") - @unittest.skipIf(torch.version.hip, "HIP not supported") - @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") - def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self): - def mm(a, b): - return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0) - - # The pointwise ops seem to be pre-fused into a single Pointwise - self._test_max_autotune_cutlass_backend_epilogue_fusion( - mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm - ) - - @unittest.skipIf(not SM90OrLater, "need sm_90") - @unittest.skipIf(torch.version.hip, "HIP not supported") - @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") - def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self): - def mm(a, b): - # this should not be fused, since the output dtype is different from the matmul dtype - return (a @ b).to(torch.float32) * 0.00001 - - self._test_max_autotune_cutlass_backend_epilogue_fusion( - mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm - ) - - @unittest.skipIf(not SM90OrLater, "need sm_90") - @unittest.skipIf(torch.version.hip, "HIP not supported") - @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") - def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self): - def mm(a, b): - return (a @ b) / b.size(1) - - self._test_max_autotune_cutlass_backend_epilogue_fusion( - mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm - ) - - # TODO: Enable dynamic test cases when dynamic support is added. 
-    @unittest.skipIf(not SM75OrLater, "need sm_75")
-    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
-    @parametrize("dynamic", (False,))
-    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
-    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
-    def test_max_autotune_cutlass_backend_mm_bias(
-        self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
-    ):
-        """
-        Make sure autotuning mm in sub processes work without crashes.
-        """
-
-        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
-            return
-
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
-        def mm(a, b, bias):
-            return torch.nn.functional.linear(a, b, bias)
-
-        a = torch.randn(2048, 4096).cuda().half()
-        bias = torch.randn(2048).cuda().half()
-
-        with config.patch(
-            {
-                "max_autotune": True,
-                "autotune_in_subproc": False,
-                "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_dir": _CUTLASS_DIR,
-                "cuda.cutlass_max_profiling_configs": 2,
-            }
-        ):
-            Y = mm(a, a, bias)
-            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, a, bias)
-            torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1)
-
     @parametrize("dynamic", (False, True))
     def test_max_autotune_addmm(self, dynamic=False):
         """
@@ -650,60 +392,6 @@ class TestMaxAutotune(TestCase):
         with config.patch({"max_autotune": True}):
             torch.compile(addmm, dynamic=dynamic)(x, a, b)
 
-    # TODO: Enable dynamic test cases when dynamic support is added.
-    @unittest.skipIf(not SM75OrLater, "need sm_75")
-    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
-    @parametrize("dynamic", (False,))
-    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
-    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
-    def test_max_autotune_cutlass_backend_addmm(
-        self, dynamic, max_autotune_gemm_backends
-    ):
-        """
-        Make sure autotuning addmm in sub processes work without crashes.
-        """
-
-        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
-            return
-
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
-        def addmm(x, a, b, alpha, beta):
-            return torch.addmm(x, a, b, alpha=alpha, beta=beta)
-
-        def compare_results(
-            m: int, k: int, n: int, alpha: float, beta: float, x_shape: List[int]
-        ) -> None:
-            x = torch.randn(x_shape).cuda().half()
-            a = torch.randn(m, k).cuda().half()
-            b = torch.randn(k, n).cuda().half()
-            y_expected = addmm(x, a, b, alpha, beta)
-
-            compiled_fn = torch.compile(addmm, dynamic=dynamic)
-            y = compiled_fn(x, a, b, alpha, beta)
-            torch.testing.assert_close(y, y_expected)
-
-        with config.patch(
-            {
-                "max_autotune": True,
-                "autotune_in_subproc": False,
-                "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_dir": _CUTLASS_DIR,
-                "cuda.cutlass_max_profiling_configs": 2,
-            }
-        ):
-            # No broadcast
-            compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 2048])
-            # Broadcast first dim.
-            compare_results(4096, 25728, 2048, 2.0, 0.4, [2048])
-            # Broadcast last dim.
-            if not SM90OrLater and max_autotune_gemm_backends == "CUTLASS":
-                with self.assertRaisesRegex(RuntimeError, "No choices to select"):
-                    # CUTLASS2 doesn't support Bias last-dim broadcast.
-                    compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 1])
-            else:
-                compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 1])
-
     @skipIfRocm
     def test_autotune_conv1x1(self):
         # Assuming input has 3 channels and we want to produce 16 channels as output
diff --git a/test/run_test.py b/test/run_test.py
index e86af962304..49498ffa106 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -242,7 +242,8 @@ CI_SERIAL_LIST = [
     "test_autocast",  # OOM
     "test_native_mha",  # OOM
     "test_module_hooks",  # OOM
-    "inductor/test_max_autotune",  # Testing, probably revert later
+    "inductor/test_max_autotune",
+    "inductor/test_cutlass_backend",  # slow due to many nvcc compilation steps
     "inductor/test_torchinductor",  # OOM on test_large_block_sizes
     "inductor/test_torchinductor_dynamic_shapes",  # OOM on test_large_block_sizes
     "inductor/test_torchinductor_codegen_dynamic_shapes",  # OOM on test_large_block_sizes
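
For anyone wanting to exercise the new backend outside the test harness, the sketch below distills the pattern the tests use. It is illustrative only: the config keys mirror the ones patched in the tests, the shapes follow the 8-element alignment note in `_test_max_autotune_cutlass_backend_epilogue_fusion`, and the `CUTLASS_DIR` path is an assumption that holds for a source checkout (installed wheels do not ship `third_party/cutlass`).

# A minimal sketch, assuming a CUDA-enabled source build of PyTorch run from
# the repo root; CUTLASS_DIR and the shapes are illustrative, not part of this PR.
import os

import torch
from torch._inductor import config

CUTLASS_DIR = os.path.join(os.getcwd(), "third_party/cutlass/")  # assumed layout


def mm(a, b):
    return a @ b


# Multiples of 8 keep the operands aligned so CUTLASS 3.x ops are eligible
# (see the alignment note in the fusion helper above).
a = torch.randn(256, 32, device="cuda", dtype=torch.half)
b = torch.randn(32, 256, device="cuda", dtype=torch.half)

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS,ATen",
        "cuda.cutlass_dir": CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 2,  # bound the nvcc compile time
    }
):
    y_compiled = torch.compile(mm, dynamic=False)(a, b)

torch.testing.assert_close(y_compiled, mm(a, b))

Listing ATen alongside CUTLASS gives autotuning a fallback choice, matching the spirit of the parametrized backend strings in the tests.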