diff --git a/test/export/test_export.py b/test/export/test_export.py index 1359907fa90..33b6822d9fc 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6074,6 +6074,61 @@ graph(): self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp)) + def test_unflatten_no_unroll(self): + inp = (torch.ones(1),) + + class N(torch.nn.Module): + def forward(self, x, b): + if b: + return x + 1 + else: + return x + 2 + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.n = N() + + def forward(self, x): + x0 = x + 3 + x1 = self.n(x0, True) + x2 = self.n(x0, False) + return x1 + x2 + + m = M() + eager_result = m(*inp) + + if not is_retracebility_test(self._testMethodName): + ep = export(M(), inp, preserve_module_call_signature=("n",)) + with self.assertRaisesRegex( + ValueError, + "Cannot unflatten multiple calls to module n while preserving its signature", + ): + torch.export.unflatten(ep) + + ep = export(M(), inp) + epm = ep.module() + ufm = torch.export.unflatten(ep) + + exported_result = epm(*inp) + self.assertTrue(torch.allclose(exported_result, eager_result)) + + unflattened_result = ufm(*inp) + self.assertTrue(torch.allclose(unflattened_result, eager_result)) + + class _N(torch.nn.Module): + def forward(self, x): + return x + 1 + + class _N_1(torch.nn.Module): + def forward(self, x): + return x + 2 + + ufm.set_submodule("n", _N()) + ufm.set_submodule("n@1", _N_1()) + unflattened_result = ufm(*inp) + self.assertTrue(torch.allclose(unflattened_result, eager_result)) + def test_preserve_module_call_signature_unflatten_specialization(self): class N(torch.nn.Module): def forward(self, x, b): diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 27c2f74c26b..090fa5f61cc 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1995,7 +1995,11 @@ class SubgraphTracer(fx.Tracer): rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ ( rv.node.name, - 
rv.node.meta["nn_module_stack"][target][1], + next( + ty + for k, (_, ty) in rv.node.meta["nn_module_stack"].items() + if k.split("@")[0] == target + ), ) ] diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index c6120047b23..c12e235ff3a 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -2710,6 +2710,7 @@ class InstructionTranslatorBase( # The first field of tuple is the fully qualified name of current module # in original hierarchy. The second field is the type of current nn.module self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {} + self.num_calls: Dict[str, int] = {} # Flag to indicate whether tracing is used for export. self.export = export self.one_graph = False diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index f3ddcd80cf5..a5839e34d49 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -82,8 +82,11 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs): @contextmanager def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module): fully_qualified_name = source.name() + num_calls = tx.num_calls.get(fully_qualified_name, 0) + module_key = f"{module_key}@{num_calls}" if num_calls > 0 else module_key try: tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__) + tx.num_calls[fully_qualified_name] = num_calls + 1 yield finally: del tx.nn_module_stack[module_key] diff --git a/torch/distributed/pipelining/_unflatten.py b/torch/distributed/pipelining/_unflatten.py index d5aaba95ea3..7c68eecb3bb 100644 --- a/torch/distributed/pipelining/_unflatten.py +++ b/torch/distributed/pipelining/_unflatten.py @@ -18,7 +18,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph): seen_nodes, seen_modules, None, - [""], + [("", 0)], "", {}, module=new_module, diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 0ca0d1b3dc6..4e7f74d8448 
100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -749,7 +749,7 @@ class _ModuleFrame: seen_nodes, seen_modules, parent, - module_stack: List[str], + module_stack: List[Tuple[str, int]], module_id, module_call_graph: Dict[str, ModuleCallSignature], module: Optional[torch.nn.Module] = None, @@ -765,7 +765,7 @@ class _ModuleFrame: self.module_call_graph = module_call_graph self.verbose = False - self.fqn = self.module_stack[-1] + self.fqn, num_calls = self.module_stack[-1] if module is not None: self.module = module else: @@ -779,9 +779,6 @@ class _ModuleFrame: self.parent_call_module: Optional[torch.fx.Node] = None if parent is not None: - num_calls = len( - [x for x in self.seen_modules[self.module_id] if x.fqn == self.fqn] - ) if self.fqn in module_call_graph and num_calls == 1: raise ValueError( f"Cannot unflatten multiple calls to module {self.fqn} while preserving its signature " @@ -1068,8 +1065,9 @@ class _ModuleFrame: self.print() self.print("STEP", node_idx, node.format_node()) self.print(self.module_stack) + depth = len(self.module_stack) if node.op == "output": - if len(self.module_stack) == 1: + if depth == 1: # We want the output node of the original graph to be handled # specially by the outermost stack frame (in run_outer). So # skip finalization here. @@ -1095,10 +1093,11 @@ class _ModuleFrame: node_module_stack = self.module_stack else: node_module_stack = [ - path for path, ty in node.meta["nn_module_stack"].values() + (path, int(k.split("@")[-1]) if "@" in k else 0) + for k, (path, ty) in node.meta["nn_module_stack"].items() ] - if node_module_stack[: len(self.module_stack)] != self.module_stack: + if node_module_stack[:depth] != self.module_stack: # This means that the current module is done executing and the # current node is the beginning of a new module. 
# @@ -1114,10 +1113,11 @@ class _ModuleFrame: if _is_prefix(self.module_stack, node_module_stack): # This means that the current node represents the execution of a new # module. - next_module = node_module_stack[len(self.module_stack)] + next_module = node_module_stack[depth] self.print("Creating new stack frame for", next_module) # Run a nested version of module outliner from the current node # counter. Once it is complete, continue from that point. + next_module_key = list(node.meta["nn_module_stack"].keys())[depth] node_idx = _ModuleFrame( self.flat_graph, self.nodes, @@ -1125,7 +1125,7 @@ class _ModuleFrame: self.seen_modules, self, self.module_stack + [next_module], - list(node.meta["nn_module_stack"].keys())[len(self.module_stack)], + next_module_key.split("@")[0], self.module_call_graph, ).run_from(node_idx) module_idx += 1 @@ -1157,7 +1157,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModu seen_nodes, seen_modules, None, - [""], + [("", 0)], "", { entry.fqn: entry.signature diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 66938633865..63568607039 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -308,6 +308,7 @@ class Tracer(TracerBase): self.scope = Scope("", None) # Records the module call stack self.module_stack = collections.OrderedDict() + self.num_calls: Dict[str, int] = {} # Mapping of node name to module scope self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} @@ -514,13 +515,16 @@ class Tracer(TracerBase): with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: # module_stack is an ordered dict so writing then deleting the # entry is equivalent to push/pop on a list - self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type) + num_calls = self.num_calls.get(module_qualified_name, 0) + module_key = f"{_scope.module_path}@{num_calls}" if num_calls > 0 else _scope.module_path + 
self.module_stack[module_key] = (module_qualified_name, _scope.module_type) + self.num_calls[module_qualified_name] = num_calls + 1 if not self.is_leaf_module(m, module_qualified_name): ret_val = forward(*args, **kwargs) else: ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) key, _ = self.module_stack.popitem(last=True) - assert key == _scope.module_path, f" Unexpected key {key}" + assert key == module_key, f" Unexpected key {key}" return ret_val diff --git a/torch/onnx/_internal/fx/passes/modularization.py b/torch/onnx/_internal/fx/passes/modularization.py index e1ec411aea1..b819bb4f663 100644 --- a/torch/onnx/_internal/fx/passes/modularization.py +++ b/torch/onnx/_internal/fx/passes/modularization.py @@ -140,7 +140,7 @@ class _ModuleMeta: ) -> _ModuleMeta: """Create a module meta from raw meta produced by FX dynamo tracer.""" module_name, (qualified_name, module_class) = raw_meta - return _ModuleMeta(module_name, module_class, raw_meta) + return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta) @classmethod def from_raw_meta(