mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix CUTLASS 2.x kernels for auto-tuning
ghstack-source-id: f4a15fb2d6
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146755
This commit is contained in:
parent
91c4bf39d3
commit
5eda7bbe7b
3 changed files with 12 additions and 30 deletions
|
|
@ -45,7 +45,6 @@ log = logging.getLogger(__name__)
|
|||
HAS_CUDA = HAS_CUDA and not torch.version.hip
|
||||
SM80OrLater = SM80OrLater and not torch.version.hip
|
||||
SM90OrLater = SM90OrLater and not torch.version.hip
|
||||
SM80 = SM80OrLater and torch.cuda.get_device_capability() == (8, 0)
|
||||
|
||||
|
||||
def _get_path_without_sccache() -> str:
|
||||
|
|
@ -797,20 +796,15 @@ class TestCutlassBackend(TestCase):
|
|||
torch.testing.assert_close(expected, actual, atol=0.01, rtol=0.01)
|
||||
|
||||
# TODO: Enable dynamic test cases when dynamic support is added.
|
||||
@unittest.skipIf(True, "disabled due to broken on A100")
|
||||
# error: TypeError: can't multiply sequence by non-int of type 'str'
|
||||
@unittest.skipIf(not SM80, "need sm_80 exactly")
|
||||
@unittest.skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly")
|
||||
@parametrize("dynamic", (False,))
|
||||
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
|
||||
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
|
||||
def test_max_autotune_cutlass_backend_mixed_mm(
|
||||
self, dynamic: bool, max_autotune_gemm_backends: str
|
||||
):
|
||||
def test_max_autotune_cutlass_backend_mixed_mm(self, dynamic: bool):
|
||||
"""
|
||||
Make sure autotuning mm in sub processes work without crashes.
|
||||
"""
|
||||
|
||||
if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
|
||||
if torch.version.hip:
|
||||
return
|
||||
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
|
||||
|
|
@ -830,7 +824,7 @@ class TestCutlassBackend(TestCase):
|
|||
{
|
||||
"max_autotune": True,
|
||||
"autotune_in_subproc": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_dir": _CUTLASS_DIR,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"use_mixed_mm": True,
|
||||
|
|
@ -854,20 +848,17 @@ class TestCutlassBackend(TestCase):
|
|||
assert cutlass_kernels_count > 0
|
||||
|
||||
# TODO: Enable dynamic test cases when dynamic support is added.
|
||||
@unittest.skipIf(True, "disabled due to broken on A100")
|
||||
# error: TypeError: can't multiply sequence by non-int of type 'str'
|
||||
@unittest.skipIf(not SM80, "need sm_80 exactly")
|
||||
@unittest.skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly")
|
||||
@parametrize("dynamic", (False,))
|
||||
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
|
||||
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
|
||||
def test_max_autotune_cutlass_backend_sparse_semi_structured_mm(
|
||||
self, dynamic: bool, max_autotune_gemm_backends: str
|
||||
self, dynamic: bool
|
||||
):
|
||||
"""
|
||||
Make sure autotuning mm in sub processes work without crashes.
|
||||
"""
|
||||
|
||||
if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
|
||||
if torch.version.hip:
|
||||
return
|
||||
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = True
|
||||
|
|
@ -885,7 +876,7 @@ class TestCutlassBackend(TestCase):
|
|||
{
|
||||
"max_autotune": True,
|
||||
"autotune_in_subproc": True,
|
||||
"max_autotune_gemm_backends": max_autotune_gemm_backends,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
"cuda.cutlass_dir": _CUTLASS_DIR,
|
||||
"cuda.cutlass_max_profiling_configs": 2,
|
||||
"autotune_local_cache": True,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from sympy import Expr
|
||||
from sympy import Expr, symbols
|
||||
|
||||
from torch import dtype as torch_dtype
|
||||
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
|
||||
|
|
@ -404,6 +404,7 @@ class CUDATemplateKernel(CUDAKernel):
|
|||
if len(sizes) == 0:
|
||||
return str(default_value)
|
||||
|
||||
sizes = [symbols(v) if isinstance(v, str) else v for v in sizes]
|
||||
val = sympy_product(sizes)
|
||||
return val
|
||||
|
||||
|
|
|
|||
|
|
@ -153,9 +153,6 @@ extern "C" {
|
|||
PT_EXPORT {{kernel_call_signature}} {
|
||||
try {
|
||||
int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
|
||||
int64_t M = {{kernel.size(X, -2)}};
|
||||
int64_t K = {{kernel.size(W, -2)}};
|
||||
int64_t N = {{kernel.size(W, -1)}};
|
||||
using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
|
||||
using coord_t = cutlass::gemm::GemmCoord::Index;
|
||||
static cutlass::KernelHardwareInfo hw_info;
|
||||
|
|
@ -174,13 +171,6 @@ PT_EXPORT {{kernel_call_signature}} {
|
|||
|
||||
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
|
||||
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
|
||||
{{kernel.check_not_null(X)}}
|
||||
{{kernel.check_not_null(W)}}
|
||||
{{kernel.check_not_null(Bias)}}
|
||||
{{kernel.check_not_null(Meta)}}
|
||||
{{kernel.check_not_null(Y)}}
|
||||
|
||||
|
||||
{
|
||||
auto status = gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
|
|
@ -276,7 +266,7 @@ GEMM_ARGS_SPARSE_CUTLASS_2X = r"""
|
|||
{
|
||||
static_cast<coord_t>({{M}}),
|
||||
static_cast<coord_t>({{N}}),
|
||||
static_cast<coord_t>(K),
|
||||
static_cast<coord_t>(2 * K),
|
||||
}, // GemmCoord problem_size
|
||||
X_ref, // TensorRef<ElementA const, LayoutA> ref_A
|
||||
W_ref, // TensorRef<ElementB const, LayoutB> ref_B
|
||||
|
|
@ -1379,7 +1369,7 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
|
|||
A_size = [int(i) for i in A_layout.size]
|
||||
B_size = [int(i) for i in B_layout.size]
|
||||
K = max(A_size[1], B_size[0])
|
||||
return (K == A_size[1] or K == 2 * A_size[0]) and K == B_size[0]
|
||||
return (K == A_size[1] or K == 2 * A_size[1]) and K == B_size[0]
|
||||
|
||||
def _shape_match(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Reference in a new issue