diff --git a/test/test_serialization.py b/test/test_serialization.py index 6ada03e6cbd..bf5effb8d08 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -522,6 +522,7 @@ class SerializationMixin: def test_serialization_backwards_compat_safe(self): self._test_serialization_backwards_compat(True) + @skipIfTorchDynamo("graph breaks messages collide with warnings") def test_serialization_save_warnings(self): with warnings.catch_warnings(record=True) as warns: with tempfile.NamedTemporaryFile() as checkpoint: diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 48ca3fa65cd..7b459ffcbb9 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3145,7 +3145,6 @@ BUILTIN_SKIPLIST = ( abc, collections, copy, - inspect, random, traceback, linecache, diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 848cced5bc8..9fc28fe50a6 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -81,7 +81,6 @@ from .misc import ( DeletedVariable, ExceptionVariable, GetAttrVariable, - InspectSignatureVariable, LambdaVariable, MethodWrapperVariable, NewGlobalVariable, @@ -148,7 +147,6 @@ __all__ = [ "FakeItemVariable", "GetAttrVariable", "GradModeVariable", - "InspectSignatureVariable", "IteratorVariable", "ItertoolsVariable", "LambdaVariable", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index a026989a089..92d7a971b7a 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -175,7 +175,6 @@ from .misc import ( DelayGraphBreakVariable, GetAttrVariable, GetSetDescriptorVariable, - InspectSignatureVariable, LambdaVariable, LoggingLoggerVariable, MethodWrapperVariable, @@ -486,14 +485,6 @@ class VariableBuilder: from ..comptime import comptime entries = [ - ( - inspect.signature, - lambda self, value: LambdaVariable( - InspectSignatureVariable.create, - source=self.source, - **self.install_guards(GuardBuilder.CLOSURE_MATCH), - ), - ), (comptime, lambda self, value: ComptimeVariable()), ( dataclasses.fields, diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 46f84992696..d213e20169e 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -20,13 +20,7 @@ from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard from ..mutation_guard import unpatched_nn_module_init -from ..source import ( - AttrSource, - DefaultsSource, - GetItemSource, - TypeSource, - WeakRefCallSource, -) +from ..source import AttrSource, GetItemSource, TypeSource, WeakRefCallSource from ..utils import ( check_unspec_or_constant_args, identity, @@ -36,13 +30,7 @@ from ..utils import ( tuple_methods, ) from .base import VariableTracker -from .functions import ( - NestedUserFunctionVariable, - UserFunctionVariable, - UserMethodVariable, - wrap_bound_arg, -) -from .nn_module import UnspecializedNNModuleVariable +from .functions import NestedUserFunctionVariable, UserFunctionVariable from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable @@ -356,254 +344,6 @@ class NewGlobalVariable(VariableTracker): super().__init__(**kwargs) -class InspectSignatureVariable(VariableTracker): - """represents inspect.signature(...)""" - - _nonvar_fields = { - "signature", - "parameters", - *VariableTracker._nonvar_fields, - } - - @staticmethod - def create(callable, **kwargs): - if kwargs: - unimplemented(f"inspect.signature with {kwargs}") - return InspectSignatureVariable( - callable, mutation_type=variables.base.ValueMutationNew() - ) - - def __init__(self, inspected: VariableTracker, **kwargs) -> None: - super().__init__(**kwargs) - self.inspected = inspected - - try: - if hasattr(self.inspected, "get_function"): - self.fn = self.inspected.get_function() - elif isinstance(self.inspected, UnspecializedNNModuleVariable): - self.fn = self.inspected.value - else: - self.fn = self.inspected.as_python_constant() - except NotImplementedError: - unimplemented("inspect.signature with non-constant function") - - self.signature = inspect.signature(self.fn) - self.parameters = list(self.signature.parameters.items()) - if isinstance(self.inspected, UserMethodVariable): - self.parameters = self.parameters[1:] - - def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": - if name == "parameters": - return variables.ConstDictVariable( - { - variables.ConstantVariable.create( - param[0] - ): InspectParameterVariable(param[1]) - for param in self.parameters - }, - user_cls=dict, - ) - return super().var_getattr(tx, name) - - def call_method( - self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - if name == "bind": - if not hasattr(self.fn, "__kwdefaults__"): - unimplemented( - f"inspect.signature.bind with {self.fn} without __kwdefaults__" - ) - obj = self.signature.bind(*args, **kwargs) - - # wrap function defaults in VTs - defaults = {} - if self.fn.__kwdefaults__: - wrap = functools.partial(wrap_bound_arg, tx=tx) - kwdefaults_sources = { - k: ( - None - if self.source is None - else DefaultsSource(self.source, k, is_kw=True) - ) - for k in self.fn.__kwdefaults__ - } - defaults = { - k: wrap(val=v, source=kwdefaults_sources[k]) - for k, v in self.fn.__kwdefaults__.items() - } - - return InspectBoundArgumentsVariable( - obj, - defaults, - self, - ) - return super().call_method(tx, name, args, kwargs) - - def reconstruct(self, codegen): - codegen.add_push_null( - lambda: codegen.extend_output( - [ - codegen.create_load_python_module(inspect), - codegen.create_load_attr("signature"), - ] - ) - ) - codegen(self.inspected) - codegen.extend_output(create_call_function(1, False)) - - -class InspectParameterVariable(VariableTracker): - """represents inspect.Parameter(...)""" - - def __init__(self, value, **kwargs) -> None: - super().__init__(**kwargs) - self.value = value - - def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": - try: - attr_value = getattr(self.value, name) - source = self.source and AttrSource(self.source, name) - return VariableTracker.build(tx, attr_value, source) - except AttributeError: - unimplemented(f"getattr({self.value}, {name})") - - -class InspectBoundArgumentsVariable(VariableTracker): - """represents inspect.signature(...).bind(...)""" - - _nonvar_fields = { - "bound_arguments", - "packed_vars", - *VariableTracker._nonvar_fields, - } - - # NOTE: we keep track of changes to arguments via bound_arguments_var, - # but we still keep a copy of the inspect.BoundArguments object in order - # to get the correct args/kwargs. - def __init__( - self, - bound_arguments: inspect.BoundArguments, - defaults: dict[str, VariableTracker], - signature: InspectSignatureVariable, - **kwargs, - ): - super().__init__(**kwargs) - self.bound_arguments = bound_arguments - self.defaults = defaults - # used to convert from VT to tuple/dict when updating bound_arguments - self.packed_vars = set() - - arguments_dict = {} - for key, val in bound_arguments.arguments.items(): - key_var = variables.ConstantVariable(key) - # convert val to VT - if isinstance(val, tuple): - arguments_dict[key_var] = variables.TupleVariable(list(val)) - self.packed_vars.add(key) - elif isinstance(val, dict): - self.packed_vars.add(key) - arguments_dict[key_var] = variables.ConstDictVariable( - {variables.ConstantVariable(k): v for k, v in val.items()} - ) - elif isinstance(val, VariableTracker): - arguments_dict[key_var] = val - else: - unimplemented( - "inspect.signature(...).bind(...).arguments contains non-variable/tuple/dict" - ) - - self.bound_arguments_var = variables.ConstDictVariable( - arguments_dict, - type(bound_arguments.arguments), - mutation_type=variables.base.ValueMutationNew(), - ) - self.signature = signature - - def _update_bound_arguments(self): - for key, val in self.bound_arguments_var.items.items(): - true_val = val - if key.underlying_value in self.packed_vars: - if isinstance(val, variables.TupleVariable): - true_val = tuple(val.items) - elif isinstance(val, variables.ConstDictVariable): - true_val = {k.underlying_value: v for k, v in val.items.items()} - else: - unimplemented( - "inspect.signature(...).bind(...) cannot update bound arguments" - ) - self.bound_arguments.arguments[key.underlying_value] = true_val - - def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": - if name == "arguments": - return self.bound_arguments_var - elif name == "args": - self._update_bound_arguments() - return variables.TupleVariable(list(self.bound_arguments.args)) - elif name == "kwargs": - self._update_bound_arguments() - kw = { - variables.ConstantVariable(key): val - for key, val in self.bound_arguments.kwargs.items() - } - return variables.ConstDictVariable(kw) - elif name == "signature": - return self.signature - return super().var_getattr(tx, name) - - def call_method( - self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - if name == "apply_defaults": - # mimic calling apply_defaults - for key, val in self.defaults.items(): - key_var = variables.ConstantVariable(key) - if key_var not in self.bound_arguments_var: - self.bound_arguments_var.call_method( - tx, "__setitem__", [key_var, val], {} - ) - - # actually apply the changes - self._update_bound_arguments() - - return variables.ConstantVariable(None) - return super().call_method(tx, name, args, kwargs) - - def reconstruct(self, codegen): - # reconstruct inspect.signature(...).bind(*bound_arguments.args, **bound_arguments.kwargs) - # NOTE the reconstructed inspect.signature(...) object might not be the same object - # as the Signature object that originally created the BoundArguments object. - self._update_bound_arguments() - - def gen_fn(): - codegen(self.signature) - codegen.append_output(codegen.create_load_attr("bind")) - - codegen.add_push_null(gen_fn, call_function_ex=True) - - codegen.foreach(self.bound_arguments.args) - codegen.append_output( - create_instruction("BUILD_TUPLE", arg=len(self.bound_arguments.args)) - ) - - for key, val in self.bound_arguments.kwargs.items(): - codegen.append_output(codegen.create_load_const(key)) - codegen(val) - codegen.extend_output( - [ - create_instruction("BUILD_MAP", arg=len(self.bound_arguments.kwargs)), - create_instruction("CALL_FUNCTION_EX", arg=1), - ] - ) - - def produce_trampoline_autograd_apply(fn_cls): def trampoline_autograd_apply(*args, **kwargs): return fn_cls.apply(*args, **kwargs)