[cutlass backend] fix bug for accumulator dtype (#146356)
Will add unit tests for accuracy.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146356
Approved by: https://github.com/Chillee
parent 13e17aa106
commit 7c8ec84dab

2 changed files with 6 additions and 73 deletions
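The change concerns which dtype the CUTLASS GEMM templates accumulate in. As a side note (not part of the commit), the CPU-only sketch below emulates the inner-product reduction of an fp16 matmul with an fp16 versus an fp32 accumulator; the vector length and the constant are arbitrary, chosen only to make the rounding loss visible.

    import torch

    k = 4096
    a = torch.full((k,), 0.001, dtype=torch.half)
    b = torch.ones(k, dtype=torch.half)

    acc_fp16 = torch.tensor(0.0, dtype=torch.half)
    acc_fp32 = torch.tensor(0.0, dtype=torch.float)
    for ai, bi in zip(a, b):
        prod = ai * bi                       # the product itself is fp16 in both cases
        acc_fp16 = acc_fp16 + prod           # fp16 accumulator: rounding error compounds
        acc_fp32 = acc_fp32 + prod.float()   # fp32 accumulator keeps the low-order bits

    # The fp16 running sum stops growing once its magnitude makes each ~1e-3 addend
    # round away, so it ends up noticeably below the fp32 result (roughly 4.09 here).
    print(acc_fp16.item(), acc_fp32.item())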
@@ -92,8 +92,6 @@ class TestCutlassBackend(TestCase):
        if torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b
@@ -141,8 +139,6 @@ class TestCutlassBackend(TestCase):
        if torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b
@@ -170,7 +166,6 @@ class TestCutlassBackend(TestCase):
        Compile with one shape, then re-run with different input shapes
        """
        max_autotune_gemm_backends = "CUTLASS"
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        class MyModel(torch.nn.Module):
            def forward(self, a, b):
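For context (not from the diff): a minimal standalone sketch of how these tests route matmuls through the CUTLASS backend via Inductor's max-autotune path, assuming a CUDA build with CUTLASS support on an SM90 GPU; the shapes and the mm function below are arbitrary.

    import torch
    import torch._inductor.config as inductor_config

    def mm(a, b):
        return a @ b

    a = torch.randn(128, 64, dtype=torch.half, device="cuda")
    b = torch.randn(64, 32, dtype=torch.half, device="cuda")

    # Route Inductor's autotuned GEMMs to CUTLASS for this compilation only.
    with inductor_config.patch(
        {"max_autotune": True, "max_autotune_gemm_backends": "CUTLASS"}
    ):
        compiled_mm = torch.compile(mm, dynamic=False)
        y = compiled_mm(a, b)

    # Loose tolerances, matching the comparisons used elsewhere in this test file.
    torch.testing.assert_close(y, mm(a, b), atol=1e-2, rtol=1e-2)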
@@ -216,7 +211,6 @@ class TestCutlassBackend(TestCase):
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_diff_matmul_share_same_kernel(self, dynamic):
        max_autotune_gemm_backends = "CUTLASS"
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        class MyModel(torch.nn.Module):
            def __init__(self):
@@ -267,8 +261,6 @@ class TestCutlassBackend(TestCase):
        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -312,8 +304,6 @@ class TestCutlassBackend(TestCase):
        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b
@@ -356,16 +346,11 @@ class TestCutlassBackend(TestCase):
        self,
        dynamic: bool = False,
        max_autotune_gemm_backends: str = "CUTLASS",
        mixed_precision=False,
        fp16=True,
        expected_fuse_count=0,
        mm: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        batch_size: Optional[int] = None,
    ):
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            mixed_precision
        )

        # Note: The ops that are available
        # also depend on the alignment of the shapes
        # so if these shapes don't all align to at least 8 elements
@@ -400,17 +385,6 @@ class TestCutlassBackend(TestCase):
        ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
        torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.0

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
@@ -418,18 +392,7 @@ class TestCutlassBackend(TestCase):
            return (a @ b) * 3.0

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
            fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -439,17 +402,7 @@ class TestCutlassBackend(TestCase):
            return (a @ b) * 3.3 - 1.234

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
            fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -460,7 +413,7 @@ class TestCutlassBackend(TestCase):
        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
            fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -471,7 +424,7 @@ class TestCutlassBackend(TestCase):
        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
            fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -482,7 +435,7 @@ class TestCutlassBackend(TestCase):
            return (a @ b).to(torch.float32) * 0.00001

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
            fp16=True, expected_fuse_count=0, mm=mm
        )

    def test_max_autotune_cutlass_backend_simple_bmm(self):
@@ -490,7 +443,6 @@ class TestCutlassBackend(TestCase):
            return torch.bmm(a, b)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(  # test bmm
            mixed_precision=False,
            fp16=True,
            expected_fuse_count=0,
            mm=bmm,
@@ -504,7 +456,7 @@ class TestCutlassBackend(TestCase):
            return (a @ b) / b.size(1)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
            fp16=True, expected_fuse_count=0, mm=mm
        )

    # TODO: Enable dynamic test cases when dynamic support is added.
@@ -522,8 +474,6 @@ class TestCutlassBackend(TestCase):
        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b, bias):
            return torch.nn.functional.linear(a, b, bias)
@@ -558,8 +508,6 @@ class TestCutlassBackend(TestCase):
        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def addmm(x, a, b, alpha, beta):
            return torch.addmm(x, a, b, alpha=alpha, beta=beta)
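Not part of the diff: a small reference example of the torch.addmm call the helper above wraps; torch.addmm(x, a, b, alpha=alpha, beta=beta) computes beta * x + alpha * (a @ b). The values below are arbitrary.

    import torch

    x = torch.ones(2, 3)
    a = torch.ones(2, 4)
    b = torch.ones(4, 3)

    out = torch.addmm(x, a, b, alpha=2.0, beta=0.5)  # 0.5 * x + 2.0 * (a @ b)
    expected = 0.5 * x + 2.0 * (a @ b)
    torch.testing.assert_close(out, expected)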
@@ -597,8 +545,6 @@ class TestCutlassBackend(TestCase):
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_addmm_with_expanded_bias(self):
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        class MyModel(torch.nn.Module):
            def forward(self, x, w):
                bias = torch.zeros(
@@ -671,8 +617,6 @@ class TestCutlassBackend(TestCase):
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    def test_force_cutlass_backend_aoti_dynamic(self):
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        class MyModel(torch.nn.Module):
            def forward(self, x, w):
                return x @ w
@@ -709,8 +653,6 @@ class TestCutlassBackend(TestCase):
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    def test_force_cutlass_backend_aoti_cexpr_codegen(self):
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        class MyModel(torch.nn.Module):
            def forward(self, x, w):
                x0, x1 = x.shape
@@ -752,8 +694,6 @@ class TestCutlassBackend(TestCase):
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    def test_aoti_workspace_ptr(self):
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        class MyModel(torch.nn.Module):
            def forward(self, x, w):
                return x @ w
@@ -798,8 +738,6 @@ class TestCutlassBackend(TestCase):
        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return torch.mm(a, b.to(torch.half))
@@ -279,11 +279,6 @@ def get_accumulator_dtype(
    ]:
        torch_dtype = dtype0

-   if torch_dtype == torch.half:
-       if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction:
-           return torch_dtype
-       else:
-           return torch.float
    if torch_dtype in (torch.float16, torch.bfloat16, torch.float):
        return torch.float
    if torch_dtype == torch.int8:
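Not from the repository: a hypothetical restatement of the rule the fixed hunk leaves in place, assuming the torch.half special case marked with "-" above is what this commit deletes (consistent with the -11/+6 line counts). Under that reading, half/bfloat16/float inputs now always get a float32 accumulator, independent of torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction. pick_accumulator_dtype is an illustrative name, not the real get_accumulator_dtype; the int8 branch is cut off in the hunk and therefore omitted.

    import torch

    def pick_accumulator_dtype(torch_dtype: torch.dtype) -> torch.dtype:
        # Post-fix behaviour for floating-point GEMMs: always accumulate in fp32.
        if torch_dtype in (torch.float16, torch.bfloat16, torch.float):
            return torch.float
        raise NotImplementedError(f"not covered by this sketch: {torch_dtype}")

    assert pick_accumulator_dtype(torch.half) is torch.float
    assert pick_accumulator_dtype(torch.bfloat16) is torch.float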