diff --git a/docs/source/conf.py b/docs/source/conf.py
index c4b7abb65df..73780029ec9 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1973,6 +1973,7 @@ coverage_ignore_functions = [
     "mm_flop",
     "normalize_tuple",
     "register_flop_formula",
+    "scaled_mm_flop",
     "sdpa_backward_flop",
     "sdpa_backward_flop_count",
     "sdpa_flop",
diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json
index b81fe3929eb..15ae5c891d6 100644
--- a/test/allowlist_for_publicAPI.json
+++ b/test/allowlist_for_publicAPI.json
@@ -2410,6 +2410,7 @@
     "get_suffix_str",
     "mm_flop",
     "normalize_tuple",
+    "scaled_mm_flop",
     "sdpa_backward_flop",
     "sdpa_backward_flop_count",
     "sdpa_flop",
diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py
index ed44a17e5a0..cbdba314a48 100644
--- a/test/test_flop_counter.py
+++ b/test/test_flop_counter.py
@@ -9,6 +9,7 @@ import torch.utils.flop_counter
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     PLATFORM_SUPPORTS_CUDNN_ATTENTION
 )
@@ -835,5 +836,22 @@ class TestFlopCounter(TestCase):
         ]
         self.assertEqual(layer1_conv_flops_standard, layer1_conv_flops_inference)
 
+    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "Does not support fp8 (pre-SM90 hardware on CUDA)",
+    )
+    def test_scaled_mm(self):
+        with FlopCounterMode() as mode:
+            torch._scaled_mm(
+                torch.randn((3 * 16, 5 * 16), device="cuda").to(torch.float8_e4m3fn),
+                torch.randn((7 * 16, 5 * 16), device="cuda").to(torch.float8_e4m3fn).t(),
+                scale_a=torch.ones((), device="cuda"),
+                scale_b=torch.ones((), device="cuda"),
+                out_dtype=torch.bfloat16,
+            )
+
+        self.assertExpectedInline(get_total_flops(mode), """860160""")
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py
index a5ed14425a7..ea3d453dbf4 100644
--- a/torch/utils/flop_counter.py
+++ b/torch/utils/flop_counter.py
@@ -89,6 +89,22 @@ def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
     # Inputs contains the shapes of three tensors.
     return bmm_flop(a_shape, b_shape)
 
+@register_flop_formula(aten._scaled_mm)
+def scaled_mm_flop(
+    a_shape,
+    b_shape,
+    scale_a_shape,
+    scale_b_shape,
+    bias_shape=None,
+    scale_result_shape=None,
+    out_dtype=None,
+    use_fast_accum=False,
+    out_shape=None,
+    **kwargs,
+) -> int:
+    """Count flops for _scaled_mm."""
+    return mm_flop(a_shape, b_shape)
+
 
 def conv_flop_count(
     x_shape: List[int],
@@ -541,6 +557,7 @@ flop_registry = {
     aten.addmm: addmm_flop,
     aten.bmm: bmm_flop,
     aten.baddbmm: baddbmm_flop,
+    aten._scaled_mm: scaled_mm_flop,
     aten.convolution: conv_flop,
     aten._convolution: conv_flop,
     aten.convolution_backward: conv_backward_flop,
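
A quick sanity check of the test's expected value, as a sketch outside the patch: `scaled_mm_flop` ignores the scale and bias shapes and delegates to `mm_flop`, which counts 2 * M * K * N for an (M, K) @ (K, N) matmul, so the inline constant 860160 falls straight out of the shapes the test uses.

    # Sketch (not part of the patch): re-derive the expected inline value.
    # scaled_mm_flop delegates to mm_flop; the scale tensors are elementwise
    # factors and contribute no counted flops.
    M, K, N = 3 * 16, 5 * 16, 7 * 16  # a: (48, 80), b.t(): (80, 112)
    assert 2 * M * K * N == 860160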