Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
This makes the next PR in the stack cleaner: having the top level entry point to aot autograd perform the functionalization analysis pass once, and plumb the metadata everywhere else that we need it. I put it in a separate PR because I recently learned that this function is used in fbcode, so I'll need to fix up internals when I land this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/95991 Approved by: https://github.com/ezyang
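The analysis pass in question is run_functionalized_fw_and_collect_metadata, which the file below already calls once up front. A minimal sketch of the pattern the PR description refers to (fn and example_args are hypothetical placeholders, not names from the PR):

from torch._functorch.aot_autograd import run_functionalized_fw_and_collect_metadata

# Trace the functionalized forward once, purely for analysis, and reuse the
# resulting metadata (per-input mutation info, output info, ...) downstream.
fw_metadata = run_functionalized_fw_and_collect_metadata(
    fn,  # a flattened forward callable (hypothetical)
    keep_input_mutations=False,
)(*example_args)  # hypothetical example inputs
mutated = [
    i
    for i, info in enumerate(fw_metadata.input_info)
    if info.mutates_data or info.mutates_metadata
]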
import copy

from typing import Callable, Tuple
from unittest.mock import patch

import torch
import torch._dynamo as torchdynamo
from torch._decomp import core_aten_decompositions
from torch._dispatch.python import enable_python_dispatcher
from torch.nn.utils import stateless
from torch.utils import _pytree as pytree

from torch._functorch.aot_autograd import (
    AOTConfig,
    create_aot_dispatcher_function,
    default_partition,
    run_functionalized_fw_and_collect_metadata,
)

from torch.fx.experimental.proxy_tensor import (
    get_torch_dispatch_modes,
    has_proxy_slot,
    ProxyTorchDispatchMode,
    set_proxy_slot,
)

from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional

from .workflow import ExportedProgram

CORE_ATEN_DECOMPOSITIONS_TABLE = core_aten_decompositions()

__all__ = ["do_not_use_experimental_export"]


def _aot_capture(mod, flat_args):
    """
    A wrapper around aot_autograd() to mix AOT Autograd + torch.export.
    Some assumptions were made about the AOT Autograd internals:

    1. The functionalization metadata format.
    2. The calling convention of the returned forward graph.
    3. make_fx() internal proxy storage.

    In the current context we're just experimenting with the idea, so it's
    possible things could break. As a next step we should find a way to
    upstream something reasonable.
    """
    param_list = [
        *mod.named_parameters(remove_duplicate=False),
        *mod.named_buffers(remove_duplicate=False),
    ]
    params = dict(param_list)
    params_flat, params_spec = pytree.tree_flatten(params)
    params_len = len(params_flat)

    full_args = []
    full_args.extend(params_flat)
    full_args.extend(flat_args)

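    # Descriptive note (added): functional_call treats the first params_len
    # arguments as module state to re-parametrize with, and the remainder as
    # the user inputs that are run through an FX Interpreter.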
    def functional_call(*args):
        with stateless._reparametrize_module(
            mod,
            pytree.tree_unflatten(args[:params_len], params_spec),  # type: ignore[arg-type]
        ):
            return torch.fx.Interpreter(mod).run(*args[params_len:])

    out_spec = None

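    # Descriptive note (added): run the functionalized forward once, purely
    # for analysis, to learn which inputs are mutated (assumption 1 above).
    # Deep copies keep the analysis from clobbering the real state.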
    with enable_python_dispatcher():
        fw_metadata = run_functionalized_fw_and_collect_metadata(
            lambda *args: pytree.tree_flatten(functional_call(*args))[0],
            keep_input_mutations=False,
        )(*copy.deepcopy(full_args))  # type: ignore[operator]

    assert len(fw_metadata.input_info) == len(full_args)
    mutated_input_indices = [
        i
        for i, input_info in enumerate(fw_metadata.input_info)
        if input_info.mutates_data or input_info.mutates_metadata
    ]

    graph_module = None

    def fw_compiler(gm, inputs):
        # We only need to capture the forward graph, not compile it.
        nonlocal graph_module
        graph_module = gm

    num_fwd_returns = None

    def partition_fn(joint_module, joint_inputs, *, num_fwd_outputs, **kwargs):
        # Record the number of forward outputs before deferring to the
        # default partitioner.
        nonlocal num_fwd_returns
        num_fwd_returns = num_fwd_outputs
        return default_partition(
            joint_module, joint_inputs, num_fwd_outputs=num_fwd_outputs, **kwargs
        )

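    # Descriptive note (added): while tracing, point the proxy slots of the
    # functionalization-unwrapped state tensors back at the original flat
    # parameters/buffers (assumption 3 above, make_fx proxy storage).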
    def set_state_proxies(state_args):
        modes = get_torch_dispatch_modes()
        proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
        if len(proxy_tensor_modes) == 0:
            return
        assert len(state_args) == len(params_flat)
        for i, arg in enumerate(state_args):
            tracer = next(
                m.tracer for m in proxy_tensor_modes if has_proxy_slot(arg, m.tracer)
            )
            set_proxy_slot(arg, tracer, params_flat[i])

    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=lambda gm, inputs: None,
        partition_fn=partition_fn,
        decompositions=CORE_ATEN_DECOMPOSITIONS_TABLE,  # type: ignore[arg-type]
        num_params_buffers=params_len,
        aot_id=-1,
        keep_inference_input_mutations=False,
        dynamic_shapes=True,
    )

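    # Descriptive note (added): exported_call is the callable actually handed
    # to AOT Autograd. It unwraps the functional state tensors, rebinds their
    # proxies, then runs the real forward while preserving node metadata.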
    def exported_call(*args):
        state_args = args[:params_len]
        unwrapped_state_args = _unwrap_all_tensors_from_functional(
            state_args, reapply_views=False
        )
        set_state_proxies(unwrapped_state_args)
        with torch.fx.traceback.preserve_node_meta():
            outputs = functional_call(*args)
        nonlocal out_spec
        outputs, out_spec = pytree.tree_flatten(outputs)
        return outputs

    with torch.enable_grad():
        create_aot_dispatcher_function(
            exported_call,
            full_args,
            aot_config,
        )

    assert graph_module is not None

    # Erase the parameter/buffer placeholders; after tracing they should
    # have no remaining users.
    for i, node in enumerate(graph_module.graph.nodes):
        if i == len(params_flat):
            break
        assert node.op == "placeholder" and len(node.users) == 0
        graph_module.graph.erase_node(node)

    output_node = next(iter(reversed(graph_module.graph.nodes)))
    assert output_node.op == "output" and len(output_node.args) == 1
    assert num_fwd_returns is not None
    # Truncate the outputs so we only return what we need: the mutated
    # inputs followed by the user-visible outputs, dropping the values
    # AOT Autograd saves for backward (assumption 2 above).
    output_node.args = (
        output_node.args[0][
            : len(mutated_input_indices) + len(fw_metadata.output_info)
        ],
    )

    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()

    def find_mutation_destinations(gm, w):
        # Find every parameter/buffer name in gm that aliases tensor w.
        assert isinstance(w, torch.Tensor)
        ret = [
            name for name, x in [*gm.named_parameters(), *gm.named_buffers()] if x is w
        ]
        assert len(ret) != 0, "Cannot find mutation destination."
        return ret

    # Record each input mutation as a ("copy_", source node name,
    # destination fully-qualified names) triple.
    mutation = [
        (
            "copy_",
            output_node.args[0][k].name,
            find_mutation_destinations(graph_module, param_list[i][1]),
        )
        for k, i in enumerate(mutated_input_indices)
    ]
    assert out_spec is not None
    return graph_module, mutation, out_spec


@patch.object(torchdynamo.config, "dynamic_shapes", True)
@patch.object(torchdynamo.config, "capture_scalar_outputs", True)
@patch.object(torchdynamo.config, "guard_nn_modules", True)
@patch.object(torchdynamo.config, "specialize_int", True)
@patch.object(torchdynamo.config, "allow_rnn", True)
@patch.object(torchdynamo.config, "verbose", True)
def do_not_use_experimental_export(f: Callable, args: Tuple, training=False):
    """
    This prototype is under heavy development. Please don't use it if you
    are not part of the PyTorch 2.0 Export team.
    """
    if training:
        raise NotImplementedError("training mode is not supported yet")

    flattened_args, in_spec = pytree.tree_flatten(args)
    # Keep two copies so that if graph_module accidentally modifies its
    # inputs, we still hand back the original ones.
    original_flat_args = tuple(flattened_args)
    flat_args = tuple(flattened_args)

    graph_module, guards = torchdynamo.export(f, *args, aten_graph=False)
    # TODO (tmanlaibaatar) do sth with guards?
    graph_module, _, out_spec = _aot_capture(graph_module, flat_args)
    return ExportedProgram(
        fw_module=graph_module,
        example_inputs=original_flat_args,
        in_spec=in_spec,
        out_spec=out_spec,
    )
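For orientation, a minimal sketch of how this entry point would be driven (the toy module and shapes are hypothetical, and the API is internal and unstable):

import torch

class ToyModule(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

# do_not_use_experimental_export flattens the example inputs, traces with
# torchdynamo.export, runs _aot_capture, and wraps the result in an
# ExportedProgram carrying the in/out pytree specs.
exported = do_not_use_experimental_export(ToyModule(), (torch.randn(3, 4),))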