[cutlass backend] Set no fallback to aten, disabled a few broken tests, default to test on H100 (#146554)

This PR does a few things:
* Set fallback to ATen (`autotune_fallback_to_aten`) to False for most tests; a sketch of the resulting test pattern follows this list. Without this, a lot of tests would silently pass without actually exercising the CUTLASS backend, since they would just fall back to ATen.
* Disable two broken subprocess-related tests; they crash when run in a subprocess and need further investigation.
* Remove or disable the tests on A100. Let me elaborate a bit more below.
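
As referenced in the first bullet, here is a minimal sketch of the test pattern this PR moves to. The model, shapes, and tolerances are illustrative only; the config keys are the ones touched in the diff below.

```python
# Minimal sketch (illustrative model/shapes): with autotune_fallback_to_aten
# disabled, autotuning must pick a CUTLASS kernel instead of silently
# compiling the ATen fallback and passing anyway.
import torch
from torch._inductor import config


def mm(a, b):
    return a @ b


a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 256, device="cuda", dtype=torch.float16)

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_max_profiling_configs": 2,
        # The point of this PR: without this, a test that never finds a usable
        # CUTLASS kernel would just use ATen and still pass.
        "autotune_fallback_to_aten": False,
    }
):
    compiled_mm = torch.compile(mm)
    torch.testing.assert_close(compiled_mm(a, b), mm(a, b))
```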

There are two types of A100 tests.
* Normal tests that also run on A100, e.g. mm, addmm, bmm. Since the shift to CUTLASS 3x, these no longer work: GenerateSM80 generates ops that use CUTLASS 2x, but they are of GemmKind.Universal and get filtered out, since only GemmKind.Universal3x is supported in the 3x template (see the sketch after this list).
* Tests for A100 only. The mixed mm and sparse semi-structured tests have been failing for a while with "TypeError: can't multiply sequence by non-int of type 'str'", so they are disabled for now. Do let us know if you care about them @alexsamardzic.
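
To make the filtering described in the first bullet concrete, here is a hedged sketch of the behavior; the `GemmKind` enum, `GemmOp`, and `filter_cutlass_3x_ops` below are illustrative stand-ins rather than the actual Inductor/CUTLASS generator code.

```python
# Hedged sketch of the op filtering described above (names are illustrative):
# the 3x GEMM template only accepts Universal3x ops, so the CUTLASS 2x ops
# that GenerateSM80 emits for A100 are all dropped, leaving nothing to tune.
from dataclasses import dataclass
from enum import Enum


class GemmKind(Enum):
    Universal = "universal"      # CUTLASS 2x style ops (what GenerateSM80 emits)
    Universal3x = "universal3x"  # CUTLASS 3x style ops (SM90 / Hopper)


@dataclass
class GemmOp:
    name: str
    gemm_kind: GemmKind


def filter_cutlass_3x_ops(ops):
    # Only 3x ops survive the filter.
    return [op for op in ops if op.gemm_kind == GemmKind.Universal3x]


sm80_ops = [GemmOp("cutlass2x_gemm", GemmKind.Universal)]
assert filter_cutlass_3x_ops(sm80_ops) == []  # all A100 candidates filtered out
```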

Differential Revision: D69209929

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146554
Approved by: https://github.com/chenyang78
commit 206ad9f4ad (parent: f17109bd96)
Author: Henry Tsang
Date: 2025-02-07 19:59:28 +00:00
Committed by: PyTorch MergeBot


@@ -26,7 +26,7 @@ from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM75OrLater, SM80OrLater, SM90OrLater
from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@@ -43,7 +43,6 @@ _CUTLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/cutlas
log = logging.getLogger(__name__)
HAS_CUDA = HAS_CUDA and not torch.version.hip
SM75OrLater = SM75OrLater and not torch.version.hip
SM80OrLater = SM80OrLater and not torch.version.hip
SM90OrLater = SM90OrLater and not torch.version.hip
SM80 = SM80OrLater and torch.cuda.get_device_capability() == (8, 0)
@@ -82,7 +81,7 @@ class TestCutlassBackend(TestCase):
super().tearDown()
clear_inductor_caches()
@unittest.skipIf(not SM75OrLater, "need sm_75")
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_threshold(self):
"""
@@ -104,11 +103,13 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"autotune_in_subproc": True,
"max_autotune_gemm_backends": "CUTLASS,ATen",
"max_autotune_gemm_backends": "CUTLASS",
"compile_threads": 4,
"cuda.cutlass_backend_min_gemm_size": 100000,
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 2,
# allow fallback to aten as intended
"autotune_fallback_to_aten": True,
}
):
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
@@ -131,7 +132,7 @@ class TestCutlassBackend(TestCase):
), "Cutlass Kernels should have been filtered, GEMM size is too small"
torch.testing.assert_close(Y_compiled, Y)
@unittest.skipIf(not SM75OrLater, "need sm_75")
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_precompile(self):
"""
@@ -189,6 +190,7 @@ class TestCutlassBackend(TestCase):
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 3,
"autotune_fallback_to_aten": False,
}
):
from torch.export import Dim
@@ -239,6 +241,7 @@ class TestCutlassBackend(TestCase):
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 1,
"autotune_fallback_to_aten": False,
}
):
from torch._inductor.utils import run_and_get_code
@@ -252,8 +255,8 @@ class TestCutlassBackend(TestCase):
2,
).run(codes[0])
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM75OrLater, "need sm_75")
# NOTE: right now tuned_mm doesn't support cutlass 2x, which is used by A100
@unittest.skipIf(not SM90OrLater, "need sm_90")
@parametrize("dynamic", (False, True))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
@parametrize("use_aoti", (False, True))
@@ -287,6 +290,7 @@ class TestCutlassBackend(TestCase):
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 2,
"autotune_fallback_to_aten": False,
}
):
Y = model(a, b)
@@ -328,6 +332,7 @@ class TestCutlassBackend(TestCase):
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 2,
"cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
"autotune_fallback_to_aten": False,
}
):
for M, K, N in (
@@ -389,6 +394,7 @@ class TestCutlassBackend(TestCase):
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 4,
"cuda.version": "12.2", # required to enable the Kernels we need
"autotune_fallback_to_aten": False,
}
):
counters["inductor"]["cuda_epilogue_fusion_counter"] = 0
@@ -485,6 +491,7 @@ class TestCutlassBackend(TestCase):
mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
)
@unittest.skipIf(not SM90OrLater, "need sm_90")
def test_max_autotune_cutlass_backend_simple_bmm(self):
def bmm(a, b):
return torch.bmm(a, b)
@@ -508,7 +515,8 @@ class TestCutlassBackend(TestCase):
)
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM75OrLater, "need sm_75")
@unittest.skipIf(True, "FIXME: Disabled temporarily since crashing in subprocess")
@unittest.skipIf(not SM90OrLater, "need sm_90")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@@ -537,14 +545,15 @@ class TestCutlassBackend(TestCase):
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 2,
"autotune_fallback_to_aten": False,
}
):
Y = mm(a, a, bias)
Y_compiled = torch.compile(mm, dynamic=dynamic)(a, a, bias)
torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1)
@unittest.skipIf(not SM75OrLater, "need sm_75")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@unittest.skipIf(True, "FIXME: Disabled temporarily since crashing in subprocess")
@unittest.skipIf(not SM90OrLater, "need sm_90")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@@ -586,6 +595,7 @@ class TestCutlassBackend(TestCase):
"cuda.cutlass_max_profiling_configs": 4,
"cuda.cutlass_op_allowlist_regex": "",
"cuda.cutlass_op_denylist_regex": "pingpong", # Pingpong Kernels can lead to numerical issues
"autotune_fallback_to_aten": False,
}
):
# No broadcast
@@ -595,6 +605,7 @@ class TestCutlassBackend(TestCase):
# Broadcast last dim.
compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 1])
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_addmm_with_expanded_bias(self):
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
@@ -610,9 +621,10 @@ class TestCutlassBackend(TestCase):
{
"max_autotune": True,
"autotune_in_subproc": False,
"max_autotune_gemm_backends": "ATEN,CUTLASS",
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 1,
"autotune_fallback_to_aten": False,
}
):
model = MyModel()
@@ -629,7 +641,7 @@ class TestCutlassBackend(TestCase):
torch.testing.assert_close(expected, actual)
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM80OrLater, "need sm_80")
@unittest.skipIf(not SM90OrLater, "need sm_90")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,ATen"))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@@ -662,6 +674,7 @@ class TestCutlassBackend(TestCase):
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 2,
"autotune_fallback_to_aten": False,
}
):
Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
@@ -784,6 +797,8 @@ class TestCutlassBackend(TestCase):
torch.testing.assert_close(expected, actual, atol=0.01, rtol=0.01)
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(True, "disabled due to broken on A100")
# error: TypeError: can't multiply sequence by non-int of type 'str'
@unittest.skipIf(not SM80, "need sm_80 exactly")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
@@ -820,6 +835,7 @@ class TestCutlassBackend(TestCase):
"cuda.cutlass_max_profiling_configs": 2,
"use_mixed_mm": True,
"autotune_local_cache": True,
"autotune_fallback_to_aten": False,
}
):
Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
@@ -838,6 +854,8 @@ class TestCutlassBackend(TestCase):
assert cutlass_kernels_count > 0
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(True, "disabled due to broken on A100")
# error: TypeError: can't multiply sequence by non-int of type 'str'
@unittest.skipIf(not SM80, "need sm_80 exactly")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
@@ -871,6 +889,7 @@ class TestCutlassBackend(TestCase):
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 2,
"autotune_local_cache": True,
"autotune_fallback_to_aten": False,
}
):
Y_compiled = torch.compile(mm, dynamic=dynamic)(a_sparse, b)