mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
unflatten with specialized graphs per submodule call (#137013)
Previously we were making a fairly restrictive assumption when unflattening an exported program: for any submodule, we would assert that the graph of every call to that submodule must be the same. This assertion is load-bearing, i.e., if we simply remove the assertion then we can get incorrect results, as shown by the following example.
```
class N(torch.nn.Module):
def forward(self, x, b):
if b:
return x + 1
else:
return x + 2
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.n = N()
def forward(self, x):
x0 = x + 3
x1 = self.n(x0, True)
x2 = x1 + 4
x3 = self.n(x2, False)
return x3 + 5
m = M()
inp = (torch.ones(1),)
print(m(*inp)) # tensor([16.])
ep = torch.export.export(m, inp)
print(ep.module()(*inp)) # tensor([16.])
unflattened = torch.export.unflatten(ep)
print(unflattened(*inp)) # tensor([15.])
```
However, this goes against the spirit of specializing graphs when exporting: we should *expect* that for every call to a submodule we *might* generate a different graph. The goal of this PR is to fix unflattening to handle multiple specialized graphs corresponding to multiple calls to the same submodule.
The idea is simple: for every call to a child module `foo`, we will create potentially different child modules `foo`, `foo@1`, `foo@2`, etc. and use those names as targets in `callmodule` instructions in the parent graph. An immediate consequence of this is that the list of fqns in an unflattened module may not be the same as an exported module. Note that all these variants share the same parameters / buffers, so that multiple calls to the same submodule can share state as expected.
However, as described so far this scheme may end up with needlessly too many submodules. Thus, between calls to the same submodule, if graphs are equal then we optimize away the extra submodules and reuse call names as much as possible. Moreover, when submodules are shared across fqns, we also try to de-duplicate graphs corresponding to their calls as much as possible. Note that no matter what, information about which submodule was called is still preserved, so that if a submodule has to be swapped with another, one can still find all calls to the former submodule and replace them with calls to the latter.
A note on the choice of naming scheme for call names: instead of generating "sibling" modules `foo@1`, `foo@2`, etc. for `foo`, we had considered generating "children" modules `foo._1`, `foo._2`, etc. of `foo`. However this can cause spurious cycles when de-duplicating graphs. E.g., suppose that `foo` is an alias for `bar._1` and `foo._1` is an alias for `bar`, then we must either introduce a cycle or drop the opportunity to optimize. Another idea would be to make `foo` a dummy module that contains `foo._0` corresponding to the first call, but this necessitates too many changes to existing tests and hurts the common case.
Differential Revision: D63642479
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137013
Approved by: https://github.com/pianpwk
This commit is contained in:
parent
6241006c28
commit
cd5d1fe015
3 changed files with 379 additions and 54 deletions
|
|
@ -5974,6 +5974,250 @@ graph():
|
|||
|
||||
self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))
|
||||
|
||||
def test_unflatten_multiple_graphs_preserve_signature_error(self):
|
||||
class N(torch.nn.Module):
|
||||
def forward(self, x, b):
|
||||
if b:
|
||||
return x + 1
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.n = N()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.n(x, True)
|
||||
x = x + 1
|
||||
x = self.n(x, False)
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
inp = (torch.ones(1),)
|
||||
m = M()
|
||||
eager_result = m(*inp)
|
||||
|
||||
if not is_retracebility_test(self._testMethodName):
|
||||
ep = export(M(), inp, preserve_module_call_signature=("n",))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Cannot unflatten multiple calls to module n while preserving its signature",
|
||||
):
|
||||
torch.export.unflatten(ep)
|
||||
|
||||
ep = export(M(), inp)
|
||||
epm = ep.module()
|
||||
ufm = torch.export.unflatten(ep)
|
||||
|
||||
exported_result = epm(*inp)
|
||||
self.assertTrue(torch.allclose(exported_result, eager_result))
|
||||
|
||||
unflattened_result = ufm(*inp)
|
||||
self.assertTrue(torch.allclose(unflattened_result, eager_result))
|
||||
|
||||
def test_unflatten_multiple_graphs_state(self):
|
||||
class N(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("buf", torch.ones(1), persistent=False)
|
||||
|
||||
def forward(self, x, b):
|
||||
if b:
|
||||
self.buf.add_(1)
|
||||
else:
|
||||
self.buf.add_(2)
|
||||
return x + self.buf
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.n = N()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.n(x, True)
|
||||
x = x + 1
|
||||
x = self.n(x, False)
|
||||
x = x + 1
|
||||
x = self.n(x, True)
|
||||
x = x + 1
|
||||
x = self.n(x, False)
|
||||
return x
|
||||
|
||||
inp = (torch.ones(1),)
|
||||
m = M()
|
||||
eager_result = m(*inp)
|
||||
|
||||
ep = export(M(), inp)
|
||||
epm = ep.module()
|
||||
ufm = torch.export.unflatten(ep)
|
||||
|
||||
exported_result = epm(*inp)
|
||||
self.assertTrue(torch.allclose(exported_result, eager_result))
|
||||
|
||||
unflattened_result = ufm(*inp)
|
||||
self.assertTrue(torch.allclose(unflattened_result, eager_result))
|
||||
|
||||
def test_unflatten_multiple_graphs_shared_submodule(self):
|
||||
class N(torch.nn.Module):
|
||||
def forward(self, x, b):
|
||||
if b:
|
||||
return x + 1
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
def gen_m(n, n_1, p, p_1):
|
||||
# Create a module instance where self.n and self.p
|
||||
# share the same submodule instance.
|
||||
# The booleans n, n_1 and p, p_1 are passed to two calls each
|
||||
# to self.n and self.p, and they determine which path through
|
||||
# the shared submodule instance is taken during export.
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.n = N()
|
||||
self.p = self.n
|
||||
|
||||
def forward(self, x):
|
||||
x = x + 3
|
||||
x = self.n(x, n)
|
||||
x = x + 4
|
||||
x = self.n(x, n_1)
|
||||
x = x + 5
|
||||
x = self.p(x, p)
|
||||
x = x + 6
|
||||
x = self.p(x, p_1)
|
||||
return x + 7
|
||||
|
||||
return M()
|
||||
|
||||
inp = (torch.ones(1),)
|
||||
|
||||
def test(m, expected_graph, expected_fqns, expected_duplicates):
|
||||
eager_result = m(*inp)
|
||||
|
||||
ep = export(m, inp)
|
||||
exported_result = ep.module()(*inp)
|
||||
# exported and eager results should match (baseline)
|
||||
self.assertTrue(torch.allclose(exported_result, eager_result))
|
||||
|
||||
unflattened = torch.export.unflatten(ep)
|
||||
unflattened_result = unflattened(*inp)
|
||||
# unflattened and eager results should match
|
||||
# (needs multiple specialized graphs for shared submodule instance)
|
||||
self.assertTrue(torch.allclose(unflattened_result, eager_result))
|
||||
|
||||
# expected graph should call minimal number of specialized submodules
|
||||
self.assertExpectedInline(
|
||||
str(unflattened.graph).strip(),
|
||||
expected_graph,
|
||||
)
|
||||
|
||||
# expected graph should contain minimal number of specialized submodule fqns
|
||||
self.assertEqual(
|
||||
sorted(
|
||||
[
|
||||
fqn
|
||||
for fqn, _ in unflattened.named_modules(remove_duplicate=False)
|
||||
]
|
||||
),
|
||||
expected_fqns,
|
||||
)
|
||||
# expected graph should contain minimal number of specialized submodule instances
|
||||
for a, b in expected_duplicates:
|
||||
if is_non_strict_test(self._testMethodName):
|
||||
# NOTE: non-strict does not de-duplicate shared submodules through different fqns.
|
||||
# In particular, we use different module ids for self.n and self.p calls in non-strict,
|
||||
# but in strict we use the same module id, which enables additional reuse.
|
||||
# This is pre-existing behavior that might need to be fixed orthogonally.
|
||||
self.assertNotEqual(
|
||||
id(getattr(unflattened, a)), id(getattr(unflattened, b))
|
||||
)
|
||||
else:
|
||||
self.assertEqual(
|
||||
id(getattr(unflattened, a)), id(getattr(unflattened, b))
|
||||
)
|
||||
|
||||
test(
|
||||
gen_m(n=True, n_1=False, p=False, p_1=False),
|
||||
# p should share n_1 graph, p_1 should be optimized away
|
||||
"""\
|
||||
graph():
|
||||
%x : [num_users=1] = placeholder[target=x]
|
||||
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
|
||||
%n : [num_users=1] = call_module[target=n](args = (%add,), kwargs = {})
|
||||
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n, 4), kwargs = {})
|
||||
%n_1 : [num_users=1] = call_module[target=n@1](args = (%add_2,), kwargs = {})
|
||||
%add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n_1, 5), kwargs = {})
|
||||
%p : [num_users=1] = call_module[target=p](args = (%add_4,), kwargs = {})
|
||||
%add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p, 6), kwargs = {})
|
||||
%p_1 : [num_users=1] = call_module[target=p](args = (%add_6,), kwargs = {})
|
||||
%add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p_1, 7), kwargs = {})
|
||||
return (add_8,)""",
|
||||
["", "n", "n@1", "p"],
|
||||
[("n@1", "p")],
|
||||
)
|
||||
|
||||
test(
|
||||
gen_m(n=True, n_1=False, p=True, p_1=False),
|
||||
# p should reuse n graph, p_1 should reuse n_1 graph
|
||||
"""\
|
||||
graph():
|
||||
%x : [num_users=1] = placeholder[target=x]
|
||||
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
|
||||
%n : [num_users=1] = call_module[target=n](args = (%add,), kwargs = {})
|
||||
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n, 4), kwargs = {})
|
||||
%n_1 : [num_users=1] = call_module[target=n@1](args = (%add_2,), kwargs = {})
|
||||
%add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n_1, 5), kwargs = {})
|
||||
%p : [num_users=1] = call_module[target=p](args = (%add_4,), kwargs = {})
|
||||
%add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p, 6), kwargs = {})
|
||||
%p_1 : [num_users=1] = call_module[target=p@1](args = (%add_6,), kwargs = {})
|
||||
%add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p_1, 7), kwargs = {})
|
||||
return (add_8,)""",
|
||||
["", "n", "n@1", "p", "p@1"],
|
||||
[("n", "p"), ("n@1", "p@1")],
|
||||
)
|
||||
|
||||
test(
|
||||
gen_m(n=True, n_1=True, p=True, p_1=False),
|
||||
# n_1 should be optimized away, p should reuse n graph
|
||||
"""\
|
||||
graph():
|
||||
%x : [num_users=1] = placeholder[target=x]
|
||||
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
|
||||
%n : [num_users=1] = call_module[target=n](args = (%add,), kwargs = {})
|
||||
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n, 4), kwargs = {})
|
||||
%n_1 : [num_users=1] = call_module[target=n](args = (%add_2,), kwargs = {})
|
||||
%add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n_1, 5), kwargs = {})
|
||||
%p : [num_users=1] = call_module[target=p](args = (%add_4,), kwargs = {})
|
||||
%add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p, 6), kwargs = {})
|
||||
%p_1 : [num_users=1] = call_module[target=p@1](args = (%add_6,), kwargs = {})
|
||||
%add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p_1, 7), kwargs = {})
|
||||
return (add_8,)""",
|
||||
["", "n", "p", "p@1"],
|
||||
[("n", "p")],
|
||||
)
|
||||
|
||||
test(
|
||||
gen_m(n=True, n_1=False, p=False, p_1=True),
|
||||
# p should reuse n_1 graph, p_1 should reuse n graph
|
||||
"""\
|
||||
graph():
|
||||
%x : [num_users=1] = placeholder[target=x]
|
||||
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 3), kwargs = {})
|
||||
%n : [num_users=1] = call_module[target=n](args = (%add,), kwargs = {})
|
||||
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n, 4), kwargs = {})
|
||||
%n_1 : [num_users=1] = call_module[target=n@1](args = (%add_2,), kwargs = {})
|
||||
%add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%n_1, 5), kwargs = {})
|
||||
%p : [num_users=1] = call_module[target=p](args = (%add_4,), kwargs = {})
|
||||
%add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p, 6), kwargs = {})
|
||||
%p_1 : [num_users=1] = call_module[target=p@1](args = (%add_6,), kwargs = {})
|
||||
%add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%p_1, 7), kwargs = {})
|
||||
return (add_8,)""",
|
||||
["", "n", "n@1", "p", "p@1"],
|
||||
[("n", "p@1"), ("p", "n@1")],
|
||||
)
|
||||
|
||||
def test_stack_trace(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -1,16 +1,17 @@
|
|||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from typing import Dict
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.export.unflatten import _ModuleFrame
|
||||
from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry
|
||||
|
||||
|
||||
def _outline_submodules(orig_graph: torch.fx.Graph):
|
||||
# Create an empty GraphModule to hold the outlined modules
|
||||
new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
|
||||
seen_nodes: Dict[str, torch.fx.Node] = {}
|
||||
seen_modules: Dict[int, torch.nn.Module] = {}
|
||||
seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list)
|
||||
_ModuleFrame(
|
||||
orig_graph,
|
||||
tuple(orig_graph.nodes),
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ import abc
|
|||
import copy
|
||||
import logging
|
||||
import operator
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
|
|
@ -68,35 +70,47 @@ def _assign_attr(
|
|||
persistent: bool = True,
|
||||
):
|
||||
*prefix, field = target.split(".")
|
||||
# We need to generate all submodules of `to_module` that are at `prefix` and
|
||||
# variants of `prefix` that differ only by call name. All of these submodules
|
||||
# will then be assigned `from_obj` at `field` so that they can share this attribute.
|
||||
# For example, if target is foo.bar.f, foo has another call name foo@1,
|
||||
# and bar has other call names bar@1, bar@2, then we will assign f to
|
||||
# foo.bar, foo.bar@1, foo.bar@2, foo@1.bar, foo@1.bar@1, foo@1.bar@2.
|
||||
to_modules = [to_module]
|
||||
for item in prefix:
|
||||
t = getattr(to_module, item, None)
|
||||
ts: List[torch.nn.Module] = []
|
||||
for to_module in to_modules:
|
||||
if not hasattr(to_module, item):
|
||||
setattr(to_module, item, torch.nn.Module())
|
||||
ts.extend(
|
||||
t_call # type: ignore[misc]
|
||||
for k, t_call in to_module._modules.items()
|
||||
if _is_call_name(k, item)
|
||||
)
|
||||
to_modules = ts
|
||||
|
||||
if t is None:
|
||||
t = torch.nn.Module()
|
||||
setattr(to_module, item, t)
|
||||
to_module = t
|
||||
|
||||
if attr_kind == _AttrKind.PARAMETER:
|
||||
assert isinstance(from_obj, torch.nn.Parameter)
|
||||
to_module.register_parameter(field, from_obj)
|
||||
elif attr_kind == _AttrKind.BUFFER:
|
||||
assert isinstance(from_obj, torch.Tensor)
|
||||
to_module.register_buffer(field, from_obj, persistent=persistent)
|
||||
elif attr_kind == _AttrKind.CONSTANT:
|
||||
assert not isinstance(
|
||||
from_obj, FakeScriptObject
|
||||
), "FakeScriptObject should only exist during tracing."
|
||||
assert isinstance(
|
||||
from_obj,
|
||||
(
|
||||
torch.Tensor,
|
||||
torch.ScriptObject,
|
||||
),
|
||||
)
|
||||
setattr(to_module, field, from_obj)
|
||||
elif attr_kind == _AttrKind.MODULE:
|
||||
assert isinstance(from_obj, torch.nn.Module)
|
||||
setattr(to_module, field, from_obj)
|
||||
for to_module in to_modules:
|
||||
if attr_kind == _AttrKind.PARAMETER:
|
||||
assert isinstance(from_obj, torch.nn.Parameter)
|
||||
to_module.register_parameter(field, from_obj)
|
||||
elif attr_kind == _AttrKind.BUFFER:
|
||||
assert isinstance(from_obj, torch.Tensor)
|
||||
to_module.register_buffer(field, from_obj, persistent=persistent)
|
||||
elif attr_kind == _AttrKind.CONSTANT:
|
||||
assert not isinstance(
|
||||
from_obj, FakeScriptObject
|
||||
), "FakeScriptObject should only exist during tracing."
|
||||
assert isinstance(
|
||||
from_obj,
|
||||
(
|
||||
torch.Tensor,
|
||||
torch.ScriptObject,
|
||||
),
|
||||
)
|
||||
setattr(to_module, field, from_obj)
|
||||
elif attr_kind == _AttrKind.MODULE:
|
||||
assert isinstance(from_obj, torch.nn.Module)
|
||||
setattr(to_module, field, from_obj)
|
||||
|
||||
|
||||
class InterpreterModule(torch.nn.Module):
|
||||
|
|
@ -220,7 +234,7 @@ class UnflattenedModule(torch.nn.Module):
|
|||
self._run_with_interpeter = RUN_WITH_INTERPRETER
|
||||
|
||||
_inplace_buffer_mutations(export_graph, self.graph_signature)
|
||||
_outline_submodules(export_graph, self)
|
||||
seen_modules = _outline_submodules(export_graph, self)
|
||||
|
||||
self.range_constraints = export_module.range_constraints
|
||||
self.equality_constraints: List = []
|
||||
|
|
@ -392,6 +406,7 @@ class UnflattenedModule(torch.nn.Module):
|
|||
inputs_to_state[n] = targets
|
||||
|
||||
_sink_params(self, inputs_to_state, [])
|
||||
_deduplicate_modules(seen_modules.values())
|
||||
|
||||
# Helper function to check input nodes of `module` has been processed.
|
||||
def check_module_inputs(module, scope):
|
||||
|
|
@ -436,9 +451,6 @@ class UnflattenedModule(torch.nn.Module):
|
|||
if name not in fqn_order:
|
||||
fqn_order[name] = len(fqn_order)
|
||||
_reorder_submodules(self, fqn_order)
|
||||
assert [fqn for fqn, _ in self.named_modules(remove_duplicate=False)] == list(
|
||||
fqn_order.keys()
|
||||
)
|
||||
self.graph.lint()
|
||||
|
||||
def _print_graph(self):
|
||||
|
|
@ -644,7 +656,7 @@ def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
|
|||
return ".".join(child_split[len(parent_split) :])
|
||||
|
||||
|
||||
def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
|
||||
def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
|
||||
def graph_dump(graph: torch.fx.Graph) -> str:
|
||||
ret = []
|
||||
nodes_idx: Dict[int, int] = {}
|
||||
|
|
@ -665,7 +677,7 @@ def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
|
|||
nodes_idx[id(node)] = i
|
||||
return "\n".join(ret)
|
||||
|
||||
assert graph_dump(x.graph) == graph_dump(y.graph)
|
||||
return graph_dump(x.graph) == graph_dump(y.graph)
|
||||
|
||||
|
||||
def _add_spec(gm: torch.nn.Module, spec) -> str:
|
||||
|
|
@ -724,6 +736,17 @@ def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Mo
|
|||
mod.add_module(field, module_to_add)
|
||||
|
||||
|
||||
def _call_name(base: str, n: int) -> str:
|
||||
# Given n >= 0, generate call names to a submodule `base` of the form
|
||||
# `base`, `base@1`, `base@2`, etc.
|
||||
return base if n == 1 else f"{base}@{n-1}"
|
||||
|
||||
|
||||
def _is_call_name(call_name: str, base: str) -> bool:
|
||||
# Recognize when call_name = _call_name(base, n) for some n >= 0.
|
||||
return re.match(re.escape(base) + r"(@\d+)?$", call_name) is not None
|
||||
|
||||
|
||||
class _ModuleFrame:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -753,11 +776,6 @@ class _ModuleFrame:
|
|||
self.module = module
|
||||
else:
|
||||
self.module = InterpreterModule(torch.fx.Graph())
|
||||
if self.module_id in self.seen_modules:
|
||||
self.cached_graph_module = self.seen_modules[self.module_id]
|
||||
else:
|
||||
self.cached_graph_module = None
|
||||
self.seen_modules[self.module_id] = self.module
|
||||
|
||||
self.graph = self.module.graph
|
||||
|
||||
|
|
@ -767,17 +785,31 @@ class _ModuleFrame:
|
|||
|
||||
self.parent_call_module: Optional[torch.fx.Node] = None
|
||||
if parent is not None:
|
||||
accessor = _compute_accessor(parent.fqn, self.fqn)
|
||||
_add_submodule(
|
||||
parent.module,
|
||||
accessor,
|
||||
(
|
||||
self.module
|
||||
if self.cached_graph_module is None
|
||||
else self.cached_graph_module
|
||||
),
|
||||
num_calls = len(
|
||||
[x for x in self.seen_modules[self.module_id] if x.fqn == self.fqn]
|
||||
)
|
||||
if self.fqn in module_call_graph and num_calls == 1:
|
||||
raise ValueError(
|
||||
f"Cannot unflatten multiple calls to module {self.fqn} while preserving its signature "
|
||||
"because each of these calls might have generated a different specialized graph. "
|
||||
f"If the reason you want to preserve the signature is to swap {self.fqn} with another module, "
|
||||
"consider using _swap_modules() directly on the exported program instead of unflattening it."
|
||||
)
|
||||
# generate call name for self.fqn
|
||||
child_fqn = _call_name(self.fqn, num_calls + 1)
|
||||
accessor = _compute_accessor(parent.fqn, child_fqn)
|
||||
_add_submodule(parent.module, accessor, self.module)
|
||||
self.parent_call_module = parent.graph.call_module(accessor)
|
||||
self.seen_modules[self.module_id].append(
|
||||
_SubmoduleEntry(
|
||||
parent_fqn=self.parent.fqn,
|
||||
parent_module=self.parent.module,
|
||||
parent_call_module=self.parent_call_module,
|
||||
fqn=self.fqn,
|
||||
call_idx=num_calls + 1,
|
||||
module=self.module,
|
||||
)
|
||||
)
|
||||
|
||||
signature = module_call_graph.get(self.fqn)
|
||||
if signature is not None and self.parent is not None:
|
||||
|
|
@ -1002,9 +1034,6 @@ class _ModuleFrame:
|
|||
proxy_out.meta["val"] = orig_output.meta.get("val")
|
||||
self.parent.node_map[orig_output] = proxy_out
|
||||
|
||||
if self.cached_graph_module is not None:
|
||||
_verify_graph_equivalence(self.cached_graph_module, self.module)
|
||||
|
||||
def copy_node(self, node):
|
||||
self.print("copying", node.format_node())
|
||||
self.node_map[node] = self.graph.node_copy(node, self.remap_input)
|
||||
|
|
@ -1115,9 +1144,19 @@ class _ModuleFrame:
|
|||
node_idx += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SubmoduleEntry:
|
||||
parent_fqn: str
|
||||
parent_module: torch.nn.Module
|
||||
parent_call_module: torch.fx.Node
|
||||
fqn: str
|
||||
call_idx: int
|
||||
module: torch.nn.Module
|
||||
|
||||
|
||||
def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
|
||||
seen_nodes: Dict[str, torch.fx.Node] = {}
|
||||
seen_modules: Dict[int, torch.nn.Module] = {}
|
||||
seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list)
|
||||
_ModuleFrame(
|
||||
orig_graph,
|
||||
tuple(orig_graph.nodes),
|
||||
|
|
@ -1133,6 +1172,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModu
|
|||
},
|
||||
module=root_module,
|
||||
).run_outer()
|
||||
return seen_modules
|
||||
|
||||
|
||||
def _reorder_submodules(
|
||||
|
|
@ -1157,6 +1197,46 @@ def _reorder_submodules(
|
|||
parent.register_module(name, child)
|
||||
|
||||
|
||||
def _deduplicate_modules(partitions):
|
||||
for shared_submodules in partitions:
|
||||
for i, entry in enumerate(shared_submodules):
|
||||
child_fqn = _call_name(entry.fqn, entry.call_idx)
|
||||
target = _compute_accessor(entry.parent_fqn, child_fqn)
|
||||
deduplicated = False
|
||||
# Iterate over all previously seen modules, and deduplicate if possible
|
||||
for seen in shared_submodules[:i]:
|
||||
if _check_graph_equivalence(seen.module, entry.module):
|
||||
# Since graphs are equivalent, we can deduplicate.
|
||||
# There are two cases.
|
||||
if seen.fqn == entry.fqn:
|
||||
# Case 1: The current module has the same fqn as the seen module.
|
||||
# In this case we have generated a call name that can be optimized away.
|
||||
# So we remove the current module from the hierarchy and replace
|
||||
# the current call name with the seen call name in the parent graph.
|
||||
*prefix, name = target.split(".")
|
||||
_recursive_getattr(entry.parent_module, prefix)._modules.pop(
|
||||
name
|
||||
)
|
||||
seen_child_fqn = _call_name(seen.fqn, seen.call_idx)
|
||||
seen_target = _compute_accessor(
|
||||
entry.parent_fqn, seen_child_fqn
|
||||
)
|
||||
entry.parent_call_module.target = seen_target # type: ignore[union-attr]
|
||||
break
|
||||
elif not deduplicated:
|
||||
# Case 2: The current module has a different fqn than the seen module.
|
||||
# In this case we replace the current module with the seen module.
|
||||
# There should be nothing pointing to the current module any more,
|
||||
# so it can be garbage collected.
|
||||
# NOTE: We *do not* replace the current call name with the seen call name
|
||||
# in the parent graph, because this will lose information on which fqn
|
||||
# was actually called. However, it is possible that the current call name
|
||||
# will be optimized away when we find another seen module with the same fqn,
|
||||
# so we do not break out of the loop yet.
|
||||
entry.parent_module.set_submodule(target, seen.module)
|
||||
deduplicated = True
|
||||
|
||||
|
||||
def _sink_params(
|
||||
module: torch.nn.Module,
|
||||
inputs_to_state: Dict[str, List[str]],
|
||||
|
|
|
|||
Loading…
Reference in a new issue