diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py
index 610cbb2f2e4..78f2277a337 100644
--- a/test/inductor/test_cutlass_backend.py
+++ b/test/inductor/test_cutlass_backend.py
@@ -92,8 +92,6 @@ class TestCutlassBackend(TestCase):
         if torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return a @ b

@@ -141,8 +139,6 @@ class TestCutlassBackend(TestCase):
         if torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return a @ b

@@ -170,7 +166,6 @@ class TestCutlassBackend(TestCase):
         Compile with one shape, then re-run with different input shapes
         """
         max_autotune_gemm_backends = "CUTLASS"
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

         class MyModel(torch.nn.Module):
             def forward(self, a, b):
@@ -216,7 +211,6 @@ class TestCutlassBackend(TestCase):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_diff_matmul_share_same_kernel(self, dynamic):
         max_autotune_gemm_backends = "CUTLASS"
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

         class MyModel(torch.nn.Module):
             def __init__(self):
@@ -267,8 +261,6 @@ class TestCutlassBackend(TestCase):
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -312,8 +304,6 @@ class TestCutlassBackend(TestCase):
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return a @ b

@@ -356,16 +346,11 @@ class TestCutlassBackend(TestCase):
         self,
         dynamic: bool = False,
         max_autotune_gemm_backends: str = "CUTLASS",
-        mixed_precision=False,
         fp16=True,
         expected_fuse_count=0,
         mm: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
         batch_size: Optional[int] = None,
     ):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
-            mixed_precision
-        )
-
         # Note: The ops that are available
         # also depend on the alignment of the shapes
         # so if these shapes don't all align to at least 8 elements
@@ -400,17 +385,6 @@ class TestCutlassBackend(TestCase):
             ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
             torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @unittest.skipIf(torch.version.hip, "HIP not supported")
-    def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
-        def mm(a, b):
-            return (a @ b) * 3.0
-
-        # The pointwise ops seem to be pre-fused into a single Pointwise
-        self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
-        )
-
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(torch.version.hip, "HIP not supported")
     def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
@@ -418,18 +392,7 @@ class TestCutlassBackend(TestCase):
             return (a @ b) * 3.0

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
-        )
-
-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @unittest.skipIf(torch.version.hip, "HIP not supported")
-    def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
-        def mm(a, b):
-            return (a @ b) * 3.3 - 1.234
-
-        # The pointwise ops seem to be pre-fused into a single Pointwise
-        self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -439,17 +402,7 @@ class TestCutlassBackend(TestCase):
             return (a @ b) * 3.3 - 1.234

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
-        )
-
-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @unittest.skipIf(torch.version.hip, "HIP not supported")
-    def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
-        def mm(a, b):
-            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)
-
-        self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -460,7 +413,7 @@ class TestCutlassBackend(TestCase):

         # The pointwise ops seem to be pre-fused into a single Pointwise
         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -471,7 +424,7 @@ class TestCutlassBackend(TestCase):

         # The pointwise ops seem to be pre-fused into a single Pointwise
         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -482,7 +435,7 @@ class TestCutlassBackend(TestCase):
             return (a @ b).to(torch.float32) * 0.00001

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     def test_max_autotune_cutlass_backend_simple_bmm(self):
@@ -490,7 +443,6 @@ class TestCutlassBackend(TestCase):
             return torch.bmm(a, b)

         self._test_max_autotune_cutlass_backend_epilogue_fusion(  # test bmm
-            mixed_precision=False,
             fp16=True,
             expected_fuse_count=0,
             mm=bmm,
@@ -504,7 +456,7 @@ class TestCutlassBackend(TestCase):
             return (a @ b) / b.size(1)

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     # TODO: Enable dynamic test cases when dynamic support is added.
@@ -522,8 +474,6 @@ class TestCutlassBackend(TestCase):
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b, bias):
             return torch.nn.functional.linear(a, b, bias)

@@ -558,8 +508,6 @@ class TestCutlassBackend(TestCase):
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def addmm(x, a, b, alpha, beta):
             return torch.addmm(x, a, b, alpha=alpha, beta=beta)

@@ -597,8 +545,6 @@ class TestCutlassBackend(TestCase):

     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_addmm_with_expanded_bias(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 bias = torch.zeros(
@@ -671,8 +617,6 @@ class TestCutlassBackend(TestCase):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     def test_force_cutlass_backend_aoti_dynamic(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 return x @ w
@@ -709,8 +653,6 @@ class TestCutlassBackend(TestCase):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     def test_force_cutlass_backend_aoti_cexpr_codegen(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 x0, x1 = x.shape
@@ -752,8 +694,6 @@ class TestCutlassBackend(TestCase):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     def test_aoti_workspace_ptr(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 return x @ w
@@ -798,8 +738,6 @@ class TestCutlassBackend(TestCase):
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return torch.mm(a, b.to(torch.half))

diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py
index 1786921e98b..61b141e4ad3 100644
--- a/torch/_inductor/codegen/cuda/cutlass_utils.py
+++ b/torch/_inductor/codegen/cuda/cutlass_utils.py
@@ -279,11 +279,6 @@ def get_accumulator_dtype(
     ]:
         torch_dtype = dtype0

-    if torch_dtype == torch.half:
-        if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction:
-            return torch_dtype
-        else:
-            return torch.float
     if torch_dtype in (torch.float16, torch.bfloat16, torch.float):
         return torch.float
     if torch_dtype == torch.int8:
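
Background note, as an illustrative sketch rather than part of the patch: torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction is the existing PyTorch switch that decides whether fp16 GEMMs may accumulate in reduced precision instead of fp32. The standalone snippet below, assuming a CUDA-capable build (shapes and the comparison are chosen for this sketch, not taken from the test suite), shows the behavior the flag toggles. With the cutlass_utils.py hunk above, get_accumulator_dtype() no longer consults this flag and always selects an fp32 accumulator for fp16/bf16 inputs, which is why the *_fp16 test variants and the mixed_precision plumbing are dropped from the tests.

import torch

# Illustrative fp16 operands on the GPU (sizes are arbitrary for this sketch)
a = torch.randn(512, 512, dtype=torch.half, device="cuda")
b = torch.randn(512, 512, dtype=torch.half, device="cuda")

# fp32 reference result, downcast only at the end
ref = (a.float() @ b.float()).half()

# Force fp32 accumulation inside the fp16 GEMM
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
err_fp32_acc = ((a @ b) - ref).abs().max()

# Permit reduced-precision (fp16) accumulation, the default; the error
# versus the fp32 reference is typically larger for long reduction dims
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
err_fp16_acc = ((a @ b) - ref).abs().max()

print(err_fp32_acc.item(), err_fp16_acc.item())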