mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[FX] Rename reduce functions back to their old, public names (#64324)
Summary: Unfortunately pickle serializes the names of these functions. Also put them under backward-compatibility enforcement. Pull Request resolved: https://github.com/pytorch/pytorch/pull/64324 Test Plan: Local repro https://fb.workplace.com/groups/3440841732711443/permalink/4018921611570116/ Reviewed By: SplitInfinity, TailofJune Differential Revision: D30684185 Pulled By: jamesr66a fbshipit-source-id: 900701220155d15115cd0c07cf7774a2891bd04f
This commit is contained in:
parent
05ecaefbbf
commit
0c4e4e588e
2 changed files with 12 additions and 7 deletions
|
|
@ -28,6 +28,9 @@ torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.m
|
|||
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
|
||||
torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool
|
||||
torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode
|
||||
torch.fx.graph_module.reduce_deploy_graph_module(importer: Callable, body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module
|
||||
torch.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module
|
||||
torch.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module
|
||||
torch.fx.interpreter.Interpreter.__init__(self, module: torch.fx.graph_module.GraphModule, garbage_collect_values: bool = True)
|
||||
torch.fx.interpreter.Interpreter.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
|
||||
torch.fx.interpreter.Interpreter.call_method(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
|
||||
|
|
|
|||
|
|
@ -96,7 +96,8 @@ def _format_import_block(globals: Dict[str, Any], importer: Importer):
|
|||
return '\n'.join(import_strs)
|
||||
|
||||
|
||||
def _reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
|
||||
# BC: attribute name was changed from `code` to `_code` to facilitate
|
||||
# making `code` into a property and adding a docstring to it
|
||||
fn_src = body.get('_code') or body['code']
|
||||
|
|
@ -104,14 +105,15 @@ def _reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mo
|
|||
return _deserialize_graph_module(forward, body)
|
||||
|
||||
|
||||
def _reduce_package_graph_module(
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def reduce_package_graph_module(
|
||||
importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
|
||||
) -> torch.nn.Module:
|
||||
forward = importer.import_module(generated_module_name).forward
|
||||
return _deserialize_graph_module(forward, body)
|
||||
|
||||
|
||||
def _reduce_deploy_graph_module(
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def reduce_deploy_graph_module(
|
||||
importer: PackageImporter, body: Dict[Any, Any], import_block: str
|
||||
) -> torch.nn.Module:
|
||||
ns = dict()
|
||||
|
|
@ -626,7 +628,7 @@ class {module_name}(torch.nn.Module):
|
|||
|
||||
python_code = self.recompile()
|
||||
import_block = _format_import_block(python_code.globals, importer)
|
||||
return (_reduce_deploy_graph_module, (dict_without_graph, import_block))
|
||||
return (reduce_deploy_graph_module, (dict_without_graph, import_block))
|
||||
|
||||
def __reduce_package__(self, exporter: PackageExporter):
|
||||
dict_without_graph = self.__dict__.copy()
|
||||
|
|
@ -638,7 +640,7 @@ class {module_name}(torch.nn.Module):
|
|||
import_block = _format_import_block(python_code.globals, exporter.importer)
|
||||
module_code = import_block + self.code
|
||||
exporter.save_source_string(generated_module_name, module_code)
|
||||
return (_reduce_package_graph_module, (dict_without_graph, generated_module_name))
|
||||
return (reduce_package_graph_module, (dict_without_graph, generated_module_name))
|
||||
|
||||
def __reduce__(self):
|
||||
"""
|
||||
|
|
@ -652,7 +654,7 @@ class {module_name}(torch.nn.Module):
|
|||
python_code = self.recompile()
|
||||
import_block = _format_import_block(python_code.globals, sys_importer)
|
||||
del dict_without_graph['_graph']
|
||||
return (_reduce_graph_module, (dict_without_graph, import_block))
|
||||
return (reduce_graph_module, (dict_without_graph, import_block))
|
||||
|
||||
# because __reduce__ is defined for serialization,
|
||||
# we need to define deepcopy otherwise it will call __reduce__
|
||||
|
|
|
|||
Loading…
Reference in a new issue