diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index b5563da6d98..0190cec3f40 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -514,22 +514,31 @@ class SideEffects: elif isinstance(var.mutation_type, AttributeMutationNew): if isinstance(var, variables.AutogradFunctionContextVariable): unimplemented("AutogradFunctionContextVariable escaped") + + # Reconstruct the bytecode for + # base_cls.__new__(user_cls, *args) + if isinstance(var, variables.UserDefinedObjectVariable): - def gen_fn(): + def load_new_method(): assert var.base_cls_vt is not None cg(var.base_cls_vt) # type: ignore[attr-defined] cg.extend_output([cg.create_load_attr("__new__")]) - cg.add_push_null(gen_fn) + cg.add_push_null(load_new_method) else: cg.add_push_null( lambda: cg.load_import_from(utils.__name__, "object_new") ) cg(var.mutation_type.cls_source) - for i in var.init_args: - cg(i) + + # Generate the args to the __new__ method + for arg in var.init_args: + cg(arg) + + # Call the __new__ method cg.extend_output(create_call_function(1 + len(var.init_args), False)) + cg.add_cache(var) var.source = LocalSource(cg.tempvars[var]) else: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 73a3056c280..c0d69300d71 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -672,7 +672,6 @@ class UserDefinedClassVariable(UserDefinedVariable): new_fn = inspect.getattr_static(self.value, "__new__", None) if isinstance(new_fn, staticmethod): new_fn = new_fn.__func__ - # return new_fn in (object.__new__, Generic.__new__, dict.__new__) return new_fn is object.__new__ def call_obj_hasattr(