mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add flop formula for _scaled_mm (#144872)
This will make it work correctly with the partitioner's AutoAC. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144872 Approved by: https://github.com/vkuzo
This commit is contained in:
parent
1c290912e4
commit
f31452268b
4 changed files with 38 additions and 0 deletions
|
|
@ -1973,6 +1973,7 @@ coverage_ignore_functions = [
|
|||
"mm_flop",
|
||||
"normalize_tuple",
|
||||
"register_flop_formula",
|
||||
"scaled_mm_flop",
|
||||
"sdpa_backward_flop",
|
||||
"sdpa_backward_flop_count",
|
||||
"sdpa_flop",
|
||||
|
|
|
|||
|
|
@ -2410,6 +2410,7 @@
|
|||
"get_suffix_str",
|
||||
"mm_flop",
|
||||
"normalize_tuple",
|
||||
"scaled_mm_flop",
|
||||
"sdpa_backward_flop",
|
||||
"sdpa_backward_flop_count",
|
||||
"sdpa_flop",
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import torch.utils.flop_counter
|
|||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.testing._internal.common_cuda import (
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
||||
PLATFORM_SUPPORTS_FP8,
|
||||
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
|
||||
PLATFORM_SUPPORTS_CUDNN_ATTENTION
|
||||
)
|
||||
|
|
@ -835,5 +836,23 @@ class TestFlopCounter(TestCase):
|
|||
]
|
||||
self.assertEqual(layer1_conv_flops_standard, layer1_conv_flops_inference)
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, "CUDA not available")
@unittest.skipIf(
    not PLATFORM_SUPPORTS_FP8,
    "Does not support fp8 (pre-SM90 hardware on CUDA)",
)
def test_scaled_mm(self):
    """FlopCounterMode should count torch._scaled_mm like a plain matmul.

    a is (48, 80) fp8, b is (80, 112) fp8 (created as (112, 80) then
    transposed, since _scaled_mm requires a column-major second operand),
    so the expected count is 2 * 48 * 80 * 112 = 860160 flops — the
    scales and output dtype must not affect the total.
    """
    # Fix: removed unused local `mod = torch.nn.Linear(9, 10)` — it was
    # never exercised by the flop counter and only obscured the test.
    with FlopCounterMode() as mode:
        torch._scaled_mm(
            torch.randn((3 * 16, 5 * 16), device="cuda").to(torch.float8_e4m3fn),
            torch.randn((7 * 16, 5 * 16), device="cuda").to(torch.float8_e4m3fn).t(),
            scale_a=torch.ones((), device="cuda"),
            scale_b=torch.ones((), device="cuda"),
            out_dtype=torch.bfloat16,
        )

    self.assertExpectedInline(get_total_flops(mode), """860160""")
|
||||
|
||||
# Standard test-suite entry point; run_tests is presumably imported from
# torch.testing._internal.common_utils elsewhere in this file — not
# visible in this chunk, confirm against the full test file.
if __name__ == "__main__":
    run_tests()
|
||||
|
|
|
|||
|
|
@ -89,6 +89,22 @@ def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
|
|||
# Inputs contains the shapes of three tensors.
|
||||
return bmm_flop(a_shape, b_shape)
|
||||
|
||||
@register_flop_formula(aten._scaled_mm)
def scaled_mm_flop(
    a_shape,
    b_shape,
    scale_a_shape,
    scale_b_shape,
    bias_shape=None,
    scale_result_shape=None,
    out_dtype=None,
    use_fast_accum=False,
    out_shape=None,
    **kwargs,
) -> int:
    """Count flops for _scaled_mm.

    A scaled matmul performs the same multiply-accumulate work as an
    ordinary matmul of ``a @ b``; the per-tensor scales, optional bias,
    and output dtype only add elementwise work, which this counter
    ignores (consistent with the other formulas in this registry).
    """
    # Delegate to the plain matmul formula: 2 * m * k * n for
    # a_shape == (m, k) and b_shape == (k, n).
    total = mm_flop(a_shape, b_shape)
    return total
|
||||
|
||||
|
||||
def conv_flop_count(
|
||||
x_shape: List[int],
|
||||
|
|
@ -541,6 +557,7 @@ flop_registry = {
|
|||
aten.addmm: addmm_flop,
|
||||
aten.bmm: bmm_flop,
|
||||
aten.baddbmm: baddbmm_flop,
|
||||
aten._scaled_mm: scaled_mm_flop,
|
||||
aten.convolution: conv_flop,
|
||||
aten._convolution: conv_flop,
|
||||
aten.convolution_backward: conv_backward_flop,
|
||||
|
|
|
|||
Loading…
Reference in a new issue