diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 0190cec3f40..caa434ce7ec 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -293,9 +293,20 @@ class SideEffects: return variable def get_variable_cls(self, user_cls): + from torch.overrides import TorchFunctionMode + + from .variables.ctx_manager import GenericContextWrappingVariable + from .variables.torch_function import TorchFunctionModeVariable + variable_cls: type[ variables.UserDefinedObjectVariable ] = variables.UserDefinedObjectVariable + if issubclass( + user_cls, TorchFunctionMode + ) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls): + variable_cls = TorchFunctionModeVariable + elif hasattr(user_cls, "__enter__") and hasattr(user_cls, "__exit__"): + variable_cls = GenericContextWrappingVariable if issubclass(user_cls, torch.nn.Module): variable_cls = variables.UnspecializedNNModuleVariable elif issubclass(user_cls, (dict, collections.OrderedDict)): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 89b761cc1e0..0ea1583af57 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -480,23 +480,10 @@ class UserDefinedClassVariable(UserDefinedVariable): and self.source and not is_forbidden_context_manager(self.value) ): - from torch.overrides import TorchFunctionMode - - from .ctx_manager import GenericContextWrappingVariable from .functions import ( BaseUserFunctionVariable, FunctionDecoratedByContextlibContextManagerVariable, ) - from .torch_function import TorchFunctionModeVariable - - if issubclass( - self.value, TorchFunctionMode - ) and TorchFunctionModeVariable.is_supported_torch_function_mode( - self.value - ): - var_cls = TorchFunctionModeVariable - else: - var_cls = GenericContextWrappingVariable # graph break on any contextlib.* that it is not contextlib.contextmanager # Some of the APIs below are not supported because they rely on features @@ -533,11 +520,10 @@ class UserDefinedClassVariable(UserDefinedVariable): ) ] + args[1:] - options = {} - options["base_cls_vt"] = variables.BuiltinVariable(object) - options["init_args"] = [] - cm_obj = tx.output.side_effects.track_object_new( - self.source, self.value, var_cls, options + cm_obj = tx.output.side_effects.track_new_user_defined_object( + variables.BuiltinVariable(object), + self, + args, ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj