diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index d18053c527c..9e526837977 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -45,7 +45,6 @@ log = logging.getLogger(__name__) HAS_CUDA = HAS_CUDA and not torch.version.hip SM80OrLater = SM80OrLater and not torch.version.hip SM90OrLater = SM90OrLater and not torch.version.hip -SM80 = SM80OrLater and torch.cuda.get_device_capability() == (8, 0) def _get_path_without_sccache() -> str: @@ -797,20 +796,15 @@ class TestCutlassBackend(TestCase): torch.testing.assert_close(expected, actual, atol=0.01, rtol=0.01) # TODO: Enable dynamic test cases when dynamic support is added. - @unittest.skipIf(True, "disabled due to broken on A100") - # error: TypeError: can't multiply sequence by non-int of type 'str' - @unittest.skipIf(not SM80, "need sm_80 exactly") + @unittest.skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly") @parametrize("dynamic", (False,)) - @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen")) @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) - def test_max_autotune_cutlass_backend_mixed_mm( - self, dynamic: bool, max_autotune_gemm_backends: str - ): + def test_max_autotune_cutlass_backend_mixed_mm(self, dynamic: bool): """ Make sure autotuning mm in sub processes work without crashes. """ - if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip: + if torch.version.hip: return torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False @@ -830,7 +824,7 @@ class TestCutlassBackend(TestCase): { "max_autotune": True, "autotune_in_subproc": True, - "max_autotune_gemm_backends": max_autotune_gemm_backends, + "max_autotune_gemm_backends": "CUTLASS", "cuda.cutlass_dir": _CUTLASS_DIR, "cuda.cutlass_max_profiling_configs": 2, "use_mixed_mm": True, @@ -854,20 +848,17 @@ class TestCutlassBackend(TestCase): assert cutlass_kernels_count > 0 # TODO: Enable dynamic test cases when dynamic support is added. - @unittest.skipIf(True, "disabled due to broken on A100") - # error: TypeError: can't multiply sequence by non-int of type 'str' - @unittest.skipIf(not SM80, "need sm_80 exactly") + @unittest.skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly") @parametrize("dynamic", (False,)) - @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen")) @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_max_autotune_cutlass_backend_sparse_semi_structured_mm( - self, dynamic: bool, max_autotune_gemm_backends: str + self, dynamic: bool ): """ Make sure autotuning mm in sub processes work without crashes. """ - if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip: + if torch.version.hip: return SparseSemiStructuredTensor._FORCE_CUTLASS = True @@ -885,7 +876,7 @@ class TestCutlassBackend(TestCase): { "max_autotune": True, "autotune_in_subproc": True, - "max_autotune_gemm_backends": max_autotune_gemm_backends, + "max_autotune_gemm_backends": "CUTLASS", "cuda.cutlass_dir": _CUTLASS_DIR, "cuda.cutlass_max_profiling_configs": 2, "autotune_local_cache": True, diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index 621d9d69e37..18643b51b23 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union -from sympy import Expr +from sympy import Expr, symbols from torch import dtype as torch_dtype from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu @@ -404,6 +404,7 @@ class CUDATemplateKernel(CUDAKernel): if len(sizes) == 0: return str(default_value) + sizes = [symbols(v) if isinstance(v, str) else v for v in sizes] val = sympy_product(sizes) return val diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index c2e8bc96370..ff2263f0ac5 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -153,9 +153,6 @@ extern "C" { PT_EXPORT {{kernel_call_signature}} { try { int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}}; - int64_t M = {{kernel.size(X, -2)}}; - int64_t K = {{kernel.size(W, -2)}}; - int64_t N = {{kernel.size(W, -1)}}; using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; using coord_t = cutlass::gemm::GemmCoord::Index; static cutlass::KernelHardwareInfo hw_info; @@ -174,13 +171,6 @@ PT_EXPORT {{kernel_call_signature}} { // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers #ifndef CUTLASS_BACKEND_DISABLE_CHECKS - {{kernel.check_not_null(X)}} - {{kernel.check_not_null(W)}} - {{kernel.check_not_null(Bias)}} - {{kernel.check_not_null(Meta)}} - {{kernel.check_not_null(Y)}} - - { auto status = gemm_op.can_implement(arguments); CUTLASS_CHECK(status); @@ -276,7 +266,7 @@ GEMM_ARGS_SPARSE_CUTLASS_2X = r""" { static_cast({{M}}), static_cast({{N}}), - static_cast(K), + static_cast(2 * K), }, // GemmCoord problem_size X_ref, // TensorRef ref_A W_ref, // TensorRef ref_B @@ -1379,7 +1369,7 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate): A_size = [int(i) for i in A_layout.size] B_size = [int(i) for i in B_layout.size] K = max(A_size[1], B_size[0]) - return (K == A_size[1] or K == 2 * A_size[0]) and K == B_size[0] + return (K == A_size[1] or K == 2 * A_size[1]) and K == B_size[0] def _shape_match( self,