mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[alt] fix unroll in successive unflatten (#137646)
We use nn_module_stack in unflatten to recognize when module calls begin and end. However the current format is not sufficient to detect module call boundaries when we have successive calls to the same module, because the successive instructions (end of one call, begin of next call) have the same nn_module_stack. This causes us to effectively "unroll" successive calls to a single call. This can cause problems when preserving module call signatures because the outputs of the successive calls might be concatenated in the single call. Previously we introduced the concept of a "call index" to generate multiple graphs when unflattening, one per call. This PR pushes this concept into nn_module_stack itself. In particular, the keys of nn_module_stack now go from `key` to `key@call_index`. (In a previous attempt, https://github.com/pytorch/pytorch/pull/137457, instead values in nn_module_stack go from (fqn, type) to (fqn, type, call_index), which is BC-breaking.) Note that we still do not have the ability to preserve module call signatures for multiple calls to the same module. But now instead of randomly crashing we give a proper error. OTOH when not preserving module call signatures we simply generate multiple calls, each with its own graph, possibly deduplicated, matching what we would do for non-successive calls. Test Plan: Like D64014936 Differential Revision: D64136277 Pull Request resolved: https://github.com/pytorch/pytorch/pull/137646 Approved by: https://github.com/angelayi
This commit is contained in:
parent
561f07fae7
commit
ed55d356de
8 changed files with 84 additions and 16 deletions
|
|
@ -6074,6 +6074,62 @@ graph():
|
|||
|
||||
self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))
|
||||
|
||||
def test_unflatten_no_unroll(self):
    """Successive calls to the same submodule must not be "unrolled" into one.

    Exercises two paths:
    - preserving the module call signature of a multiply-called module is
      unsupported and must raise a clear ValueError;
    - without signature preservation, unflatten generates one call (and
      submodule entry, keyed "n" and "n@1") per invocation, and each can be
      swapped independently via set_submodule.
    """
    inp = (torch.ones(1),)

    class N(torch.nn.Module):
        def forward(self, x, b):
            # Branch on a constant bool so the two calls below specialize
            # to different computations (x + 1 vs x + 2).
            if b:
                return x + 1
            else:
                return x + 2

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.n = N()

        def forward(self, x):
            x0 = x + 3
            # Two successive calls to the SAME submodule: this is the case
            # that used to be incorrectly merged ("unrolled") into one call.
            x1 = self.n(x0, True)
            x2 = self.n(x0, False)
            return x1 + x2

    m = M()
    eager_result = m(*inp)

    if not is_retracebility_test(self._testMethodName):
        # Preserving the call signature of a module called more than once
        # is not supported; expect a descriptive error rather than a crash.
        ep = export(M(), inp, preserve_module_call_signature=("n",))
        with self.assertRaisesRegex(
            ValueError,
            "Cannot unflatten multiple calls to module n while preserving its signature",
        ):
            torch.export.unflatten(ep)

    ep = export(M(), inp)
    epm = ep.module()
    ufm = torch.export.unflatten(ep)

    exported_result = epm(*inp)
    self.assertTrue(torch.allclose(exported_result, eager_result))

    unflattened_result = ufm(*inp)
    self.assertTrue(torch.allclose(unflattened_result, eager_result))

    # Replacements mirror N's behavior for b=True (x + 1) and b=False
    # (x + 2) respectively, so results must still match eager.
    class _N(torch.nn.Module):
        def forward(self, x):
            return x + 1

    class _N_1(torch.nn.Module):
        def forward(self, x):
            return x + 2

    # The second call to "n" is addressable under the "@1" call-index key.
    ufm.set_submodule("n", _N())
    ufm.set_submodule("n@1", _N_1())
    unflattened_result = ufm(*inp)
    self.assertTrue(torch.allclose(unflattened_result, eager_result))
