[alt] fix unroll in successive unflatten (#137646)

We use nn_module_stack in unflatten to recognize when module calls begin and end. However the current format is not sufficient to detect module call boundaries when we have successive calls to the same module, because the successive instructions (end of one call, begin of next call) have the same nn_module_stack. This causes us to effectively "unroll" successive calls to a single call. This can cause problems when preserving module call signatures because the outputs of the successive calls might be concatenated in the single call.

Previously we introduced the concept of a "call index" to generate multiple graphs when unflattening, one per call. This PR pushes this concept into nn_module_stack itself. In particular, the keys of nn_module_stack now go from `key` to `key@call_index`. (In a previous attempt, https://github.com/pytorch/pytorch/pull/137457, the values in nn_module_stack instead went from (fqn, type) to (fqn, type, call_index), which is BC-breaking.)

Note that we still do not have the ability to preserve module call signatures for multiple calls to the same module. But now, instead of randomly crashing, we raise a proper error. On the other hand, when not preserving module call signatures, we simply generate multiple calls, each with its own graph (possibly deduplicated), matching what we would do for non-successive calls.

Test Plan: Like D64014936

Differential Revision: D64136277

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137646
Approved by: https://github.com/angelayi
This commit is contained in:
Avik Chaudhuri 2024-10-12 15:53:52 +00:00 committed by PyTorch MergeBot
parent 561f07fae7
commit ed55d356de
8 changed files with 84 additions and 16 deletions

View file

@ -6074,6 +6074,62 @@ graph():
self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))
def test_unflatten_no_unroll(self):
    """Successive calls to the same submodule must not be "unrolled" into one.

    ``M.forward`` calls ``self.n`` twice with different arguments. Unflattening
    should either raise a descriptive error (when preserving n's module call
    signature, which is unsupported for multiple calls) or produce one
    submodule per call, keyed ``n`` and ``n@1``.
    """
    inp = (torch.ones(1),)

    class N(torch.nn.Module):
        def forward(self, x, b):
            if b:
                return x + 1
            else:
                return x + 2

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.n = N()

        def forward(self, x):
            x0 = x + 3
            x1 = self.n(x0, True)  # first call to n
            x2 = self.n(x0, False)  # second call to n, same fqn in nn_module_stack
            return x1 + x2

    m = M()
    eager_result = m(*inp)

    if not is_retracebility_test(self._testMethodName):
        # Preserving the signature of a module that is called more than once
        # is not supported; expect a proper error instead of a crash.
        ep = export(M(), inp, preserve_module_call_signature=("n",))
        with self.assertRaisesRegex(
            ValueError,
            "Cannot unflatten multiple calls to module n while preserving its signature",
        ):
            torch.export.unflatten(ep)

    ep = export(M(), inp)
    epm = ep.module()
    ufm = torch.export.unflatten(ep)

    exported_result = epm(*inp)
    self.assertTrue(torch.allclose(exported_result, eager_result))

    unflattened_result = ufm(*inp)
    self.assertTrue(torch.allclose(unflattened_result, eager_result))

    # Each call got its own submodule: "n" for the first call, "n@1" for the
    # second. Swapping in equivalent per-call modules must preserve results.
    class _N(torch.nn.Module):
        def forward(self, x):
            return x + 1

    class _N_1(torch.nn.Module):
        def forward(self, x):
            return x + 2

    ufm.set_submodule("n", _N())
    ufm.set_submodule("n@1", _N_1())
    unflattened_result = ufm(*inp)
    self.assertTrue(torch.allclose(unflattened_result, eager_result))
def test_preserve_module_call_signature_unflatten_specialization(self):
class N(torch.nn.Module):
def forward(self, x, b):

View file

@ -1995,7 +1995,11 @@ class SubgraphTracer(fx.Tracer):
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
(
rv.node.name,
rv.node.meta["nn_module_stack"][target][1],
next(
ty
for k, (_, ty) in rv.node.meta["nn_module_stack"].items()
if k.split("@")[0] == target
),
)
]

View file

@ -2710,6 +2710,7 @@ class InstructionTranslatorBase(
# The first field of tuple is the fully qualified name of current module
# in original hierarchy. The second field is the type of current nn.module
self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {}
self.num_calls: Dict[str, int] = {}
# Flag to indicate whether tracing is used for export.
self.export = export
self.one_graph = False

View file

@ -82,8 +82,11 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
@contextmanager
def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
    """Push an entry for ``mod`` onto ``tx.nn_module_stack`` for the duration
    of the ``with`` body, popping it on exit.

    Repeat calls for the same fully qualified name get a distinct key of the
    form ``{module_key}@{call_index}`` (the first call keeps the bare key), so
    successive calls to the same module remain distinguishable downstream.
    """
    fqn = source.name()
    call_count = tx.num_calls.get(fqn, 0)
    if call_count > 0:
        module_key = f"{module_key}@{call_count}"
    try:
        tx.nn_module_stack[module_key] = (fqn, mod.__class__)
        tx.num_calls[fqn] = call_count + 1
        yield
    finally:
        del tx.nn_module_stack[module_key]

View file

@ -18,7 +18,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph):
seen_nodes,
seen_modules,
None,
[""],
[("", 0)],
"",
{},
module=new_module,

