[export] Deserialize subgraphs. (#103991)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103991
Approved by: https://github.com/angelayi, https://github.com/avikchaudhuri
This commit is contained in:
zhxchen17 2023-06-26 18:17:44 +00:00 committed by PyTorch MergeBot
parent dd4f4bb47d
commit 100aff9d4f
2 changed files with 106 additions and 59 deletions

View file

@ -70,7 +70,7 @@ class TestSerialize(TestCase):
serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[-7]
self.assertEqual(node.target, "torch._ops.aten.var_mean.correction")
self.assertEqual(node.target, "torch.ops.aten.var_mean.correction")
# aten::native_layer_norm returns 3 tensors
self.assertEqual(len(node.outputs), 2)
@ -95,7 +95,7 @@ class TestSerialize(TestCase):
serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[-1]
self.assertEqual(node.target, "torch._ops.aten.split.Tensor")
self.assertEqual(node.target, "torch.ops.aten.split.Tensor")
self.assertEqual(len(node.outputs), 1)
# Input looks like:
# tensor([[0, 1],
@ -138,7 +138,7 @@ class TestSerialize(TestCase):
serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[-1]
self.assertEqual(node.target, "torch._ops.aten.var_mean.correction")
self.assertEqual(node.target, "torch.ops.aten.var_mean.correction")
self.assertEqual(len(node.outputs), 2)
# check the names are unique
@ -163,7 +163,7 @@ class TestSerialize(TestCase):
serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[-1]
self.assertEqual(node.target, "torch._ops.aten.searchsorted.Tensor")
self.assertEqual(node.target, "torch.ops.aten.searchsorted.Tensor")
self.assertEqual(len(node.inputs), 6)
self.assertEqual(node.inputs[2].arg.as_bool, False)
self.assertEqual(node.inputs[3].arg.as_bool, True)
@ -194,6 +194,7 @@ class TestDeserialize(TestCase):
else:
self.assertEqual(orig, loaded)
self.assertEqual(len(ep.graph.nodes), len(deserialized_ep.graph.nodes))
for node1, node2 in zip(ep.graph.nodes, deserialized_ep.graph.nodes):
# Check "val" metadata
val1 = node1.meta.get("val", None)
@ -240,10 +241,11 @@ class TestDeserialize(TestCase):
)
# Check "source_fn" metadata
self.assertEqual(
node1.meta.get("source_fn", None),
node2.meta.get("source_fn", None),
)
if node1.op != "get_attr":
self.assertEqual(
node1.meta.get("source_fn", None),
node2.meta.get("source_fn", None),
)
def test_multi_return(self) -> None:
"""
@ -341,6 +343,21 @@ class TestDeserialize(TestCase):
inputs = (torch.randn(3, 3),)
self.check_graph(M(), inputs)
def test_cond(self):
from functorch.experimental.control_flow import cond
inputs = torch.ones(4, 3), torch.zeros(4, 3)
class M(torch.nn.Module):
def forward(self, x, y):
def t(x, y):
return x + y
def f(x, y):
return x - y
return cond(x[0][0] > 4, t, f, [x, y])
self.check_graph(M(), inputs)
@parametrize(
"name,case",
get_filtered_export_db_tests(),

View file

@ -253,17 +253,23 @@ def deserialize_metadata(metadata) -> Dict[str, str]:
def serialize_operator(target) -> str:
    """Return a fully qualified, serializable string name for an op target.

    Strings pass through unchanged; callables under ``torch._ops`` are
    rewritten to their public ``torch.ops`` qualified name; anything else
    falls back to ``module.name``.
    """
    if isinstance(target, str):
        # Already a qualified name — nothing to do.
        return target
    if target.__module__.startswith("torch._ops"):
        # TODO(zhxchen17) Maybe provide a function name helper in FX.
        # From torch.fx.node._get_qualified_name
        public_module = target.__module__.replace("torch._ops", "torch.ops")
        return f"{public_module}.{target.__name__}"
    # TODO(zhxchen17) Don't catch all here.
    return f"{target.__module__}.{target.__name__}"
def deserialize_operator(serialized_target: str):
if serialized_target.startswith("_operator"):
if serialized_target.startswith("_operator"): # TODO(zhxchen17) Follow up on this.
module = operator
serialized_target_names = serialized_target.split(".")[1:]
elif serialized_target.startswith("torch._ops"):
elif serialized_target.startswith("torch.ops"):
module = torch.ops
serialized_target_names = serialized_target.split(".")[2:]
else:
else: # TODO(zhxchen17) Don't catch all here.
return serialized_target
target = module
@ -775,6 +781,19 @@ class GraphModuleDeserializer:
self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
self.serialized_name_to_meta: Dict[str, MetaType] = {}
self.graph = torch.fx.Graph()
self.module = torch.nn.Module()
@contextmanager
def save_graph_module(self) -> None:
    """Swap in a fresh graph/module scratch state for the duration of the
    ``with`` block, restoring the previous state on exit.

    Used so a nested (sub)graph can be deserialized with the same instance
    without clobbering the outer graph's bookkeeping.
    """
    stashed = (
        self.graph,
        self.module,
        self.serialized_name_to_node,
        self.serialized_name_to_meta,
    )
    # Start from a clean slate for the nested graph.
    self.graph = torch.fx.Graph()
    self.module = torch.nn.Module()
    self.serialized_name_to_node = {}
    self.serialized_name_to_meta = {}
    try:
        yield
    finally:
        # Always restore the outer graph's state, even on error.
        (
            self.graph,
            self.module,
            self.serialized_name_to_node,
            self.serialized_name_to_meta,
        ) = stashed
def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
val = s.value
@ -830,19 +849,7 @@ class GraphModuleDeserializer:
),
)
def deserialize(
self,
serialized_graph_module: GraphModule,
symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
) -> Tuple[torch.fx.GraphModule, ep.ExportGraphSignature, ep.CallSpec, Dict[str, sympy.Symbol]]:
self.shape_env = symbolic_shapes.ShapeEnv()
self.fake_tensor_mode = FakeTensorMode(shape_env=self.shape_env)
self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {}
self.symbol_name_to_range = {} if symbol_name_to_range is None else symbol_name_to_range
graph = self.graph
serialized_graph = serialized_graph_module.graph
def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
# Handle the tensor metas.
for name, tensor_value in serialized_graph.tensor_values.items():
meta_val = self.deserialize_tensor_meta(tensor_value.meta, self.fake_tensor_mode)
@ -856,35 +863,30 @@ class GraphModuleDeserializer:
# Inputs: convert to placeholder nodes in FX.
for input in serialized_graph.inputs:
placeholder_node = graph.placeholder(input.as_tensor.name)
self.sync_serialized_node(input.as_tensor.name, placeholder_node)
placeholder_node = self.graph.placeholder(input.as_tensor.name)
self.sync_fx_node(input.as_tensor.name, placeholder_node)
# Nodes: convert to call_function nodes.
for serialized_node in serialized_graph.nodes:
try:
target = deserialize_operator(serialized_node.target)
if isinstance(target, str):
# Create a dummy fake op if the target does not exist
# because we cannot create a call_function node w/o a
# callable target
log.warning(f"Could not find operator {target}. Returning fake operator.") # noqa: G004
def fake_op(x):
raise NotImplementedError("Fake op is not meant to be run.")
fake_op.__name__ = target
target = fake_op
if target.__module__ == "_operator":
if target.__module__ == "_operator": # TODO(zhxchen17) Follow up on this.
name = serialized_node.outputs[0].value.as_name
args = self.deserialize_sym_op_inputs(serialized_node.inputs)
fx_node = graph.create_node("call_function", target, args, {}, name)
fx_node = self.graph.create_node("call_function", target, args, {}, name)
self.deserialize_sym_op_outputs(serialized_node, fx_node)
fx_node.meta.update(deserialize_metadata(serialized_node.metadata))
else:
target = deserialize_operator(serialized_node.target)
elif isinstance(target, torch._ops.HigherOrderOperator):
assert (
len(serialized_node.outputs) == 1
and serialized_node.outputs[0].as_tensor is not None
), "Only single tensor output is supported for higher order operators."
name = serialized_node.outputs[0].as_tensor.name
args = tuple(self.deserialize_input(input.arg) for input in serialized_node.inputs)
fx_node = self.graph.create_node("call_function", target, args, {}, name)
self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
elif isinstance(target, torch._ops.OpOverload):
# For convenience: if this node returns a single tensor, name the
# newly-created node after it. This ensures that these tensor values
# have names that are consistent with serialized.
@ -894,12 +896,12 @@ class GraphModuleDeserializer:
else None # FX will generate a name for us.
)
args, kwargs = self.deserialize_inputs(target, serialized_node)
fx_node = graph.create_node("call_function", target, args, kwargs, name)
fx_node = self.graph.create_node("call_function", target, args, kwargs, name)
self.deserialize_outputs(serialized_node, fx_node)
else:
raise SerializeError(f"Unsupported target type for node {serialized_node}: {target}")
fx_node.meta.update(deserialize_metadata(serialized_node.metadata))
fx_node.meta.update(deserialize_metadata(serialized_node.metadata))
except Exception as e:
raise SerializeError(f"Failed deserializing node {serialized_node}") from e
@ -915,19 +917,32 @@ class GraphModuleDeserializer:
raise SerializeError(f"Unable to deserialize output node {output}")
output_node = graph.output(tuple(outputs))
output_node = self.graph.output(tuple(outputs))
output_node.meta["val"] = tuple(
arg.meta["val"] for arg in output_node.args[0]
)
def deserialize(
    self,
    serialized_graph_module: GraphModule,
    symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
) -> Tuple[torch.fx.GraphModule, ep.ExportGraphSignature, ep.CallSpec, Dict[str, sympy.Symbol]]:
    """Deserialize a serialized GraphModule into an FX GraphModule, its
    export signature, call spec, and the mapping of symbol names to sympy
    symbols discovered during deserialization.
    """
    # Fresh symbolic state per deserialization; fake tensors share the env.
    self.shape_env = symbolic_shapes.ShapeEnv()
    self.fake_tensor_mode = FakeTensorMode(shape_env=self.shape_env)
    self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {}
    self.symbol_name_to_range = symbol_name_to_range if symbol_name_to_range is not None else {}

    # Populates self.graph / self.module as a side effect.
    self.deserialize_graph(serialized_graph_module.graph)

    sig = deserialize_signature(serialized_graph_module.signature)
    call_spec = deserialize_call_spec(serialized_graph_module.call_spec)
    return torch.fx.GraphModule(self.module, self.graph), sig, call_spec, self.symbol_name_to_symbol
def sync_fx_node(self, name: str, fx_node: torch.fx.Node):
    """Register ``fx_node`` as the FX node for serialized value ``name`` and
    attach its previously-recorded "val" metadata.

    Raises:
        SerializeError: if ``name`` was already bound to a node.
    """
    duplicate = name in self.serialized_name_to_node
    if duplicate:
        raise SerializeError(f"Node {name} has already been deserialized before.")
    self.serialized_name_to_node[name] = fx_node
    # Each serialized value contributes exactly one "val" entry.
    assert "val" not in fx_node.meta
    fx_node.meta["val"] = self.serialized_name_to_meta[name]
def deserialize_sym_op_inputs(self, inputs):
@ -962,6 +977,17 @@ class GraphModuleDeserializer:
return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[value]
elif typ_ == "as_layout":
return _SERIALIZE_TO_TORCH_LAYOUT[value]
elif typ_ == "as_graph":
assert isinstance(value, GraphArgument)
with self.save_graph_module():
self.deserialize_graph(value.graph)
submodule = torch.fx.GraphModule(self.module, self.graph)
self.module.register_module(value.name, submodule)
return self.graph.create_node(
"get_attr",
value.name,
name=value.name,
)
elif isinstance(value, Device):
return deserialize_device(value)
elif isinstance(value, TensorArgument):
@ -991,20 +1017,22 @@ class GraphModuleDeserializer:
return self.serialized_name_to_node[sym_int_arg.as_name]
def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node):
    """Bind the single symbolic output of a sym op to its created FX node."""
    output_name = serialized_node.outputs[0].value.as_name
    self.sync_fx_node(output_name, fx_node)
def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node) -> None:
def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node):
# Simple case for single tensor return.
assert isinstance(fx_node.target, torch._ops.OpOverload)
returns = fx_node.target._schema.returns
# Check single value return
if len(returns) == 0:
return None
return
if _is_single_tensor_return(fx_node.target):
return self.sync_serialized_node(serialized_node.outputs[0].as_tensor.name, fx_node)
self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
return
elif len(returns) == 1 and isinstance(serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument)):
return self.sync_serialized_node(serialized_node.outputs[0].value.as_name, fx_node)
self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
return
# Convert multiple return types to FX format.
# In FX, each node only returns one value. So in order to represent
@ -1029,7 +1057,7 @@ class GraphModuleDeserializer:
(fx_node, idx),
name=name,
)
self.sync_serialized_node(name, individual_output)
self.sync_fx_node(name, individual_output)
# The derived `getitem` nodes should have the same stacktrace as the
# original `fx_node`
individual_output.meta.update(deserialize_metadata(serialized_node.metadata))
@ -1084,7 +1112,7 @@ class ExportedProgramDeserializer:
equality_constraints = deserialize_equality_constraints(serialized_exported_program.equality_constraints)
exported_program = ep.ExportedProgram(
state_dict,
graph_module,
graph_module.graph,
sig,
call_spec,
@ -1157,6 +1185,7 @@ def serialize(
def _dict_to_dataclass(cls, data):
assert not isinstance(cls, str), f"Unresolved class type: '{cls}'."
if isinstance(cls, type) and issubclass(cls, _Union):
obj = cls(**data)
field_type = cls.__annotations__[obj.type]
@ -1164,9 +1193,10 @@ def _dict_to_dataclass(cls, data):
return obj
elif dataclasses.is_dataclass(cls):
obj = cls(**data) # type: ignore[assignment]
type_hints = typing.get_type_hints(cls)
for f in dataclasses.fields(cls):
name = f.name
new_field_obj = _dict_to_dataclass(f.type, getattr(obj, name))
new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name))
setattr(obj, name, new_field_obj)
return obj
elif isinstance(data, list):