[dynamo][builtin-skipfiles-cleanup] Remove inspect (#146116)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146116
Approved by: https://github.com/williamwen42, https://github.com/zou3519, https://github.com/jansel
ghstack dependencies: #146322
This commit is contained in:
Animesh Jain 2025-02-03 10:13:17 -08:00 committed by PyTorch MergeBot
parent 762a05b3b3
commit 5f53889850
5 changed files with 3 additions and 274 deletions

View file

@ -522,6 +522,7 @@ class SerializationMixin:
def test_serialization_backwards_compat_safe(self):
self._test_serialization_backwards_compat(True)
@skipIfTorchDynamo("graph breaks messages collide with warnings")
def test_serialization_save_warnings(self):
with warnings.catch_warnings(record=True) as warns:
with tempfile.NamedTemporaryFile() as checkpoint:

View file

@ -3145,7 +3145,6 @@ BUILTIN_SKIPLIST = (
abc,
collections,
copy,
inspect,
random,
traceback,
linecache,

View file

@ -81,7 +81,6 @@ from .misc import (
DeletedVariable,
ExceptionVariable,
GetAttrVariable,
InspectSignatureVariable,
LambdaVariable,
MethodWrapperVariable,
NewGlobalVariable,
@ -148,7 +147,6 @@ __all__ = [
"FakeItemVariable",
"GetAttrVariable",
"GradModeVariable",
"InspectSignatureVariable",
"IteratorVariable",
"ItertoolsVariable",
"LambdaVariable",

View file

@ -175,7 +175,6 @@ from .misc import (
DelayGraphBreakVariable,
GetAttrVariable,
GetSetDescriptorVariable,
InspectSignatureVariable,
LambdaVariable,
LoggingLoggerVariable,
MethodWrapperVariable,
@ -486,14 +485,6 @@ class VariableBuilder:
from ..comptime import comptime
entries = [
(
inspect.signature,
lambda self, value: LambdaVariable(
InspectSignatureVariable.create,
source=self.source,
**self.install_guards(GuardBuilder.CLOSURE_MATCH),
),
),
(comptime, lambda self, value: ComptimeVariable()),
(
dataclasses.fields,

View file

@ -20,13 +20,7 @@ from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import raise_observed_exception, unimplemented
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import unpatched_nn_module_init
from ..source import (
AttrSource,
DefaultsSource,
GetItemSource,
TypeSource,
WeakRefCallSource,
)
from ..source import AttrSource, GetItemSource, TypeSource, WeakRefCallSource
from ..utils import (
check_unspec_or_constant_args,
identity,
@ -36,13 +30,7 @@ from ..utils import (
tuple_methods,
)
from .base import VariableTracker
from .functions import (
NestedUserFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
wrap_bound_arg,
)
from .nn_module import UnspecializedNNModuleVariable
from .functions import NestedUserFunctionVariable, UserFunctionVariable
from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable
@ -356,254 +344,6 @@ class NewGlobalVariable(VariableTracker):
super().__init__(**kwargs)
class InspectSignatureVariable(VariableTracker):
"""represents inspect.signature(...)"""
_nonvar_fields = {
"signature",
"parameters",
*VariableTracker._nonvar_fields,
}
@staticmethod
def create(callable, **kwargs):
if kwargs:
unimplemented(f"inspect.signature with {kwargs}")
return InspectSignatureVariable(
callable, mutation_type=variables.base.ValueMutationNew()
)
def __init__(self, inspected: VariableTracker, **kwargs) -> None:
super().__init__(**kwargs)
self.inspected = inspected
try:
if hasattr(self.inspected, "get_function"):
self.fn = self.inspected.get_function()
elif isinstance(self.inspected, UnspecializedNNModuleVariable):
self.fn = self.inspected.value
else:
self.fn = self.inspected.as_python_constant()
except NotImplementedError:
unimplemented("inspect.signature with non-constant function")
self.signature = inspect.signature(self.fn)
self.parameters = list(self.signature.parameters.items())
if isinstance(self.inspected, UserMethodVariable):
self.parameters = self.parameters[1:]
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
if name == "parameters":
return variables.ConstDictVariable(
{
variables.ConstantVariable.create(
param[0]
): InspectParameterVariable(param[1])
for param in self.parameters
},
user_cls=dict,
)
return super().var_getattr(tx, name)
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "bind":
if not hasattr(self.fn, "__kwdefaults__"):
unimplemented(
f"inspect.signature.bind with {self.fn} without __kwdefaults__"
)
obj = self.signature.bind(*args, **kwargs)
# wrap function defaults in VTs
defaults = {}
if self.fn.__kwdefaults__:
wrap = functools.partial(wrap_bound_arg, tx=tx)
kwdefaults_sources = {
k: (
None
if self.source is None
else DefaultsSource(self.source, k, is_kw=True)
)
for k in self.fn.__kwdefaults__
}
defaults = {
k: wrap(val=v, source=kwdefaults_sources[k])
for k, v in self.fn.__kwdefaults__.items()
}
return InspectBoundArgumentsVariable(
obj,
defaults,
self,
)
return super().call_method(tx, name, args, kwargs)
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.extend_output(
[
codegen.create_load_python_module(inspect),
codegen.create_load_attr("signature"),
]
)
)
codegen(self.inspected)
codegen.extend_output(create_call_function(1, False))
class InspectParameterVariable(VariableTracker):
"""represents inspect.Parameter(...)"""
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
try:
attr_value = getattr(self.value, name)
source = self.source and AttrSource(self.source, name)
return VariableTracker.build(tx, attr_value, source)
except AttributeError:
unimplemented(f"getattr({self.value}, {name})")
class InspectBoundArgumentsVariable(VariableTracker):
"""represents inspect.signature(...).bind(...)"""
_nonvar_fields = {
"bound_arguments",
"packed_vars",
*VariableTracker._nonvar_fields,
}
# NOTE: we keep track of changes to arguments via bound_arguments_var,
# but we still keep a copy of the inspect.BoundArguments object in order
# to get the correct args/kwargs.
def __init__(
self,
bound_arguments: inspect.BoundArguments,
defaults: dict[str, VariableTracker],
signature: InspectSignatureVariable,
**kwargs,
):
super().__init__(**kwargs)
self.bound_arguments = bound_arguments
self.defaults = defaults
# used to convert from VT to tuple/dict when updating bound_arguments
self.packed_vars = set()
arguments_dict = {}
for key, val in bound_arguments.arguments.items():
key_var = variables.ConstantVariable(key)
# convert val to VT
if isinstance(val, tuple):
arguments_dict[key_var] = variables.TupleVariable(list(val))
self.packed_vars.add(key)
elif isinstance(val, dict):
self.packed_vars.add(key)
arguments_dict[key_var] = variables.ConstDictVariable(
{variables.ConstantVariable(k): v for k, v in val.items()}
)
elif isinstance(val, VariableTracker):
arguments_dict[key_var] = val
else:
unimplemented(
"inspect.signature(...).bind(...).arguments contains non-variable/tuple/dict"
)
self.bound_arguments_var = variables.ConstDictVariable(
arguments_dict,
type(bound_arguments.arguments),
mutation_type=variables.base.ValueMutationNew(),
)
self.signature = signature
def _update_bound_arguments(self):
for key, val in self.bound_arguments_var.items.items():
true_val = val
if key.underlying_value in self.packed_vars:
if isinstance(val, variables.TupleVariable):
true_val = tuple(val.items)
elif isinstance(val, variables.ConstDictVariable):
true_val = {k.underlying_value: v for k, v in val.items.items()}
else:
unimplemented(
"inspect.signature(...).bind(...) cannot update bound arguments"
)
self.bound_arguments.arguments[key.underlying_value] = true_val
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
if name == "arguments":
return self.bound_arguments_var
elif name == "args":
self._update_bound_arguments()
return variables.TupleVariable(list(self.bound_arguments.args))
elif name == "kwargs":
self._update_bound_arguments()
kw = {
variables.ConstantVariable(key): val
for key, val in self.bound_arguments.kwargs.items()
}
return variables.ConstDictVariable(kw)
elif name == "signature":
return self.signature
return super().var_getattr(tx, name)
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "apply_defaults":
# mimic calling apply_defaults
for key, val in self.defaults.items():
key_var = variables.ConstantVariable(key)
if key_var not in self.bound_arguments_var:
self.bound_arguments_var.call_method(
tx, "__setitem__", [key_var, val], {}
)
# actually apply the changes
self._update_bound_arguments()
return variables.ConstantVariable(None)
return super().call_method(tx, name, args, kwargs)
def reconstruct(self, codegen):
# reconstruct inspect.signature(...).bind(*bound_arguments.args, **bound_arguments.kwargs)
# NOTE the reconstructed inspect.signature(...) object might not be the same object
# as the Signature object that originally created the BoundArguments object.
self._update_bound_arguments()
def gen_fn():
codegen(self.signature)
codegen.append_output(codegen.create_load_attr("bind"))
codegen.add_push_null(gen_fn, call_function_ex=True)
codegen.foreach(self.bound_arguments.args)
codegen.append_output(
create_instruction("BUILD_TUPLE", arg=len(self.bound_arguments.args))
)
for key, val in self.bound_arguments.kwargs.items():
codegen.append_output(codegen.create_load_const(key))
codegen(val)
codegen.extend_output(
[
create_instruction("BUILD_MAP", arg=len(self.bound_arguments.kwargs)),
create_instruction("CALL_FUNCTION_EX", arg=1),
]
)
def produce_trampoline_autograd_apply(fn_cls):
def trampoline_autograd_apply(*args, **kwargs):
return fn_cls.apply(*args, **kwargs)