diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 79d5d84a95f..460e2076b17 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -14034,7 +14034,7 @@
     CUDA: _scaled_dot_product_flash_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda

-- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offse, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
   variants: function
   dispatch:
     CUDA: _scaled_dot_product_flash_attention_backward_cuda
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h
index ba8ec98d597..3c994fcf0e0 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h
@@ -39,22 +39,38 @@ struct sdp_params {
   bool is_causal;
 };

+inline bool check_requires_grad(sdp_params params, bool debug) {
+  const bool any_inputs_require_grad = params.query.requires_grad() ||
+      params.key.requires_grad() || params.value.requires_grad();
+  const bool gradmode_enabled = at::GradMode::is_enabled();
+  if ((any_inputs_require_grad && gradmode_enabled)) {
+    if (debug) {
+      TORCH_WARN("Flash Attention does not currently support training.");
+    }
+    return false;
+  }
+  return true;
+}
+
 inline std::array priority_order(sdp_params params) {
   constexpr std::array default_order{
       SDPBackend::flash_attention,
       SDPBackend::efficient_attention,
       SDPBackend::math};
+
+  constexpr std::array efficient_first{
+      SDPBackend::efficient_attention,
+      SDPBackend::flash_attention,
+      SDPBackend::math};
   // Logic is taken from xformers
   // FlashAttention parallelizes across "batch_size * num_heads"
   // MemEff parallelizes across "batch_size * num_heads * num_queries" and can
   // be more efficient. batch_size, q_len, num_heads, k = inp.query.shape
+
   if (params.query.is_nested() || params.key.is_nested() || params.value.is_nested()) {
     // See check_for_nested_inputs for details
-    return {
-        SDPBackend::efficient_attention,
-        SDPBackend::flash_attention,
-        SDPBackend::math};
+    return efficient_first;
   }
   if (params.query.dim() != 4) {
     return default_order;
@@ -70,13 +86,14 @@ inline std::array priority_order(sdp_params params) {
   bool more_threads_cutlass = (threads_cutlass / 2) >= threads_flash;
   bool small_threads_flash = threads_flash < 60;
   bool large_head_dim = head_dim.max(params.key.sym_size(3)) == 128;
-  if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
-    return {
-        SDPBackend::efficient_attention,
-        SDPBackend::flash_attention,
-        SDPBackend::math};
+
+  // The training heuristic is taken from https://github.com/pytorch/pytorch/pull/99644
+  // Revisit when updated cutlass kernel is upstreamed.
+ if (check_requires_grad(params, false)) { + if (6 * threads_flash > query_lengths) return efficient_first; + } else if ((small_threads_flash && more_threads_cutlass) || large_head_dim) + return efficient_first; } - } return default_order; } @@ -253,19 +270,6 @@ inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) { return true; } -inline bool check_requires_grad(sdp_params params, bool debug) { - const bool any_inputs_require_grad = params.query.requires_grad() || - params.key.requires_grad() || params.value.requires_grad(); - const bool gradmode_enabled = at::GradMode::is_enabled(); - if ((any_inputs_require_grad && gradmode_enabled)) { - if (debug) { - TORCH_WARN("Flash Attention does not currently support training."); - } - return false; - } - return true; -} - inline bool check_requires_grad_and_nested(sdp_params params, bool debug) { // If we fail both checks then we return false if (check_for_nested_inputs(params) && !check_requires_grad(params, false)){ diff --git a/functorch/benchmarks/transformer_fusion_patterns/__init__.py b/functorch/benchmarks/transformer_fusion_patterns/__init__.py deleted file mode 100644 index 10a55772ab5..00000000000 --- a/functorch/benchmarks/transformer_fusion_patterns/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/functorch/benchmarks/transformer_fusion_patterns/benchmark.py b/functorch/benchmarks/transformer_fusion_patterns/benchmark.py deleted file mode 100644 index f7994223e12..00000000000 --- a/functorch/benchmarks/transformer_fusion_patterns/benchmark.py +++ /dev/null @@ -1,190 +0,0 @@ -import torch -from functorch.compile import memory_efficient_fusion -import benchmark_helper - - -device = "cuda" -dtype = torch.float16 - -# LightSeq pattern 1 -class DropoutResBias: - @staticmethod - def fn(input, bias, residual): - a = torch.add(input, bias) - b = torch.nn.functional.dropout(a, p=0.7, training=True) - c = b + residual - return c - - @staticmethod - def args(): - batch_size, seq_len, hidden_size = 32, 196, 1024 - input = torch.randn( - batch_size, - seq_len, - hidden_size, - requires_grad=True, - device=device, - dtype=dtype, - ) - bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype) - residual = torch.randn( - batch_size, - seq_len, - hidden_size, - requires_grad=False, - device=device, - dtype=dtype, - ) - args = (input, bias, residual) - return args - - -class DropoutResBiasScalar: - @staticmethod - def fn(input, bias, residual, p: float): - a = torch.add(input, bias) - b = torch.nn.functional.dropout(a, p, training=True) - c = b + residual - return c - - @staticmethod - def args(): - batch_size, seq_len, hidden_size = 32, 196, 1024 - input = torch.randn( - batch_size, - seq_len, - hidden_size, - requires_grad=True, - device=device, - dtype=dtype, - ) - bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype) - residual = torch.randn( - batch_size, - seq_len, - hidden_size, - requires_grad=False, - device=device, - dtype=dtype, - ) - args = (input, bias, residual, 0.7) - return args - - - -# LightSeq pattern 2 -class BiasReluDropout: - @staticmethod - def fn(input, bias): - a = torch.add(input, bias) - b = torch.nn.functional.relu(a) - c = torch.nn.functional.dropout(b, p=0.6, training=True) - return c - - @staticmethod - def args(): - batch_size = 32 - 
seq_len = 196 - intermediate_size = 4096 - input = torch.randn( - batch_size, - seq_len, - intermediate_size, - requires_grad=True, - device=device, - dtype=dtype, - ) - bias = torch.randn( - intermediate_size, requires_grad=True, device=device, dtype=dtype - ) - args = (input, bias) - return args - - -class BiasDropoutResLayerNorm: - @staticmethod - def fn(input, bias, residual): - hidden_size = 1024 - a = torch.add(input, bias) - b = torch.nn.functional.dropout(a, p=0.7, training=True) - c = b + residual - d = torch.nn.functional.layer_norm(c, normalized_shape=(hidden_size,)) - return d - - @staticmethod - def args(): - batch_size = 32 - seq_len = 196 - hidden_size = 1024 - - input = torch.randn( - batch_size, - seq_len, - hidden_size, - requires_grad=True, - device=device, - dtype=dtype, - ) - bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype) - residual = torch.randn( - batch_size, - seq_len, - hidden_size, - requires_grad=False, - device=device, - dtype=dtype, - ) - args = (input, bias, residual) - return args - - -class LayerNormSigmoid: - @staticmethod - def fn(inp): - hidden_size = 512 - a = torch.nn.functional.layer_norm(inp, normalized_shape=(hidden_size,)) - b = torch.sigmoid(a) - return b - - @staticmethod - def args(): - batch_size = 8192 - hidden_size = 512 - inp = torch.randn( - batch_size, hidden_size, requires_grad=True, device=device, dtype=dtype - ) - args = (inp,) - return args - - -for cl in [DropoutResBias, BiasReluDropout, DropoutResBiasScalar, BiasDropoutResLayerNorm, LayerNormSigmoid]: - # Clear the compile cache - - # Get the function and inputs - obj = cl() - fn = obj.fn - args = obj.args() - - # Find the static args - static_argnums = [] - for idx, arg in enumerate(args): - if not isinstance(arg, torch.Tensor): - static_argnums.append(idx) - - # Get the optimized function - opt_fn = memory_efficient_fusion(fn, static_argnums) - - # Profile cuda kernels - benchmark_helper.profile_cuda_kernels(fn, args, "Eager") - with torch.jit.fuser("fuser2"): - benchmark_helper.profile_cuda_kernels(opt_fn, args, "AOTAutograd") - - # Time it with Torch Timer - benchmark_helper.time_with_torch_timer(fn, args, "Eager") - with torch.jit.fuser("fuser2"): - benchmark_helper.time_with_torch_timer(opt_fn, args, "AOTAutograd") - - # Time it with manual Timer - benchmark_helper.time_with_manual_timer(fn, args, "Eager") - with torch.jit.fuser("fuser2"): - benchmark_helper.time_with_manual_timer(opt_fn, args, "AOTAutograd") diff --git a/functorch/benchmarks/transformer_fusion_patterns/benchmark_helper.py b/functorch/benchmarks/transformer_fusion_patterns/benchmark_helper.py deleted file mode 100644 index bad27572e97..00000000000 --- a/functorch/benchmarks/transformer_fusion_patterns/benchmark_helper.py +++ /dev/null @@ -1,148 +0,0 @@ -import torch -from torch.profiler import profile, record_function, ProfilerActivity -from torch.utils.benchmark import Timer -import time - - -def profile_cuda_kernels(fn, args, string_id="Model time"): - print("################################################") - print(f"#### Profiling for {string_id} starts #########") - print("################################################") - warmup = 50 - old_args = args[:] - n_repeats = 1 - n_layers = 1 - ref = fn(*old_args) - gO = torch.rand_like(ref) - for _ in range(0, warmup // n_layers): - args = list(old_args[:]) - ref = fn(*args) - ref.backward(gO) - - torch.cuda.synchronize() - - # Forward profile - def fwd_run(): - for _ in range(0, n_repeats // n_layers): - args = 
list(old_args[:]) - for arg in args: - if isinstance(arg, torch.Tensor): - arg.grad = None - ref = fn(*args) - - print(f"###### Forward profile for {string_id} starts #####") - with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("baseline"): - fwd_run() - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30)) - print(f"###### Forward profile for {string_id} ends #####") - - # Backward profile - def bwd_run(): - for _ in range(0, n_repeats // n_layers): - args = list(old_args[:]) - for arg in args: - if isinstance(arg, torch.Tensor): - arg.grad = None - ref = fn(*args) - - print(f"###### Backward profile for {string_id} starts #####") - torch.cuda.synchronize() - with profile( - activities=[ProfilerActivity.CUDA], record_shapes=True - ) as prof: - with record_function("baseline"): - ref.backward(gO) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30)) - torch.cuda.synchronize() - print(f"###### Backward profile for {string_id} ends #####") - - bwd_run() - print("################################################") - print(f"#### Profiling for {string_id} ends #########") - print("################################################\n\n\n\n") - - -def time_with_torch_timer(fn, args, string_id, kwargs=None): - if kwargs is None: - kwargs = {} - print("################################################") - print(f"#### Torch Timer for {string_id} starts #########") - print("################################################") - ref = fn(*args, **kwargs) - gO = torch.rand_like(ref) - env = {"args": args, "gO": gO, "kwargs": kwargs, "fn": fn} - grad_none = {"for x in args: x.grad=None"} - fn_call = "fn(*args, **kwargs)" - # Measure end-to-end fwd time - timer = Timer(stmt=f"{fn_call}", globals=env) - fwd_latency = round(timer.timeit(1000).mean * 10 ** 6, 3) - timer_blocked = timer.blocked_autorange() - print(f"Forward = {fwd_latency}") - - # Measure end-to-end fwd bwd - timer = Timer( - stmt=f"{grad_none}; fwd = {fn_call}; fwd.backward(gO)", - globals=env, - ) - fwd_bwd_latency = round(timer.timeit(1000).mean * 10 ** 6, 3) - timer_blocked = timer.blocked_autorange() - # print(f"Forward + sum + Backward = {fwd_sum_bwd_latency}") - - bwd_latency = round(fwd_bwd_latency - fwd_latency, 3) - print(f"Backward = {bwd_latency}") - - print("################################################") - print(f"#### Torch Timer for {string_id} ends ###############") - print("################################################\n\n\n\n") - - -def time_with_manual_timer(fn, args, string_id): - print("################################################") - print(f"#### Manual Timer for {string_id} starts #########") - print("################################################") - warmup = 50 - repeats = 1000 - old_args = args[:] - ref = fn(*old_args) - gO = torch.rand_like(ref) - for _ in range(0, warmup): - args = list(old_args[:]) - - for arg in args: - if isinstance(arg, torch.Tensor): - arg.grad = None - ref = fn(*args) - ref.backward(gO) - - torch.cuda.synchronize() - - fwd_times = [] - bwd_times = [] - for _ in range(0, repeats): - args = list(old_args[:]) - for arg in args: - if isinstance(arg, torch.Tensor): - arg.grad = None - fwd_start = time.time() - ref = fn(*args) - torch.cuda.synchronize() - fwd_end = time.time() - - bwd_start = time.time() - ref.backward(gO) - torch.cuda.synchronize() - bwd_end = time.time() - - fwd_times.append(fwd_end - fwd_start) - bwd_times.append(bwd_end - bwd_start) - avg_fwd = round(sum(fwd_times) / repeats * 
10 ** 6, 2) - avg_bwd = round(sum(bwd_times) / repeats * 10 ** 6, 2) - avg_total = round(avg_fwd + avg_bwd, 2) - - print(f"Forward = {avg_fwd}") - print(f"Backward = {avg_bwd}") - - print("################################################") - print(f"#### Manual Timer for {string_id} ends #########") - print("################################################\n\n\n") diff --git a/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py b/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py deleted file mode 100644 index 26c6d7c9e9f..00000000000 --- a/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch -from functorch.compile import memory_efficient_pointwise_fusion -import benchmark_helper - -# ALL comments regarding the patetrns - - -def bias_gelu_dropout(input, bias): - a = torch.add(input, bias) - b = torch.nn.functional.gelu(a) - c = torch.nn.functional.dropout(b, p=0.6, training=True) - return c - - -def aot_fn(input, bias): - a = torch.add(input, bias) - b = a * 0.5 * (1.0 + torch.tanh(0.79788456 * a * (1 + 0.044715 * a * a))) - c = torch.nn.functional.dropout(b, p=0.6, training=True) - return c - - -fn = bias_gelu_dropout - - -# Set inputs -device = "cuda" -dtype = torch.float16 -batch_size = 32 -seq_len = 196 -intermediate_size = 4096 -# batch_size = 2 -# seq_len = 4 -# intermediate_size = 3 -input = torch.randn( - batch_size, - seq_len, - intermediate_size, - requires_grad=True, - device=device, - dtype=dtype, -) -bias = torch.randn(intermediate_size, requires_grad=True, device=device, dtype=dtype) - - -# Get the optimized function -opt_fn = memory_efficient_pointwise_fusion( - aot_fn, compiler_name="torchscript_nvfuser" -) - - -# Profile cuda kernels -benchmark_helper.profile_cuda_kernels(fn, (input, bias), "Eager") -with torch.jit.fuser("fuser2"): - benchmark_helper.profile_cuda_kernels(opt_fn, (input, bias), "AOTAutograd") - - -# Time it with Torch Timer -benchmark_helper.time_with_torch_timer(fn, (input, bias), "Eager") -with torch.jit.fuser("fuser2"): - benchmark_helper.time_with_torch_timer(opt_fn, (input, bias), "AOTAutograd") - -# Time it with manual Timer -benchmark_helper.time_with_manual_timer(fn, (input, bias), "Eager") -with torch.jit.fuser("fuser2"): - benchmark_helper.time_with_manual_timer(opt_fn, (input, bias), "AOTAutograd") diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 19d5a3b750d..3e5b9086bbc 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -357,6 +357,7 @@ ALLOW_LIST = [ ("aten::_nested_view_from_buffer_copy.out", datetime.date(2023, 5, 1)), ("aten::_nested_view_from_buffer_copy", datetime.date(2023, 5, 1)), ("aten::_nested_view_from_buffer", datetime.date(2023, 5, 1)), + ("aten::_scaled_dot_product_flash_attention_backward", datetime.date(2023, 6, 1)), # These ops were moved to python under the c10d_functional namespace ("aten::wait_tensor", datetime.date(9999, 1, 30)), ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)), diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index 4ead9f1eea9..2e34421abe9 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -148,7 +148,7 @@ class TestFlopCounter(TestCase): self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""") def 
test_custom(self): - mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: lambda *args, out: 5}) + mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5}) with mode: a = T(4, 5) a + a diff --git a/test/test_transformers.py b/test/test_transformers.py index 2bf19cb448f..cb248e6389f 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1352,7 +1352,7 @@ class TestSDPA(NNTestCase): assert torch._fused_sdp_choice(q, k, v) == SDPBackend.MATH if PLATFORM_SUPPORTS_FUSED_SDPA: - batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 + batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64 shape = (batch_size, seq_len, num_heads, head_dim) device = "cuda" make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float16, packed=True) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 4a898a847ad..7c0f6e7ee60 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -2932,8 +2932,6 @@ def aot_function( partition_fn: Callable = default_partition, decompositions: Optional[Dict] = None, num_params_buffers: int = 0, - hasher_type=None, # deprecated - static_argnums: Optional[Tuple[int]] = None, # deprecated keep_inference_input_mutations: bool = False, inference_compiler: Optional[Callable] = None, *, @@ -2992,10 +2990,6 @@ def aot_function( >>> x = torch.randn(4, 5, requires_grad=True) >>> aot_fn(x) """ - if static_argnums is not None: - raise RuntimeError( - "static_argnums has been deprecated - manually wrap your function or use torchdynamo." - ) if bw_compiler is None: bw_compiler = fw_compiler @@ -3127,8 +3121,6 @@ def aot_module_simplified( bw_compiler: Optional[Callable] = None, partition_fn: Callable = default_partition, decompositions: Optional[Dict] = None, - hasher_type=None, - static_argnums=None, keep_inference_input_mutations=False, inference_compiler: Optional[Callable] = None, ) -> nn.Module: @@ -3175,6 +3167,7 @@ def aot_module_simplified( with stateless._reparametrize_module( mod, pytree.tree_unflatten(args[:params_len], params_spec) ): + if isinstance(mod, torch.fx.GraphModule): with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): warnings.filterwarnings( @@ -3193,7 +3186,6 @@ def aot_module_simplified( ) return out - assert static_argnums is None if bw_compiler is None: bw_compiler = fw_compiler if inference_compiler is None: diff --git a/torch/_functorch/compilers.py b/torch/_functorch/compilers.py index 735fcadb1c4..183a336f107 100644 --- a/torch/_functorch/compilers.py +++ b/torch/_functorch/compilers.py @@ -5,7 +5,7 @@ import pickle import random from contextlib import contextmanager from functools import partial -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Union import sympy import torch @@ -188,8 +188,8 @@ def simple_ts_compile(fx_g, _): return f -def nnc_jit(f, static_argnums=None): - return aot_function(f, simple_ts_compile, static_argnums=static_argnums) +def nnc_jit(f): + return aot_function(f, simple_ts_compile) aten = torch.ops.aten @@ -229,7 +229,6 @@ def print_compile(fx_g, _): def memory_efficient_fusion( fn: Union[Callable, nn.Module], - static_argnums: Optional[Tuple[int]] = None, **kwargs, ): """ @@ -245,8 +244,6 @@ def memory_efficient_fusion( Args: fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module`` that takes one ore more arguments. Must return one or more Tensors. 
- static_argnums (Optional[Tuple[Int]]): An option tuple of ints to mark - the arguments of the function as static. **kwargs: Any other overrides you want to make to the settings Returns: @@ -261,7 +258,6 @@ def memory_efficient_fusion( "bw_compiler": ts_compile, "partition_fn": min_cut_rematerialization_partition, "decompositions": default_decompositions, - "static_argnums": static_argnums, } config.update(kwargs) if isinstance(fn, torch.nn.Module): diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index 06d9b2cb53b..f4d226dd8bc 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -94,7 +94,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph: func1.__name__ = name return func1 - FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"]) + FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"]) func_dict = {s: get_fake_func(s) for s in ["extern", "nop", "compute", "fused"]} buf_to_fx_node = {} @@ -135,7 +135,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph: name = snode.get_name() fx_node.name = name - fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type) + fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type) if isinstance(snode, FusedSchedulerNode): for x in snode.snodes: diff --git a/torch/distributed/_spmd/aot_function_patch.py b/torch/distributed/_spmd/aot_function_patch.py index c41f0cfa1c0..49576852888 100644 --- a/torch/distributed/_spmd/aot_function_patch.py +++ b/torch/distributed/_spmd/aot_function_patch.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional import torch.utils._pytree as pytree from torch._functorch.aot_autograd import ( @@ -19,8 +19,6 @@ def patched_aot_function( partition_fn: Callable[..., object] = default_partition, decompositions: Optional[Dict[object, object]] = None, num_params_buffers: int = 0, - hasher_type: object = None, # deprecated - static_argnums: Optional[Tuple[int]] = None, # deprecated keep_inference_input_mutations: bool = False, pre_compile_fn: Optional[Callable[..., object]] = None, ) -> Callable[..., object]: @@ -98,11 +96,6 @@ def patched_aot_function( >>> x = torch.randn(4, 5, requires_grad=True) >>> aot_fn(x) """ - if static_argnums is not None: - raise RuntimeError( - "static_argnums has been deprecated - manually wrap your function or use torchdynamo." - ) - if bw_compiler is None: bw_compiler = fw_compiler diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index 72198445bf8..dc05af128ba 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -14,7 +14,7 @@ def get_shape(i): return i.shape return i -def mm_flop(a_shape, b_shape, out=None) -> int: +def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int: """ Count flops for matmul. """ @@ -26,13 +26,13 @@ def mm_flop(a_shape, b_shape, out=None) -> int: # NB(chilli): Should be 2 * k - 1 technically for FLOPs. return m * n * 2 * k -def addmm_flop(self_shape, a_shape, b_shape, out=None, **kwargs) -> int: +def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int: """ Count flops for addmm """ return mm_flop(a_shape, b_shape) -def bmm_flop(a_shape, b_shape, out=None, **kwargs) -> int: +def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int: """ Count flops for the bmm operation. 
""" @@ -46,7 +46,7 @@ def bmm_flop(a_shape, b_shape, out=None, **kwargs) -> int: flop = b * m * n * 2 * k return flop -def baddbmm_flop(self_shape, a_shape, b_shape, out=None, **kwargs) -> int: +def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int: """ Count flops for the baddbmm operation. """ @@ -83,11 +83,11 @@ def conv_flop_count( flop = batch_size * prod(conv_shape) * c_out * prod(dims) * 2 * c_in return flop -def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out=None, **kwargs) -> int: +def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int: """ Count flops for convolution. """ - return conv_flop_count(x_shape, w_shape, out, transposed=transposed) + return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) def transpose_shape(shape): return [shape[1], shape[0]] + list(shape[2:]) @@ -104,14 +104,14 @@ def conv_backward_flop( _output_padding, _groups, output_mask, - out) -> int: + out_shape) -> int: flop_count = 0 if output_mask[0]: - grad_input_shape = get_shape(out[0]) + grad_input_shape = get_shape(out_shape[0]) flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed) if output_mask[1]: - grad_weight_shape = get_shape(out[1]) + grad_weight_shape = get_shape(out_shape[1]) flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, transposed) return flop_count @@ -134,7 +134,7 @@ def sdpa_flop_count(query_shape, key_shape, value_shape): -def sdpa_flop(query_shape, key_shape, value_shape, *args, out=None, **kwargs) -> int: +def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int: """ Count flops for self-attention. """ @@ -169,7 +169,7 @@ def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape return total_flops -def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out=None, **kwargs) -> int: +def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int: """ Count flops for self-attention backward. """ @@ -306,6 +306,9 @@ class FlopCounterMode(TorchDispatchMode): return PopState.apply + def get_total_flops(self) -> int: + return sum(self.flop_counts['Global'].values()) + def get_flop_counts(self) -> Dict[str, Dict[Any, int]]: """Returns the flop counts as a dictionary of dictionaries. The outer dictionary is keyed by module name, and the inner dictionary is keyed by @@ -326,7 +329,7 @@ class FlopCounterMode(TorchDispatchMode): tabulate.PRESERVE_WHITESPACE = True header = ["Module", "FLOP", "% Total"] values = [] - global_flops = sum(self.flop_counts['Global'].values()) + global_flops = self.get_total_flops() global_suffix = get_suffix_str(global_flops) is_global_subsumed = False @@ -394,7 +397,7 @@ class FlopCounterMode(TorchDispatchMode): if func_packet in self.flop_mapping: flop_count_func = self.flop_mapping[func_packet] args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out)) - flop_count = flop_count_func(*args, **kwargs, out=out_shape) # type: ignore[operator] + flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape) # type: ignore[operator] for par in self.parents: self.flop_counts[par][func_packet] += flop_count