mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo] Alternative way to skip empty hooks guards on inbuilt nn modules (#131057)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131057 Approved by: https://github.com/williamwen42, https://github.com/jansel ghstack dependencies: #131056
This commit is contained in:
parent
00e54e74ff
commit
ac76dd606f
2 changed files with 30 additions and 22 deletions
|
|
@ -85,7 +85,6 @@ from .source import (
|
|||
GlobalStateSource,
|
||||
GlobalWeakRefSource,
|
||||
GradSource,
|
||||
is_unspecialized_builtin_nnmodule_attr,
|
||||
LocalSource,
|
||||
NNModuleSource,
|
||||
NumpyTensorSource,
|
||||
|
|
@ -1684,6 +1683,13 @@ class GuardBuilder(GuardBuilderBase):
|
|||
else:
|
||||
self._produce_guard_code(guard, code)
|
||||
|
||||
def EMPTY_NN_MODULE_HOOKS_DICT(self, guard):
|
||||
"""Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards"""
|
||||
if config.skip_nnmodule_hook_guards:
|
||||
# This is unsafe if you add/remove a hook on nn module variable
|
||||
return
|
||||
self.SEQUENCE_LENGTH(guard)
|
||||
|
||||
def OBJECT_MUTATION(self, guard: Guard):
|
||||
mutation_guard.watch(self.get(guard.name), self.check_fn_manager)
|
||||
|
||||
|
|
@ -2112,22 +2118,6 @@ class DeletedGuardFn:
|
|||
pass
|
||||
|
||||
|
||||
def is_nn_module_hook(source: Source) -> bool:
|
||||
# Note that we only skip guards on builtin nn modules like Conv2D etc. But still this is a soundness issue if one
|
||||
# adds/removes a hook after the model is compiled.
|
||||
return (
|
||||
is_unspecialized_builtin_nnmodule_attr(source)
|
||||
and isinstance(source, AttrSource)
|
||||
and source.member
|
||||
in (
|
||||
"_backward_hooks",
|
||||
"_backward_pre_hooks",
|
||||
"_forward_hooks",
|
||||
"_forward_pre_hooks",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# NB: Naively, you'd expect this to only be a function that produces
|
||||
# the callable that constitutes the guard. However, there is some
|
||||
# delicate handling for invalidating this check function when the
|
||||
|
|
@ -2188,11 +2178,6 @@ class CheckFunctionManager:
|
|||
):
|
||||
continue
|
||||
|
||||
# This is unsafe if you add/remove a hook on unspecialized nn module variable
|
||||
if config.skip_nnmodule_hook_guards and is_nn_module_hook(
|
||||
guard.originating_source
|
||||
):
|
||||
continue
|
||||
guard.create(builder)
|
||||
|
||||
self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
|
||||
|
|
|
|||
|
|
@ -1038,6 +1038,29 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
|
|||
return dict_vt.maybe_getitem_const(name_vt)
|
||||
return None
|
||||
|
||||
def var_getattr(self, tx, name):
|
||||
# Allow skipping of empty hook dict guards on inbuilt nn modules
|
||||
if name in (
|
||||
"_backward_hooks",
|
||||
"_backward_pre_hooks",
|
||||
"_forward_hooks",
|
||||
"_forward_pre_hooks",
|
||||
):
|
||||
if not tx.output.side_effects.has_pending_mutation_of_attr(
|
||||
self, name
|
||||
) and self.value.__module__.startswith(("torch.nn.", "torch.ao.")):
|
||||
hooks_dict = getattr(self.value, name)
|
||||
if isinstance(hooks_dict, dict) and len(hooks_dict) == 0:
|
||||
if self.source:
|
||||
hooks_source = AttrSource(self.source, name)
|
||||
install_guard(
|
||||
hooks_source.make_guard(
|
||||
GuardBuilder.EMPTY_NN_MODULE_HOOKS_DICT
|
||||
)
|
||||
)
|
||||
return variables.ConstDictVariable({})
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
def manually_trace_nn_module_getattr(self, tx, name):
|
||||
"""
|
||||
Dynamo tracing of nn.Module __getattr__ can be expensive if the model
|
||||
|
|
|
|||
Loading…
Reference in a new issue