[FX] Rename reduce functions back to their old, public names (#64324)

Summary:
Unfortunately pickle serializes the names of these functions. Also put them under backward-compatibility enforcement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64324

Test Plan: Local repro https://fb.workplace.com/groups/3440841732711443/permalink/4018921611570116/

Reviewed By: SplitInfinity, TailofJune

Differential Revision: D30684185

Pulled By: jamesr66a

fbshipit-source-id: 900701220155d15115cd0c07cf7774a2891bd04f
This commit is contained in:
James Reed 2021-08-31 22:20:41 -07:00 committed by Facebook GitHub Bot
parent 05ecaefbbf
commit 0c4e4e588e
2 changed files with 12 additions and 7 deletions

View file

@ -28,6 +28,9 @@ torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.m
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool
torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode
torch.fx.graph_module.reduce_deploy_graph_module(importer: Callable, body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module
torch.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module
torch.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module
torch.fx.interpreter.Interpreter.__init__(self, module: torch.fx.graph_module.GraphModule, garbage_collect_values: bool = True)
torch.fx.interpreter.Interpreter.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
torch.fx.interpreter.Interpreter.call_method(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any

View file

@ -96,7 +96,8 @@ def _format_import_block(globals: Dict[str, Any], importer: Importer):
return '\n'.join(import_strs)
def _reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
@compatibility(is_backward_compatible=True)
def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
# BC: attribute name was changed from `code` to `_code` to facilitate
# making `code` into a property and adding a docstring to it
fn_src = body.get('_code') or body['code']
@ -104,14 +105,15 @@ def _reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mo
return _deserialize_graph_module(forward, body)
def _reduce_package_graph_module(
@compatibility(is_backward_compatible=True)
def reduce_package_graph_module(
importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
) -> torch.nn.Module:
forward = importer.import_module(generated_module_name).forward
return _deserialize_graph_module(forward, body)
def _reduce_deploy_graph_module(
@compatibility(is_backward_compatible=True)
def reduce_deploy_graph_module(
importer: PackageImporter, body: Dict[Any, Any], import_block: str
) -> torch.nn.Module:
ns = dict()
@ -626,7 +628,7 @@ class {module_name}(torch.nn.Module):
python_code = self.recompile()
import_block = _format_import_block(python_code.globals, importer)
return (_reduce_deploy_graph_module, (dict_without_graph, import_block))
return (reduce_deploy_graph_module, (dict_without_graph, import_block))
def __reduce_package__(self, exporter: PackageExporter):
dict_without_graph = self.__dict__.copy()
@ -638,7 +640,7 @@ class {module_name}(torch.nn.Module):
import_block = _format_import_block(python_code.globals, exporter.importer)
module_code = import_block + self.code
exporter.save_source_string(generated_module_name, module_code)
return (_reduce_package_graph_module, (dict_without_graph, generated_module_name))
return (reduce_package_graph_module, (dict_without_graph, generated_module_name))
def __reduce__(self):
"""
@ -652,7 +654,7 @@ class {module_name}(torch.nn.Module):
python_code = self.recompile()
import_block = _format_import_block(python_code.globals, sys_importer)
del dict_without_graph['_graph']
return (_reduce_graph_module, (dict_without_graph, import_block))
return (reduce_graph_module, (dict_without_graph, import_block))
# because __reduce__ is defined for serialization,
# we need to define deepcopy otherwise it will call __reduce__