mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Reland of https://github.com/pytorch/pytorch/pull/114787 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115558 Approved by: https://github.com/zhxchen17, https://github.com/atalman ghstack dependencies: #115556, #115557
60 lines
1.6 KiB
Python
60 lines
1.6 KiB
Python
import dataclasses
|
|
from typing import Optional
|
|
import warnings
|
|
|
|
|
|
import torch
|
|
import torch.fx
|
|
import torch.utils._pytree as pytree
|
|
|
|
|
|
# TODO(ycao): This is added to avoid breaking existing code temporarily.
|
|
# Remove when migration is done.
|
|
from torch.export.graph_signature import (
|
|
ExportBackwardSignature,
|
|
ExportGraphSignature,
|
|
)
|
|
|
|
from torch.export.exported_program import (
|
|
ExportedProgram,
|
|
ModuleCallEntry,
|
|
ModuleCallSignature,
|
|
)
|
|
|
|
|
|
|
|
# Public names of this module. Every entry is one of the classes imported
# above from torch.export.graph_signature / torch.export.exported_program,
# re-exported here so existing imports keep working during the migration.
__all__ = [
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
]
|
|
|
|
|
|
@dataclasses.dataclass
class CallSpec:
    """Information to maintain the user's calling and returning specs.

    Neither field declares a default, so both must be supplied when an
    instance is constructed; ``None`` is an accepted value for either one.
    """

    # Pytree spec for the calling (input) side, or None.
    in_spec: Optional[pytree.TreeSpec]
    # Pytree spec for the returning (output) side, or None.
    out_spec: Optional[pytree.TreeSpec]
|
|
|
|
|
|
def _create_graph_module_for_export(root, graph):
    """Wrap ``graph`` in a ``torch.fx.GraphModule`` rooted at ``root``.

    Args:
        root: Passed through as the first argument of ``torch.fx.GraphModule``.
        graph: The graph the returned module should carry.

    Returns:
        ``torch.fx.GraphModule(root, graph)`` when construction succeeds.
        If construction raises ``SyntaxError``, a warning is emitted and a
        module built from an empty ``torch.fx.Graph()`` is returned instead,
        with its ``_graph`` attribute pointed at ``graph``.
    """
    try:
        return torch.fx.GraphModule(root, graph)
    except SyntaxError:
        # Graphs that reference custom in-memory objects produce python source
        # that cannot be parsed, which surfaces here as a SyntaxError. Such a
        # graph can still be run eagerly through torch.fx.Interpreter, so the
        # error is downgraded to a warning instead of being propagated.
        message = (
            "Unable to execute the generated python source code from "
            "the graph. The graph module will no longer be directly callable, "
            "but you can still run the ExportedProgram, and if needed, you can "
            "run the graph module eagerly using torch.fx.Interpreter."
        )
        warnings.warn(message)

        # Construct the module around an empty graph, then attach the real
        # graph through the private attribute.
        # NOTE(review): presumably this avoids re-triggering code generation
        # for the unparsable graph -- confirm against torch.fx.GraphModule.
        fallback_module = torch.fx.GraphModule(root, torch.fx.Graph())
        fallback_module._graph = graph
        return fallback_module
|