[Reland] Recursively print graph module and its submodule (#81639)

ghstack-source-id: fcfc024c440981ee3fe3537a5816089eadf2cc13
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81639
Approved by: https://github.com/ezyang
This commit is contained in:
Sherlock Huang 2022-07-21 16:58:25 +00:00 committed by PyTorch MergeBot
parent 8d0cbce069
commit 43e7fee764
3 changed files with 25 additions and 7 deletions

View file

@ -2,7 +2,7 @@ torch.fx._symbolic_trace.ProxyableClassMeta []
torch.fx._symbolic_trace.Tracer ['call_module', 'create_arg', 'create_args_for_root', 'is_leaf_module', 'path_of_module', 'trace']
torch.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'on_generate_code', 'output', 'owning_module', 'placeholder', 'print_tabular', 'process_inputs', 'process_outputs', 'python_code', 'set_codegen']
torch.fx.graph.PythonCode []
torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'recompile', 'to_folder']
torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'nested_str', 'recompile', 'to_folder']
torch.fx.immutable_collections.immutable_dict ['clear', 'pop', 'popitem', 'update']
torch.fx.immutable_collections.immutable_list ['append', 'clear', 'extend', 'insert', 'pop', 'remove']
torch.fx.interpreter.Interpreter ['call_function', 'call_method', 'call_module', 'fetch_args_kwargs_from_env', 'fetch_attr', 'get_attr', 'map_nodes_to_values', 'output', 'placeholder', 'run', 'run_node']

View file

@ -706,6 +706,25 @@ class {module_name}(torch.nn.Module):
def __copy__(self):
return GraphModule(self, self.graph)
@compatibility(is_backward_compatible=False)
def nested_str(self) -> str:
"""
Return the Python code generated for current GraphModule and its children GraphModules
"""
module_code = self.code
module_code = module_code.lstrip('\n')
module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
module_code = _addindent(module_code, 4)
submodule_code_list = [""]
for submodule in self.children():
if isinstance(submodule, GraphModule):
submodule_code_list.append(submodule.__nested_code())
submodule_code = "\n".join(submodule_code_list)
submodule_code = _addindent(submodule_code, 4)
return module_code + submodule_code
def __str__(self) -> str:
orig_str = super().__str__()
return '\n'.join([orig_str, self._code])

View file

@ -167,13 +167,15 @@ def fuse_as_graphmodule(gm: GraphModule,
return fused_gm, original_inputs, original_outputs
def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
# assign sub_gm into gm
setattr(gm, sub_gm.__class__.__name__, sub_gm)
# add sub_gm into gm
submodule_name = sub_gm.__class__.__name__
gm.add_submodule(submodule_name, sub_gm)
# Create a call_module node in main graph.
module_node = gm.graph.call_module(
sub_gm.__class__.__name__,
submodule_name,
args=orig_inputs,
kwargs=None)
@ -185,8 +187,6 @@ def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node,
# Use Proxy to record getitem access.
proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index]
orig_output.replace_all_uses_with(proxy_out)
return gm
def erase_nodes(gm: GraphModule, nodes: NodeList):
@ -196,7 +196,6 @@ def erase_nodes(gm: GraphModule, nodes: NodeList):
gm.graph.erase_node(node)
def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule:
for partition_id, nodes in enumerate(partitions):
sorted_nodes = topo_sort(nodes)