Revert "Reland D62220158 (#136213)"

This reverts commit 083c9149b7.

Reverted https://github.com/pytorch/pytorch/pull/136213 on behalf of https://github.com/jeanschmidt due to Seems to have introduced regressions in rocm signals ([comment](https://github.com/pytorch/pytorch/pull/136213#issuecomment-2360885064))
This commit is contained in:
PyTorch MergeBot 2024-09-19 12:44:54 +00:00
parent bce52d0b60
commit 4ea741d24f
3 changed files with 0 additions and 59 deletions

View file

@@ -9,7 +9,6 @@ from torch._inductor.fx_passes.pad_mm import (
get_pad_cache,
get_padded_length,
should_pad_common,
should_pad_mm_bf16,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
@@ -450,40 +449,6 @@ class PadMMTest(TestCase):
repr(get_pad_cache().get_local_cache())
)
@fresh_inductor_cache()
@inductor_config.patch(
post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}}
)
def test_pad_mm_bf16(self):
m = 2
n = 13
k = 15691904
mat1 = torch.ones((m, k), device="cuda", dtype=torch.bfloat16)
mat2 = torch.ones((k, n), device="cuda", dtype=torch.bfloat16)
expected_alignment = get_alignment_size(mat1)
assert expected_alignment == 8, "Alignment for bfloat16 should be 8"
assert should_pad_common(
mat1, mat2
), "This should pass the common padding criteria"
if torch.cuda.get_device_capability() < (9, 0):
assert should_pad_mm_bf16(
mat1.dtype, m, n, k
), "This should pass the should_pad_mm_bf16 padding criteria"
@torch.compile()
def mm(mat1, mat2):
return torch.mm(mat1, mat2)
res2, (code,) = run_and_get_code(mm, mat1, mat2)
mm_expected_result = torch.mm(mat1, mat2)
# in call code, expect to see a single pad per input, and then we should see padded allocation for output
FileCheck().check("del async_compile").check_count(
".run(", 2, exactly=True
).check("empty_strided_cuda((8, 16)").run(code)
assert torch.allclose(res2, mm_expected_result), "MM results are not identical"
if __name__ == "__main__":
if HAS_CUDA:

View file

@@ -364,23 +364,6 @@ def should_pad(key: str, ori_time, pad_time) -> bool:
return should_pad
def should_pad_mm_bf16(dtype, M, N, K):
# always force pad for mm with bf16 when the following are satisfied to avoid perf regression
large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[
"pad_aten_mm_pass"
].get("k_threshold_to_pad", 8388608)
if (
dtype is torch.bfloat16
and K > M
and K > N
and N % 2 == 1
and K >= large_k_threshold_to_pad
and torch.cuda.get_device_capability() < (9, 0)
): # doesnt repro on h100s:
return True
return False
def should_pad_bench(
match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
) -> bool:
@@ -427,12 +410,6 @@ def should_pad_bench(
if torch._inductor.config.force_shape_pad:
return True
if (
"pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options
and should_pad_mm_bf16(mat1.dtype, m, n, k)
):
return True
if not has_triton():
return False

View file

@@ -65,7 +65,6 @@ post_grad_pass_names = [
"decompose_mm_pass",
"unbind_stack_aten_pass",
"shape_padding_multiplier",
"pad_aten_mm_pass",
]
for pass_name in pre_grad_pass_names: