From 16ef1f542f503c23f4fd0d8351c82c770a74b719 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sun, 9 Feb 2025 00:32:51 -0800 Subject: [PATCH] Update base for Update on "[dynamo][user-defined] Unify standard and non-standard __new__ codebase" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned] --- torch/_dynamo/variables/misc.py | 7 +++++-- torch/_dynamo/variables/user_defined.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index fe2f3407ab0..d7c78bdaa93 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -164,10 +164,13 @@ class SuperVariable(VariableTracker): and variables.UserDefinedClassVariable.is_supported_new_method(inner_fn) ): user_cls = inner_fn.__self__ - if hasattr(user_cls, "__module__") and user_cls.__module__ == "bulitins": + if hasattr(user_cls, "__module__") and user_cls.__module__ == "builtins": user_cls_vt = variables.BuiltinVariable(user_cls) else: - user_cls_vt = variables.UserDefinedClassVariable(user_cls) + user_cls_source = source.member + user_cls_vt = variables.UserDefinedClassVariable( + user_cls, source=user_cls_source + ) return user_cls_vt.call_method(tx, "__new__", args, kwargs) elif isinstance(inner_fn, staticmethod) and isinstance( inner_fn.__func__, types.FunctionType diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 0dd3aec308f..61cea30d8a1 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -344,6 +344,17 @@ class UserDefinedClassVariable(UserDefinedVariable): return variables.ConstantVariable(self.value == args[0].value) elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"): return variables.ConstantVariable(self.value != args[0].value) + elif ( + name == "__new__" + and self.value is collections.OrderedDict + and isinstance(args[0], UserDefinedClassVariable) + and args[0].value is collections.OrderedDict + ): + assert len(args) == 1 + assert len(kwargs) == 0 + return variables.ConstDictVariable( + {}, collections.OrderedDict, mutation_type=ValueMutationNew() + ) elif name == "__new__" and UserDefinedClassVariable.is_supported_new_method( self.value.__new__ ):