diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index e10b2105742..67495f5a522 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -70,7 +70,7 @@ class TestSerialize(TestCase): serialized, _ = ExportedProgramSerializer().serialize(exported_module) node = serialized.graph_module.graph.nodes[-7] - self.assertEqual(node.target, "torch._ops.aten.var_mean.correction") + self.assertEqual(node.target, "torch.ops.aten.var_mean.correction") # aten::native_layer_norm returns 3 tensnors self.assertEqual(len(node.outputs), 2) @@ -95,7 +95,7 @@ class TestSerialize(TestCase): serialized, _ = ExportedProgramSerializer().serialize(exported_module) node = serialized.graph_module.graph.nodes[-1] - self.assertEqual(node.target, "torch._ops.aten.split.Tensor") + self.assertEqual(node.target, "torch.ops.aten.split.Tensor") self.assertEqual(len(node.outputs), 1) # Input looks like: # tensor([[0, 1], @@ -138,7 +138,7 @@ class TestSerialize(TestCase): serialized, _ = ExportedProgramSerializer().serialize(exported_module) node = serialized.graph_module.graph.nodes[-1] - self.assertEqual(node.target, "torch._ops.aten.var_mean.correction") + self.assertEqual(node.target, "torch.ops.aten.var_mean.correction") self.assertEqual(len(node.outputs), 2) # check the names are unique @@ -163,7 +163,7 @@ class TestSerialize(TestCase): serialized, _ = ExportedProgramSerializer().serialize(exported_module) node = serialized.graph_module.graph.nodes[-1] - self.assertEqual(node.target, "torch._ops.aten.searchsorted.Tensor") + self.assertEqual(node.target, "torch.ops.aten.searchsorted.Tensor") self.assertEqual(len(node.inputs), 6) self.assertEqual(node.inputs[2].arg.as_bool, False) self.assertEqual(node.inputs[3].arg.as_bool, True) @@ -194,6 +194,7 @@ class TestDeserialize(TestCase): else: self.assertEqual(orig, loaded) + self.assertEqual(len(ep.graph.nodes), len(deserialized_ep.graph.nodes)) for node1, node2 in zip(ep.graph.nodes, deserialized_ep.graph.nodes): # Check "val" metadata val1 = node1.meta.get("val", None) @@ -240,10 +241,11 @@ class TestDeserialize(TestCase): ) # Check "source_fn" metadata - self.assertEqual( - node1.meta.get("source_fn", None), - node2.meta.get("source_fn", None), - ) + if node1.op != "get_attr": + self.assertEqual( + node1.meta.get("source_fn", None), + node2.meta.get("source_fn", None), + ) def test_multi_return(self) -> None: """ @@ -341,6 +343,21 @@ class TestDeserialize(TestCase): inputs = (torch.randn(3, 3),) self.check_graph(M(), inputs) + def test_cond(self): + from functorch.experimental.control_flow import cond + inputs = torch.ones(4, 3), torch.zeros(4, 3) + + class M(torch.nn.Module): + def forward(self, x, y): + def t(x, y): + return x + y + + def f(x, y): + return x - y + return cond(x[0][0] > 4, t, f, [x, y]) + + self.check_graph(M(), inputs) + @parametrize( "name,case", get_filtered_export_db_tests(), diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 8cb264a31ed..31fe6600ac4 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -253,17 +253,23 @@ def deserialize_metadata(metadata) -> Dict[str, str]: def serialize_operator(target) -> str: if isinstance(target, str): return target - return f"{target.__module__}.{target.__name__}" + elif target.__module__.startswith("torch._ops"): + # TODO(zhxchen17) Maybe provide a function name helper in FX. + # From torch.fx.node._get_qualified_name + module = target.__module__.replace("torch._ops", "torch.ops") + return f"{module}.{target.__name__}" + else: # TODO(zhxchen17) Don't catch all here. + return f"{target.__module__}.{target.__name__}" def deserialize_operator(serialized_target: str): - if serialized_target.startswith("_operator"): + if serialized_target.startswith("_operator"): # TODO(zhxchen17) Follow up on this. module = operator serialized_target_names = serialized_target.split(".")[1:] - elif serialized_target.startswith("torch._ops"): + elif serialized_target.startswith("torch.ops"): module = torch.ops serialized_target_names = serialized_target.split(".")[2:] - else: + else: # TODO(zhxchen17) Don't catch all here. return serialized_target target = module @@ -775,6 +781,19 @@ class GraphModuleDeserializer: self.serialized_name_to_node: Dict[str, torch.fx.Node] = {} self.serialized_name_to_meta: Dict[str, MetaType] = {} self.graph = torch.fx.Graph() + self.module = torch.nn.Module() + + @contextmanager + def save_graph_module(self) -> None: + saved = self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta + self.graph = torch.fx.Graph() + self.module = torch.nn.Module() + self.serialized_name_to_node = {} + self.serialized_name_to_meta = {} + try: + yield + finally: + self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta = saved def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: val = s.value @@ -830,19 +849,7 @@ class GraphModuleDeserializer: ), ) - def deserialize( - self, - serialized_graph_module: GraphModule, - symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None, - ) -> Tuple[torch.fx.GraphModule, ep.ExportGraphSignature, ep.CallSpec, Dict[str, sympy.Symbol]]: - self.shape_env = symbolic_shapes.ShapeEnv() - self.fake_tensor_mode = FakeTensorMode(shape_env=self.shape_env) - self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {} - self.symbol_name_to_range = {} if symbol_name_to_range is None else symbol_name_to_range - - graph = self.graph - serialized_graph = serialized_graph_module.graph - + def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: # Handle the tensor metas. for name, tensor_value in serialized_graph.tensor_values.items(): meta_val = self.deserialize_tensor_meta(tensor_value.meta, self.fake_tensor_mode) @@ -856,35 +863,30 @@ class GraphModuleDeserializer: # Inputs: convert to placeholder nodes in FX. for input in serialized_graph.inputs: - placeholder_node = graph.placeholder(input.as_tensor.name) - self.sync_serialized_node(input.as_tensor.name, placeholder_node) + placeholder_node = self.graph.placeholder(input.as_tensor.name) + self.sync_fx_node(input.as_tensor.name, placeholder_node) # Nodes: convert to call_function nodes. for serialized_node in serialized_graph.nodes: try: target = deserialize_operator(serialized_node.target) - if isinstance(target, str): - # Create a dummy fake op if the target does not exist - # because we cannot create a call_function node w/o a - # callable target - log.warning(f"Could not find operator {target}. Returning fake operator.") # noqa: G004 - def fake_op(x): - raise NotImplementedError("Fake op is not meant to be run.") - fake_op.__name__ = target - target = fake_op - - if target.__module__ == "_operator": + if target.__module__ == "_operator": # TODO(zhxchen17) Follow up on this. name = serialized_node.outputs[0].value.as_name args = self.deserialize_sym_op_inputs(serialized_node.inputs) - fx_node = graph.create_node("call_function", target, args, {}, name) + fx_node = self.graph.create_node("call_function", target, args, {}, name) self.deserialize_sym_op_outputs(serialized_node, fx_node) - fx_node.meta.update(deserialize_metadata(serialized_node.metadata)) - - else: - target = deserialize_operator(serialized_node.target) - + elif isinstance(target, torch._ops.HigherOrderOperator): + assert ( + len(serialized_node.outputs) == 1 + and serialized_node.outputs[0].as_tensor is not None + ), "Only single tensor output is supported for higher order operators." + name = serialized_node.outputs[0].as_tensor.name + args = tuple(self.deserialize_input(input.arg) for input in serialized_node.inputs) + fx_node = self.graph.create_node("call_function", target, args, {}, name) + self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) + elif isinstance(target, torch._ops.OpOverload): # For convenience: if this node returns a single tensor, name the # newly-created node after it. This ensures that these tensor values # have names that are consistent with serialized. @@ -894,12 +896,12 @@ class GraphModuleDeserializer: else None # FX will generate a name for us. ) args, kwargs = self.deserialize_inputs(target, serialized_node) - - fx_node = graph.create_node("call_function", target, args, kwargs, name) - + fx_node = self.graph.create_node("call_function", target, args, kwargs, name) self.deserialize_outputs(serialized_node, fx_node) + else: + raise SerializeError(f"Unsupported target type for node {serialized_node}: {target}") - fx_node.meta.update(deserialize_metadata(serialized_node.metadata)) + fx_node.meta.update(deserialize_metadata(serialized_node.metadata)) except Exception as e: raise SerializeError(f"Failed deserializing node {serialized_node}") from e @@ -915,19 +917,32 @@ class GraphModuleDeserializer: raise SerializeError(f"Unable to deserialize output node {output}") - output_node = graph.output(tuple(outputs)) + output_node = self.graph.output(tuple(outputs)) output_node.meta["val"] = tuple( arg.meta["val"] for arg in output_node.args[0] ) + def deserialize( + self, + serialized_graph_module: GraphModule, + symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None, + ) -> Tuple[torch.fx.GraphModule, ep.ExportGraphSignature, ep.CallSpec, Dict[str, sympy.Symbol]]: + self.shape_env = symbolic_shapes.ShapeEnv() + self.fake_tensor_mode = FakeTensorMode(shape_env=self.shape_env) + self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {} + self.symbol_name_to_range = {} if symbol_name_to_range is None else symbol_name_to_range + + self.deserialize_graph(serialized_graph_module.graph) + sig = deserialize_signature(serialized_graph_module.signature) call_spec = deserialize_call_spec(serialized_graph_module.call_spec) - return torch.fx.GraphModule({}, graph), sig, call_spec, self.symbol_name_to_symbol + return torch.fx.GraphModule(self.module, self.graph), sig, call_spec, self.symbol_name_to_symbol - def sync_serialized_node(self, name: str, fx_node: torch.fx.Node): + def sync_fx_node(self, name: str, fx_node: torch.fx.Node): if name in self.serialized_name_to_node: raise SerializeError(f"Node {name} has already been deserialized before.") self.serialized_name_to_node[name] = fx_node + assert "val" not in fx_node.meta fx_node.meta["val"] = self.serialized_name_to_meta[name] def deserialize_sym_op_inputs(self, inputs): @@ -962,6 +977,17 @@ class GraphModuleDeserializer: return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[value] elif typ_ == "as_layout": return _SERIALIZE_TO_TORCH_LAYOUT[value] + elif typ_ == "as_graph": + assert isinstance(value, GraphArgument) + with self.save_graph_module(): + self.deserialize_graph(value.graph) + submodule = torch.fx.GraphModule(self.module, self.graph) + self.module.register_module(value.name, submodule) + return self.graph.create_node( + "get_attr", + value.name, + name=value.name, + ) elif isinstance(value, Device): return deserialize_device(value) elif isinstance(value, TensorArgument): @@ -991,20 +1017,22 @@ class GraphModuleDeserializer: return self.serialized_name_to_node[sym_int_arg.as_name] def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): - self.sync_serialized_node(serialized_node.outputs[0].value.as_name, fx_node) + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) - def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node) -> None: + def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): # Simple case for single tensor return. assert isinstance(fx_node.target, torch._ops.OpOverload) returns = fx_node.target._schema.returns # Check single value return if len(returns) == 0: - return None + return if _is_single_tensor_return(fx_node.target): - return self.sync_serialized_node(serialized_node.outputs[0].as_tensor.name, fx_node) + self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) + return elif len(returns) == 1 and isinstance(serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument)): - return self.sync_serialized_node(serialized_node.outputs[0].value.as_name, fx_node) + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + return # Convert multiple return types to FX format. # In FX, each node only returns one value. So in order to represent @@ -1029,7 +1057,7 @@ class GraphModuleDeserializer: (fx_node, idx), name=name, ) - self.sync_serialized_node(name, individual_output) + self.sync_fx_node(name, individual_output) # The derived `getitem` nodes should have the same stacktrace as the # original `fx_node` individual_output.meta.update(deserialize_metadata(serialized_node.metadata)) @@ -1084,7 +1112,7 @@ class ExportedProgramDeserializer: equality_constraints = deserialize_equality_constraints(serialized_exported_program.equality_constraints) exported_program = ep.ExportedProgram( - state_dict, + graph_module, graph_module.graph, sig, call_spec, @@ -1157,6 +1185,7 @@ def serialize( def _dict_to_dataclass(cls, data): + assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." if isinstance(cls, type) and issubclass(cls, _Union): obj = cls(**data) field_type = cls.__annotations__[obj.type] @@ -1164,9 +1193,10 @@ def _dict_to_dataclass(cls, data): return obj elif dataclasses.is_dataclass(cls): obj = cls(**data) # type: ignore[assignment] + type_hints = typing.get_type_hints(cls) for f in dataclasses.fields(cls): name = f.name - new_field_obj = _dict_to_dataclass(f.type, getattr(obj, name)) + new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name)) setattr(obj, name, new_field_obj) return obj elif isinstance(data, list):