diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 946dabeb7e1..b3a0ddc6245 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -85,7 +85,6 @@ from .source import ( GlobalStateSource, GlobalWeakRefSource, GradSource, - is_unspecialized_builtin_nnmodule_attr, LocalSource, NNModuleSource, NumpyTensorSource, @@ -1684,6 +1683,13 @@ class GuardBuilder(GuardBuilderBase): else: self._produce_guard_code(guard, code) + def EMPTY_NN_MODULE_HOOKS_DICT(self, guard): + """Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards""" + if config.skip_nnmodule_hook_guards: + # This is unsafe if you add/remove a hook on nn module variable + return + self.SEQUENCE_LENGTH(guard) + def OBJECT_MUTATION(self, guard: Guard): mutation_guard.watch(self.get(guard.name), self.check_fn_manager) @@ -2112,22 +2118,6 @@ class DeletedGuardFn: pass -def is_nn_module_hook(source: Source) -> bool: - # Note that we only skip guards on builtin nn modules like Conv2D etc. But still this is a soundness issue if one - # adds/removes a hook after the model is compiled. - return ( - is_unspecialized_builtin_nnmodule_attr(source) - and isinstance(source, AttrSource) - and source.member - in ( - "_backward_hooks", - "_backward_pre_hooks", - "_forward_hooks", - "_forward_pre_hooks", - ) - ) - - # NB: Naively, you'd expect this to only be a function that produces # the callable that constitutes the guard. However, there is some # delicate handling for invalidating this check function when the @@ -2188,11 +2178,6 @@ class CheckFunctionManager: ): continue - # This is unsafe if you add/remove a hook on unspecialized nn module variable - if config.skip_nnmodule_hook_guards and is_nn_module_hook( - guard.originating_source - ): - continue guard.create(builder) self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn) diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 895cfffdc72..81c73b5a3b0 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -1038,6 +1038,29 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable): return dict_vt.maybe_getitem_const(name_vt) return None + def var_getattr(self, tx, name): + # Allow skipping of empty hook dict guards on inbuilt nn modules + if name in ( + "_backward_hooks", + "_backward_pre_hooks", + "_forward_hooks", + "_forward_pre_hooks", + ): + if not tx.output.side_effects.has_pending_mutation_of_attr( + self, name + ) and self.value.__module__.startswith(("torch.nn.", "torch.ao.")): + hooks_dict = getattr(self.value, name) + if isinstance(hooks_dict, dict) and len(hooks_dict) == 0: + if self.source: + hooks_source = AttrSource(self.source, name) + install_guard( + hooks_source.make_guard( + GuardBuilder.EMPTY_NN_MODULE_HOOKS_DICT + ) + ) + return variables.ConstDictVariable({}) + return super().var_getattr(tx, name) + def manually_trace_nn_module_getattr(self, tx, name): """ Dynamo tracing of nn.Module __getattr__ can be expensive if the model