Update on "[dynamo][user-defined] Unify standard and non-standard __new__ codebase"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
This commit is contained in:
Animesh Jain 2025-02-09 15:08:59 -08:00
commit e5a69acd2f
2 changed files with 15 additions and 18 deletions

View file

@ -293,9 +293,20 @@ class SideEffects:
return variable
def get_variable_cls(self, user_cls):
from torch.overrides import TorchFunctionMode
from .variables.ctx_manager import GenericContextWrappingVariable
from .variables.torch_function import TorchFunctionModeVariable
variable_cls: type[
variables.UserDefinedObjectVariable
] = variables.UserDefinedObjectVariable
if issubclass(
user_cls, TorchFunctionMode
) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls):
variable_cls = TorchFunctionModeVariable
elif hasattr(user_cls, "__enter__") and hasattr(user_cls, "__exit__"):
variable_cls = GenericContextWrappingVariable
if issubclass(user_cls, torch.nn.Module):
variable_cls = variables.UnspecializedNNModuleVariable
elif issubclass(user_cls, (dict, collections.OrderedDict)):

View file

@ -480,23 +480,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
and self.source
and not is_forbidden_context_manager(self.value)
):
from torch.overrides import TorchFunctionMode
from .ctx_manager import GenericContextWrappingVariable
from .functions import (
BaseUserFunctionVariable,
FunctionDecoratedByContextlibContextManagerVariable,
)
from .torch_function import TorchFunctionModeVariable
if issubclass(
self.value, TorchFunctionMode
) and TorchFunctionModeVariable.is_supported_torch_function_mode(
self.value
):
var_cls = TorchFunctionModeVariable
else:
var_cls = GenericContextWrappingVariable
# graph break on any contextlib.* that it is not contextlib.contextmanager
# Some of the APIs below are not supported because they rely on features
@ -533,11 +520,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
)
] + args[1:]
options = {}
options["base_cls_vt"] = variables.BuiltinVariable(object)
options["init_args"] = []
cm_obj = tx.output.side_effects.track_object_new(
self.source, self.value, var_cls, options
cm_obj = tx.output.side_effects.track_new_user_defined_object(
variables.BuiltinVariable(object),
self,
args,
)
cm_obj.call_method(tx, "__init__", args, kwargs)
return cm_obj