[export] cache unflatten forward module
Differential Revision: D69361235
commit 276098a5e2
parent 91c4bf39d3

1 changed file with 11 additions and 3 deletions
@@ -510,6 +510,8 @@ class UnflattenedModule(torch.nn.Module):
         _reorder_submodules(self, fqn_order)
         self.graph.lint()
+
+        self._cached_mod = None
 
     def _print_graph(self):
         for fqn, mod in self.named_modules():
             print(fqn + ":")
@@ -601,10 +603,16 @@ class UnflattenedModule(torch.nn.Module):
                 return return_val[0]
             return return_val
 
-        if torch.compiler.is_dynamo_compiling() and not self._run_with_interpreter:
-            tree_out = torch.fx.GraphModule(self, self.graph)(*flat_args)
+        if torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter:
+            if self._cached_mod is None:
+                gm = torch.fx.GraphModule(self, self.graph)
+                self._cached_mod = lambda: gm
+            tree_out = self._cached_mod()(*flat_args)
         else:
-            tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
+            if self._cached_mod is None:
+                gm = torch.fx.Interpreter(self, graph=self.graph)
+                self._cached_mod = lambda: gm
+            tree_out = self._cached_mod().run(
                 *flat_args, enable_io_processing=False
             )
         return pytree.tree_unflatten(tree_out, signature.out_spec)
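
The change memoizes the callable built in forward: the first call constructs a torch.fx.GraphModule (or torch.fx.Interpreter) and stashes it in self._cached_mod, and every later call reuses it instead of rebuilding it from self.graph. Below is a minimal sketch of the same pattern on a toy module; ToyUnflattened and its hand-built graph are illustrative assumptions, not code from this commit.

import torch
import torch.fx

class ToyUnflattened(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))
        # Hand-build an fx.Graph computing x @ self.weight.
        graph = torch.fx.Graph()
        x = graph.placeholder("x")
        w = graph.get_attr("weight")
        graph.output(graph.call_function(torch.matmul, (x, w)))
        self.graph = graph
        self._cached_mod = None  # populated lazily on the first forward

    def forward(self, x):
        if self._cached_mod is None:
            # Built once; every later call skips GraphModule construction.
            gm = torch.fx.GraphModule(self, self.graph)
            # The lambda keeps nn.Module.__setattr__ from registering gm as
            # a child submodule of self (a bare module assignment would).
            self._cached_mod = lambda: gm
        return self._cached_mod()(x)

m = ToyUnflattened()
y1 = m(torch.randn(2, 4))  # first call builds and caches the GraphModule
y2 = m(torch.randn(2, 4))  # reuses the cached module

One caveat the sketch shares with the diff: if self.graph is mutated after the first call, the cached module goes stale, so the cache must be reset to None alongside any graph edits.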