From abd759d50d87fa56078de95a444fe658de2983f9 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Tue, 23 Jan 2024 22:28:40 +0000 Subject: [PATCH] [fx] Add hooks to intercept node replacements. (#117825) Summary: Add an experimental API to the FX GraphModule for registering "hooks" that are invoked every time a node in the graph is renamed or replaced, so that the new name can be properly propagated to the graph signature and potentially other places. Test Plan: buck test mode/opt -c fbcode.enable_gpu_sections=true caffe2/test/distributed/_tensor/experimental:tp_transform buck test mode/opt caffe2/test:test_export -- -r test_replace_hook Differential Revision: D52896531 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117825 Approved by: https://github.com/avikchaudhuri --- docs/source/export.rst | 1 + test/export/test_pass_infra.py | 43 ++++++++++++++++++- .../_tensor/experimental/tp_transform.py | 19 +++++--- torch/export/exported_program.py | 16 +++++++ torch/export/graph_signature.py | 14 +++++- torch/fx/graph_module.py | 18 ++++++++ torch/fx/node.py | 17 ++++++++ 7 files changed, 120 insertions(+), 8 deletions(-) diff --git a/docs/source/export.rst b/docs/source/export.rst index ccad3e8c161..d7db226d695 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -570,6 +570,7 @@ API Reference .. autoclass:: ExportGraphSignature .. automethod:: replace_all_uses + .. automethod:: get_replace_hook .. 
autoclass:: torch.export.graph_signature.CustomObjArgument diff --git a/test/export/test_pass_infra.py b/test/export/test_pass_infra.py index 4a1cd999e9d..40de84bedb3 100644 --- a/test/export/test_pass_infra.py +++ b/test/export/test_pass_infra.py @@ -1,11 +1,13 @@ # Owner(s): ["oncall: export"] +import copy import unittest import torch from functorch.experimental import control_flow from torch._dynamo.eval_frame import is_dynamo_supported -from torch.export import export from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse +from torch.export import export +from torch.fx.passes.infra.pass_base import PassResult from torch.testing._internal.common_utils import run_tests, TestCase @@ -141,6 +143,45 @@ class TestPassInfra(TestCase): old_signature = ep_before.graph_signature self.assertNotEqual(new_signature.user_outputs, old_signature.user_outputs) + def test_replace_hook_basic(self) -> None: + class CustomModule(torch.nn.Module): + def __init__(self): + super().__init__() + + self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) + + self.register_buffer("my_buffer1", torch.tensor(3.0)) + self.register_buffer("my_buffer2", torch.tensor(4.0)) + + def forward(self, x1, x2): + # Use the parameter, buffers, and both inputs in the forward method + output = ( + x1 + self.my_parameter + ) * self.my_buffer1 + x2 * self.my_buffer2 + return output + + my_module = CustomModule() + inputs = (torch.tensor(6.0), torch.tensor(7.0)) + ep_before = export(my_module, inputs) + + def replace_pass(gm): + for node in gm.graph.nodes: + if node.op == "call_function": + node.name = node.name + "_modified" + gm.recompile() + return PassResult(gm, True) + + gm = copy.deepcopy(ep_before.graph_module) + sig = copy.deepcopy(ep_before.graph_signature) + + with gm._set_replace_hook(sig.get_replace_hook()): + replace_pass(gm) + + for node_name in sig.user_outputs: + self.assertTrue("_modified" in node_name) + + old_signature = ep_before.graph_signature + 
self.assertNotEqual(sig.user_outputs, old_signature.user_outputs) if __name__ == '__main__': run_tests() diff --git a/torch/distributed/_tensor/experimental/tp_transform.py b/torch/distributed/_tensor/experimental/tp_transform.py index 832c167dcc3..2fb8b1b0093 100644 --- a/torch/distributed/_tensor/experimental/tp_transform.py +++ b/torch/distributed/_tensor/experimental/tp_transform.py @@ -47,17 +47,24 @@ def tensor_parallel_transformation( .. warning:: This API is experimental and subject to change. """ - # TODO Migrate this to plain function call. - return exported_program._transform_do_not_use( - TensorParallelTransformPass( + + gm = exported_program.graph_module + sig = copy.deepcopy(exported_program.graph_signature) + state_dict = copy.copy(exported_program.state_dict) + + with gm._set_replace_hook(sig.get_replace_hook()): + res = TensorParallelTransformPass( rank, world_size, device_type, - exported_program.state_dict, + state_dict, exported_program.graph_signature, parallel_strategies, - ) - ) + )(gm) + assert res is not None + gm = res.graph_module + + return exported_program._update(gm, sig, state_dict) class TensorParallelTransformPass(PassBase): diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index fe94e695b9c..7d490f6c2a5 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -579,6 +579,22 @@ class ExportedProgram: def _validate(self): self.verifier().check(self) + # TODO(zhxchen17) Formalize this. 
+ def _update( + self, graph_module, graph_signature, state_dict=None + ) -> "ExportedProgram": + return ExportedProgram( + root=graph_module, + graph=graph_module.graph, + graph_signature=graph_signature, + state_dict=state_dict or self.state_dict, + range_constraints=copy.deepcopy(self.range_constraints), + module_call_graph=copy.deepcopy(self._module_call_graph), + example_inputs=self.example_inputs, + verifier=self.verifier, + tensor_constants=self.tensor_constants, + ) + def _get_updated_range_constraints( gm: torch.fx.GraphModule, diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index 0030f63d191..cf781849ad9 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -424,7 +424,19 @@ class ExportGraphSignature: """ assert isinstance(old, str) assert isinstance(new, str) + arg_types = (TensorArgument, SymIntArgument, CustomObjArgument) for o in self.output_specs: - if isinstance(o.arg, TensorArgument): + if isinstance(o.arg, arg_types): if o.arg.name == old: o.arg.name = new + for i in self.input_specs: + if isinstance(i.arg, arg_types): + if i.arg.name == old: + i.arg.name = new + + def get_replace_hook(self): + def _(old, new, user): + if user.op in ("output", "input"): + self.replace_all_uses(old.name, new) + + return _ diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 8ee898ef20a..277674b229d 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -1,3 +1,4 @@ +import contextlib import copy import itertools import linecache @@ -445,6 +446,7 @@ class GraphModule(torch.nn.Module): # Dictionary to store metadata self.meta: Dict[str, Any] = {} + self._replace_hook = None # TorchScript breaks trying to compile the graph setter because of the # continued string literal. 
Issue here: https://github.com/pytorch/pytorch/issues/44842 @@ -799,6 +801,7 @@ class {module_name}(torch.nn.Module): "_state_dict_hooks", "_load_state_dict_pre_hooks", "_load_state_dict_post_hooks", + "_replace_hook", ] for attr in extra_preserved_attrs: if attr in self.__dict__: @@ -849,6 +852,21 @@ class {module_name}(torch.nn.Module): new_gm._is_replica = True return new_gm + @contextlib.contextmanager + def _set_replace_hook(self, f): + """ + Takes a callable which will be called everytime when we replace a node + to a new node, or change the node's name. Callable takes three arguments: + the old node we're changing, and NAME of the new node, followed by the + user node which consumes the old node to be replaced. + """ + assert callable(f), "Replace hook must be a callable." + prev, self._replace_hook = self._replace_hook, f + try: + yield + finally: + self._replace_hook = prev + # workarounds for issues in __torch_function__ diff --git a/torch/fx/node.py b/torch/fx/node.py index 616b3888f89..9a020f46566 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -562,6 +562,7 @@ class Node: replace_with.meta[k] = v to_process = list(self.users) skipped = [] + m = self.graph.owning_module for use_node in to_process: if not delete_user_cb(use_node): skipped.append(use_node) @@ -573,6 +574,9 @@ class Node: else: return n + if getattr(m, "_replace_hook", None): + m._replace_hook(old=self, new=replace_with.name, user=use_node) + new_args = map_arg(use_node.args, maybe_replace_node) new_kwargs = map_arg(use_node.kwargs, maybe_replace_node) assert isinstance(new_args, tuple) @@ -662,6 +666,10 @@ class Node: def maybe_replace_node(n : Node) -> Node: return new_input if n == old_input else n + m = self.graph.owning_module + if getattr(m, "_replace_hook", None): + m._replace_hook(old=old_input, new=new_input.name, user=self) + new_args = map_arg(self.args, maybe_replace_node) new_kwargs = map_arg(self.kwargs, maybe_replace_node) assert isinstance(new_args, tuple) @@ -675,6 
+683,15 @@ class Node: self.name = name self.graph._graph_namespace._rename_object(self, name) + def __setattr__(self, name: str, value: Any) -> None: + if name == 'name' and hasattr(self, "name"): + m = self.graph.owning_module + if getattr(m, "_replace_hook", None): + assert isinstance(value, str) + for user in self.users: + m._replace_hook(old=self, new=value, user=user) + object.__setattr__(self, name, value) + @compatibility(is_backward_compatible=True) def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: