diff --git a/functorch/functorch/_src/aot_autograd.py b/functorch/functorch/_src/aot_autograd.py
index 63d6989cf36..03889cfb6a2 100644
--- a/functorch/functorch/_src/aot_autograd.py
+++ b/functorch/functorch/_src/aot_autograd.py
@@ -1,3 +1,4 @@
+import dataclasses
 import warnings
 from contextlib import contextmanager, nullcontext
 from functools import wraps
@@ -6,12 +7,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import torch.fx.traceback as fx_traceback
 import torch.nn as nn
-import torch.nn.utils.stateless as stateless
 import torch.utils._pytree as pytree
 import torch.utils.dlpack
 from torch import Tensor
 from torch._subclasses import FakeTensorMode
 from torch.fx import immutable_collections, Interpreter
+from torch.nn.utils import stateless
 
 from functorch import make_fx
 from functorch._C import CompileCache
@@ -114,16 +115,6 @@ def _reshape_alias(x, shape, strides):
     return aten.view(x, shape)
 
 
-@register_decomposition(aten.new_zeros, aot_autograd_decompositions)
-def new_zeros(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
-    return torch.zeros(size, dtype=inp.dtype, device=inp.device)
-
-
-@register_decomposition(aten.new_full, aot_autograd_decompositions)
-def new_full(inp, size, value, dtype=None, layout=None, device=None, pin_memory=None):
-    return torch.full(size, value, dtype=inp.dtype, device=inp.device)
-
-
 graph_being_compiled: str = None
 nth_graph: int = 0
 model_name: str = "model"
@@ -175,21 +166,134 @@ def call_func_with_args(f, args, steal_args=False):
     if not steal_args:
         args = list(args)
     assert isinstance(args, list)
-    # TODO: Please remove soon
-    # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
+
     if hasattr(f, "_boxed_call"):
-        return normalize_as_list(f(args))
+        out = normalize_as_list(f(args))
     else:
+        # TODO: Please remove soon
+        # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
         warnings.warn(
             "Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. "
             "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
             "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
         )
-        return normalize_as_list(f(*args))
+        out = normalize_as_list(f(*args))
+    return out
 
 
-def create_aot_autograd_function(
-    flat_fn, fw_compiler, bw_compiler, partition_fn, decompositions, grad_state
+@dataclasses.dataclass
+class AOTConfig:
+    """
+    Configuration for AOTDispatcher
+    """
+
+    fw_compiler: Callable
+    bw_compiler: Callable
+    partition_fn: Callable
+    decompositions: Dict[Callable, Callable]
+
+
+def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
+    fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
+    with track_graph_compiling("forward"):
+        compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
+
+    @wraps(compiled_fw)
+    def new_fn(args):
+        fw_outs = call_func_with_args(compiled_fw, args)
+        return fw_outs
+
+    return new_fn
+
+
+@contextmanager
+def _disable_jit_autocast():
+    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
+
+
+def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
+    with _disable_jit_autocast():
+        joint_forward_backward = create_joint_forward_backward(flat_fn)
+        # Set input tensors that require grad to leaves
+        with torch.set_grad_enabled(True):
+            out = flat_fn(*flat_args)
+            out = pytree.tree_map(
+                lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
+                out,
+            )
+
+        if isinstance(out, (list, tuple)):
+            num_outs = len(out)
+        else:
+            num_outs = 1
+
+        joint_inputs = (flat_args, out)
+        with torch.set_grad_enabled(True):
+            fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(
+                *joint_inputs
+            )
+
+        if config.use_functionalize:
+            # Functionalize the forward backward graph. First create a
+            # fake fn to make functionalize happy
+            def fake_fn(primals, tangents):
+                return fx_g(primals, tangents)
+
+            fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
+
+        if config.debug_joint:
+            print(fx_g.code)
+
+        with track_graph_compiling("joint"):
+            fw_module, bw_module = aot_config.partition_fn(fx_g, joint_inputs)
+
+        if config.debug_graphs:
+            print(fw_module.code, bw_module.code)
+
+        with track_graph_compiling("forward"):
+            compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
+
+        # TODO: Delay this backwards compilation until the backwards pass
+        with torch.no_grad():
+            fw_outs = call_func_with_args(compiled_fw, flat_args)
+
+        if config.debug_partitioner:
+            activation_sizes = 0
+            for out in fw_outs[num_outs:]:
+                if isinstance(out, torch.Tensor):
+                    activation_sizes += out.storage().nbytes()
+            print(f"Real Activations Stored(GB): {activation_sizes/1e9}")
+
+        bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+        with track_graph_compiling("backward", True):
+            compiled_bw = aot_config.bw_compiler(bw_module, bw_args)
+
+    class CompiledFunction(torch.autograd.Function):
+        @staticmethod
+        @disable_torchdynamo
+        def forward(ctx, *flat_tensor_args):
+            fw_outs = call_func_with_args(compiled_fw, flat_tensor_args)
+            ctx.save_for_backward(*fw_outs[num_outs:])
+            return tuple(fw_outs[0:num_outs])
+
+        @staticmethod
+        @disable_torchdynamo
+        def backward(ctx, *flat_args):
+            contiguous_args = [t.contiguous() for t in flat_args]
+            all_args = list(ctx.saved_tensors) + list(contiguous_args)
+            ctx.maybe_clear_saved_tensors()
+            out = call_func_with_args(compiled_bw, all_args, steal_args=True)
+            return tuple(out)
+
+    return CompiledFunction.apply
+
+
+def create_aot_dispatcher_function(
+    flat_fn, flat_args: List[Tensor], aot_config: AOTConfig
 ):
     """
     Traces the forward and backward graphs of the attr:`flat_fn` to generate a
@@ -203,118 +307,54 @@ def create_aot_autograd_function(
     The resulting compiled forward and backward graphs are then wrapped up in a
     ``torch.autograd.Function`` object.
     """
-    if decompositions is None:
-        decompositions = {}
-    joint_forward_backward = create_joint_forward_backward(flat_fn)
+    if aot_config.decompositions is None:
+        aot_config.decompositions = {}
 
-    compiled_fw = None
-    compiled_bw = None
-    num_outs = None
+    aot_config.decompositions = {
+        **aot_autograd_decompositions,
+        **aot_config.decompositions,
+    }
+    fake_mode = FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
 
-    class CompiledFunction(torch.autograd.Function):
-        @staticmethod
-        @disable_torchdynamo
-        def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
-            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
-            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
-            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
-            if compiled_fw is None:
-                flat_tensor_args = pytree.tree_map(
-                    lambda x: x.detach().requires_grad_(x.requires_grad)
+    with preserve_rng_state(), fake_mode as mode:
+
+        def process_inputs(flat_args):
+            flat_args = pytree.tree_map(
+                lambda x: x.detach().requires_grad_(x.requires_grad)
+                if isinstance(x, Tensor)
+                else x,
+                flat_args,
+            )
+            fake_flat_tensor_args = pytree.tree_map(
+                lambda x: mode.from_tensor(x)
+                if mode
+                else x
+                if isinstance(x, Tensor)
+                else x,
+                flat_args,
+            )
+            return fake_flat_tensor_args
+
+        fake_flat_tensor_args = process_inputs(flat_args)
+
+        needs_autograd = (
+            any(
+                [
+                    x.requires_grad
+                    for x in fake_flat_tensor_args
                     if isinstance(x, Tensor)
-                    else x,
-                    flat_tensor_args,
-                )
-                fake_mode = (
-                    FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
-                )
-                with preserve_rng_state(), fake_mode as mode:
-                    # Set input tensors that require grad to leaves
-                    fake_flat_tensor_args = pytree.tree_map(
-                        lambda x: mode.from_tensor(x)
-                        if mode
-                        else x
-                        if isinstance(x, Tensor)
-                        else x,
-                        flat_tensor_args,
-                    )
-                    with torch.set_grad_enabled(grad_state):
-                        out = flat_fn(*fake_flat_tensor_args)
-                        out = pytree.tree_map(
-                            lambda x: x.detach().contiguous()
-                            if isinstance(x, Tensor)
-                            else x,
-                            out,
-                        )
-
-                    if isinstance(out, (list, tuple)):
-                        num_outs = len(out)
-                    else:
-                        num_outs = 1
-
-                    joint_inputs = (fake_flat_tensor_args, out)
-                    aot_decompositions = {
-                        **aot_autograd_decompositions,
-                        **decompositions,
-                    }
-                    with torch.set_grad_enabled(grad_state):
-                        fx_g = make_fx(joint_forward_backward, aot_decompositions)(
-                            *joint_inputs
-                        )
-
-                    if config.use_functionalize:
-                        # Functionalize the foward backward graph. First create a
-                        # fake fn to make functionalize happy
-                        def fake_fn(primals, tangents):
-                            return fx_g(primals, tangents)
-
-                        fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
-
-                    if config.debug_joint:
-                        print(fx_g.code)
-
-                    with track_graph_compiling("joint"):
-                        fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-
-                    if config.debug_graphs:
-                        print(fw_module.code, bw_module.code)
-
-                    with track_graph_compiling("forward"):
-                        compiled_fw = fw_compiler(fw_module, flat_tensor_args)
-
-                    fw_outs = call_func_with_args(compiled_fw, flat_tensor_args)
-                    if config.debug_partitioner:
-                        activation_sizes = 0
-                        for out in fw_outs[num_outs:]:
-                            if isinstance(out, torch.Tensor):
-                                activation_sizes += out.storage().nbytes()
-                        print(f"Real Activations Stored(GB): {activation_sizes/1e9}")
-
-                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                    with track_graph_compiling("backward", True):
-                        compiled_bw = bw_compiler(bw_module, bw_args)
-            else:
-                fw_outs = call_func_with_args(compiled_fw, flat_tensor_args)
-            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
-            ctx.save_for_backward(*fw_outs[num_outs:])
-            return tuple(fw_outs[0:num_outs])
-
-        @staticmethod
-        @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
-            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
-            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
-            contiguous_args = [t.contiguous() for t in flat_args]
-            all_args = list(ctx.saved_tensors) + list(contiguous_args)
-            ctx.maybe_clear_saved_tensors()
-            out = call_func_with_args(compiled_bw, all_args, steal_args=True)
-
-            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
-            return tuple(out)
-
-    return CompiledFunction
+                ]
+            )
+            and torch.is_grad_enabled()
+        )
+        # crappy version of dispatcher
+        # TODO: Do this properly
+        if needs_autograd:
+            return make_boxed_func(
+                aot_dispatch_autograd(flat_fn, fake_flat_tensor_args, aot_config)
+            )
+        else:
+            return aot_dispatch_base(flat_fn, fake_flat_tensor_args, aot_config)
 
 
 class _CompileCache(CompileCache):
@@ -486,6 +526,12 @@ def aot_function(
         compile_cache = CompileCache()
     if bw_compiler is None:
         bw_compiler = fw_compiler
+    aot_config = AOTConfig(
+        fw_compiler=fw_compiler,
+        bw_compiler=bw_compiler,
+        partition_fn=partition_fn,
+        decompositions=decompositions,
+    )
     cached_res = None
 
     fn_id = id(fn)
@@ -571,14 +617,11 @@ def aot_function(
                 out_spec.set(spec)
                 return flat_out
 
-            compiled_fn = create_aot_autograd_function(
+            compiled_fn = create_aot_dispatcher_function(
                 flat_fn,
-                fw_compiler,
-                bw_compiler,
-                partition_fn,
-                decompositions,
-                grad_state=torch.is_grad_enabled(),
-            ).apply
+                flat_tensor_args,
+                aot_config,
+            )
             cached_res = (compiled_fn, out_spec)
 
             # Save the compiled_fn in the cache
@@ -593,7 +636,7 @@ def aot_function(
             )
 
         cached_fn, out_spec = cached_res
-        out = cached_fn(*flat_tensor_args)
+        out = cached_fn(flat_tensor_args)
         return out_spec.unflatten(out)
 
     return returned_function
@@ -716,16 +759,27 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
         assert static_argnums is None
         if bw_compiler is None:
            bw_compiler = fw_compiler
-        compiled_fn = create_aot_autograd_function(
-            fn,
-            fw_compiler,
-            bw_compiler,
-            partition_fn,
-            decompositions,
-            grad_state=torch.is_grad_enabled(),
-        ).apply
+        aot_config = AOTConfig(
+            fw_compiler=fw_compiler,
+            bw_compiler=bw_compiler,
+            partition_fn=partition_fn,
+            decompositions=decompositions,
+        )
 
-        return compiled_fn
+        compiled_fn = None
+
+        @wraps(fn)
+        def new_func(*args):
+            nonlocal compiled_fn
+            if compiled_fn is None:
+                compiled_fn = create_aot_dispatcher_function(
+                    fn,
+                    args,
+                    aot_config,
+                )
+            return compiled_fn(args)
+
+        return new_func
 
     compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs)
 
diff --git a/functorch/functorch/_src/compilers.py b/functorch/functorch/_src/compilers.py
index 1c63d223f96..44daf68a575 100644
--- a/functorch/functorch/_src/compilers.py
+++ b/functorch/functorch/_src/compilers.py
@@ -142,6 +142,7 @@ default_decompositions = {
 default_decompositions = get_decompositions(default_decompositions)
 
 
+@make_boxed_compiler
 def print_compile(fx_g, _):
     print(fx_g.code)
     return fx_g
diff --git a/functorch/test/test_pythonkey.py b/functorch/test/test_pythonkey.py
index 72479769a10..9823dc512ec 100644
--- a/functorch/test/test_pythonkey.py
+++ b/functorch/test/test_pythonkey.py
@@ -257,16 +257,16 @@ class TestAOTAutograd(TestCase):
         inps = [torch.randn((), requires_grad=True)]
         graph_size = None
 
-        def assert_graph_empty(fx_g, _):
+        def get_graph_size(fx_g, _):
             nonlocal graph_size
             graph_size = len(fx_g.graph.nodes)
             return fx_g
 
         start_recompilations = num_of_recompilations()
-        f = aot_function(foo, nop, assert_graph_empty)
+        f = aot_function(foo, nop, get_graph_size)
         with torch.set_grad_enabled(False):
             f(*inps)
-        self.assertEqual(graph_size, 2)
+        self.assertIsNone(graph_size)
         with torch.set_grad_enabled(True):
             f(*inps)
         self.assertTrue(graph_size > 2)