diff --git a/test/export/test_export.py b/test/export/test_export.py
index 8bc63faa0ca..0701d7dc75b 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -5974,6 +5974,250 @@ graph():
         self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))
 
+    def test_unflatten_multiple_graphs_preserve_signature_error(self):
+        class N(torch.nn.Module):
+            def forward(self, x, b):
+                if b:
+                    return x + 1
+                else:
+                    return x + 2
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.n = N()
+
+            def forward(self, x):
+                x = self.n(x, True)
+                x = x + 1
+                x = self.n(x, False)
+                x = x + 1
+                return x
+
+        inp = (torch.ones(1),)
+        m = M()
+        eager_result = m(*inp)
+
+        if not is_retracebility_test(self._testMethodName):
+            ep = export(M(), inp, preserve_module_call_signature=("n",))
+            with self.assertRaisesRegex(
+                ValueError,
+                "Cannot unflatten multiple calls to module n while preserving its signature",
+            ):
+                torch.export.unflatten(ep)
+
+        ep = export(M(), inp)
+        epm = ep.module()
+        ufm = torch.export.unflatten(ep)
+
+        exported_result = epm(*inp)
+        self.assertTrue(torch.allclose(exported_result, eager_result))
+
+        unflattened_result = ufm(*inp)
+        self.assertTrue(torch.allclose(unflattened_result, eager_result))
+
+    def test_unflatten_multiple_graphs_state(self):
+        class N(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buf", torch.ones(1), persistent=False)
+
+            def forward(self, x, b):
+                if b:
+                    self.buf.add_(1)
+                else:
+                    self.buf.add_(2)
+                return x + self.buf
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.n = N()
+
+            def forward(self, x):
+                x = self.n(x, True)
+                x = x + 1
+                x = self.n(x, False)
+                x = x + 1
+                x = self.n(x, True)
+                x = x + 1
+                x = self.n(x, False)
+                return x
+
+        inp = (torch.ones(1),)
+        m = M()
+        eager_result = m(*inp)
+
+        ep = export(M(), inp)
+        epm = ep.module()
+        ufm = torch.export.unflatten(ep)
+
+        exported_result = epm(*inp)
+        self.assertTrue(torch.allclose(exported_result, eager_result))
+
+        unflattened_result = ufm(*inp)
+        self.assertTrue(torch.allclose(unflattened_result, eager_result))
+
+    def test_unflatten_multiple_graphs_shared_submodule(self):
+        class N(torch.nn.Module):
+            def forward(self, x, b):
+                if b:
+                    return x + 1
+                else:
+                    return x + 2
+
+        def gen_m(n, n_1, p, p_1):
+            # Create a module instance where self.n and self.p
+            # share the same submodule instance.
+            # The booleans n, n_1 and p, p_1 are passed to two calls each
+            # to self.n and self.p, and they determine which path through
+            # the shared submodule instance is taken during export.
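+            # Unflattening such a module may create multiple specialized copies of N,
+            # exposed under call names like n, n@1, p, p@1 (see expected_fqns below).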
+            class M(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.n = N()
+                    self.p = self.n
+
+                def forward(self, x):
+                    x = x + 3
+                    x = self.n(x, n)
+                    x = x + 4
+                    x = self.n(x, n_1)
+                    x = x + 5
+                    x = self.p(x, p)
+                    x = x + 6
+                    x = self.p(x, p_1)
+                    return x + 7
+
+            return M()
+
+        inp = (torch.ones(1),)
+
+        def test(m, expected_graph, expected_fqns, expected_duplicates):
+            eager_result = m(*inp)
+
+            ep = export(m, inp)
+            exported_result = ep.module()(*inp)
+            # exported and eager results should match (baseline)
+            self.assertTrue(torch.allclose(exported_result, eager_result))
+
+            unflattened = torch.export.unflatten(ep)
+            unflattened_result = unflattened(*inp)
+            # unflattened and eager results should match
+            # (needs multiple specialized graphs for shared submodule instance)
+            self.assertTrue(torch.allclose(unflattened_result, eager_result))
+
+            # expected graph should call minimal number of specialized submodules
+            self.assertExpectedInline(
+                str(unflattened.graph).strip(),
+                expected_graph,
+            )
+
+            # expected graph should contain minimal number of specialized submodule fqns
+            self.assertEqual(
+                sorted(
+                    [
+                        fqn
+                        for fqn, _ in unflattened.named_modules(remove_duplicate=False)
+                    ]
+                ),
+                expected_fqns,
+            )
+            # expected graph should contain minimal number of specialized submodule instances
+            for a, b in expected_duplicates:
+                if is_non_strict_test(self._testMethodName):
+                    # NOTE: non-strict does not de-duplicate shared submodules through different fqns.
+                    # In particular, we use different module ids for self.n and self.p calls in non-strict,
+                    # but in strict we use the same module id, which enables additional reuse.
+                    # This is pre-existing behavior that might need to be fixed orthogonally.
+                    self.assertNotEqual(
+                        id(getattr(unflattened, a)), id(getattr(unflattened, b))
+                    )
+                else:
+                    self.assertEqual(
+                        id(getattr(unflattened, a)), id(getattr(unflattened, b))
+                    )
+
+        test(
+            gen_m(n=True, n_1=False, p=False, p_1=False),
+            # p should share n_1 graph, p_1 should be optimized away
+            """\
+graph():
+    %x : [num_users=1] = placeholder[target=x]
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
+    %n : [num_users=1] = call_module[target=n](args = (%add,), kwargs = {})
+    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n, 4), kwargs = {})
+    %n_1 : [num_users=1] = call_module[target=n@1](args = (%add_2,), kwargs = {})
+    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n_1, 5), kwargs = {})
+    %p : [num_users=1] = call_module[target=p](args = (%add_4,), kwargs = {})
+    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p, 6), kwargs = {})
+    %p_1 : [num_users=1] = call_module[target=p](args = (%add_6,), kwargs = {})
+    %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p_1, 7), kwargs = {})
+    return (add_8,)""",
+            ["", "n", "n@1", "p"],
+            [("n@1", "p")],
+        )
+
+        test(
+            gen_m(n=True, n_1=False, p=True, p_1=False),
+            # p should reuse n graph, p_1 should reuse n_1 graph
+            """\
+graph():
+    %x : [num_users=1] = placeholder[target=x]
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
+    %n : [num_users=1] = call_module[target=n](args = (%add,), kwargs = {})
+    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n, 4), kwargs = {})
+    %n_1 : [num_users=1] = call_module[target=n@1](args = (%add_2,), kwargs = {})
+    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n_1, 5), kwargs = {})
+    %p : [num_users=1] = call_module[target=p](args = (%add_4,), kwargs = {})
+    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p, 6), kwargs = {})
+    %p_1 : [num_users=1] = call_module[target=p@1](args = (%add_6,), kwargs = {})
+    %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p_1, 7), kwargs = {})
+    return (add_8,)""",
+            ["", "n", "n@1", "p", "p@1"],
+            [("n", "p"), ("n@1", "p@1")],
+        )
+
+        test(
+            gen_m(n=True, n_1=True, p=True, p_1=False),
+            # n_1 should be optimized away, p should reuse n graph
+            """\
+graph():
+    %x : [num_users=1] = placeholder[target=x]
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
+    %n : [num_users=1] = call_module[target=n](args = (%add,), kwargs = {})
+    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n, 4), kwargs = {})
+    %n_1 : [num_users=1] = call_module[target=n](args = (%add_2,), kwargs = {})
+    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n_1, 5), kwargs = {})
+    %p : [num_users=1] = call_module[target=p](args = (%add_4,), kwargs = {})
+    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p, 6), kwargs = {})
+    %p_1 : [num_users=1] = call_module[target=p@1](args = (%add_6,), kwargs = {})
+    %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p_1, 7), kwargs = {})
+    return (add_8,)""",
+            ["", "n", "p", "p@1"],
+            [("n", "p")],
+        )
+
+        test(
+            gen_m(n=True, n_1=False, p=False, p_1=True),
+            # p should reuse n_1 graph, p_1 should reuse n graph
+            """\
+graph():
+    %x : [num_users=1] = placeholder[target=x]
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
+    %n : [num_users=1] = call_module[target=n](args = (%add,), kwargs = {})
+    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n, 4), kwargs = {})
+    %n_1 : [num_users=1] = call_module[target=n@1](args = (%add_2,), kwargs = {})
+    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n_1, 5), kwargs = {})
+    %p : [num_users=1] = call_module[target=p](args = (%add_4,), kwargs = {})
+    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p, 6), kwargs = {})
+    %p_1 : [num_users=1] = call_module[target=p@1](args = (%add_6,), kwargs = {})
+    %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p_1, 7), kwargs = {})
+    return (add_8,)""",
+            ["", "n", "n@1", "p", "p@1"],
+            [("n", "p@1"), ("p", "n@1")],
+        )
+
     def test_stack_trace(self):
         class Foo(torch.nn.Module):
             def __init__(self) -> None:
diff --git a/torch/distributed/pipelining/_unflatten.py b/torch/distributed/pipelining/_unflatten.py
index 659c9804a96..d5aaba95ea3 100644
--- a/torch/distributed/pipelining/_unflatten.py
+++ b/torch/distributed/pipelining/_unflatten.py
@@ -1,16 +1,17 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
-from typing import Dict
+from collections import defaultdict
+from typing import Dict, List
 
 import torch
-from torch.export.unflatten import _ModuleFrame
+from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry
 
 
 def _outline_submodules(orig_graph: torch.fx.Graph):
     # Create an empty GraphModule to hold the outlined modules
     new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
     seen_nodes: Dict[str, torch.fx.Node] = {}
-    seen_modules: Dict[int, torch.nn.Module] = {}
+    seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list)
     _ModuleFrame(
         orig_graph,
         tuple(orig_graph.nodes),
diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py
index 55fe1a29440..dc250ca11c5 100644
--- a/torch/export/unflatten.py
+++ b/torch/export/unflatten.py
@@ -3,9 +3,11 @@ import abc
 import copy
 import logging
 import operator
+import re
 from collections import defaultdict
 from contextlib import contextmanager
 from copy import deepcopy
+from dataclasses import dataclass
 from enum import Enum
 from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
 
@@ -68,35 +70,47 @@ def _assign_attr(
     persistent: bool = True,
 ):
     *prefix, field = target.split(".")
+    # We need to generate all submodules of `to_module` that are at `prefix` and
+    # variants of `prefix` that differ only by call name. All of these submodules
+    # will then be assigned `from_obj` at `field` so that they can share this attribute.
+    # For example, if target is foo.bar.f, foo has another call name foo@1,
+    # and bar has other call names bar@1, bar@2, then we will assign f to
+    # foo.bar, foo.bar@1, foo.bar@2, foo@1.bar, foo@1.bar@1, foo@1.bar@2.
+    to_modules = [to_module]
     for item in prefix:
-        t = getattr(to_module, item, None)
+        ts: List[torch.nn.Module] = []
+        for to_module in to_modules:
+            if not hasattr(to_module, item):
+                setattr(to_module, item, torch.nn.Module())
+            ts.extend(
+                t_call  # type: ignore[misc]
+                for k, t_call in to_module._modules.items()
+                if _is_call_name(k, item)
+            )
+        to_modules = ts
 
-        if t is None:
-            t = torch.nn.Module()
-            setattr(to_module, item, t)
-        to_module = t
-
-    if attr_kind == _AttrKind.PARAMETER:
-        assert isinstance(from_obj, torch.nn.Parameter)
-        to_module.register_parameter(field, from_obj)
-    elif attr_kind == _AttrKind.BUFFER:
-        assert isinstance(from_obj, torch.Tensor)
-        to_module.register_buffer(field, from_obj, persistent=persistent)
-    elif attr_kind == _AttrKind.CONSTANT:
-        assert not isinstance(
-            from_obj, FakeScriptObject
-        ), "FakeScriptObject should only exist during tracing."
-        assert isinstance(
-            from_obj,
-            (
-                torch.Tensor,
-                torch.ScriptObject,
-            ),
-        )
-        setattr(to_module, field, from_obj)
-    elif attr_kind == _AttrKind.MODULE:
-        assert isinstance(from_obj, torch.nn.Module)
-        setattr(to_module, field, from_obj)
+    for to_module in to_modules:
+        if attr_kind == _AttrKind.PARAMETER:
+            assert isinstance(from_obj, torch.nn.Parameter)
+            to_module.register_parameter(field, from_obj)
+        elif attr_kind == _AttrKind.BUFFER:
+            assert isinstance(from_obj, torch.Tensor)
+            to_module.register_buffer(field, from_obj, persistent=persistent)
+        elif attr_kind == _AttrKind.CONSTANT:
+            assert not isinstance(
+                from_obj, FakeScriptObject
+            ), "FakeScriptObject should only exist during tracing."
+            assert isinstance(
+                from_obj,
+                (
+                    torch.Tensor,
+                    torch.ScriptObject,
+                ),
+            )
+            setattr(to_module, field, from_obj)
+        elif attr_kind == _AttrKind.MODULE:
+            assert isinstance(from_obj, torch.nn.Module)
+            setattr(to_module, field, from_obj)
 
 
 class InterpreterModule(torch.nn.Module):
@@ -220,7 +234,7 @@ class UnflattenedModule(torch.nn.Module):
         self._run_with_interpeter = RUN_WITH_INTERPRETER
 
         _inplace_buffer_mutations(export_graph, self.graph_signature)
-        _outline_submodules(export_graph, self)
+        seen_modules = _outline_submodules(export_graph, self)
 
         self.range_constraints = export_module.range_constraints
         self.equality_constraints: List = []
@@ -392,6 +406,7 @@ class UnflattenedModule(torch.nn.Module):
                 inputs_to_state[n] = targets
 
         _sink_params(self, inputs_to_state, [])
+        _deduplicate_modules(seen_modules.values())
 
         # Helper function to check input nodes of `module` has been processed.
        def check_module_inputs(module, scope):
@@ -436,9 +451,6 @@ class UnflattenedModule(torch.nn.Module):
             if name not in fqn_order:
                 fqn_order[name] = len(fqn_order)
         _reorder_submodules(self, fqn_order)
-        assert [fqn for fqn, _ in self.named_modules(remove_duplicate=False)] == list(
-            fqn_order.keys()
-        )
         self.graph.lint()
 
     def _print_graph(self):
@@ -644,7 +656,7 @@ def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
     return ".".join(child_split[len(parent_split) :])
 
 
-def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
+def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
     def graph_dump(graph: torch.fx.Graph) -> str:
         ret = []
         nodes_idx: Dict[int, int] = {}
@@ -665,7 +677,7 @@ def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
             nodes_idx[id(node)] = i
         return "\n".join(ret)
 
-    assert graph_dump(x.graph) == graph_dump(y.graph)
+    return graph_dump(x.graph) == graph_dump(y.graph)
 
 
 def _add_spec(gm: torch.nn.Module, spec) -> str:
@@ -724,6 +736,17 @@ def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Mo
     mod.add_module(field, module_to_add)
 
 
+def _call_name(base: str, n: int) -> str:
+    # Given n >= 1, generate call names to a submodule `base` of the form
+    # `base`, `base@1`, `base@2`, etc.
+    return base if n == 1 else f"{base}@{n-1}"
+
+
+def _is_call_name(call_name: str, base: str) -> bool:
+    # Recognize when call_name = _call_name(base, n) for some n >= 1.
+    return re.match(re.escape(base) + r"(@\d+)?$", call_name) is not None
+
+
 class _ModuleFrame:
     def __init__(
         self,
@@ -753,11 +776,6 @@ class _ModuleFrame:
             self.module = module
         else:
             self.module = InterpreterModule(torch.fx.Graph())
-        if self.module_id in self.seen_modules:
-            self.cached_graph_module = self.seen_modules[self.module_id]
-        else:
-            self.cached_graph_module = None
-        self.seen_modules[self.module_id] = self.module
 
         self.graph = self.module.graph
 
@@ -767,17 +785,31 @@ class _ModuleFrame:
 
         self.parent_call_module: Optional[torch.fx.Node] = None
         if parent is not None:
-            accessor = _compute_accessor(parent.fqn, self.fqn)
-            _add_submodule(
-                parent.module,
-                accessor,
-                (
-                    self.module
-                    if self.cached_graph_module is None
-                    else self.cached_graph_module
-                ),
+            num_calls = len(
+                [x for x in self.seen_modules[self.module_id] if x.fqn == self.fqn]
             )
+            if self.fqn in module_call_graph and num_calls == 1:
+                raise ValueError(
+                    f"Cannot unflatten multiple calls to module {self.fqn} while preserving its signature "
+                    "because each of these calls might have generated a different specialized graph. "
" + f"If the reason you want to preserve the signature is to swap {self.fqn} with another module, " + "consider using _swap_modules() directly on the exported program instead of unflattening it." + ) + # generate call name for self.fqn + child_fqn = _call_name(self.fqn, num_calls + 1) + accessor = _compute_accessor(parent.fqn, child_fqn) + _add_submodule(parent.module, accessor, self.module) self.parent_call_module = parent.graph.call_module(accessor) + self.seen_modules[self.module_id].append( + _SubmoduleEntry( + parent_fqn=self.parent.fqn, + parent_module=self.parent.module, + parent_call_module=self.parent_call_module, + fqn=self.fqn, + call_idx=num_calls + 1, + module=self.module, + ) + ) signature = module_call_graph.get(self.fqn) if signature is not None and self.parent is not None: @@ -1002,9 +1034,6 @@ class _ModuleFrame: proxy_out.meta["val"] = orig_output.meta.get("val") self.parent.node_map[orig_output] = proxy_out - if self.cached_graph_module is not None: - _verify_graph_equivalence(self.cached_graph_module, self.module) - def copy_node(self, node): self.print("copying", node.format_node()) self.node_map[node] = self.graph.node_copy(node, self.remap_input) @@ -1115,9 +1144,19 @@ class _ModuleFrame: node_idx += 1 +@dataclass +class _SubmoduleEntry: + parent_fqn: str + parent_module: torch.nn.Module + parent_call_module: torch.fx.Node + fqn: str + call_idx: int + module: torch.nn.Module + + def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule): seen_nodes: Dict[str, torch.fx.Node] = {} - seen_modules: Dict[int, torch.nn.Module] = {} + seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list) _ModuleFrame( orig_graph, tuple(orig_graph.nodes), @@ -1133,6 +1172,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModu }, module=root_module, ).run_outer() + return seen_modules def _reorder_submodules( @@ -1157,6 +1197,46 @@ def _reorder_submodules( parent.register_module(name, child) +def _deduplicate_modules(partitions): + for shared_submodules in partitions: + for i, entry in enumerate(shared_submodules): + child_fqn = _call_name(entry.fqn, entry.call_idx) + target = _compute_accessor(entry.parent_fqn, child_fqn) + deduplicated = False + # Iterate over all previously seen modules, and deduplicate if possible + for seen in shared_submodules[:i]: + if _check_graph_equivalence(seen.module, entry.module): + # Since graphs are equivalent, we can deduplicate. + # There are two cases. + if seen.fqn == entry.fqn: + # Case 1: The current module has the same fqn as the seen module. + # In this case we have generated a call name that can be optimized away. + # So we remove the current module from the hierarchy and replace + # the current call name with the seen call name in the parent graph. + *prefix, name = target.split(".") + _recursive_getattr(entry.parent_module, prefix)._modules.pop( + name + ) + seen_child_fqn = _call_name(seen.fqn, seen.call_idx) + seen_target = _compute_accessor( + entry.parent_fqn, seen_child_fqn + ) + entry.parent_call_module.target = seen_target # type: ignore[union-attr] + break + elif not deduplicated: + # Case 2: The current module has a different fqn than the seen module. + # In this case we replace the current module with the seen module. + # There should be nothing pointing to the current module any more, + # so it can be garbage collected. 
+                        # NOTE: We *do not* replace the current call name with the seen call name
+                        # in the parent graph, because this will lose information on which fqn
+                        # was actually called. However, it is possible that the current call name
+                        # will be optimized away when we find another seen module with the same fqn,
+                        # so we do not break out of the loop yet.
+                        entry.parent_module.set_submodule(target, seen.module)
+                        deduplicated = True
+
+
 def _sink_params(
     module: torch.nn.Module,
     inputs_to_state: Dict[str, List[str]],