Implemented basic version of AOTDispatcher that only chooses between autograd or no autograd (#83248)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83248 Approved by: https://github.com/zou3519, https://github.com/ezyang
parent 86de9e7291, commit fbe8c77427
3 changed files with 201 additions and 146 deletions
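In outline, the commit replaces create_aot_autograd_function with create_aot_dispatcher_function, which inspects the flattened inputs and picks one of two code paths. A simplified sketch using the names introduced in the diff below (fake-tensor handling, RNG-state preservation, decomposition merging, and caching are omitted; this is not a verbatim excerpt):

    def create_aot_dispatcher_function(flat_fn, flat_args, aot_config):
        # Dispatch on whether gradients can flow at all.
        needs_autograd = torch.is_grad_enabled() and any(
            x.requires_grad for x in flat_args if isinstance(x, torch.Tensor)
        )
        if needs_autograd:
            # Trace the joint forward/backward graph, partition it, and wrap
            # the pieces in a torch.autograd.Function.
            return make_boxed_func(aot_dispatch_autograd(flat_fn, flat_args, aot_config))
        # No gradients needed: trace and compile only the forward graph.
        return aot_dispatch_base(flat_fn, flat_args, aot_config)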
@@ -1,3 +1,4 @@
import dataclasses
import warnings
from contextlib import contextmanager, nullcontext
from functools import wraps
@@ -6,12 +7,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.nn.utils.stateless as stateless
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._subclasses import FakeTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.nn.utils import stateless

from functorch import make_fx
from functorch._C import CompileCache
@@ -114,16 +115,6 @@ def _reshape_alias(x, shape, strides):
    return aten.view(x, shape)


@register_decomposition(aten.new_zeros, aot_autograd_decompositions)
def new_zeros(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.zeros(size, dtype=inp.dtype, device=inp.device)


@register_decomposition(aten.new_full, aot_autograd_decompositions)
def new_full(inp, size, value, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.full(size, value, dtype=inp.dtype, device=inp.device)


graph_being_compiled: str = None
nth_graph: int = 0
model_name: str = "model"
@@ -175,21 +166,134 @@ def call_func_with_args(f, args, steal_args=False):
    if not steal_args:
        args = list(args)
    assert isinstance(args, list)
    # TODO: Please remove soon
    # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670

    if hasattr(f, "_boxed_call"):
        return normalize_as_list(f(args))
        out = normalize_as_list(f(args))
    else:
        # TODO: Please remove soon
        # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
        warnings.warn(
            "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
            "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
            "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
        )
        return normalize_as_list(f(*args))
        out = normalize_as_list(f(*args))
    return out


def create_aot_autograd_function(
    flat_fn, fw_compiler, bw_compiler, partition_fn, decompositions, grad_state
@dataclasses.dataclass
class AOTConfig:
    """
    Configuration for AOTDispatcher
    """

    fw_compiler: Callable
    bw_compiler: Callable
    partition_fn: Callable
    decompositions: Dict[Callable, Callable]


def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
    with track_graph_compiling("forward"):
        compiled_fw = aot_config.fw_compiler(fw_module, flat_args)

    @wraps(compiled_fw)
    def new_fn(args):
        fw_outs = call_func_with_args(compiled_fw, args)
        return fw_outs

    return new_fn


@contextmanager
def _disable_jit_autocast():
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    try:
        yield
    finally:
        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)


def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    with _disable_jit_autocast():
        joint_forward_backward = create_joint_forward_backward(flat_fn)
        # Set input tensors that require grad to leaves
        with torch.set_grad_enabled(True):
            out = flat_fn(*flat_args)
        out = pytree.tree_map(
            lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
            out,
        )

        if isinstance(out, (list, tuple)):
            num_outs = len(out)
        else:
            num_outs = 1

        joint_inputs = (flat_args, out)
        with torch.set_grad_enabled(True):
            fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(
                *joint_inputs
            )

        if config.use_functionalize:
            # Functionalize the forward backward graph. First create a
            # fake fn to make functionalize happy
            def fake_fn(primals, tangents):
                return fx_g(primals, tangents)

            fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)

        if config.debug_joint:
            print(fx_g.code)

        with track_graph_compiling("joint"):
            fw_module, bw_module = aot_config.partition_fn(fx_g, joint_inputs)

        if config.debug_graphs:
            print(fw_module.code, bw_module.code)

        with track_graph_compiling("forward"):
            compiled_fw = aot_config.fw_compiler(fw_module, flat_args)

        # TODO: Delay this backwards compilation until the backwards pass
        with torch.no_grad():
            fw_outs = call_func_with_args(compiled_fw, flat_args)

        if config.debug_partitioner:
            activation_sizes = 0
            for out in fw_outs[num_outs:]:
                if isinstance(out, torch.Tensor):
                    activation_sizes += out.storage().nbytes()
            print(f"Real Activations Stored(GB): {activation_sizes/1e9}")

        bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
        with track_graph_compiling("backward", True):
            compiled_bw = aot_config.bw_compiler(bw_module, bw_args)

    class CompiledFunction(torch.autograd.Function):
        @staticmethod
        @disable_torchdynamo
        def forward(ctx, *flat_tensor_args):
            fw_outs = call_func_with_args(compiled_fw, flat_tensor_args)
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])

        @staticmethod
        @disable_torchdynamo
        def backward(ctx, *flat_args):
            contiguous_args = [t.contiguous() for t in flat_args]
            all_args = list(ctx.saved_tensors) + list(contiguous_args)
            ctx.maybe_clear_saved_tensors()
            out = call_func_with_args(compiled_bw, all_args, steal_args=True)
            return tuple(out)

    return CompiledFunction.apply


def create_aot_dispatcher_function(
    flat_fn, flat_args: List[Tensor], aot_config: AOTConfig
):
    """
    Traces the forward and backward graphs of the attr:`flat_fn` to generate a
@@ -203,118 +307,54 @@ def create_aot_autograd_function(
    The resulting compiled forward and backward graphs are then wrapped up in a
    ``torch.autograd.Function`` object.
    """
    if decompositions is None:
        decompositions = {}
    joint_forward_backward = create_joint_forward_backward(flat_fn)
    if aot_config.decompositions is None:
        aot_config.decompositions = {}

    compiled_fw = None
    compiled_bw = None
    num_outs = None
    aot_config.decompositions = {
        **aot_autograd_decompositions,
        **aot_config.decompositions,
    }
    fake_mode = FakeTensorMode.push() if config.use_fake_tensor else nullcontext()

    class CompiledFunction(torch.autograd.Function):
        @staticmethod
        @disable_torchdynamo
        def forward(ctx, *flat_tensor_args):
            nonlocal compiled_fw, compiled_bw, num_outs
            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
            if compiled_fw is None:
                flat_tensor_args = pytree.tree_map(
                    lambda x: x.detach().requires_grad_(x.requires_grad)
    with preserve_rng_state(), fake_mode as mode:

        def process_inputs(flat_args):
            flat_args = pytree.tree_map(
                lambda x: x.detach().requires_grad_(x.requires_grad)
                if isinstance(x, Tensor)
                else x,
                flat_args,
            )
            fake_flat_tensor_args = pytree.tree_map(
                lambda x: mode.from_tensor(x)
                if mode
                else x
                if isinstance(x, Tensor)
                else x,
                flat_args,
            )
            return fake_flat_tensor_args

        fake_flat_tensor_args = process_inputs(flat_args)

        needs_autograd = (
            any(
                [
                    x.requires_grad
                    for x in fake_flat_tensor_args
                    if isinstance(x, Tensor)
                    else x,
                    flat_tensor_args,
                )
                fake_mode = (
                    FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
                )
                with preserve_rng_state(), fake_mode as mode:
                    # Set input tensors that require grad to leaves
                    fake_flat_tensor_args = pytree.tree_map(
                        lambda x: mode.from_tensor(x)
                        if mode
                        else x
                        if isinstance(x, Tensor)
                        else x,
                        flat_tensor_args,
                    )
                    with torch.set_grad_enabled(grad_state):
                        out = flat_fn(*fake_flat_tensor_args)
                    out = pytree.tree_map(
                        lambda x: x.detach().contiguous()
                        if isinstance(x, Tensor)
                        else x,
                        out,
                    )

                    if isinstance(out, (list, tuple)):
                        num_outs = len(out)
                    else:
                        num_outs = 1

                    joint_inputs = (fake_flat_tensor_args, out)
                    aot_decompositions = {
                        **aot_autograd_decompositions,
                        **decompositions,
                    }
                    with torch.set_grad_enabled(grad_state):
                        fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                            *joint_inputs
                        )

                    if config.use_functionalize:
                        # Functionalize the forward backward graph. First create a
                        # fake fn to make functionalize happy
                        def fake_fn(primals, tangents):
                            return fx_g(primals, tangents)

                        fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)

                    if config.debug_joint:
                        print(fx_g.code)

                    with track_graph_compiling("joint"):
                        fw_module, bw_module = partition_fn(fx_g, joint_inputs)

                    if config.debug_graphs:
                        print(fw_module.code, bw_module.code)

                    with track_graph_compiling("forward"):
                        compiled_fw = fw_compiler(fw_module, flat_tensor_args)

                    fw_outs = call_func_with_args(compiled_fw, flat_tensor_args)
                    if config.debug_partitioner:
                        activation_sizes = 0
                        for out in fw_outs[num_outs:]:
                            if isinstance(out, torch.Tensor):
                                activation_sizes += out.storage().nbytes()
                        print(f"Real Activations Stored(GB): {activation_sizes/1e9}")

                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
                    with track_graph_compiling("backward", True):
                        compiled_bw = bw_compiler(bw_module, bw_args)
            else:
                fw_outs = call_func_with_args(compiled_fw, flat_tensor_args)
            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])

        @staticmethod
        @disable_torchdynamo
        def backward(ctx, *flat_args):
            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
            contiguous_args = [t.contiguous() for t in flat_args]
            all_args = list(ctx.saved_tensors) + list(contiguous_args)
            ctx.maybe_clear_saved_tensors()
            out = call_func_with_args(compiled_bw, all_args, steal_args=True)

            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
            return tuple(out)

    return CompiledFunction
                ]
            )
            and torch.is_grad_enabled()
        )
        # crappy version of dispatcher
        # TODO: Do this properly
        if needs_autograd:
            return make_boxed_func(
                aot_dispatch_autograd(flat_fn, fake_flat_tensor_args, aot_config)
            )
        else:
            return aot_dispatch_base(flat_fn, fake_flat_tensor_args, aot_config)


class _CompileCache(CompileCache):
@@ -486,6 +526,12 @@ def aot_function(
        compile_cache = CompileCache()
    if bw_compiler is None:
        bw_compiler = fw_compiler
    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
    )
    cached_res = None

    fn_id = id(fn)
@@ -571,14 +617,11 @@ def aot_function(
                out_spec.set(spec)
                return flat_out

            compiled_fn = create_aot_autograd_function(
            compiled_fn = create_aot_dispatcher_function(
                flat_fn,
                fw_compiler,
                bw_compiler,
                partition_fn,
                decompositions,
                grad_state=torch.is_grad_enabled(),
            ).apply
                flat_tensor_args,
                aot_config,
            )
            cached_res = (compiled_fn, out_spec)

            # Save the compiled_fn in the cache
@@ -593,7 +636,7 @@ def aot_function(
            )

        cached_fn, out_spec = cached_res
        out = cached_fn(*flat_tensor_args)
        out = cached_fn(flat_tensor_args)
        return out_spec.unflatten(out)

    return returned_function
@@ -716,16 +759,27 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
        assert static_argnums is None
        if bw_compiler is None:
            bw_compiler = fw_compiler
        compiled_fn = create_aot_autograd_function(
            fn,
            fw_compiler,
            bw_compiler,
            partition_fn,
            decompositions,
            grad_state=torch.is_grad_enabled(),
        ).apply
        aot_config = AOTConfig(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            decompositions=decompositions,
        )

        return compiled_fn
        compiled_fn = None

        @wraps(fn)
        def new_func(*args):
            nonlocal compiled_fn
            if compiled_fn is None:
                compiled_fn = create_aot_dispatcher_function(
                    fn,
                    args,
                    aot_config,
                )
            return compiled_fn(args)

        return new_func

    compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs)
@@ -142,6 +142,7 @@ default_decompositions = {
default_decompositions = get_decompositions(default_decompositions)


@make_boxed_compiler
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g
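The warning added to call_func_with_args above asks compilers to return functions that take boxed (list) arguments, and this hunk decorates print_compile with make_boxed_compiler to follow that convention. A minimal sketch of what compliance looks like for a custom compiler (make_boxed_func is the helper named in the warning; the compiler body here is illustrative, not part of this commit):

    from functorch.compile import make_boxed_func

    def my_compiler(fx_module, example_inputs):
        # A plain compiler would return a callable invoked as f(*args).
        # Wrapping it with make_boxed_func makes it accept a single list of
        # arguments instead, which lets AOTAutograd free inputs eagerly.
        return make_boxed_func(fx_module.forward)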
@@ -257,16 +257,16 @@ class TestAOTAutograd(TestCase):
        inps = [torch.randn((), requires_grad=True)]
        graph_size = None

        def assert_graph_empty(fx_g, _):
        def get_graph_size(fx_g, _):
            nonlocal graph_size
            graph_size = len(fx_g.graph.nodes)
            return fx_g

        start_recompilations = num_of_recompilations()
        f = aot_function(foo, nop, assert_graph_empty)
        f = aot_function(foo, nop, get_graph_size)
        with torch.set_grad_enabled(False):
            f(*inps)
        self.assertEqual(graph_size, 2)
        self.assertIsNone(graph_size)
        with torch.set_grad_enabled(True):
            f(*inps)
        self.assertTrue(graph_size > 2)
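For context, a minimal usage sketch of the behavior this updated test checks: with gradients disabled the dispatcher takes the forward-only path, so the backward compiler (here the graph-size probe) is never invoked and graph_size stays None (aot_function and nop come from functorch.compile; the traced function is illustrative):

    import torch
    from functorch.compile import aot_function, nop

    def fn(x):
        return x.sin().cos()

    compiled = aot_function(fn, nop, nop)
    x = torch.randn((), requires_grad=True)

    with torch.no_grad():
        compiled(x)  # no autograd needed -> aot_dispatch_base, forward graph only
    compiled(x)      # grad enabled -> aot_dispatch_autograd, joint fw/bw graph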