mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[export] Deserialize subgraphs. (#103991)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/103991 Approved by: https://github.com/angelayi, https://github.com/avikchaudhuri
This commit is contained in:
parent
dd4f4bb47d
commit
100aff9d4f
2 changed files with 106 additions and 59 deletions
|
|
@ -70,7 +70,7 @@ class TestSerialize(TestCase):
|
|||
|
||||
serialized, _ = ExportedProgramSerializer().serialize(exported_module)
|
||||
node = serialized.graph_module.graph.nodes[-7]
|
||||
self.assertEqual(node.target, "torch._ops.aten.var_mean.correction")
|
||||
self.assertEqual(node.target, "torch.ops.aten.var_mean.correction")
|
||||
# aten::native_layer_norm returns 3 tensors
|
||||
self.assertEqual(len(node.outputs), 2)
|
||||
|
||||
|
|
@ -95,7 +95,7 @@ class TestSerialize(TestCase):
|
|||
|
||||
serialized, _ = ExportedProgramSerializer().serialize(exported_module)
|
||||
node = serialized.graph_module.graph.nodes[-1]
|
||||
self.assertEqual(node.target, "torch._ops.aten.split.Tensor")
|
||||
self.assertEqual(node.target, "torch.ops.aten.split.Tensor")
|
||||
self.assertEqual(len(node.outputs), 1)
|
||||
# Input looks like:
|
||||
# tensor([[0, 1],
|
||||
|
|
@ -138,7 +138,7 @@ class TestSerialize(TestCase):
|
|||
|
||||
serialized, _ = ExportedProgramSerializer().serialize(exported_module)
|
||||
node = serialized.graph_module.graph.nodes[-1]
|
||||
self.assertEqual(node.target, "torch._ops.aten.var_mean.correction")
|
||||
self.assertEqual(node.target, "torch.ops.aten.var_mean.correction")
|
||||
self.assertEqual(len(node.outputs), 2)
|
||||
|
||||
# check the names are unique
|
||||
|
|
@ -163,7 +163,7 @@ class TestSerialize(TestCase):
|
|||
serialized, _ = ExportedProgramSerializer().serialize(exported_module)
|
||||
|
||||
node = serialized.graph_module.graph.nodes[-1]
|
||||
self.assertEqual(node.target, "torch._ops.aten.searchsorted.Tensor")
|
||||
self.assertEqual(node.target, "torch.ops.aten.searchsorted.Tensor")
|
||||
self.assertEqual(len(node.inputs), 6)
|
||||
self.assertEqual(node.inputs[2].arg.as_bool, False)
|
||||
self.assertEqual(node.inputs[3].arg.as_bool, True)
|
||||
|
|
@ -194,6 +194,7 @@ class TestDeserialize(TestCase):
|
|||
else:
|
||||
self.assertEqual(orig, loaded)
|
||||
|
||||
self.assertEqual(len(ep.graph.nodes), len(deserialized_ep.graph.nodes))
|
||||
for node1, node2 in zip(ep.graph.nodes, deserialized_ep.graph.nodes):
|
||||
# Check "val" metadata
|
||||
val1 = node1.meta.get("val", None)
|
||||
|
|
@ -240,10 +241,11 @@ class TestDeserialize(TestCase):
|
|||
)
|
||||
|
||||
# Check "source_fn" metadata
|
||||
self.assertEqual(
|
||||
node1.meta.get("source_fn", None),
|
||||
node2.meta.get("source_fn", None),
|
||||
)
|
||||
if node1.op != "get_attr":
|
||||
self.assertEqual(
|
||||
node1.meta.get("source_fn", None),
|
||||
node2.meta.get("source_fn", None),
|
||||
)
|
||||
|
||||
def test_multi_return(self) -> None:
|
||||
"""
|
||||
|
|
@ -341,6 +343,21 @@ class TestDeserialize(TestCase):
|
|||
inputs = (torch.randn(3, 3),)
|
||||
self.check_graph(M(), inputs)
|
||||
|
||||
def test_cond(self):
    """Check that a module whose forward goes through ``cond`` passes ``check_graph``."""
    from functorch.experimental.control_flow import cond

    class M(torch.nn.Module):
        def forward(self, x, y):
            # Branch callables handed to ``cond``; names kept as in the
            # original so anything derived from them while tracing is unchanged.
            def t(x, y):
                return x + y

            def f(x, y):
                return x - y

            pred = x[0][0] > 4
            return cond(pred, t, f, [x, y])

    inputs = (torch.ones(4, 3), torch.zeros(4, 3))
    self.check_graph(M(), inputs)
|
||||
|
||||
@parametrize(
|
||||
"name,case",
|
||||
get_filtered_export_db_tests(),
|
||||
|
|
|
|||
|
|
@ -253,17 +253,23 @@ def deserialize_metadata(metadata) -> Dict[str, str]:
|
|||
def serialize_operator(target) -> str:
|
||||
if isinstance(target, str):
|
||||
return target
|
||||
return f"{target.__module__}.{target.__name__}"
|
||||
elif target.__module__.startswith("torch._ops"):
|
||||
# TODO(zhxchen17) Maybe provide a function name helper in FX.
|
||||
# From torch.fx.node._get_qualified_name
|
||||
module = target.__module__.replace("torch._ops", "torch.ops")
|
||||
return f"{module}.{target.__name__}"
|
||||
else: # TODO(zhxchen17) Don't catch all here.
|
||||
return f"{target.__module__}.{target.__name__}"
|
||||
|
||||
|
||||
def deserialize_operator(serialized_target: str):
|
||||
if serialized_target.startswith("_operator"):
|
||||
if serialized_target.startswith("_operator"): # TODO(zhxchen17) Follow up on this.
|
||||
module = operator
|
||||
serialized_target_names = serialized_target.split(".")[1:]
|
||||
elif serialized_target.startswith("torch._ops"):
|
||||
elif serialized_target.startswith("torch.ops"):
|
||||
module = torch.ops
|
||||
serialized_target_names = serialized_target.split(".")[2:]
|
||||
else:
|
||||
else: # TODO(zhxchen17) Don't catch all here.
|
||||
return serialized_target
|
||||
|
||||
target = module
|
||||
|
|
@ -775,6 +781,19 @@ class GraphModuleDeserializer:
|
|||
self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
|
||||
self.serialized_name_to_meta: Dict[str, MetaType] = {}
|
||||
self.graph = torch.fx.Graph()
|
||||
self.module = torch.nn.Module()
|
||||
|
||||
@contextmanager
def save_graph_module(self) -> typing.Iterator[None]:
    """Temporarily replace the graph-building state with fresh, empty state.

    On entry the current ``graph``, ``module``, ``serialized_name_to_node`` and
    ``serialized_name_to_meta`` attributes are stashed and replaced with a new
    ``torch.fx.Graph``, a new ``torch.nn.Module`` and two empty dicts. On exit
    the stashed values are put back, including when the ``with`` body raises.

    Yields:
        None. The fresh state is reached through ``self``.
    """
    # The function is a generator consumed by ``@contextmanager``, so its
    # return annotation is an iterator type rather than ``None``.
    saved = (
        self.graph,
        self.module,
        self.serialized_name_to_node,
        self.serialized_name_to_meta,
    )
    self.graph = torch.fx.Graph()
    self.module = torch.nn.Module()
    self.serialized_name_to_node = {}
    self.serialized_name_to_meta = {}
    try:
        yield
    finally:
        # Restore unconditionally so an error inside the nested scope does not
        # leave the outer state pointing at the nested graph.
        (
            self.graph,
            self.module,
            self.serialized_name_to_node,
            self.serialized_name_to_meta,
        ) = saved
|
||||
|
||||
def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
|
||||
val = s.value
|
||||
|
|
@ -830,19 +849,7 @@ class GraphModuleDeserializer:
|
|||
),
|
||||
)
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
serialized_graph_module: GraphModule,
|
||||
symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
|
||||
) -> Tuple[torch.fx.GraphModule, ep.ExportGraphSignature, ep.CallSpec, Dict[str, sympy.Symbol]]:
|
||||
self.shape_env = symbolic_shapes.ShapeEnv()
|
||||
self.fake_tensor_mode = FakeTensorMode(shape_env=self.shape_env)
|
||||
self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {}
|
||||
self.symbol_name_to_range = {} if symbol_name_to_range is None else symbol_name_to_range
|
||||
|
||||
graph = self.graph
|
||||
serialized_graph = serialized_graph_module.graph
|
||||
|
||||
def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
|
||||
# Handle the tensor metas.
|
||||
for name, tensor_value in serialized_graph.tensor_values.items():
|
||||
meta_val = self.deserialize_tensor_meta(tensor_value.meta, self.fake_tensor_mode)
|
||||
|
|
@ -856,35 +863,30 @@ class GraphModuleDeserializer:
|
|||
|
||||
# Inputs: convert to placeholder nodes in FX.
|
||||
for input in serialized_graph.inputs:
|
||||
placeholder_node = graph.placeholder(input.as_tensor.name)
|
||||
self.sync_serialized_node(input.as_tensor.name, placeholder_node)
|
||||
placeholder_node = self.graph.placeholder(input.as_tensor.name)
|
||||
self.sync_fx_node(input.as_tensor.name, placeholder_node)
|
||||
|
||||
# Nodes: convert to call_function nodes.
|
||||
for serialized_node in serialized_graph.nodes:
|
||||
try:
|
||||
target = deserialize_operator(serialized_node.target)
|
||||
if isinstance(target, str):
|
||||
# Create a dummy fake op if the target does not exist
|
||||
# because we cannot create a call_function node w/o a
|
||||
# callable target
|
||||
log.warning(f"Could not find operator {target}. Returning fake operator.") # noqa: G004
|
||||
|
||||
def fake_op(x):
|
||||
raise NotImplementedError("Fake op is not meant to be run.")
|
||||
fake_op.__name__ = target
|
||||
target = fake_op
|
||||
|
||||
if target.__module__ == "_operator":
|
||||
if target.__module__ == "_operator": # TODO(zhxchen17) Follow up on this.
|
||||
name = serialized_node.outputs[0].value.as_name
|
||||
args = self.deserialize_sym_op_inputs(serialized_node.inputs)
|
||||
|
||||
fx_node = graph.create_node("call_function", target, args, {}, name)
|
||||
fx_node = self.graph.create_node("call_function", target, args, {}, name)
|
||||
self.deserialize_sym_op_outputs(serialized_node, fx_node)
|
||||
fx_node.meta.update(deserialize_metadata(serialized_node.metadata))
|
||||
|
||||
else:
|
||||
target = deserialize_operator(serialized_node.target)
|
||||
|
||||
elif isinstance(target, torch._ops.HigherOrderOperator):
|
||||
assert (
|
||||
len(serialized_node.outputs) == 1
|
||||
and serialized_node.outputs[0].as_tensor is not None
|
||||
), "Only single tensor output is supported for higher order operators."
|
||||
name = serialized_node.outputs[0].as_tensor.name
|
||||
args = tuple(self.deserialize_input(input.arg) for input in serialized_node.inputs)
|
||||
fx_node = self.graph.create_node("call_function", target, args, {}, name)
|
||||
self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
|
||||
elif isinstance(target, torch._ops.OpOverload):
|
||||
# For convenience: if this node returns a single tensor, name the
|
||||
# newly-created node after it. This ensures that these tensor values
|
||||
# have names that are consistent with serialized.
|
||||
|
|
@ -894,12 +896,12 @@ class GraphModuleDeserializer:
|
|||
else None # FX will generate a name for us.
|
||||
)
|
||||
args, kwargs = self.deserialize_inputs(target, serialized_node)
|
||||
|
||||
fx_node = graph.create_node("call_function", target, args, kwargs, name)
|
||||
|
||||
fx_node = self.graph.create_node("call_function", target, args, kwargs, name)
|
||||
self.deserialize_outputs(serialized_node, fx_node)
|
||||
else:
|
||||
raise SerializeError(f"Unsupported target type for node {serialized_node}: {target}")
|
||||
|
||||
fx_node.meta.update(deserialize_metadata(serialized_node.metadata))
|
||||
fx_node.meta.update(deserialize_metadata(serialized_node.metadata))
|
||||
|
||||
except Exception as e:
|
||||
raise SerializeError(f"Failed deserializing node {serialized_node}") from e
|
||||
|
|
@ -915,19 +917,32 @@ class GraphModuleDeserializer:
|
|||
raise SerializeError(f"Unable to deserialize output node {output}")
|
||||
|
||||
|
||||
output_node = graph.output(tuple(outputs))
|
||||
output_node = self.graph.output(tuple(outputs))
|
||||
output_node.meta["val"] = tuple(
|
||||
arg.meta["val"] for arg in output_node.args[0]
|
||||
)
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
serialized_graph_module: GraphModule,
|
||||
symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
|
||||
) -> Tuple[torch.fx.GraphModule, ep.ExportGraphSignature, ep.CallSpec, Dict[str, sympy.Symbol]]:
|
||||
self.shape_env = symbolic_shapes.ShapeEnv()
|
||||
self.fake_tensor_mode = FakeTensorMode(shape_env=self.shape_env)
|
||||
self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {}
|
||||
self.symbol_name_to_range = {} if symbol_name_to_range is None else symbol_name_to_range
|
||||
|
||||
self.deserialize_graph(serialized_graph_module.graph)
|
||||
|
||||
sig = deserialize_signature(serialized_graph_module.signature)
|
||||
call_spec = deserialize_call_spec(serialized_graph_module.call_spec)
|
||||
return torch.fx.GraphModule({}, graph), sig, call_spec, self.symbol_name_to_symbol
|
||||
return torch.fx.GraphModule(self.module, self.graph), sig, call_spec, self.symbol_name_to_symbol
|
||||
|
||||
def sync_serialized_node(self, name: str, fx_node: torch.fx.Node):
|
||||
def sync_fx_node(self, name: str, fx_node: torch.fx.Node):
|
||||
if name in self.serialized_name_to_node:
|
||||
raise SerializeError(f"Node {name} has already been deserialized before.")
|
||||
self.serialized_name_to_node[name] = fx_node
|
||||
assert "val" not in fx_node.meta
|
||||
fx_node.meta["val"] = self.serialized_name_to_meta[name]
|
||||
|
||||
def deserialize_sym_op_inputs(self, inputs):
|
||||
|
|
@ -962,6 +977,17 @@ class GraphModuleDeserializer:
|
|||
return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[value]
|
||||
elif typ_ == "as_layout":
|
||||
return _SERIALIZE_TO_TORCH_LAYOUT[value]
|
||||
elif typ_ == "as_graph":
|
||||
assert isinstance(value, GraphArgument)
|
||||
with self.save_graph_module():
|
||||
self.deserialize_graph(value.graph)
|
||||
submodule = torch.fx.GraphModule(self.module, self.graph)
|
||||
self.module.register_module(value.name, submodule)
|
||||
return self.graph.create_node(
|
||||
"get_attr",
|
||||
value.name,
|
||||
name=value.name,
|
||||
)
|
||||
elif isinstance(value, Device):
|
||||
return deserialize_device(value)
|
||||
elif isinstance(value, TensorArgument):
|
||||
|
|
@ -991,20 +1017,22 @@ class GraphModuleDeserializer:
|
|||
return self.serialized_name_to_node[sym_int_arg.as_name]
|
||||
|
||||
def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node):
|
||||
self.sync_serialized_node(serialized_node.outputs[0].value.as_name, fx_node)
|
||||
self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
|
||||
|
||||
def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node) -> None:
|
||||
def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node):
|
||||
# Simple case for single tensor return.
|
||||
assert isinstance(fx_node.target, torch._ops.OpOverload)
|
||||
returns = fx_node.target._schema.returns
|
||||
|
||||
# Check single value return
|
||||
if len(returns) == 0:
|
||||
return None
|
||||
return
|
||||
if _is_single_tensor_return(fx_node.target):
|
||||
return self.sync_serialized_node(serialized_node.outputs[0].as_tensor.name, fx_node)
|
||||
self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
|
||||
return
|
||||
elif len(returns) == 1 and isinstance(serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument)):
|
||||
return self.sync_serialized_node(serialized_node.outputs[0].value.as_name, fx_node)
|
||||
self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
|
||||
return
|
||||
|
||||
# Convert multiple return types to FX format.
|
||||
# In FX, each node only returns one value. So in order to represent
|
||||
|
|
@ -1029,7 +1057,7 @@ class GraphModuleDeserializer:
|
|||
(fx_node, idx),
|
||||
name=name,
|
||||
)
|
||||
self.sync_serialized_node(name, individual_output)
|
||||
self.sync_fx_node(name, individual_output)
|
||||
# The derived `getitem` nodes should have the same stacktrace as the
|
||||
# original `fx_node`
|
||||
individual_output.meta.update(deserialize_metadata(serialized_node.metadata))
|
||||
|
|
@ -1084,7 +1112,7 @@ class ExportedProgramDeserializer:
|
|||
equality_constraints = deserialize_equality_constraints(serialized_exported_program.equality_constraints)
|
||||
|
||||
exported_program = ep.ExportedProgram(
|
||||
state_dict,
|
||||
graph_module,
|
||||
graph_module.graph,
|
||||
sig,
|
||||
call_spec,
|
||||
|
|
@ -1157,6 +1185,7 @@ def serialize(
|
|||
|
||||
|
||||
def _dict_to_dataclass(cls, data):
|
||||
assert not isinstance(cls, str), f"Unresolved class type: '{cls}'."
|
||||
if isinstance(cls, type) and issubclass(cls, _Union):
|
||||
obj = cls(**data)
|
||||
field_type = cls.__annotations__[obj.type]
|
||||
|
|
@ -1164,9 +1193,10 @@ def _dict_to_dataclass(cls, data):
|
|||
return obj
|
||||
elif dataclasses.is_dataclass(cls):
|
||||
obj = cls(**data) # type: ignore[assignment]
|
||||
type_hints = typing.get_type_hints(cls)
|
||||
for f in dataclasses.fields(cls):
|
||||
name = f.name
|
||||
new_field_obj = _dict_to_dataclass(f.type, getattr(obj, name))
|
||||
new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name))
|
||||
setattr(obj, name, new_field_obj)
|
||||
return obj
|
||||
elif isinstance(data, list):
|
||||
|
|
|
|||
Loading…
Reference in a new issue