diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect index f01221172b7..eb6bd8aca3d 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect @@ -2,7 +2,7 @@ torch.fx._symbolic_trace.ProxyableClassMeta [] torch.fx._symbolic_trace.Tracer ['call_module', 'create_arg', 'create_args_for_root', 'is_leaf_module', 'path_of_module', 'trace'] torch.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'on_generate_code', 'output', 'owning_module', 'placeholder', 'print_tabular', 'process_inputs', 'process_outputs', 'python_code', 'set_codegen'] torch.fx.graph.PythonCode [] -torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'recompile', 'to_folder'] +torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'nested_str', 'recompile', 'to_folder'] torch.fx.immutable_collections.immutable_dict ['clear', 'pop', 'popitem', 'update'] torch.fx.immutable_collections.immutable_list ['append', 'clear', 'extend', 'insert', 'pop', 'remove'] torch.fx.interpreter.Interpreter ['call_function', 'call_method', 'call_module', 'fetch_args_kwargs_from_env', 'fetch_attr', 'get_attr', 'map_nodes_to_values', 'output', 'placeholder', 'run', 'run_node'] diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 36daa56cc21..4f8bcc13d9e 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -706,6 +706,25 @@ class {module_name}(torch.nn.Module): def __copy__(self): return GraphModule(self, self.graph) + @compatibility(is_backward_compatible=False) + def nested_str(self) -> str: + """ + Return the Python code generated for current GraphModule and its children GraphModules + """ + module_code = self.code + module_code = module_code.lstrip('\n') + module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 4) + + submodule_code_list = [""] + for submodule in self.children(): + if isinstance(submodule, GraphModule): + submodule_code_list.append(submodule.__nested_code()) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 4) + + return module_code + submodule_code + def __str__(self) -> str: orig_str = super().__str__() return '\n'.join([orig_str, self._code]) diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index aa626f70831..f3d5f024216 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -167,13 +167,15 @@ def fuse_as_graphmodule(gm: GraphModule, return fused_gm, original_inputs, original_outputs + def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]): - # assign sub_gm into gm - setattr(gm, sub_gm.__class__.__name__, sub_gm) + # add sub_gm into gm + submodule_name = sub_gm.__class__.__name__ + gm.add_submodule(submodule_name, sub_gm) # Create a call_module node in main graph. module_node = gm.graph.call_module( - sub_gm.__class__.__name__, + submodule_name, args=orig_inputs, kwargs=None) @@ -185,8 +187,6 @@ def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, # Use Proxy to record getitem access. proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] orig_output.replace_all_uses_with(proxy_out) - - return gm def erase_nodes(gm: GraphModule, nodes: NodeList): @@ -196,7 +196,6 @@ def erase_nodes(gm: GraphModule, nodes: NodeList): gm.graph.erase_node(node) - def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule: for partition_id, nodes in enumerate(partitions): sorted_nodes = topo_sort(nodes)