From 016fcca24366f745e652d5ff0df0d09f2f7244a5 Mon Sep 17 00:00:00 2001
From: Horace He
Date: Sat, 13 Aug 2022 11:34:10 +0000
Subject: [PATCH] format some aotautograd-related files in functorch with black (#83240)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83240
Approved by: https://github.com/ezyang
---
 .lintrunner.toml                          |  2 +
 functorch/functorch/_src/aot_autograd.py  | 65 +++++++++++-----
 functorch/functorch/_src/compilers.py     | 94 +++++++++++++++++-------
 3 files changed, 114 insertions(+), 47 deletions(-)

diff --git a/.lintrunner.toml b/.lintrunner.toml
index 131898e3c6d..4c206c5fc74 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -720,6 +720,8 @@ include_patterns = [
     'torch/_subclasses/**/*.py',
     'torch/_*.py',
     'torchgen/**/*.py',
+    'functorch/functorch/_src/aot_autograd.py',
+    'functorch/functorch/_src/compilers.py',
 ]
 command = [
     'python3',
diff --git a/functorch/functorch/_src/aot_autograd.py b/functorch/functorch/_src/aot_autograd.py
index 45f442cb025..63d6989cf36 100644
--- a/functorch/functorch/_src/aot_autograd.py
+++ b/functorch/functorch/_src/aot_autograd.py
@@ -1,30 +1,34 @@
+import warnings
 from contextlib import contextmanager, nullcontext
+from functools import wraps
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
 import torch
-import torch.nn as nn
-from torch import Tensor
-from functorch import make_fx
-from torch.fx import immutable_collections, Interpreter
 import torch.fx.traceback as fx_traceback
-from torch._subclasses import FakeTensorMode
+import torch.nn as nn
+import torch.nn.utils.stateless as stateless
 import torch.utils._pytree as pytree
 import torch.utils.dlpack
-from torch.nn.utils import _stateless
+from torch import Tensor
+from torch._subclasses import FakeTensorMode
+from torch.fx import immutable_collections, Interpreter
+
+from functorch import make_fx
 from functorch._C import CompileCache
 from functorch.experimental import functionalize
 from . import config
 from .decompositions import register_decomposition
+from .named_members_polyfill import _named_buffers, _named_parameters
 from .partitioners import default_partition
-from .named_members_polyfill import _named_parameters, _named_buffers
-from typing import Callable, List, Dict, Any, Tuple, Optional
-from functools import wraps
-import warnings
 
 try:
     from torchdynamo import disable as disable_torchdynamo
 except ImportError:
+
     def disable_torchdynamo(x):
         return x
 
+
 pytree._register_pytree_node(
     immutable_collections.immutable_list,
     lambda x: (list(x), None),
@@ -148,20 +152,25 @@ def track_graph_compiling(graph_name, increment_index=False):
         nth_graph += 1
     graph_being_compiled = None
 
+
 def make_boxed_func(f):
     def g(args):
         return f(*args)
+
     g._boxed_call = True
     return g
 
+
 def make_boxed_compiler(compiler):
     @wraps(compiler)
     def f(fx_g, inps):
         out_f = compiler(fx_g, inps)
         fx_g = make_boxed_func(out_f)
         return fx_g
+
     return f
 
+
 def call_func_with_args(f, args, steal_args=False):
     if not steal_args:
         args = list(args)
@@ -178,6 +187,7 @@ def call_func_with_args(f, args, steal_args=False):
         )
     return normalize_as_list(f(*args))
 
+
 def create_aot_autograd_function(
     flat_fn, fw_compiler, bw_compiler, partition_fn, decompositions, grad_state
 ):
@@ -212,19 +222,30 @@ def create_aot_autograd_function(
             if compiled_fw is None:
                 flat_tensor_args = pytree.tree_map(
                     lambda x: x.detach().requires_grad_(x.requires_grad)
-                    if isinstance(x, Tensor) else x, flat_tensor_args
+                    if isinstance(x, Tensor)
+                    else x,
+                    flat_tensor_args,
+                )
+                fake_mode = (
+                    FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
                 )
-                fake_mode = FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
                 with preserve_rng_state(), fake_mode as mode:
                     # Set input tensors that require grad to leaves
                     fake_flat_tensor_args = pytree.tree_map(
-                        lambda x: mode.from_tensor(x) if mode else x
-                        if isinstance(x, Tensor) else x, flat_tensor_args
+                        lambda x: mode.from_tensor(x)
+                        if mode
+                        else x
+                        if isinstance(x, Tensor)
+                        else x,
+                        flat_tensor_args,
                     )
                     with torch.set_grad_enabled(grad_state):
                         out = flat_fn(*fake_flat_tensor_args)
                     out = pytree.tree_map(
-                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                        lambda x: x.detach().contiguous()
+                        if isinstance(x, Tensor)
+                        else x,
+                        out,
                     )

                     if isinstance(out, (list, tuple)):
@@ -233,7 +254,10 @@ def create_aot_autograd_function(
                         num_outs = 1

                     joint_inputs = (fake_flat_tensor_args, out)
-                    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                    aot_decompositions = {
+                        **aot_autograd_decompositions,
+                        **decompositions,
+                    }
                     with torch.set_grad_enabled(grad_state):
                         fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                             *joint_inputs
@@ -244,6 +268,7 @@ def create_aot_autograd_function(
                         # fake fn to make functionalize happy
                         def fake_fn(primals, tangents):
                             return fx_g(primals, tangents)
+
                         fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)

                     if config.debug_joint:
@@ -255,7 +280,6 @@ def create_aot_autograd_function(

                 if config.debug_graphs:
                     print(fw_module.code, bw_module.code)

-
                 with track_graph_compiling("forward"):
                     compiled_fw = fw_compiler(fw_module, flat_tensor_args)
@@ -621,7 +645,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:

     def functional_call(named_params, named_buffers, *args, **kwargs):
         params_and_buffers = {**named_params, **named_buffers}
-        return _stateless.functional_call(mod, params_and_buffers, args, kwargs)
+        return stateless.functional_call(mod, params_and_buffers, args, kwargs)

     compiled_f = aot_function(functional_call, *args, **kwargs)
 
@@ -663,7 +687,7 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
     params_len = len(params_flat)

     def functional_call(*args, **kwargs):
-        with _stateless.reparametrize_module(
+        with stateless._reparametrize_module(
             mod, pytree.tree_unflatten(args[:params_len], params_spec)
         ):
             if isinstance(mod, torch.fx.GraphModule):
@@ -706,13 +730,16 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
     compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs)

     if top_kwargs:
+
         def forward(*args, **kwargs):
             return compiled_f(
                 *params_flat,
                 *args,
                 **kwargs,
             )
+
     else:
+
         def forward(*args):
             return compiled_f(
                 *params_flat,
diff --git a/functorch/functorch/_src/compilers.py b/functorch/functorch/_src/compilers.py
index 6333708e9e4..1c63d223f96 100644
--- a/functorch/functorch/_src/compilers.py
+++ b/functorch/functorch/_src/compilers.py
@@ -1,18 +1,23 @@
-import torch
-import torch.fx as fx
-import torch.nn as nn
-from functools import partial
-from typing import Callable, Optional, Tuple, Union
-
-from .aot_autograd import aot_function, aot_module, make_boxed_compiler
-from .decompositions import get_decompositions
-from .partitioners import draw_graph, min_cut_rematerialization_partition, default_partition
-from .compile_utils import strip_overloads
+import copy
+import logging
 import os
 import pickle
 import random
-import copy
-import logging
+from functools import partial
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+import torch.fx as fx
+import torch.nn as nn
+
+from .aot_autograd import aot_function, aot_module, make_boxed_compiler
+from .compile_utils import strip_overloads
+from .decompositions import get_decompositions
+from .partitioners import (
+    default_partition,
+    draw_graph,
+    min_cut_rematerialization_partition,
+)


 # These canonicalizations are needed here (and not decompositions), as the ops
@@ -43,8 +48,12 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
     strip_overloads(fx_g)

     for node in fx_g.graph.nodes:
-        if (node.target == torch.ops.aten._to_copy and len(node.args) == 1
-                and len(node.kwargs) == 1 and 'dtype' in node.kwargs):
+        if (
+            node.target == torch.ops.aten._to_copy
+            and len(node.args) == 1
+            and len(node.kwargs) == 1
+            and "dtype" in node.kwargs
+        ):
             node.target = torch.ops.aten.to

     for node in fx_g.graph.nodes:
@@ -55,7 +64,6 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
                 new_kwargs[k] = v
             node.kwargs = new_kwargs

-
     fx_g.graph.lint()
     fx_g.recompile()

@@ -140,7 +148,9 @@ def print_compile(fx_g, _):


 def memory_efficient_fusion(
-    fn: Union[Callable, nn.Module], static_argnums: Optional[Tuple[int]] = None, **kwargs
+    fn: Union[Callable, nn.Module],
+    static_argnums: Optional[Tuple[int]] = None,
+    **kwargs,
 ):
     """
     Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
@@ -218,7 +228,7 @@ def get_inputs(input_data_path):
     Return a random input for the given inputs meta generated from _save_fx_default.
""" inputs = [] - with (open(input_data_path, 'rb')) as f: + with (open(input_data_path, "rb")) as f: inputs_meta = pickle.load(f) inputs = [] for meta in inputs_meta: @@ -227,7 +237,16 @@ def get_inputs(input_data_path): input = type(random.rand()) else: type, shape, stride, dtype, device = meta - if dtype in {torch.int, torch.int32, torch.int64, torch.bool, torch.int, torch.uint8, int, float}: + if dtype in { + torch.int, + torch.int32, + torch.int64, + torch.bool, + torch.int, + torch.uint8, + int, + float, + }: input = torch.randint(0, 1, shape, dtype=dtype, device=device) else: input = torch.rand(shape, dtype=dtype, device=device) @@ -260,16 +279,21 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_ input_meta += get_input_meta(args[1]) return input_meta for arg in args: - if(type(arg) == int or type(arg) == float): + if type(arg) == int or type(arg) == float: input_meta.append((type(arg),)) else: - input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)) + input_meta.append( + (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device) + ) return input_meta def graph_saver_helper(gm_to_save, args, type_name): global graph_index if len(gm_to_save.graph.nodes) == 0: - logging.log(logging.WARNING, f"No nodes in graph {current_name}_{type_name}_{graph_index}.") + logging.log( + logging.WARNING, + f"No nodes in graph {current_name}_{type_name}_{graph_index}.", + ) return gm = copy.deepcopy(gm_to_save) @@ -281,10 +305,21 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_ isExist = os.path.exists(f"{folder_name}/{current_name}") if not isExist: os.makedirs(f"{folder_name}/{current_name}") - gm.to_folder(f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}") - pickle.dump(input_meta, open(f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", "wb")) # noqa: E501 + gm.to_folder( + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}" + ) + pickle.dump( + input_meta, + open( + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", # noqa: B950 + "wb", + ), + ) # noqa: E501 if dump_example_input: - torch.save(args, f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt") # noqa: E501 + torch.save( + args, + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt", # noqa: B950 + ) # noqa: E501 def graph_saver_forward(gm, fw_args): graph_saver_helper(gm, fw_args, "forward") @@ -300,10 +335,13 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_ graph_saver_helper(gm, joint_args, "joint") return default_partition(gm, joint_args) - return aot_module_simplified(gm, fw_compiler=graph_saver_forward, - bw_compiler=graph_saver_backward, - partition_fn=graph_saver_joint, - decompositions=default_decompositions) + return aot_module_simplified( + gm, + fw_compiler=graph_saver_forward, + bw_compiler=graph_saver_backward, + partition_fn=graph_saver_joint, + decompositions=default_decompositions, + ) def graph_dumper_aot(current_name, folder_name, dump_example_input=False):