mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Reland] Recursively print graph module and its submodule (#81639)
ghstack-source-id: fcfc024c440981ee3fe3537a5816089eadf2cc13 Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/81639 Approved by: https://github.com/ezyang
This commit is contained in:
parent
8d0cbce069
commit
43e7fee764
3 changed files with 25 additions and 7 deletions
|
|
@ -2,7 +2,7 @@ torch.fx._symbolic_trace.ProxyableClassMeta []
|
|||
torch.fx._symbolic_trace.Tracer ['call_module', 'create_arg', 'create_args_for_root', 'is_leaf_module', 'path_of_module', 'trace']
|
||||
torch.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'on_generate_code', 'output', 'owning_module', 'placeholder', 'print_tabular', 'process_inputs', 'process_outputs', 'python_code', 'set_codegen']
|
||||
torch.fx.graph.PythonCode []
|
||||
torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'recompile', 'to_folder']
|
||||
torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'nested_str', 'recompile', 'to_folder']
|
||||
torch.fx.immutable_collections.immutable_dict ['clear', 'pop', 'popitem', 'update']
|
||||
torch.fx.immutable_collections.immutable_list ['append', 'clear', 'extend', 'insert', 'pop', 'remove']
|
||||
torch.fx.interpreter.Interpreter ['call_function', 'call_method', 'call_module', 'fetch_args_kwargs_from_env', 'fetch_attr', 'get_attr', 'map_nodes_to_values', 'output', 'placeholder', 'run', 'run_node']
|
||||
|
|
|
|||
|
|
@ -706,6 +706,25 @@ class {module_name}(torch.nn.Module):
|
|||
def __copy__(self):
|
||||
return GraphModule(self, self.graph)
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def nested_str(self) -> str:
|
||||
"""
|
||||
Return the Python code generated for current GraphModule and its children GraphModules
|
||||
"""
|
||||
module_code = self.code
|
||||
module_code = module_code.lstrip('\n')
|
||||
module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
|
||||
module_code = _addindent(module_code, 4)
|
||||
|
||||
submodule_code_list = [""]
|
||||
for submodule in self.children():
|
||||
if isinstance(submodule, GraphModule):
|
||||
submodule_code_list.append(submodule.__nested_code())
|
||||
submodule_code = "\n".join(submodule_code_list)
|
||||
submodule_code = _addindent(submodule_code, 4)
|
||||
|
||||
return module_code + submodule_code
|
||||
|
||||
def __str__(self) -> str:
|
||||
orig_str = super().__str__()
|
||||
return '\n'.join([orig_str, self._code])
|
||||
|
|
|
|||
|
|
@ -167,13 +167,15 @@ def fuse_as_graphmodule(gm: GraphModule,
|
|||
|
||||
return fused_gm, original_inputs, original_outputs
|
||||
|
||||
|
||||
def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
|
||||
# assign sub_gm into gm
|
||||
setattr(gm, sub_gm.__class__.__name__, sub_gm)
|
||||
# add sub_gm into gm
|
||||
submodule_name = sub_gm.__class__.__name__
|
||||
gm.add_submodule(submodule_name, sub_gm)
|
||||
|
||||
# Create a call_module node in main graph.
|
||||
module_node = gm.graph.call_module(
|
||||
sub_gm.__class__.__name__,
|
||||
submodule_name,
|
||||
args=orig_inputs,
|
||||
kwargs=None)
|
||||
|
||||
|
|
@ -185,8 +187,6 @@ def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node,
|
|||
# Use Proxy to record getitem access.
|
||||
proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index]
|
||||
orig_output.replace_all_uses_with(proxy_out)
|
||||
|
||||
|
||||
return gm
|
||||
|
||||
def erase_nodes(gm: GraphModule, nodes: NodeList):
|
||||
|
|
@ -196,7 +196,6 @@ def erase_nodes(gm: GraphModule, nodes: NodeList):
|
|||
gm.graph.erase_node(node)
|
||||
|
||||
|
||||
|
||||
def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule:
|
||||
for partition_id, nodes in enumerate(partitions):
|
||||
sorted_nodes = topo_sort(nodes)
|
||||
|
|
|
|||
Loading…
Reference in a new issue