[export] cache unflatten forward module

Differential Revision: D69361235
Pian Pawakapan, 2025-02-08 23:35:43 -08:00 (committed by Facebook GitHub Bot)
parent 91c4bf39d3
commit 276098a5e2

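The change below makes UnflattenedModule.forward build its torch.fx.GraphModule (or fx.Interpreter) once and reuse it on later calls, instead of reconstructing it on every invocation. A minimal sketch of the same build-once pattern, independent of the commit itself (LazyForward and its names are illustrative, not part of torch.export); the cached object is hidden behind a lambda, presumably so that nn.Module.__setattr__ does not register the built module as a child submodule:

import torch
import torch.fx

class LazyForward(torch.nn.Module):
    """Illustrative only: trace an inner module on first use, then reuse the result."""

    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.inner = inner
        self._cached_mod = None  # plain attribute, not a registered submodule

    def forward(self, *args):
        if self._cached_mod is None:
            gm = torch.fx.symbolic_trace(self.inner)  # built only on the first call
            self._cached_mod = lambda: gm             # lambda keeps gm out of the module tree
        return self._cached_mod()(*args)

mod = LazyForward(torch.nn.Linear(4, 4))
x = torch.randn(2, 4)
mod(x)  # traces and caches
mod(x)  # reuses the cached GraphModule
print([n for n, _ in mod.named_modules()])  # ['', 'inner'] -- the cache does not appear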

@@ -510,6 +510,8 @@ class UnflattenedModule(torch.nn.Module):
         _reorder_submodules(self, fqn_order)
         self.graph.lint()
 
+        self._cached_mod = None
+
     def _print_graph(self):
         for fqn, mod in self.named_modules():
             print(fqn + ":")
@@ -601,10 +603,16 @@ class UnflattenedModule(torch.nn.Module):
                 return return_val[0]
             return return_val
 
-        if torch.compiler.is_dynamo_compiling() and not self._run_with_interpreter:
-            tree_out = torch.fx.GraphModule(self, self.graph)(*flat_args)
+        if torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter:
+            if self._cached_mod is None:
+                gm = torch.fx.GraphModule(self, self.graph)
+                self._cached_mod = lambda: gm
+            tree_out = self._cached_mod()(*flat_args)
         else:
-            tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
+            if self._cached_mod is None:
+                gm = torch.fx.Interpreter(self, graph=self.graph)
+                self._cached_mod = lambda: gm
+            tree_out = self._cached_mod().run(
                 *flat_args, enable_io_processing=False
             )
         return pytree.tree_unflatten(tree_out, signature.out_spec)
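For reference, the cached path is exercised whenever an unflattened exported program is called more than once. A rough usage sketch, assuming a torch build with torch.export available (the toy modules and shapes here are made up, not from this commit):

import torch
from torch.export import export, unflatten

class Child(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

class Top(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()

    def forward(self, x):
        return self.child(x) + 1

ep = export(Top(), (torch.randn(3, 4),))
um = unflatten(ep)  # UnflattenedModule that preserves the child hierarchy

x = torch.randn(3, 4)
um(x)  # first call constructs the GraphModule / Interpreter and stores it in _cached_mod
um(x)  # later calls reuse the cached module instead of rebuilding it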