diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 822e0ee0373..f4b975d3055 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -510,6 +510,8 @@ class UnflattenedModule(torch.nn.Module): _reorder_submodules(self, fqn_order) self.graph.lint() + self._cached_mod = None + def _print_graph(self): for fqn, mod in self.named_modules(): print(fqn + ":") @@ -601,10 +603,16 @@ class UnflattenedModule(torch.nn.Module): return return_val[0] return return_val - if torch.compiler.is_dynamo_compiling() and not self._run_with_interpreter: - tree_out = torch.fx.GraphModule(self, self.graph)(*flat_args) + if torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter: + if self._cached_mod is None: + gm = torch.fx.GraphModule(self, self.graph) + self._cached_mod = lambda: gm + tree_out = self._cached_mod()(*flat_args) else: - tree_out = torch.fx.Interpreter(self, graph=self.graph).run( + if self._cached_mod is None: + gm = torch.fx.Interpreter(self, graph=self.graph) + self._cached_mod = lambda: gm + tree_out = self._cached_mod().run( *flat_args, enable_io_processing=False ) return pytree.tree_unflatten(tree_out, signature.out_spec)