From ed55d356de80f011fecd63a4aed03e4c0694acee Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Sat, 12 Oct 2024 15:53:52 +0000 Subject: [PATCH] [alt] fix unroll in successive unflatten (#137646) We use nn_module_stack in unflatten to recognize when module calls begin and end. However the current format is not sufficient to detect module call boundaries when we have successive calls to the same module, because the successive instructions (end of one call, begin of next call) have the same nn_module_stack. This causes us to effectively "unroll" successive calls to a single call. This can cause problems when preserving module call signatures because the outputs of the successive calls might be concatenated in the single call. Previously we introduced the concept of a "call index" to generate multiple graphs when unflattening, one per call. This PR pushes this concept into nn_module_stack itself. In particular, the keys of nn_module_stack now go from `key` to `key@call_index`. (In a previous attempt, https://github.com/pytorch/pytorch/pull/137457, instead values in nn_module_stack go from (fqn, type) to (fqn, type, call_index), which is BC-breaking.) Note that we still do not have the ability to preserve module call signatures for multiple calls to the same module. But now instead of randomly crashing we give a proper error. OTOH when not preserving module call signatures we simply generate multiple calls, each with its own graph, possibly deduplicated, matching what we would do for non-successive calls. Test Plan: Like D64014936 Differential Revision: D64136277 Pull Request resolved: https://github.com/pytorch/pytorch/pull/137646 Approved by: https://github.com/angelayi --- test/export/test_export.py | 56 +++++++++++++++++++ torch/_dynamo/output_graph.py | 6 +- torch/_dynamo/symbolic_convert.py | 1 + torch/_dynamo/variables/nn_module.py | 3 + torch/distributed/pipelining/_unflatten.py | 2 +- torch/export/unflatten.py | 22 ++++---- torch/fx/_symbolic_trace.py | 8 ++- .../_internal/fx/passes/modularization.py | 2 +- 8 files changed, 84 insertions(+), 16 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 1359907fa90..33b6822d9fc 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6074,6 +6074,62 @@ graph(): self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp)) + def test_unflatten_no_unroll(self): + inp = (torch.ones(1),) + + class N(torch.nn.Module): + def forward(self, x, b): + if b: + return x + 1 + else: + return x + 2 + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.n = N() + + def forward(self, x): + x0 = x + 3 + x1 = self.n(x0, True) + x2 = self.n(x0, False) + return x1 + x2 + + m = M() + eager_result = m(*inp) + + if not is_retracebility_test(self._testMethodName): + ep = export(M(), inp, preserve_module_call_signature=("n",)) + with self.assertRaisesRegex( + ValueError, + "Cannot unflatten multiple calls to module n while preserving its signature", + ): + torch.export.unflatten(ep) + + ep = export(M(), inp) + print(ep) + epm = ep.module() + ufm = torch.export.unflatten(ep) + + exported_result = epm(*inp) + self.assertTrue(torch.allclose(exported_result, eager_result)) + + unflattened_result = ufm(*inp) + self.assertTrue(torch.allclose(unflattened_result, eager_result)) + + class _N(torch.nn.Module): + def forward(self, x): + return x + 1 + + class _N_1(torch.nn.Module): + def forward(self, x): + return x + 2 + + ufm.set_submodule("n", _N()) + ufm.set_submodule("n@1", _N_1()) + unflattened_result = ufm(*inp) + self.assertTrue(torch.allclose(unflattened_result, eager_result)) + def test_preserve_module_call_signature_unflatten_specialization(self): class N(torch.nn.Module): def forward(self, x, b): diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 27c2f74c26b..090fa5f61cc 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1995,7 +1995,11 @@ class SubgraphTracer(fx.Tracer): rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ ( rv.node.name, - rv.node.meta["nn_module_stack"][target][1], + next( + ty + for k, (_, ty) in rv.node.meta["nn_module_stack"].items() + if k.split("@")[0] == target + ), ) ] diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index c6120047b23..c12e235ff3a 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -2710,6 +2710,7 @@ class InstructionTranslatorBase( # The first field of tuple is the fully qualified name of current module # in original hierarchy. The second field is the type of current nn.module self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {} + self.num_calls: Dict[str, int] = {} # Flag to indicate whether tracing is used for export. self.export = export self.one_graph = False diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index f3ddcd80cf5..a5839e34d49 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -82,8 +82,11 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs): @contextmanager def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module): fully_qualified_name = source.name() + num_calls = tx.num_calls.get(fully_qualified_name, 0) + module_key = f"{module_key}@{num_calls}" if num_calls > 0 else module_key try: tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__) + tx.num_calls[fully_qualified_name] = num_calls + 1 yield finally: del tx.nn_module_stack[module_key] diff --git a/torch/distributed/pipelining/_unflatten.py b/torch/distributed/pipelining/_unflatten.py index d5aaba95ea3..7c68eecb3bb 100644 --- a/torch/distributed/pipelining/_unflatten.py +++ b/torch/distributed/pipelining/_unflatten.py @@ -18,7 +18,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph): seen_nodes, seen_modules, None, - [""], + [("", 0)], "", {}, module=new_module, diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 0ca0d1b3dc6..4e7f74d8448 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -749,7 +749,7 @@ class _ModuleFrame: seen_nodes, seen_modules, parent, - module_stack: List[str], + module_stack: List[Tuple[str, int]], module_id, module_call_graph: Dict[str, ModuleCallSignature], module: Optional[torch.nn.Module] = None, @@ -765,7 +765,7 @@ class _ModuleFrame: self.module_call_graph = module_call_graph self.verbose = False - self.fqn = self.module_stack[-1] + self.fqn, num_calls = self.module_stack[-1] if module is not None: self.module = module else: @@ -779,9 +779,6 @@ class _ModuleFrame: self.parent_call_module: Optional[torch.fx.Node] = None if parent is not None: - num_calls = len( - [x for x in self.seen_modules[self.module_id] if x.fqn == self.fqn] - ) if self.fqn in module_call_graph and num_calls == 1: raise ValueError( f"Cannot unflatten multiple calls to module {self.fqn} while preserving its signature " @@ -1068,8 +1065,9 @@ class _ModuleFrame: self.print() self.print("STEP", node_idx, node.format_node()) self.print(self.module_stack) + depth = len(self.module_stack) if node.op == "output": - if len(self.module_stack) == 1: + if depth == 1: # We want the output node of the original graph to be handled # specially by the outermost stack frame (in run_outer). So # skip finalization here. @@ -1095,10 +1093,11 @@ class _ModuleFrame: node_module_stack = self.module_stack else: node_module_stack = [ - path for path, ty in node.meta["nn_module_stack"].values() + (path, int(k.split("@")[-1]) if "@" in k else 0) + for k, (path, ty) in node.meta["nn_module_stack"].items() ] - if node_module_stack[: len(self.module_stack)] != self.module_stack: + if node_module_stack[:depth] != self.module_stack: # This means that the current module is done executing and the # current node is the beginning of a new module. # @@ -1114,10 +1113,11 @@ class _ModuleFrame: if _is_prefix(self.module_stack, node_module_stack): # This means that the current node represents the execution of a new # module. - next_module = node_module_stack[len(self.module_stack)] + next_module = node_module_stack[depth] self.print("Creating new stack frame for", next_module) # Run a nested version of module outliner from the current node # counter. Once it is complete, continue from that point. + next_module_key = list(node.meta["nn_module_stack"].keys())[depth] node_idx = _ModuleFrame( self.flat_graph, self.nodes, @@ -1125,7 +1125,7 @@ class _ModuleFrame: self.seen_modules, self, self.module_stack + [next_module], - list(node.meta["nn_module_stack"].keys())[len(self.module_stack)], + next_module_key.split("@")[0], self.module_call_graph, ).run_from(node_idx) module_idx += 1 @@ -1157,7 +1157,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModu seen_nodes, seen_modules, None, - [""], + [("", 0)], "", { entry.fqn: entry.signature diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 66938633865..63568607039 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -308,6 +308,7 @@ class Tracer(TracerBase): self.scope = Scope("", None) # Records the module call stack self.module_stack = collections.OrderedDict() + self.num_calls: Dict[str, int] = {} # Mapping of node name to module scope self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} @@ -514,13 +515,16 @@ class Tracer(TracerBase): with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: # module_stack is an ordered dict so writing then deleting the # entry is equivalent to push/pop on a list - self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type) + num_calls = self.num_calls.get(module_qualified_name, 0) + module_key = f"{_scope.module_path}@{num_calls}" if num_calls > 0 else _scope.module_path + self.module_stack[module_key] = (module_qualified_name, _scope.module_type) + self.num_calls[module_qualified_name] = num_calls + 1 if not self.is_leaf_module(m, module_qualified_name): ret_val = forward(*args, **kwargs) else: ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) key, _ = self.module_stack.popitem(last=True) - assert key == _scope.module_path, f" Unexpected key {key}" + assert key == module_key, f" Unexpected key {key}" return ret_val diff --git a/torch/onnx/_internal/fx/passes/modularization.py b/torch/onnx/_internal/fx/passes/modularization.py index e1ec411aea1..b819bb4f663 100644 --- a/torch/onnx/_internal/fx/passes/modularization.py +++ b/torch/onnx/_internal/fx/passes/modularization.py @@ -140,7 +140,7 @@ class _ModuleMeta: ) -> _ModuleMeta: """Create a module meta from raw meta produced by FX dynamo tracer.""" module_name, (qualified_name, module_class) = raw_meta - return _ModuleMeta(module_name, module_class, raw_meta) + return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta) @classmethod def from_raw_meta(