mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo][builtin-skipfiles-cleanup] Remove inspect (#146116)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146116 Approved by: https://github.com/williamwen42, https://github.com/zou3519, https://github.com/jansel ghstack dependencies: #146322
This commit is contained in:
parent
762a05b3b3
commit
5f53889850
5 changed files with 3 additions and 274 deletions
|
|
@ -522,6 +522,7 @@ class SerializationMixin:
|
|||
def test_serialization_backwards_compat_safe(self):
|
||||
self._test_serialization_backwards_compat(True)
|
||||
|
||||
@skipIfTorchDynamo("graph breaks messages collide with warnings")
|
||||
def test_serialization_save_warnings(self):
|
||||
with warnings.catch_warnings(record=True) as warns:
|
||||
with tempfile.NamedTemporaryFile() as checkpoint:
|
||||
|
|
|
|||
|
|
@ -3145,7 +3145,6 @@ BUILTIN_SKIPLIST = (
|
|||
abc,
|
||||
collections,
|
||||
copy,
|
||||
inspect,
|
||||
random,
|
||||
traceback,
|
||||
linecache,
|
||||
|
|
|
|||
|
|
@ -81,7 +81,6 @@ from .misc import (
|
|||
DeletedVariable,
|
||||
ExceptionVariable,
|
||||
GetAttrVariable,
|
||||
InspectSignatureVariable,
|
||||
LambdaVariable,
|
||||
MethodWrapperVariable,
|
||||
NewGlobalVariable,
|
||||
|
|
@ -148,7 +147,6 @@ __all__ = [
|
|||
"FakeItemVariable",
|
||||
"GetAttrVariable",
|
||||
"GradModeVariable",
|
||||
"InspectSignatureVariable",
|
||||
"IteratorVariable",
|
||||
"ItertoolsVariable",
|
||||
"LambdaVariable",
|
||||
|
|
|
|||
|
|
@ -175,7 +175,6 @@ from .misc import (
|
|||
DelayGraphBreakVariable,
|
||||
GetAttrVariable,
|
||||
GetSetDescriptorVariable,
|
||||
InspectSignatureVariable,
|
||||
LambdaVariable,
|
||||
LoggingLoggerVariable,
|
||||
MethodWrapperVariable,
|
||||
|
|
@ -486,14 +485,6 @@ class VariableBuilder:
|
|||
from ..comptime import comptime
|
||||
|
||||
entries = [
|
||||
(
|
||||
inspect.signature,
|
||||
lambda self, value: LambdaVariable(
|
||||
InspectSignatureVariable.create,
|
||||
source=self.source,
|
||||
**self.install_guards(GuardBuilder.CLOSURE_MATCH),
|
||||
),
|
||||
),
|
||||
(comptime, lambda self, value: ComptimeVariable()),
|
||||
(
|
||||
dataclasses.fields,
|
||||
|
|
|
|||
|
|
@ -20,13 +20,7 @@ from ..create_parameter_op import do_not_convert_to_tracable_parameter
|
|||
from ..exc import raise_observed_exception, unimplemented
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..mutation_guard import unpatched_nn_module_init
|
||||
from ..source import (
|
||||
AttrSource,
|
||||
DefaultsSource,
|
||||
GetItemSource,
|
||||
TypeSource,
|
||||
WeakRefCallSource,
|
||||
)
|
||||
from ..source import AttrSource, GetItemSource, TypeSource, WeakRefCallSource
|
||||
from ..utils import (
|
||||
check_unspec_or_constant_args,
|
||||
identity,
|
||||
|
|
@ -36,13 +30,7 @@ from ..utils import (
|
|||
tuple_methods,
|
||||
)
|
||||
from .base import VariableTracker
|
||||
from .functions import (
|
||||
NestedUserFunctionVariable,
|
||||
UserFunctionVariable,
|
||||
UserMethodVariable,
|
||||
wrap_bound_arg,
|
||||
)
|
||||
from .nn_module import UnspecializedNNModuleVariable
|
||||
from .functions import NestedUserFunctionVariable, UserFunctionVariable
|
||||
from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable
|
||||
|
||||
|
||||
|
|
@ -356,254 +344,6 @@ class NewGlobalVariable(VariableTracker):
|
|||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class InspectSignatureVariable(VariableTracker):
|
||||
"""represents inspect.signature(...)"""
|
||||
|
||||
_nonvar_fields = {
|
||||
"signature",
|
||||
"parameters",
|
||||
*VariableTracker._nonvar_fields,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create(callable, **kwargs):
|
||||
if kwargs:
|
||||
unimplemented(f"inspect.signature with {kwargs}")
|
||||
return InspectSignatureVariable(
|
||||
callable, mutation_type=variables.base.ValueMutationNew()
|
||||
)
|
||||
|
||||
def __init__(self, inspected: VariableTracker, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.inspected = inspected
|
||||
|
||||
try:
|
||||
if hasattr(self.inspected, "get_function"):
|
||||
self.fn = self.inspected.get_function()
|
||||
elif isinstance(self.inspected, UnspecializedNNModuleVariable):
|
||||
self.fn = self.inspected.value
|
||||
else:
|
||||
self.fn = self.inspected.as_python_constant()
|
||||
except NotImplementedError:
|
||||
unimplemented("inspect.signature with non-constant function")
|
||||
|
||||
self.signature = inspect.signature(self.fn)
|
||||
self.parameters = list(self.signature.parameters.items())
|
||||
if isinstance(self.inspected, UserMethodVariable):
|
||||
self.parameters = self.parameters[1:]
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
if name == "parameters":
|
||||
return variables.ConstDictVariable(
|
||||
{
|
||||
variables.ConstantVariable.create(
|
||||
param[0]
|
||||
): InspectParameterVariable(param[1])
|
||||
for param in self.parameters
|
||||
},
|
||||
user_cls=dict,
|
||||
)
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
if name == "bind":
|
||||
if not hasattr(self.fn, "__kwdefaults__"):
|
||||
unimplemented(
|
||||
f"inspect.signature.bind with {self.fn} without __kwdefaults__"
|
||||
)
|
||||
obj = self.signature.bind(*args, **kwargs)
|
||||
|
||||
# wrap function defaults in VTs
|
||||
defaults = {}
|
||||
if self.fn.__kwdefaults__:
|
||||
wrap = functools.partial(wrap_bound_arg, tx=tx)
|
||||
kwdefaults_sources = {
|
||||
k: (
|
||||
None
|
||||
if self.source is None
|
||||
else DefaultsSource(self.source, k, is_kw=True)
|
||||
)
|
||||
for k in self.fn.__kwdefaults__
|
||||
}
|
||||
defaults = {
|
||||
k: wrap(val=v, source=kwdefaults_sources[k])
|
||||
for k, v in self.fn.__kwdefaults__.items()
|
||||
}
|
||||
|
||||
return InspectBoundArgumentsVariable(
|
||||
obj,
|
||||
defaults,
|
||||
self,
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_python_module(inspect),
|
||||
codegen.create_load_attr("signature"),
|
||||
]
|
||||
)
|
||||
)
|
||||
codegen(self.inspected)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
|
||||
class InspectParameterVariable(VariableTracker):
|
||||
"""represents inspect.Parameter(...)"""
|
||||
|
||||
def __init__(self, value, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.value = value
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
try:
|
||||
attr_value = getattr(self.value, name)
|
||||
source = self.source and AttrSource(self.source, name)
|
||||
return VariableTracker.build(tx, attr_value, source)
|
||||
except AttributeError:
|
||||
unimplemented(f"getattr({self.value}, {name})")
|
||||
|
||||
|
||||
class InspectBoundArgumentsVariable(VariableTracker):
|
||||
"""represents inspect.signature(...).bind(...)"""
|
||||
|
||||
_nonvar_fields = {
|
||||
"bound_arguments",
|
||||
"packed_vars",
|
||||
*VariableTracker._nonvar_fields,
|
||||
}
|
||||
|
||||
# NOTE: we keep track of changes to arguments via bound_arguments_var,
|
||||
# but we still keep a copy of the inspect.BoundArguments object in order
|
||||
# to get the correct args/kwargs.
|
||||
def __init__(
|
||||
self,
|
||||
bound_arguments: inspect.BoundArguments,
|
||||
defaults: dict[str, VariableTracker],
|
||||
signature: InspectSignatureVariable,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.bound_arguments = bound_arguments
|
||||
self.defaults = defaults
|
||||
# used to convert from VT to tuple/dict when updating bound_arguments
|
||||
self.packed_vars = set()
|
||||
|
||||
arguments_dict = {}
|
||||
for key, val in bound_arguments.arguments.items():
|
||||
key_var = variables.ConstantVariable(key)
|
||||
# convert val to VT
|
||||
if isinstance(val, tuple):
|
||||
arguments_dict[key_var] = variables.TupleVariable(list(val))
|
||||
self.packed_vars.add(key)
|
||||
elif isinstance(val, dict):
|
||||
self.packed_vars.add(key)
|
||||
arguments_dict[key_var] = variables.ConstDictVariable(
|
||||
{variables.ConstantVariable(k): v for k, v in val.items()}
|
||||
)
|
||||
elif isinstance(val, VariableTracker):
|
||||
arguments_dict[key_var] = val
|
||||
else:
|
||||
unimplemented(
|
||||
"inspect.signature(...).bind(...).arguments contains non-variable/tuple/dict"
|
||||
)
|
||||
|
||||
self.bound_arguments_var = variables.ConstDictVariable(
|
||||
arguments_dict,
|
||||
type(bound_arguments.arguments),
|
||||
mutation_type=variables.base.ValueMutationNew(),
|
||||
)
|
||||
self.signature = signature
|
||||
|
||||
def _update_bound_arguments(self):
|
||||
for key, val in self.bound_arguments_var.items.items():
|
||||
true_val = val
|
||||
if key.underlying_value in self.packed_vars:
|
||||
if isinstance(val, variables.TupleVariable):
|
||||
true_val = tuple(val.items)
|
||||
elif isinstance(val, variables.ConstDictVariable):
|
||||
true_val = {k.underlying_value: v for k, v in val.items.items()}
|
||||
else:
|
||||
unimplemented(
|
||||
"inspect.signature(...).bind(...) cannot update bound arguments"
|
||||
)
|
||||
self.bound_arguments.arguments[key.underlying_value] = true_val
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
if name == "arguments":
|
||||
return self.bound_arguments_var
|
||||
elif name == "args":
|
||||
self._update_bound_arguments()
|
||||
return variables.TupleVariable(list(self.bound_arguments.args))
|
||||
elif name == "kwargs":
|
||||
self._update_bound_arguments()
|
||||
kw = {
|
||||
variables.ConstantVariable(key): val
|
||||
for key, val in self.bound_arguments.kwargs.items()
|
||||
}
|
||||
return variables.ConstDictVariable(kw)
|
||||
elif name == "signature":
|
||||
return self.signature
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
if name == "apply_defaults":
|
||||
# mimic calling apply_defaults
|
||||
for key, val in self.defaults.items():
|
||||
key_var = variables.ConstantVariable(key)
|
||||
if key_var not in self.bound_arguments_var:
|
||||
self.bound_arguments_var.call_method(
|
||||
tx, "__setitem__", [key_var, val], {}
|
||||
)
|
||||
|
||||
# actually apply the changes
|
||||
self._update_bound_arguments()
|
||||
|
||||
return variables.ConstantVariable(None)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
# reconstruct inspect.signature(...).bind(*bound_arguments.args, **bound_arguments.kwargs)
|
||||
# NOTE the reconstructed inspect.signature(...) object might not be the same object
|
||||
# as the Signature object that originally created the BoundArguments object.
|
||||
self._update_bound_arguments()
|
||||
|
||||
def gen_fn():
|
||||
codegen(self.signature)
|
||||
codegen.append_output(codegen.create_load_attr("bind"))
|
||||
|
||||
codegen.add_push_null(gen_fn, call_function_ex=True)
|
||||
|
||||
codegen.foreach(self.bound_arguments.args)
|
||||
codegen.append_output(
|
||||
create_instruction("BUILD_TUPLE", arg=len(self.bound_arguments.args))
|
||||
)
|
||||
|
||||
for key, val in self.bound_arguments.kwargs.items():
|
||||
codegen.append_output(codegen.create_load_const(key))
|
||||
codegen(val)
|
||||
codegen.extend_output(
|
||||
[
|
||||
create_instruction("BUILD_MAP", arg=len(self.bound_arguments.kwargs)),
|
||||
create_instruction("CALL_FUNCTION_EX", arg=1),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def produce_trampoline_autograd_apply(fn_cls):
|
||||
def trampoline_autograd_apply(*args, **kwargs):
|
||||
return fn_cls.apply(*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue