[dynamo] Alternative way to skip empty hooks guards on inbuilt nn modules (#131057)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131057
Approved by: https://github.com/williamwen42, https://github.com/jansel
ghstack dependencies: #131056
This commit is contained in:
Animesh Jain 2024-07-18 12:31:45 -07:00 committed by PyTorch MergeBot
parent 00e54e74ff
commit ac76dd606f
2 changed files with 30 additions and 22 deletions

View file

@ -85,7 +85,6 @@ from .source import (
GlobalStateSource,
GlobalWeakRefSource,
GradSource,
is_unspecialized_builtin_nnmodule_attr,
LocalSource,
NNModuleSource,
NumpyTensorSource,
@ -1684,6 +1683,13 @@ class GuardBuilder(GuardBuilderBase):
else:
self._produce_guard_code(guard, code)
def EMPTY_NN_MODULE_HOOKS_DICT(self, guard):
"""Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards"""
if config.skip_nnmodule_hook_guards:
# This is unsafe if you add/remove a hook on nn module variable
return
self.SEQUENCE_LENGTH(guard)
def OBJECT_MUTATION(self, guard: Guard):
mutation_guard.watch(self.get(guard.name), self.check_fn_manager)
@ -2112,22 +2118,6 @@ class DeletedGuardFn:
pass
def is_nn_module_hook(source: Source) -> bool:
# Note that we only skip guards on builtin nn modules like Conv2D etc. But still this is a soundness issue if one
# adds/removes a hook after the model is compiled.
return (
is_unspecialized_builtin_nnmodule_attr(source)
and isinstance(source, AttrSource)
and source.member
in (
"_backward_hooks",
"_backward_pre_hooks",
"_forward_hooks",
"_forward_pre_hooks",
)
)
# NB: Naively, you'd expect this to only be a function that produces
# the callable that constitutes the guard. However, there is some
# delicate handling for invalidating this check function when the
@ -2188,11 +2178,6 @@ class CheckFunctionManager:
):
continue
# This is unsafe if you add/remove a hook on unspecialized nn module variable
if config.skip_nnmodule_hook_guards and is_nn_module_hook(
guard.originating_source
):
continue
guard.create(builder)
self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)

View file

@ -1038,6 +1038,29 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
return dict_vt.maybe_getitem_const(name_vt)
return None
def var_getattr(self, tx, name):
# Allow skipping of empty hook dict guards on inbuilt nn modules
if name in (
"_backward_hooks",
"_backward_pre_hooks",
"_forward_hooks",
"_forward_pre_hooks",
):
if not tx.output.side_effects.has_pending_mutation_of_attr(
self, name
) and self.value.__module__.startswith(("torch.nn.", "torch.ao.")):
hooks_dict = getattr(self.value, name)
if isinstance(hooks_dict, dict) and len(hooks_dict) == 0:
if self.source:
hooks_source = AttrSource(self.source, name)
install_guard(
hooks_source.make_guard(
GuardBuilder.EMPTY_NN_MODULE_HOOKS_DICT
)
)
return variables.ConstDictVariable({})
return super().var_getattr(tx, name)
def manually_trace_nn_module_getattr(self, tx, name):
"""
Dynamo tracing of nn.Module __getattr__ can be expensive if the model