From a74d29b06949e210aa7cfe61fb64402ce8ae54a0 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sun, 9 Feb 2025 15:08:59 -0800 Subject: [PATCH] Update base for Update on "[dynamo][user-defined] Unify standard and non-standard __new__ codebase" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned] --- torch/_dynamo/side_effects.py | 11 +++++++++++ torch/_dynamo/variables/user_defined.py | 22 ++++------------------ 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 0190cec3f40..caa434ce7ec 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -293,9 +293,20 @@ class SideEffects: return variable def get_variable_cls(self, user_cls): + from torch.overrides import TorchFunctionMode + + from .variables.ctx_manager import GenericContextWrappingVariable + from .variables.torch_function import TorchFunctionModeVariable + variable_cls: type[ variables.UserDefinedObjectVariable ] = variables.UserDefinedObjectVariable + if issubclass( + user_cls, TorchFunctionMode + ) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls): + variable_cls = TorchFunctionModeVariable + elif hasattr(user_cls, "__enter__") and hasattr(user_cls, "__exit__"): + variable_cls = GenericContextWrappingVariable if issubclass(user_cls, torch.nn.Module): variable_cls = variables.UnspecializedNNModuleVariable elif issubclass(user_cls, (dict, collections.OrderedDict)): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 61cea30d8a1..17a7a723c08 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -480,23 +480,10 @@ class UserDefinedClassVariable(UserDefinedVariable): and self.source and not is_forbidden_context_manager(self.value) ): - from torch.overrides import TorchFunctionMode - - from .ctx_manager import GenericContextWrappingVariable from .functions import ( BaseUserFunctionVariable, FunctionDecoratedByContextlibContextManagerVariable, ) - from .torch_function import TorchFunctionModeVariable - - if issubclass( - self.value, TorchFunctionMode - ) and TorchFunctionModeVariable.is_supported_torch_function_mode( - self.value - ): - var_cls = TorchFunctionModeVariable - else: - var_cls = GenericContextWrappingVariable # graph break on any contextlib.* that it is not contextlib.contextmanager # Some of the APIs below are not supported because they rely on features @@ -533,11 +520,10 @@ class UserDefinedClassVariable(UserDefinedVariable): ) ] + args[1:] - options = {} - options["base_cls_vt"] = variables.BuiltinVariable(object) - options["init_args"] = [] - cm_obj = tx.output.side_effects.track_object_new( - self.source, self.value, var_cls, options + cm_obj = tx.output.side_effects.track_new_user_defined_object( + variables.BuiltinVariable(object), + self, + args, ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj