Fix CUTLASS 2.x kernels for auto-tuning

ghstack-source-id: f4a15fb2d6
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146755
Aleksandar Samardžić 2025-02-09 15:29:27 +01:00
parent 91c4bf39d3
commit 5eda7bbe7b
3 changed files with 12 additions and 30 deletions


@@ -45,7 +45,6 @@ log = logging.getLogger(__name__)
  HAS_CUDA = HAS_CUDA and not torch.version.hip
  SM80OrLater = SM80OrLater and not torch.version.hip
  SM90OrLater = SM90OrLater and not torch.version.hip
- SM80 = SM80OrLater and torch.cuda.get_device_capability() == (8, 0)
  def _get_path_without_sccache() -> str:
@@ -797,20 +796,15 @@ class TestCutlassBackend(TestCase):
  torch.testing.assert_close(expected, actual, atol=0.01, rtol=0.01)
  # TODO: Enable dynamic test cases when dynamic support is added.
- @unittest.skipIf(True, "disabled due to broken on A100")
- # error: TypeError: can't multiply sequence by non-int of type 'str'
- @unittest.skipIf(not SM80, "need sm_80 exactly")
+ @unittest.skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly")
  @parametrize("dynamic", (False,))
- @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
  @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
- def test_max_autotune_cutlass_backend_mixed_mm(
-     self, dynamic: bool, max_autotune_gemm_backends: str
- ):
+ def test_max_autotune_cutlass_backend_mixed_mm(self, dynamic: bool):
  """
  Make sure autotuning mm in sub processes work without crashes.
  """
- if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
+ if torch.version.hip:
  return
  torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
@@ -830,7 +824,7 @@ class TestCutlassBackend(TestCase):
  {
  "max_autotune": True,
  "autotune_in_subproc": True,
- "max_autotune_gemm_backends": max_autotune_gemm_backends,
+ "max_autotune_gemm_backends": "CUTLASS",
  "cuda.cutlass_dir": _CUTLASS_DIR,
  "cuda.cutlass_max_profiling_configs": 2,
  "use_mixed_mm": True,
@@ -854,20 +848,17 @@ class TestCutlassBackend(TestCase):
  assert cutlass_kernels_count > 0
  # TODO: Enable dynamic test cases when dynamic support is added.
- @unittest.skipIf(True, "disabled due to broken on A100")
- # error: TypeError: can't multiply sequence by non-int of type 'str'
- @unittest.skipIf(not SM80, "need sm_80 exactly")
+ @unittest.skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly")
  @parametrize("dynamic", (False,))
- @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
  @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
  def test_max_autotune_cutlass_backend_sparse_semi_structured_mm(
-     self, dynamic: bool, max_autotune_gemm_backends: str
+     self, dynamic: bool
  ):
  """
  Make sure autotuning mm in sub processes work without crashes.
  """
- if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
+ if torch.version.hip:
  return
  SparseSemiStructuredTensor._FORCE_CUTLASS = True
@@ -885,7 +876,7 @@ class TestCutlassBackend(TestCase):
  {
  "max_autotune": True,
  "autotune_in_subproc": True,
- "max_autotune_gemm_backends": max_autotune_gemm_backends,
+ "max_autotune_gemm_backends": "CUTLASS",
  "cuda.cutlass_dir": _CUTLASS_DIR,
  "cuda.cutlass_max_profiling_configs": 2,
  "autotune_local_cache": True,


@@ -3,7 +3,7 @@ import logging
  from dataclasses import dataclass
  from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
- from sympy import Expr
+ from sympy import Expr, symbols
  from torch import dtype as torch_dtype
  from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
@@ -404,6 +404,7 @@ class CUDATemplateKernel(CUDAKernel):
  if len(sizes) == 0:
  return str(default_value)
+ sizes = [symbols(v) if isinstance(v, str) else v for v in sizes]
  val = sympy_product(sizes)
  return val
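The TypeError quoted in the test comments above ("can't multiply sequence by non-int of type 'str'") is what the symbols(...) conversion in this hunk avoids: size expressions that arrive as plain strings cannot be multiplied together, while sympy symbols can. A standalone sketch (not the inductor code path itself):

from sympy import symbols

sizes = ["M", "K"]                          # dynamic dims rendered as plain strings

try:
    sizes[0] * sizes[1]                     # str * str
except TypeError as e:
    print(e)                                # can't multiply sequence by non-int of type 'str'

# With the conversion added above, the product becomes a symbolic expression:
sizes = [symbols(v) if isinstance(v, str) else v for v in sizes]
print(sizes[0] * sizes[1])                  # M*K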


@@ -153,9 +153,6 @@ extern "C" {
  PT_EXPORT {{kernel_call_signature}} {
  try {
  int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
- int64_t M = {{kernel.size(X, -2)}};
- int64_t K = {{kernel.size(W, -2)}};
- int64_t N = {{kernel.size(W, -1)}};
  using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
  using coord_t = cutlass::gemm::GemmCoord::Index;
  static cutlass::KernelHardwareInfo hw_info;
@@ -174,13 +171,6 @@ PT_EXPORT {{kernel_call_signature}} {
  // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
  #ifndef CUTLASS_BACKEND_DISABLE_CHECKS
- {{kernel.check_not_null(X)}}
- {{kernel.check_not_null(W)}}
- {{kernel.check_not_null(Bias)}}
- {{kernel.check_not_null(Meta)}}
- {{kernel.check_not_null(Y)}}
  {
  auto status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);
@@ -276,7 +266,7 @@ GEMM_ARGS_SPARSE_CUTLASS_2X = r"""
  {
  static_cast<coord_t>({{M}}),
  static_cast<coord_t>({{N}}),
- static_cast<coord_t>(K),
+ static_cast<coord_t>(2 * K),
  }, // GemmCoord problem_size
  X_ref, // TensorRef<ElementA const, LayoutA> ref_A
  W_ref, // TensorRef<ElementB const, LayoutB> ref_B
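The doubling of K here presumably reflects 2:4 semi-structured compression: the sparse operand stores only half of the reduction dimension, so the logical GemmCoord K must be reconstructed as twice the stored inner size (see also the layout check fixed below). A shape-only sketch with made-up dimensions:

M, K, N = 128, 64, 256

A_compressed = (M, K // 2)    # 2:4 sparse operand keeps 2 of every 4 values along K
B_dense = (K, N)              # dense operand keeps the full reduction dimension

# Logical GEMM problem size (M, N, K), with K recovered from the compressed side:
problem_size = (A_compressed[0], B_dense[1], 2 * A_compressed[1])
assert problem_size == (M, N, K)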
@@ -1379,7 +1369,7 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
  A_size = [int(i) for i in A_layout.size]
  B_size = [int(i) for i in B_layout.size]
  K = max(A_size[1], B_size[0])
- return (K == A_size[1] or K == 2 * A_size[0]) and K == B_size[0]
+ return (K == A_size[1] or K == 2 * A_size[1]) and K == B_size[0]
  def _shape_match(
  self,
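A worked instance (made-up shapes, mirroring the variable names in the hunk above) of why the condition moves from 2 * A_size[0] to 2 * A_size[1]: for a 2:4-compressed A of shape (M, K/2) against a dense B of shape (K, N), it is A's inner dimension, doubled, that matches K, not its outer one:

M, K, N = 128, 64, 256
A_size = [M, K // 2]          # [128, 32] -- compressed sparse operand
B_size = [K, N]               # [64, 256] -- dense operand

K_inferred = max(A_size[1], B_size[0])      # 64, as in the hunk above

old_ok = (K_inferred == A_size[1] or K_inferred == 2 * A_size[0]) and K_inferred == B_size[0]
new_ok = (K_inferred == A_size[1] or K_inferred == 2 * A_size[1]) and K_inferred == B_size[0]
print(old_ok, new_ok)         # False True: 2 * A_size[0] is 2*M (256), not the GEMM K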