mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Skip nnmodule hook guards by default (#98371)
This PR makes basic nnmodule forward hooks work by default, without any overhead. But it leaves silent correctness issues if users modify/remove their hooks later, thus also emits a warning. - the usual case is to not use hooks, so avoid guard overhead here - registering any hook before compile will trigger a warning about hook support - registering a hook later (or removing one) requires user knowledge and opting in, currently this isn't warnable (but maybe we can observe compiled nnmodules to make it warnable). Why skip hook guards by default instead of not tracing __call__/hooks by default? - avoid having a mode flag that alters dynamo tracing behavior (harder to test both codepaths in CI with full coverage) - the most basic hook usecase (registering a hook before compile, and never removing it) will work by default with this PR, while it would require enablement and incur overhead in the 'not tracing __call__' proposal. Pull Request resolved: https://github.com/pytorch/pytorch/pull/98371 Approved by: https://github.com/jansel
This commit is contained in:
parent
46d765c15e
commit
390c51bf87
6 changed files with 113 additions and 23 deletions
|
|
@ -78,3 +78,4 @@ please check out the references below.
|
|||
|
||||
get-started
|
||||
technical-overview
|
||||
nn-module
|
||||
|
|
|
|||
47
docs/source/compile/nn-module.rst
Normal file
47
docs/source/compile/nn-module.rst
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
PyTorch 2.0 NNModule Support
|
||||
============================
|
||||
|
||||
**Author**: `Will Constable <https://github.com/wconstab>`_
|
||||
|
||||
`torch.compile` has special handling for torch.nn.Module objects, tracing them differently than it traces
|
||||
arbitrary python classes, with the intent of producing faster code by making assumptions about the structure.
|
||||
|
||||
This doc describes some of the tradeoffs or edge cases that come up due to this specialization.
|
||||
|
||||
NNModule Hooks Support
|
||||
----------------------
|
||||
Previously, `torch.compile` had no support for hooks on nn.Modules, and if hooks were registered
|
||||
they would simply be ignored in the compiled program. Indeed many users do not
|
||||
use nn.Module hooks at all, or only use them for debug workflows, but there are valid use cases
|
||||
for composing nn.Module hooks with `torch.compile`.
|
||||
|
||||
Hooks that are orchestrated via nn.Module.__call__ implementation include `_forward_pre_hooks`,
|
||||
`forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`, and will be referred to as 'call hooks'.
|
||||
These hooks are partially supported by `torch.compile` with limitations described below.
|
||||
|
||||
Another category of hooks includes `_state_dict_hooks` and its `pre` and `load_` variants, and are still
|
||||
unsupported by `torch.compile`.
|
||||
|
||||
`nn.Module.__call__` Hooks Usage and limitations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
By default, `torch.compile` will trace the contents of `nn.Module.__call__` which means it will encounter
|
||||
and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and then do not remove
|
||||
or alter the hooks later, your use case should be supported by default.
|
||||
|
||||
**skip_nnmodule_hook_guards**
|
||||
By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed
|
||||
on each nn.Module hook dictionary, improving runtime by reducing guard execution time, at the cost of not noticing
|
||||
if any hook dict is changed after compilation.
|
||||
|
||||
If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately
|
||||
(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added
|
||||
guards.
|
||||
|
||||
TODO: confirm if backward/pre_backward hooks are working or not and document accordingly
|
||||
|
||||
state_dict Hooks
|
||||
~~~~~~~~~~~~~~~~
|
||||
State dict hooks have not yet been supported in `torch.compile`.
|
||||
|
||||
|
||||
TODO: warn_once if graph-breaking on hooks. warn_once to point to this doc if hooks are present.
|
||||
|
|
@ -1308,6 +1308,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
|
||||
def test_hooks_outer(self):
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
|
@ -1354,6 +1355,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
the eval_frame entrypoint to Module.__call__?
|
||||
"""
|
||||
|
||||
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
|
||||
def test_hooks_inner(self):
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ skip_fsdp_guards = True
|
|||
# Make dynamo skip guarding on hooks on nn modules
|
||||
# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
|
||||
# dynamo will not notice and will execute whichever version you first compiled.
|
||||
skip_nnmodule_hook_guards = False
|
||||
skip_nnmodule_hook_guards = True
|
||||
|
||||
# If True, raises exception if TorchDynamo is called with a context manager
|
||||
raise_on_ctx_manager_usage = True
|
||||
|
|
|
|||
|
|
@ -56,6 +56,8 @@ from .utils import (
|
|||
dynamo_timed,
|
||||
format_graph_code,
|
||||
format_graph_tabular,
|
||||
nnmodule_doc_url_msg,
|
||||
nnmodule_has_hooks,
|
||||
same,
|
||||
)
|
||||
from .variables.base import VariableTracker
|
||||
|
|
@ -382,26 +384,6 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
|
|||
if name not in self.code_options["co_names"]:
|
||||
self.code_options["co_names"] += (name,)
|
||||
|
||||
@staticmethod
|
||||
def module_has_hooks(mod, only_check_unsupported=False):
|
||||
supported_hooks = [
|
||||
"_forward_pre_hooks",
|
||||
"_forward_hooks",
|
||||
]
|
||||
unsupported_hooks = [
|
||||
"_backward_pre_hooks",
|
||||
"_backward_hooks",
|
||||
"_state_dict_pre_hooks",
|
||||
"_state_dict_hooks",
|
||||
"_load_state_dict_pre_hooks",
|
||||
"_load_state_dict_post_hooks",
|
||||
]
|
||||
check_hooks = unsupported_hooks
|
||||
if not only_check_unsupported:
|
||||
check_hooks += supported_hooks
|
||||
|
||||
return any(len(getattr(mod, x)) > 0 for x in check_hooks if hasattr(mod, x))
|
||||
|
||||
def register_attr_or_module(
|
||||
self,
|
||||
target: Union[torch.nn.Module, torch.Tensor, Any],
|
||||
|
|
@ -433,10 +415,24 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
|
|||
|
||||
elif isinstance(target, torch.nn.Module):
|
||||
assert isinstance(target, torch.nn.Module)
|
||||
if self.module_has_hooks(target, only_check_unsupported=True):
|
||||
if nnmodule_has_hooks(target, check_forward_hooks=True):
|
||||
torch._logging.warning_once(
|
||||
log, "nn.Module hooks are not fully supported, they may be ignored"
|
||||
log,
|
||||
"nn.Module forward/_pre hooks are only partially supported, and were detected in your model. "
|
||||
"In particular, if you do not change/remove hooks after calling .compile(), you can disregard this "
|
||||
"warning, and otherwise you may need to set torch._dynamo.config.skip_nnmodule_hook_guards=False "
|
||||
"to ensure recompiling after changing hooks."
|
||||
f"{nnmodule_doc_url_msg} ",
|
||||
)
|
||||
if nnmodule_has_hooks(
|
||||
target, check_backward_hooks=True, check_state_dict_hooks=True
|
||||
):
|
||||
torch._logging.warning_once(
|
||||
log,
|
||||
"nn.Module state_dict and backward hooks are not yet supported by torch.compile, "
|
||||
f"but were detected in your model and will be silently ignored. {nnmodule_doc_url_msg}",
|
||||
)
|
||||
|
||||
options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))
|
||||
|
||||
def wrap_name(module_key):
|
||||
|
|
|
|||
|
|
@ -49,6 +49,8 @@ from torch.utils._pytree import tree_map
|
|||
|
||||
counters = collections.defaultdict(collections.Counter)
|
||||
troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
|
||||
nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
|
||||
nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -1439,3 +1441,45 @@ def format_graph_tabular(fn_name, gm):
|
|||
|
||||
def format_bytecode(prefix, name, filename, line_no, code):
|
||||
return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n"
|
||||
|
||||
|
||||
def nnmodule_has_hooks(
|
||||
mod,
|
||||
check_forward_hooks=False,
|
||||
check_backward_hooks=False,
|
||||
check_state_dict_hooks=False,
|
||||
):
|
||||
"""
|
||||
Sometimes its useful to differentiate between types of hooks such as forward/backward/pre
|
||||
hooks executed during module.__call__, and state_dict hooks which are executed separately.
|
||||
"""
|
||||
hook_dicts_to_check = []
|
||||
check_all_hooks = (
|
||||
not check_forward_hooks
|
||||
and not check_backward_hooks
|
||||
and not check_state_dict_hooks
|
||||
)
|
||||
if check_forward_hooks or check_all_hooks:
|
||||
hook_dicts_to_check.extend(
|
||||
[
|
||||
"_forward_pre_hooks",
|
||||
"_forward_hooks",
|
||||
]
|
||||
)
|
||||
if check_backward_hooks or check_all_hooks:
|
||||
hook_dicts_to_check.extend(
|
||||
[
|
||||
"_backward_pre_hooks",
|
||||
"_backward_hooks",
|
||||
]
|
||||
)
|
||||
if check_state_dict_hooks:
|
||||
hook_dicts_to_check.extend(
|
||||
[
|
||||
"_state_dict_pre_hooks",
|
||||
"_state_dict_hooks",
|
||||
"_load_state_dict_pre_hooks",
|
||||
"_load_state_dict_post_hooks",
|
||||
]
|
||||
)
|
||||
return any(len(getattr(mod, x)) > 0 for x in hook_dicts_to_check if hasattr(mod, x))
|
||||
|
|
|
|||
Loading…
Reference in a new issue