[cutlass backend] fix bug for accumulator dtype (#146356)

Will add unit tests for accuracy.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146356
Approved by: https://github.com/Chillee
This commit is contained in:
Henry Tsang 2025-02-03 14:02:56 -08:00 committed by PyTorch MergeBot
parent 13e17aa106
commit 7c8ec84dab
2 changed files with 6 additions and 73 deletions

View file

@ -92,8 +92,6 @@ class TestCutlassBackend(TestCase):
if torch.version.hip:
return
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
def mm(a, b):
return a @ b
@ -141,8 +139,6 @@ class TestCutlassBackend(TestCase):
if torch.version.hip:
return
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
def mm(a, b):
return a @ b
@ -170,7 +166,6 @@ class TestCutlassBackend(TestCase):
Compile with one shape, then re-run with different input shapes
"""
max_autotune_gemm_backends = "CUTLASS"
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
class MyModel(torch.nn.Module):
def forward(self, a, b):
@ -216,7 +211,6 @@ class TestCutlassBackend(TestCase):
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_diff_matmul_share_same_kernel(self, dynamic):
max_autotune_gemm_backends = "CUTLASS"
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
class MyModel(torch.nn.Module):
def __init__(self):
@ -267,8 +261,6 @@ class TestCutlassBackend(TestCase):
if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
return
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
@ -312,8 +304,6 @@ class TestCutlassBackend(TestCase):
if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
return
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
def mm(a, b):
return a @ b
@ -356,16 +346,11 @@ class TestCutlassBackend(TestCase):
self,
dynamic: bool = False,
max_autotune_gemm_backends: str = "CUTLASS",
mixed_precision=False,
fp16=True,
expected_fuse_count=0,
mm: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
batch_size: Optional[int] = None,
):
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
mixed_precision
)
# Note: The ops that are available
# also depend on the alignment of the shapes
# so if these shapes don't all align to at least 8 elements
@ -400,17 +385,6 @@ class TestCutlassBackend(TestCase):
), f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
def mm(a, b):
return (a @ b) * 3.0
# The pointwise ops seem to be pre-fused into a single Pointwise
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
@ -418,18 +392,7 @@ class TestCutlassBackend(TestCase):
return (a @ b) * 3.0
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
def mm(a, b):
return (a @ b) * 3.3 - 1.234
# The pointwise ops seem to be pre-fused into a single Pointwise
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
fp16=True, expected_fuse_count=0, mm=mm
)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@ -439,17 +402,7 @@ class TestCutlassBackend(TestCase):
return (a @ b) * 3.3 - 1.234
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
def mm(a, b):
return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
fp16=True, expected_fuse_count=0, mm=mm
)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@ -460,7 +413,7 @@ class TestCutlassBackend(TestCase):
# The pointwise ops seem to be pre-fused into a single Pointwise
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
fp16=True, expected_fuse_count=0, mm=mm
)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@ -471,7 +424,7 @@ class TestCutlassBackend(TestCase):
# The pointwise ops seem to be pre-fused into a single Pointwise
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
fp16=True, expected_fuse_count=0, mm=mm
)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@ -482,7 +435,7 @@ class TestCutlassBackend(TestCase):
return (a @ b).to(torch.float32) * 0.00001
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
fp16=True, expected_fuse_count=0, mm=mm
)
def test_max_autotune_cutlass_backend_simple_bmm(self):
@ -490,7 +443,6 @@ class TestCutlassBackend(TestCase):
return torch.bmm(a, b)
self._test_max_autotune_cutlass_backend_epilogue_fusion( # test bmm
mixed_precision=False,
fp16=True,
expected_fuse_count=0,
mm=bmm,
@ -504,7 +456,7 @@ class TestCutlassBackend(TestCase):
return (a @ b) / b.size(1)
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
fp16=True, expected_fuse_count=0, mm=mm
)
# TODO: Enable dynamic test cases when dynamic support is added.
@ -522,8 +474,6 @@ class TestCutlassBackend(TestCase):
if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
return
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
def mm(a, b, bias):
return torch.nn.functional.linear(a, b, bias)
@ -558,8 +508,6 @@ class TestCutlassBackend(TestCase):
if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
return
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
def addmm(x, a, b, alpha, beta):
return torch.addmm(x, a, b, alpha=alpha, beta=beta)
@ -597,8 +545,6 @@ class TestCutlassBackend(TestCase):
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_addmm_with_expanded_bias(self):
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
class MyModel(torch.nn.Module):
def forward(self, x, w):
bias = torch.zeros(
@ -671,8 +617,6 @@ class TestCutlassBackend(TestCase):
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@unittest.skipIf(not SM90OrLater, "need sm_90")
def test_force_cutlass_backend_aoti_dynamic(self):
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
class MyModel(torch.nn.Module):
def forward(self, x, w):
return x @ w
@ -709,8 +653,6 @@ class TestCutlassBackend(TestCase):
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@unittest.skipIf(not SM90OrLater, "need sm_90")
def test_force_cutlass_backend_aoti_cexpr_codegen(self):
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
class MyModel(torch.nn.Module):
def forward(self, x, w):
x0, x1 = x.shape
@ -752,8 +694,6 @@ class TestCutlassBackend(TestCase):
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@unittest.skipIf(not SM90OrLater, "need sm_90")
def test_aoti_workspace_ptr(self):
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
class MyModel(torch.nn.Module):
def forward(self, x, w):
return x @ w
@ -798,8 +738,6 @@ class TestCutlassBackend(TestCase):
if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
return
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
def mm(a, b):
return torch.mm(a, b.to(torch.half))

View file

@ -279,11 +279,6 @@ def get_accumulator_dtype(
]:
torch_dtype = dtype0
if torch_dtype == torch.half:
if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction:
return torch_dtype
else:
return torch.float
if torch_dtype in (torch.float16, torch.bfloat16, torch.float):
return torch.float
if torch_dtype == torch.int8: