From b80a86ba0f92491f02358cad029591153f421428 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sun, 9 Feb 2025 17:56:29 -0800 Subject: [PATCH] Update base for Update on "[dynamo][not ready] polyfill infra for classes" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned] --- torch/_dynamo/side_effects.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index caa434ce7ec..93ec5aa05e2 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -297,6 +297,7 @@ class SideEffects: from .variables.ctx_manager import GenericContextWrappingVariable from .variables.torch_function import TorchFunctionModeVariable + from .variables.user_defined import is_forbidden_context_manager variable_cls: type[ variables.UserDefinedObjectVariable @@ -305,9 +306,13 @@ class SideEffects: user_cls, TorchFunctionMode ) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls): variable_cls = TorchFunctionModeVariable - elif hasattr(user_cls, "__enter__") and hasattr(user_cls, "__exit__"): + elif ( + hasattr(user_cls, "__enter__") + and hasattr(user_cls, "__exit__") + and not is_forbidden_context_manager(user_cls) + ): variable_cls = GenericContextWrappingVariable - if issubclass(user_cls, torch.nn.Module): + elif issubclass(user_cls, torch.nn.Module): variable_cls = variables.UnspecializedNNModuleVariable elif issubclass(user_cls, (dict, collections.OrderedDict)): variable_cls = variables.UserDefinedDictVariable