||||
def test_preserve_module_call_signature_unflatten_specialization(self):
|
||||
class N(torch.nn.Module):
|
||||
def forward(self, x, b):
|
||||
|
|
|
|||
|
|
@ -1995,7 +1995,11 @@ class SubgraphTracer(fx.Tracer):
|
|||
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
|
||||
(
|
||||
rv.node.name,
|
||||
rv.node.meta["nn_module_stack"][target][1],
|
||||
next(
|
||||
ty
|
||||
for k, (_, ty) in rv.node.meta["nn_module_stack"].items()
|
||||
if k.split("@")[0] == target
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -2710,6 +2710,7 @@ class InstructionTranslatorBase(
|
|||
# The first field of tuple is the fully qualified name of current module
|
||||
# in original hierarchy. The second field is the type of current nn.module
|
||||
self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {}
|
||||
self.num_calls: Dict[str, int] = {}
|
||||
# Flag to indicate whether tracing is used for export.
|
||||
self.export = export
|
||||
self.one_graph = False
|
||||
|
|
|
|||
|
|
@ -82,8 +82,11 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
|
|||
@contextmanager
def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
    """Track a module call on ``tx.nn_module_stack`` for the scope of the call.

    Repeated calls to the same module (same fully qualified name) get a
    disambiguated stack key of the form ``"<key>@<call_index>"`` so that
    downstream consumers (e.g. unflatten) can tell successive call
    boundaries apart. The per-FQN call counter on ``tx.num_calls`` is
    incremented on entry and intentionally never decremented.
    """
    fqn = source.name()
    call_index = tx.num_calls.get(fqn, 0)
    # First call keeps the plain key for backward compatibility; later
    # calls are suffixed with their call index.
    if call_index > 0:
        module_key = f"{module_key}@{call_index}"
    try:
        tx.nn_module_stack[module_key] = (fqn, mod.__class__)
        tx.num_calls[fqn] = call_index + 1
        yield
    finally:
        # Always pop our entry, even if the traced call raised.
        del tx.nn_module_stack[module_key]
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph):
|
|||
seen_nodes,
|
||||
seen_modules,
|
||||
None,
|
||||
[""],
|
||||
[("", 0)],
|
||||
"",
|
||||
{},
|
||||
module=new_module,
|
||||
|
|
|
|||
|
|
@ -749,7 +749,7 @@ class _ModuleFrame:
|
|||
seen_nodes,
|
||||
seen_modules,
|
||||
parent,
|
||||
module_stack: List[str],
|
||||
module_stack: List[Tuple[str, int]],
|
||||
module_id,
|
||||
module_call_graph: Dict[str, ModuleCallSignature],
|
||||
module: Optional[torch.nn.Module] = None,
|
||||
|
|
@ -765,7 +765,7 @@ class _ModuleFrame:
|
|||
self.module_call_graph = module_call_graph
|
||||
self.verbose = False
|
||||
|
||||
self.fqn = self.module_stack[-1]
|
||||
self.fqn, num_calls = self.module_stack[-1]
|
||||
if module is not None:
|
||||
self.module = module
|
||||
else:
|
||||
|
|
@ -779,9 +779,6 @@ class _ModuleFrame:
|
|||
|
||||
self.parent_call_module: Optional[torch.fx.Node] = None
|
||||
if parent is not None:
|
||||
num_calls = len(
|
||||
[x for x in self.seen_modules[self.module_id] if x.fqn == self.fqn]
|
||||
)
|
||||
if self.fqn in module_call_graph and num_calls == 1:
|
||||
raise ValueError(
|
||||
f"Cannot unflatten multiple calls to module {self.fqn} while preserving its signature "
|
||||
|
|
@ -1068,8 +1065,9 @@ class _ModuleFrame:
|
|||
self.print()
|
||||
self.print("STEP", node_idx, node.format_node())
|
||||
self.print(self.module_stack)
|
||||
depth = len(self.module_stack)
|
||||
if node.op == "output":
|
||||
if len(self.module_stack) == 1:
|
||||
if depth == 1:
|
||||
# We want the output node of the original graph to be handled
|
||||
# specially by the outermost stack frame (in run_outer). So
|
||||
# skip finalization here.
|
||||
|
|
@ -1095,10 +1093,11 @@ class _ModuleFrame:
|
|||
node_module_stack = self.module_stack
|
||||
else:
|
||||
node_module_stack = [
|
||||
path for path, ty in node.meta["nn_module_stack"].values()
|
||||
(path, int(k.split("@")[-1]) if "@" in k else 0)
|
||||
for k, (path, ty) in node.meta["nn_module_stack"].items()
|
||||
]
|
||||
|
||||
if node_module_stack[: len(self.module_stack)] != self.module_stack:
|
||||
if node_module_stack[:depth] != self.module_stack:
|
||||
# This means that the current module is done executing and the
|
||||
# current node is the beginning of a new module.
|
||||
#
|
||||
|
|
@ -1114,10 +1113,11 @@ class _ModuleFrame:
|
|||
if _is_prefix(self.module_stack, node_module_stack):
|
||||
# This means that the current node represents the execution of a new
|
||||
# module.
|
||||
next_module = node_module_stack[len(self.module_stack)]
|
||||
next_module = node_module_stack[depth]
|
||||
self.print("Creating new stack frame for", next_module)
|
||||
# Run a nested version of module outliner from the current node
|
||||
# counter. Once it is complete, continue from that point.
|
||||
next_module_key = list(node.meta["nn_module_stack"].keys())[depth]
|
||||
node_idx = _ModuleFrame(
|
||||
self.flat_graph,
|
||||
self.nodes,
|
||||
|
|
@ -1125,7 +1125,7 @@ class _ModuleFrame:
|
|||
self.seen_modules,
|
||||
self,
|
||||
self.module_stack + [next_module],
|
||||
list(node.meta["nn_module_stack"].keys())[len(self.module_stack)],
|
||||
next_module_key.split("@")[0],
|
||||
self.module_call_graph,
|
||||
).run_from(node_idx)
|
||||
module_idx += 1
|
||||
|
|
@ -1157,7 +1157,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModu
|
|||
seen_nodes,
|
||||
seen_modules,
|
||||
None,
|
||||
[""],
|
||||
[("", 0)],
|
||||
"",
|
||||
{
|
||||
entry.fqn: entry.signature
|
||||
|
|
|
|||
|
|
@ -308,6 +308,7 @@ class Tracer(TracerBase):
|
|||
self.scope = Scope("", None)
|
||||
# Records the module call stack
|
||||
self.module_stack = collections.OrderedDict()
|
||||
self.num_calls: Dict[str, int] = {}
|
||||
# Mapping of node name to module scope
|
||||
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
|
||||
|
||||
|
|
@ -514,13 +515,16 @@ class Tracer(TracerBase):
|
|||
with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
|
||||
# module_stack is an ordered dict so writing then deleting the
|
||||
# entry is equivalent to push/pop on a list
|
||||
self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type)
|
||||
num_calls = self.num_calls.get(module_qualified_name, 0)
|
||||
module_key = f"{_scope.module_path}@{num_calls}" if num_calls > 0 else _scope.module_path
|
||||
self.module_stack[module_key] = (module_qualified_name, _scope.module_type)
|
||||
self.num_calls[module_qualified_name] = num_calls + 1
|
||||
if not self.is_leaf_module(m, module_qualified_name):
|
||||
ret_val = forward(*args, **kwargs)
|
||||
else:
|
||||
ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
|
||||
key, _ = self.module_stack.popitem(last=True)
|
||||
assert key == _scope.module_path, f" Unexpected key {key}"
|
||||
assert key == module_key, f" Unexpected key {key}"
|
||||
|
||||
return ret_val
|
||||
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ class _ModuleMeta:
|
|||
) -> _ModuleMeta:
|
||||
"""Create a module meta from raw meta produced by FX dynamo tracer."""
|
||||
module_name, (qualified_name, module_class) = raw_meta
|
||||
return _ModuleMeta(module_name, module_class, raw_meta)
|
||||
return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta)
|
||||
|
||||
@classmethod
|
||||
def from_raw_meta(
|
||||
|
|
|
|||
Loading…
Reference in a new issue