Update on "[dynamo][user-defined] Unify standard and non-standard __new__ codebase"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
This commit is contained in:
Animesh Jain 2025-02-09 00:11:08 -08:00
commit 84626097ad
2 changed files with 13 additions and 5 deletions

View file

@ -514,22 +514,31 @@ class SideEffects:
elif isinstance(var.mutation_type, AttributeMutationNew):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
# Reconstruct the bytecode for
# base_cls.__new__(user_cls, *args)
if isinstance(var, variables.UserDefinedObjectVariable):
def gen_fn():
def load_new_method():
assert var.base_cls_vt is not None
cg(var.base_cls_vt) # type: ignore[attr-defined]
cg.extend_output([cg.create_load_attr("__new__")])
cg.add_push_null(gen_fn)
cg.add_push_null(load_new_method)
else:
cg.add_push_null(
lambda: cg.load_import_from(utils.__name__, "object_new")
)
cg(var.mutation_type.cls_source)
for i in var.init_args:
cg(i)
# Generate the args to the __new__ method
for arg in var.init_args:
cg(arg)
# Call the __new__ method
cg.extend_output(create_call_function(1 + len(var.init_args), False))
cg.add_cache(var)
var.source = LocalSource(cg.tempvars[var])
else:

View file

@ -672,7 +672,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
new_fn = inspect.getattr_static(self.value, "__new__", None)
if isinstance(new_fn, staticmethod):
new_fn = new_fn.__func__
# return new_fn in (object.__new__, Generic.__new__, dict.__new__)
return new_fn is object.__new__
def call_obj_hasattr(