View file

@ -749,7 +749,7 @@ class _ModuleFrame:
seen_nodes,
seen_modules,
parent,
module_stack: List[str],
module_stack: List[Tuple[str, int]],
module_id,
module_call_graph: Dict[str, ModuleCallSignature],
module: Optional[torch.nn.Module] = None,
@ -765,7 +765,7 @@ class _ModuleFrame:
self.module_call_graph = module_call_graph
self.verbose = False
self.fqn = self.module_stack[-1]
self.fqn, num_calls = self.module_stack[-1]
if module is not None:
self.module = module
else:
@ -779,9 +779,6 @@ class _ModuleFrame:
self.parent_call_module: Optional[torch.fx.Node] = None
if parent is not None:
num_calls = len(
[x for x in self.seen_modules[self.module_id] if x.fqn == self.fqn]
)
if self.fqn in module_call_graph and num_calls == 1:
raise ValueError(
f"Cannot unflatten multiple calls to module {self.fqn} while preserving its signature "
@ -1068,8 +1065,9 @@ class _ModuleFrame:
self.print()
self.print("STEP", node_idx, node.format_node())
self.print(self.module_stack)
depth = len(self.module_stack)
if node.op == "output":
if len(self.module_stack) == 1:
if depth == 1:
# We want the output node of the original graph to be handled
# specially by the outermost stack frame (in run_outer). So
# skip finalization here.
@ -1095,10 +1093,11 @@ class _ModuleFrame:
node_module_stack = self.module_stack
else:
node_module_stack = [
path for path, ty in node.meta["nn_module_stack"].values()
(path, int(k.split("@")[-1]) if "@" in k else 0)
for k, (path, ty) in node.meta["nn_module_stack"].items()
]
if node_module_stack[: len(self.module_stack)] != self.module_stack:
if node_module_stack[:depth] != self.module_stack:
# This means that the current module is done executing and the
# current node is the beginning of a new module.
#
@ -1114,10 +1113,11 @@ class _ModuleFrame:
if _is_prefix(self.module_stack, node_module_stack):
# This means that the current node represents the execution of a new
# module.
next_module = node_module_stack[len(self.module_stack)]
next_module = node_module_stack[depth]
self.print("Creating new stack frame for", next_module)
# Run a nested version of module outliner from the current node
# counter. Once it is complete, continue from that point.
next_module_key = list(node.meta["nn_module_stack"].keys())[depth]
node_idx = _ModuleFrame(
self.flat_graph,
self.nodes,
@ -1125,7 +1125,7 @@ class _ModuleFrame:
self.seen_modules,
self,
self.module_stack + [next_module],
list(node.meta["nn_module_stack"].keys())[len(self.module_stack)],
next_module_key.split("@")[0],
self.module_call_graph,
).run_from(node_idx)
module_idx += 1
@ -1157,7 +1157,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModu
seen_nodes,
seen_modules,
None,
[""],
[("", 0)],
"",
{
entry.fqn: entry.signature

View file

@ -308,6 +308,7 @@ class Tracer(TracerBase):
self.scope = Scope("", None)
# Records the module call stack
self.module_stack = collections.OrderedDict()
self.num_calls: Dict[str, int] = {}
# Mapping of node name to module scope
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
@ -514,13 +515,16 @@ class Tracer(TracerBase):
with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
# module_stack is an ordered dict so writing then deleting the
# entry is equivalent to push/pop on a list
self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type)
num_calls = self.num_calls.get(module_qualified_name, 0)
module_key = f"{_scope.module_path}@{num_calls}" if num_calls > 0 else _scope.module_path
self.module_stack[module_key] = (module_qualified_name, _scope.module_type)
self.num_calls[module_qualified_name] = num_calls + 1
if not self.is_leaf_module(m, module_qualified_name):
ret_val = forward(*args, **kwargs)
else:
ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
key, _ = self.module_stack.popitem(last=True)
assert key == _scope.module_path, f" Unexpected key {key}"
assert key == module_key, f" Unexpected key {key}"
return ret_val

View file

@ -140,7 +140,7 @@ class _ModuleMeta:
) -> _ModuleMeta:
"""Create a module meta from raw meta produced by FX dynamo tracer."""
module_name, (qualified_name, module_class) = raw_meta
return _ModuleMeta(module_name, module_class, raw_meta)
return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta)
@classmethod
def from_raw_meta(