pytorch/torch/_higher_order_ops/utils.py


# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._guards import detect_fake_mode
from torch._ops import OperatorBase
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing, make_fx
from torch.fx.passes.shape_prop import TensorMetadata
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad, this will either
    raise an error or return a DelayedError depending on the value of delayed_error.

    Args:
        operator: The Operator to call with the given *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled, any of the arguments to the Operator
            require grad, and delayed_error is False
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner
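
# Illustrative sketch (not part of this module): how the returned callable is
# typically used as an autograd fallback for a higher-order op. `my_hop` is a
# hypothetical HigherOrderOperator without a real backward implementation.
#
#   backward_fallback = autograd_not_implemented(my_hop, deferred_error=True)
#   out = backward_fallback(*args)
#   # This runs my_hop below the Autograd key; if any tensor argument requires
#   # grad, the outputs are wrapped so that calling .backward() on them raises
#   # "Autograd not implemented for my_hop".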


def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running the graph with an interpreter is needed for propagating the stack_trace

        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn


def reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped


def _maybe_reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:

        def _maybe_make_fx_with_fake_mode(fn):
            @functools.wraps(fn)
            def wrapped(*args):
                from torch._guards import detect_fake_mode

                fake_mode = detect_fake_mode(args)
                if fake_mode is None:
                    # We create a fake_mode here to make sure we can
                    # trace the graph with data-dependent calls, e.g. .item()
                    return make_fx(fn, tracing_mode="fake")(*args)
                # Trace with real tensors if all inputs have already been fakified
                return make_fx(fn)(*args)

            return wrapped

        return _maybe_make_fx_with_fake_mode(fn)


@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing


def _detect_input_mutation(gm: torch.fx.GraphModule) -> bool:
    example_inputs = [
        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
    ]
    inp_mutation, _, _, _ = check_input_alias_and_mutation(gm, example_inputs)
    if len(inp_mutation) > 0:
        return True

    for _, module in gm.named_children():
        if isinstance(module, torch.fx.GraphModule):
            if _detect_input_mutation(module):
                return True

    return False


def _detect_input_alias(gm: torch.fx.GraphModule) -> bool:
    example_inputs = [
        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
    ]
    _, inp_inp_alias_map, inp_out_alias_map, _ = check_input_alias_and_mutation(
        gm, example_inputs
    )
    if len(inp_out_alias_map) > 0 or len(inp_inp_alias_map) > 0:
        return True

    return False


# The invariant here is that we always trace the branch with fake tensors
def _maybe_fake_tracing(fn, inputs: List[Any], pre_dispatch):
    fake_mode = detect_fake_mode(inputs)
    tracing_mode = "real"
    if fake_mode is None:
        tracing_mode = "fake"

    # Note: we need to turn off proxy tensor mode to avoid tracing the infra
    # code that happens in make_fx, e.g. we now call as_strided when wrapping
    # a tensor as a fake tensor.
    with disable_proxy_modes_tracing():
        return make_fx(
            fn,
            tracing_mode=tracing_mode,
            pre_dispatch=pre_dispatch,
            _error_on_data_dependent_ops=False,
        )(*inputs)


def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
    try:
        gm = _maybe_fake_tracing(gm, inputs, pre_dispatch)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    return _detect_input_mutation(gm) or _detect_input_alias(gm)


def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if the
    resulting graph has a mutable op on the input. This is a
    bit restrictive as the branch must be traceable.
    """
    try:
        gm = _maybe_fake_tracing(branch, inputs, pre_dispatch)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    return _detect_input_mutation(gm)


def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if the
    resulting graph has an output aliasing the branch input. This is a
    bit restrictive as the branch must be traceable.
    """
    try:
        gm = _maybe_fake_tracing(branch, inputs, pre_dispatch)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    return _detect_input_alias(gm)
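
# Illustrative sketch (not part of this module): the kind of branch these
# checks are meant to flag.
#
#   def branch(x):
#       x.add_(1)          # in-place mutation of the branch input
#       return x.view(-1)  # output aliases the branch input
#
#   # _has_potential_branch_input_mutation(branch, (torch.randn(3),)) -> True
#   # _has_potential_branch_input_alias(branch, (torch.randn(3),)) -> True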


def unique_graph_id(proxy_mode, prefix):
    """Returns a unique name and id for a graph to be added to a proxy_mode tracer"""
    # There are probably better ways - I know that create_arg has some self incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{prefix}_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate
    return i, next_name
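
# Illustrative example: if the tracer root already has attributes "true_graph_0"
# and "true_graph_1", then unique_graph_id(proxy_mode, "true_graph") returns
# (2, "true_graph_2"), and "true_graph_2" is safe to pass to register_module.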


def _from_fun(t):
    from torch._functorch.aot_autograd import from_fun
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
                device=t.device,
            )
        else:
            # clone of a functional tensor produces a functional tensor
            # but we want to avoid it so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t


def clone_outputs_aliasing_inputs(args):
    input_storage = {
        StorageWeakRef(arg._typed_storage())
        for arg in args
        if isinstance(arg, torch.Tensor)
    }

    def maybe_clone(t):
        if (
            isinstance(t, torch.Tensor)
            and StorageWeakRef(t._typed_storage()) in input_storage
        ):
            return t.clone()
        return t

    return maybe_clone


def prepare_fw_with_masks(fn):
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
            for ret in fw_out
        ]

    return fw_with_masks


# This function replaces None gradients with all-zero gradients.
# `None` gradients are problematic for CUDA graphs. Those gradients are
# replaced with an all-zero tensor for better optimization
def unmask_none_gradients(grads, operands):
    allowed_types = (torch.Tensor, int, torch.SymInt)
    assert all(
        isinstance(o, allowed_types) for o in operands
    ), f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}"

    unmasked_grads = []
    for g, o in zip(grads, operands):
        if g is not None:
            unmasked_grads.append(g)
        else:
            # In case the operand is an int or a torch.SymInt, return None.
            # This can happen for lifted arguments, e.g. the shapes of a dynamic tensor
            # are lifted and passed as additional arguments
            unmasked_grads.append(
                torch.zeros_like(o) if isinstance(o, torch.Tensor) else None
            )

    return unmasked_grads
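
# Illustrative example: a None gradient for a tensor operand becomes a zero
# tensor, while a None gradient for an int/SymInt operand stays None.
#
#   grads = [None, torch.ones(2), None]
#   operands = [torch.zeros(2), torch.zeros(2), 5]
#   unmask_none_gradients(grads, operands)
#   # -> [tensor([0., 0.]), tensor([1., 1.]), None]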


# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note: [HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. We will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
    # fetch the proxy for the inputs and to fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
    # only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create a _tensor_constant as output.

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # Unmask None gradients to all-zero gradients
        unmasked_grads = unmask_none_gradients(grads, inputs)

        # In order to keep map functional for the backward graph,
        # we clone outputs that are aliasing inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, unmasked_grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph
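
# Illustrative sketch (not part of this module): for a simple single-output fn,
# the joint graph (with use_output_and_grad_bw=False) takes (*grads, *fw_inputs)
# and returns the gradients w.r.t. fw_inputs.
#
#   def fn(x):
#       return (x.sin(),)
#
#   fw_inputs = (torch.randn(3, requires_grad=True),)
#   fw_outputs = fn(*fw_inputs)
#   # fw_graph, joint_graph = create_fw_bw_graph(fn, False, fw_inputs, fw_outputs)
#   # joint_graph(grad_out, x) then computes grad_out * cos(x).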


def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
        )

    a = zip(*flat_xs)

    pytrees = [pytree.tree_unflatten(tuple, inspec) for tuple in a]
    return pytrees


def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
    assert out_spec is not None

    b = zip(*flat_out)
    stacked_out = []
    for leaves in b:
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # Backward graph can return None output when forward inputs don't require grad.
            # When we eagerly execute the backward graph, we need to call _stack_pytree on its output,
            # therefore we need to deal with None output.
            stacked_out.append(None)  # type: ignore[arg-type]
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)
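
# Illustrative example: _unstack_pytree splits every leaf along the leading
# dimension and _stack_pytree is its inverse.
#
#   batched = {"a": torch.arange(6).reshape(3, 2)}
#   rows = _unstack_pytree(batched)   # 3 pytrees whose leaves have shape (2,)
#   assert torch.equal(_stack_pytree(rows)["a"], batched["a"])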


# We cannot call save_for_backward for symints. This helper function
# can be used to save symints as direct attributes of ctx in autograd.Function.
#
# For example, if args = (x, y, s0, z, s1),
# save_tensors_and_symints_for_backward will partition the args into two lists, and a bookkeeping list pos:
# partitioned_args[0] = (x, y, z)
# partitioned_args[1] = (s0, s1)
# pos = (0, 0, 1, 0, 1)
# The pos list keeps track of which partition each arg
# is placed into, in order to recover it in saved_tensors_and_symints.
#
# In saved_tensors_and_symints, we can recover the original args by
# iterating over the pos list and popping one item from the front of partitioned_args[pos[i]].
# We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists.
def save_tensors_and_symints_for_backward(ctx, args):
    assert all(
        isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
    ), args
    partitioned_args: list[Any] = [[], []]
    pos = []
    for arg in args:
        idx = 0 if isinstance(arg, torch.Tensor) else 1
        partitioned_args[idx].append(arg)
        pos.append(idx)

    assert not hasattr(ctx, "sym_int_args"), "ctx already has sym_int_args attribute."
    assert not hasattr(ctx, "pos"), "ctx already has pos attribute."
    ctx.save_for_backward(*partitioned_args[0])
    ctx.sym_int_args = partitioned_args[1]
    ctx.pos = pos


def saved_tensors_and_symints(ctx):
    args = []
    t_idx = 0
    s_idx = 0
    saved_tensors = ctx.saved_tensors
    for p in ctx.pos:
        if p == 0:
            args.append(saved_tensors[t_idx])
            t_idx += 1
        else:
            args.append(ctx.sym_int_args[s_idx])
            s_idx += 1
    assert t_idx + s_idx == len(ctx.pos)
    return tuple(args)
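
# Illustrative sketch (hypothetical autograd.Function, not part of this module)
# of how the two helpers pair up:
#
#   class _Scale(torch.autograd.Function):
#       @staticmethod
#       def forward(ctx, x, s):
#           # s may be an int/SymInt, which save_for_backward cannot handle
#           save_tensors_and_symints_for_backward(ctx, (x, s))
#           return x * s
#
#       @staticmethod
#       def backward(ctx, grad_out):
#           x, s = saved_tensors_and_symints(ctx)
#           return grad_out * s, None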


def get_dummy_aot_autograd_config():
    from torch._functorch.aot_autograd import AOTConfig

    return AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )


# Slices off the first element of a given dimension
def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
    return torch.select_copy(t, dim, 0)
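
# Illustrative example: for t of shape (B, N, D), first_slice_copy(t, dim=1)
# returns a copy of t[:, 0, :] with shape (B, D); the sliced dimension is removed.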


# Reports the differences between the metadata of two tensors as a list of strings
def diff_tensor_meta(
    meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
) -> list[str]:
    from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode

    pair_diffs = []
    for meta_name in TensorMetadata._fields:
        if not check_grad and meta_name == "requires_grad":
            continue
        val1 = getattr(meta1, meta_name)
        val2 = getattr(meta2, meta_name)
        try:
            if val1 != val2:
                pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
        except GuardOnDataDependentSymNode as _:
            pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
            continue
    return pair_diffs
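
# Illustrative sketch, assuming torch.fx.passes.shape_prop._extract_tensor_metadata
# is used to build the TensorMetadata inputs:
#
#   from torch.fx.passes.shape_prop import _extract_tensor_metadata
#   m1 = _extract_tensor_metadata(torch.randn(2, 3))
#   m2 = _extract_tensor_metadata(torch.randn(2, 4, dtype=torch.float64))
#   diff_tensor_meta(m1, m2)
#   # -> ["'shape: torch.Size([2, 3]) vs torch.Size([2, 4])'",
#   #     "'dtype: torch.float32 vs torch.float64'", ...]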


# Note [lifted arg types in hop]
# For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
# This has implications for the types of lifted args for different dispatch keys:
# 1. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd need to support torch.SymInt
#    lifted args because it's on the path of torch.compile(dynamic=True).
# 2. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd, CompositeExplicitAutograd need
#    to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner
#    hops may receive int inputs from the shape of outer tensor inputs.
#    However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]):
    allowed_types = (torch.Tensor, int, torch.SymInt)
    assert all(
        isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args
    ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"


def check_input_alias_and_mutation(
    gm: torch.fx.GraphModule,
    fake_args: List[FakeTensor],
) -> Tuple[List[int], dict[int, int], dict[int, int], dict[int, int]]:
    """Returns the mutated inputs, inp-inp alias map, inp-out alias map and
    out-out alias map of the graph module gm. It checks whether input tensor
    versions have changed after running gm once to detect mutation, and checks
    tensor storages to detect aliasing.
    """
    from torch._prims_common import clone_preserve_strides

    def _tensor_version(t) -> Optional[int]:
        if isinstance(t, torch.Tensor):
            assert isinstance(t, FakeTensor), "Only fake tensor is allowed"
            return t._version
        return None

    def _tensor_storage(t) -> StorageWeakRef:
        return StorageWeakRef(t._typed_storage())

    with disable_proxy_modes_tracing():
        # Clone the fake args to avoid mutating the original fake args
        with ExitStack() as ctx_stack:
            # We need to temporarily turn inference_mode off because
            # under inference mode, the tensor version counter is not tracked.
            ctx_stack.enter_context(torch.inference_mode(False))
            if (fake_mode := detect_fake_mode(fake_args)) is not None:
                ctx_stack.enter_context(fake_mode)
                if fake_mode.shape_env is not None:
                    ctx_stack.enter_context(
                        fake_mode.shape_env.ignore_fresh_unbacked_symbols()
                    )
            cloned = [
                clone_preserve_strides(arg) if isinstance(arg, torch.Tensor) else arg
                for arg in fake_args
            ]
            before = [_tensor_version(arg) for arg in cloned]
            outputs = gm(*cloned)
            outputs = [outputs] if not isinstance(outputs, (list, tuple)) else outputs
            after = [_tensor_version(arg) for arg in cloned]
            mutated_inputs = [
                i for i, (v1, v2) in enumerate(zip(before, after)) if v1 != v2
            ]
            # We need to analyze the original fake_args to detect
            # inp-inp alias.
            inp_storage_map = {
                _tensor_storage(inp): i
                for i, inp in enumerate(fake_args)
                if isinstance(inp, torch.Tensor)
            }
            inp_inp_alias_map = {
                i: inp_storage_map[_tensor_storage(inp)]
                for i, inp in enumerate(fake_args)
                if isinstance(inp, torch.Tensor)
                and inp_storage_map[_tensor_storage(inp)] != i
            }
            out_storage_map = {
                _tensor_storage(out): i
                for i, out in enumerate(outputs)
                if isinstance(out, torch.Tensor)
            }
            out_out_alias_map = {
                i: out_storage_map[_tensor_storage(out)]
                for i, out in enumerate(outputs)
                if isinstance(out, torch.Tensor)
                and out_storage_map[_tensor_storage(out)] != i
            }
            inp_out_alias_map = {
                i: out_storage_map[_tensor_storage(inp)]
                for i, inp in enumerate(cloned)
                if isinstance(inp, torch.Tensor)
                and _tensor_storage(inp) in out_storage_map
            }
    return mutated_inputs, inp_inp_alias_map, inp_out_alias_map, out_out_alias_map
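
# Illustrative example: for a graph module equivalent to
#
#   def gm(x, y):
#       x.add_(1)           # mutates input 0
#       return x.view(-1)   # output 0 aliases input 0
#
# check_input_alias_and_mutation(gm, fake_args) would report mutated_inputs == [0]
# and inp_out_alias_map == {0: 0}, with empty inp-inp and out-out alias maps.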