Revert "Reland D62220158 (#136213)"
This reverts commit 083c9149b7.
Reverted https://github.com/pytorch/pytorch/pull/136213 on behalf of https://github.com/jeanschmidt due to Seems to have introduced regressions in rocm signals ([comment](https://github.com/pytorch/pytorch/pull/136213#issuecomment-2360885064))
This commit is contained in:
parent bce52d0b60
commit 4ea741d24f
3 changed files with 0 additions and 59 deletions
test/inductor/test_pad_mm.py

@@ -9,7 +9,6 @@ from torch._inductor.fx_passes.pad_mm import (
     get_pad_cache,
     get_padded_length,
     should_pad_common,
-    should_pad_mm_bf16,
 )
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
@@ -450,40 +449,6 @@ class PadMMTest(TestCase):
             repr(get_pad_cache().get_local_cache())
         )

-    @fresh_inductor_cache()
-    @inductor_config.patch(
-        post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}}
-    )
-    def test_pad_mm_bf16(self):
-        m = 2
-        n = 13
-        k = 15691904
-        mat1 = torch.ones((m, k), device="cuda", dtype=torch.bfloat16)
-        mat2 = torch.ones((k, n), device="cuda", dtype=torch.bfloat16)
-        expected_alignment = get_alignment_size(mat1)
-
-        assert expected_alignment == 8, "Alignment for bfloat16 should be 8"
-        assert should_pad_common(
-            mat1, mat2
-        ), "This should pass the common padding criteria"
-        if torch.cuda.get_device_capability() < (9, 0):
-            assert should_pad_mm_bf16(
-                mat1.dtype, m, n, k
-            ), "This should pass the should_pad_mm_bf16 padding criteria"
-
-        @torch.compile()
-        def mm(mat1, mat2):
-            return torch.mm(mat1, mat2)
-
-        res2, (code,) = run_and_get_code(mm, mat1, mat2)
-        mm_expected_result = torch.mm(mat1, mat2)
-        # in call code, expect to see a single pad per input, and then we should see padded allocation for output
-        FileCheck().check("del async_compile").check_count(
-            ".run(", 2, exactly=True
-        ).check("empty_strided_cuda((8, 16)").run(code)
-
-        assert torch.allclose(res2, mm_expected_result), "MM results are not identical"
-

 if __name__ == "__main__":
     if HAS_CUDA:
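For context on the FileCheck pattern in the deleted test: it asserts an alignment of 8 for torch.bfloat16 (2-byte elements), so m = 2 rounds up to 8 and n = 13 rounds up to 16, which is exactly the `empty_strided_cuda((8, 16)` allocation being checked. A minimal sketch of that round-up arithmetic; `pad_to_multiple` is a hypothetical helper, not inductor's `get_padded_length`:

```python
def pad_to_multiple(size: int, alignment: int) -> int:
    # Round `size` up to the next multiple of `alignment`.
    return (size + alignment - 1) // alignment * alignment

m, n = 2, 13
alignment = 8  # the alignment the deleted test asserts for torch.bfloat16

assert pad_to_multiple(m, alignment) == 8
assert pad_to_multiple(n, alignment) == 16  # -> padded output buffer (8, 16)
```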
torch/_inductor/fx_passes/pad_mm.py

@@ -364,23 +364,6 @@ def should_pad(key: str, ori_time, pad_time) -> bool:
     return should_pad


-def should_pad_mm_bf16(dtype, M, N, K):
-    # always force pad for mm with bf16 when the following are satisfied to avoid perf regression
-    large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[
-        "pad_aten_mm_pass"
-    ].get("k_threshold_to_pad", 8388608)
-    if (
-        dtype is torch.bfloat16
-        and K > M
-        and K > N
-        and N % 2 == 1
-        and K >= large_k_threshold_to_pad
-        and torch.cuda.get_device_capability() < (9, 0)
-    ):  # doesnt repro on h100s:
-        return True
-    return False
-
-
 def should_pad_bench(
     match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
 ) -> bool:
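For reference, the removed heuristic can be restated as a self-contained sketch. Every condition and the 8388608 (8 * 1024 * 1024) default come from the hunk above; the real function read the threshold from `torch._inductor.config.post_grad_fusion_options`, and its device-capability check is kept as a comment here so the sketch runs without CUDA:

```python
import torch

def should_pad_mm_bf16_sketch(dtype, M: int, N: int, K: int,
                              k_threshold: int = 8388608) -> bool:
    return (
        dtype is torch.bfloat16
        and K > M
        and K > N
        and N % 2 == 1        # only odd N hit the regression
        and K >= k_threshold  # only very large K was worth force-padding
        # the removed code additionally required a pre-Hopper GPU:
        # torch.cuda.get_device_capability() < (9, 0)  # "doesnt repro on h100s"
    )

# The shapes from the deleted test satisfy every condition:
assert should_pad_mm_bf16_sketch(torch.bfloat16, M=2, N=13, K=15691904)
```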
@@ -427,12 +410,6 @@ def should_pad_bench(
     if torch._inductor.config.force_shape_pad:
         return True

-    if (
-        "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options
-        and should_pad_mm_bf16(mat1.dtype, m, n, k)
-    ):
-        return True
-
     if not has_triton():
         return False

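After this revert, the early exits at the top of `should_pad_bench` reduce to the sketch below. `force_shape_pad` and `has_triton` appear in the hunk above; the import path for `has_triton` is an assumption, and the benchmarking tail of the real function is elided:

```python
import torch
from torch.utils._triton import has_triton  # assumed import path

def should_pad_bench_sketch() -> bool:
    # 1. An explicit user override always pads.
    if torch._inductor.config.force_shape_pad:
        return True
    # 2. The bf16 fast-path deleted above used to return True at this point.
    # 3. Without Triton there is no padded kernel to benchmark, so never pad.
    if not has_triton():
        return False
    # ... the real function continues by benchmarking padded vs. unpadded ...
    return False
```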
torch/_inductor/config.py

@@ -65,7 +65,6 @@ post_grad_pass_names = [
     "decompose_mm_pass",
     "unbind_stack_aten_pass",
     "shape_padding_multiplier",
-    "pad_aten_mm_pass",
 ]

 for pass_name in pre_grad_pass_names:
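With `"pad_aten_mm_pass"` also dropped from `post_grad_pass_names`, the opt-in that the deleted test relied on no longer activates anything. A sketch of that opt-in for reference; the knob names come from this diff, while the decorated function is a hypothetical example that assumes a CUDA-capable build:

```python
import torch
import torch._inductor.config as inductor_config

@inductor_config.patch(
    post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}}
)
def padded_mm(mat1: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    # Compile so the inductor post-grad passes (including, before this
    # revert, pad_aten_mm_pass) get a chance to rewrite the matmul.
    return torch.compile(torch.mm)(mat1, mat2)
```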