Add flop formula for _scaled_mm (#144872)

This will make it work correctly with the partitioner's AutoAC
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144872
Approved by: https://github.com/vkuzo
This commit is contained in:
Luca Wehrstedt 2025-01-16 09:47:28 +00:00 committed by PyTorch MergeBot
parent 1c290912e4
commit f31452268b
4 changed files with 38 additions and 0 deletions

View file

@@ -1973,6 +1973,7 @@ coverage_ignore_functions = [
"mm_flop",
"normalize_tuple",
"register_flop_formula",
"scaled_mm_flop",
"sdpa_backward_flop",
"sdpa_backward_flop_count",
"sdpa_flop",

View file

@@ -2410,6 +2410,7 @@
"get_suffix_str",
"mm_flop",
"normalize_tuple",
"scaled_mm_flop",
"sdpa_backward_flop",
"sdpa_backward_flop_count",
"sdpa_flop",

View file

@@ -9,6 +9,7 @@ import torch.utils.flop_counter
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
PLATFORM_SUPPORTS_CUDNN_ATTENTION
)
@@ -835,5 +836,23 @@ class TestFlopCounter(TestCase):
]
self.assertEqual(layer1_conv_flops_standard, layer1_conv_flops_inference)
@unittest.skipIf(not HAS_CUDA, "CUDA not available")
@unittest.skipIf(
    not PLATFORM_SUPPORTS_FP8,
    "Does not support fp8 (pre-SM90 hardware on CUDA)",
)
def test_scaled_mm(self):
    """FlopCounterMode should count _scaled_mm like a plain matmul.

    (48, 80) @ (80, 112) -> 2 * 48 * 80 * 112 = 860160 flops; the scale
    tensors and output dtype must not change the count.
    """
    # NOTE(review): removed unused local `mod = torch.nn.Linear(9, 10)` —
    # it was never referenced anywhere in this test.
    with FlopCounterMode() as mode:
        torch._scaled_mm(
            torch.randn((3 * 16, 5 * 16), device="cuda").to(torch.float8_e4m3fn),
            # presumably .t() produces the column-major layout _scaled_mm
            # expects for its second operand — confirm against the op's docs
            torch.randn((7 * 16, 5 * 16), device="cuda").to(torch.float8_e4m3fn).t(),
            scale_a=torch.ones((), device="cuda"),
            scale_b=torch.ones((), device="cuda"),
            out_dtype=torch.bfloat16,
        )
    self.assertExpectedInline(get_total_flops(mode), """860160""")
if __name__ == "__main__":
run_tests()

View file

@@ -89,6 +89,22 @@ def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
# Inputs contains the shapes of three tensors.
return bmm_flop(a_shape, b_shape)
@register_flop_formula(aten._scaled_mm)
def scaled_mm_flop(
    a_shape,
    b_shape,
    scale_a_shape,
    scale_b_shape,
    bias_shape=None,
    scale_result_shape=None,
    out_dtype=None,
    use_fast_accum=False,
    out_shape=None,
    **kwargs,
) -> int:
    """Count flops for _scaled_mm.

    The scale/bias arguments and the output dtype don't add any
    multiply-accumulate work, so the count is that of an ordinary
    matmul over the same operand shapes.
    """
    # Delegate to the plain-matmul formula; all other arguments are ignored.
    matmul_flops = mm_flop(a_shape, b_shape)
    return matmul_flops
def conv_flop_count(
x_shape: List[int],
@@ -541,6 +557,7 @@ flop_registry = {
aten.addmm: addmm_flop,
aten.bmm: bmm_flop,
aten.baddbmm: baddbmm_flop,
aten._scaled_mm: scaled_mm_flop,
aten.convolution: conv_flop,
aten._convolution: conv_flop,
aten.convolution_backward: conv_backward_flop,