mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Turn on mypy for _dynamo/variables/builtin.py (#145552)
The fact that mypy errors were ignored was hiding several bugs in builtin.py (for example the previous diff's incorrect override and use of `call_getattr`) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145552 Approved by: https://github.com/anijain2305, https://github.com/Skylion007 ghstack dependencies: #145551
This commit is contained in:
parent
f3120f6d26
commit
ccbbc88bbb
8 changed files with 113 additions and 34 deletions
|
|
@ -252,6 +252,8 @@ class OutputGraph:
|
|||
the root InstructionTranslator's OutputGraph.
|
||||
"""
|
||||
|
||||
side_effects: SideEffects
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code_options: dict[str, Any],
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class ModuleRecord:
|
|||
class DummyModule:
|
||||
name: str
|
||||
is_torch: bool = False
|
||||
value: object = None
|
||||
|
||||
@property
|
||||
def __name__(self) -> str:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import collections
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
|
||||
|
||||
from .. import variables
|
||||
from ..current_scope_id import current_scope_id
|
||||
|
|
@ -390,7 +390,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
|||
def call_function(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
args: "list[VariableTracker]",
|
||||
args: Sequence["VariableTracker"],
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
unimplemented(f"call_function {self} {args} {kwargs}")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# mypy: ignore-errors
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
|
|
@ -8,9 +8,10 @@ import logging
|
|||
import math
|
||||
import operator
|
||||
import types
|
||||
import typing
|
||||
from collections import defaultdict, OrderedDict
|
||||
from collections.abc import KeysView
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Callable, Sequence, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch import sym_float, sym_int
|
||||
|
|
@ -78,6 +79,7 @@ from .user_defined import UserDefinedObjectVariable, UserDefinedVariable
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Cyclic dependency...
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -100,6 +102,12 @@ IN_PLACE_DESUGARING_MAP = {
|
|||
}
|
||||
|
||||
|
||||
_HandlerCallback = Callable[
|
||||
["InstructionTranslator", typing.Any, typing.Any], VariableTracker
|
||||
]
|
||||
_TrackersType = Union[type[VariableTracker], tuple[type[VariableTracker], ...]]
|
||||
|
||||
|
||||
class BuiltinVariable(VariableTracker):
|
||||
_SENTINEL = object()
|
||||
_nonvar_fields = {
|
||||
|
|
@ -230,9 +238,11 @@ class BuiltinVariable(VariableTracker):
|
|||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def _binops():
|
||||
def _binops() -> (
|
||||
dict[Callable[..., object], tuple[list[str], Callable[..., object]]]
|
||||
):
|
||||
# function -> ([forward name, reverse name, in-place name], in-place op)
|
||||
fns = {
|
||||
fns: dict[Callable[..., object], tuple[list[str], Callable[..., object]]] = {
|
||||
operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd),
|
||||
operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub),
|
||||
operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul),
|
||||
|
|
@ -284,7 +294,18 @@ class BuiltinVariable(VariableTracker):
|
|||
)
|
||||
|
||||
# Override table contains: op_fn -> [list of handlers]
|
||||
op_handlers = {}
|
||||
op_handlers: dict[
|
||||
Callable[..., object],
|
||||
list[
|
||||
tuple[
|
||||
tuple[
|
||||
type[VariableTracker],
|
||||
_TrackersType,
|
||||
],
|
||||
_HandlerCallback,
|
||||
]
|
||||
],
|
||||
] = {}
|
||||
for (
|
||||
op,
|
||||
(magic_method_names, in_place_op),
|
||||
|
|
@ -376,7 +397,15 @@ class BuiltinVariable(VariableTracker):
|
|||
def size_add_handler(tx: "InstructionTranslator", a, b):
|
||||
return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
|
||||
|
||||
list_like_addition_handlers = [
|
||||
list_like_addition_handlers: list[
|
||||
tuple[
|
||||
tuple[
|
||||
type[VariableTracker],
|
||||
_TrackersType,
|
||||
],
|
||||
_HandlerCallback,
|
||||
]
|
||||
] = [
|
||||
# NB: Prefer the tuple-specific logic over base logic because of
|
||||
# some SizeVariable weirdness. Specifically, the tuple-specific logic
|
||||
# drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
|
||||
|
|
@ -395,7 +424,10 @@ class BuiltinVariable(VariableTracker):
|
|||
(
|
||||
(ConstantVariable, TupleVariable),
|
||||
lambda tx, a, b: TupleVariable(
|
||||
[*a.unpack_var_sequence(tx), *b.items],
|
||||
[
|
||||
*a.unpack_var_sequence(tx),
|
||||
*b.items,
|
||||
],
|
||||
),
|
||||
),
|
||||
(
|
||||
|
|
@ -410,7 +442,12 @@ class BuiltinVariable(VariableTracker):
|
|||
),
|
||||
(
|
||||
(BaseListVariable, BaseListVariable),
|
||||
lambda tx, a, b: type(a)([*a.items, *b.items]),
|
||||
lambda tx, a, b: type(a)(
|
||||
[
|
||||
*a.items,
|
||||
*b.items,
|
||||
]
|
||||
),
|
||||
),
|
||||
]
|
||||
op_handlers[operator.add].extend(list_like_addition_handlers)
|
||||
|
|
@ -425,7 +462,12 @@ class BuiltinVariable(VariableTracker):
|
|||
a.items.extend(seq)
|
||||
return a
|
||||
|
||||
list_like_iadd_handlers = [
|
||||
list_like_iadd_handlers: list[
|
||||
tuple[
|
||||
tuple[type[VariableTracker], type[VariableTracker]],
|
||||
_HandlerCallback,
|
||||
]
|
||||
] = [
|
||||
(
|
||||
(ListVariable, VariableTracker),
|
||||
list_iadd_handler,
|
||||
|
|
@ -450,7 +492,12 @@ class BuiltinVariable(VariableTracker):
|
|||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
|
||||
list_like_expansion_handlers = [
|
||||
list_like_expansion_handlers: list[
|
||||
tuple[
|
||||
tuple[type[VariableTracker], type[VariableTracker]],
|
||||
_HandlerCallback,
|
||||
]
|
||||
] = [
|
||||
((ListVariable, ConstantVariable), expand_list_like),
|
||||
((TupleVariable, ConstantVariable), expand_list_like),
|
||||
((ConstantVariable, ListVariable), expand_list_like),
|
||||
|
|
@ -465,7 +512,15 @@ class BuiltinVariable(VariableTracker):
|
|||
def compare_by_value(tx: "InstructionTranslator", a, b):
|
||||
return ConstantVariable(op(a.value, b.value))
|
||||
|
||||
result = [((ConstantVariable, ConstantVariable), compare_by_value)]
|
||||
result: list[
|
||||
tuple[
|
||||
tuple[
|
||||
_TrackersType,
|
||||
_TrackersType,
|
||||
],
|
||||
_HandlerCallback,
|
||||
]
|
||||
] = [((ConstantVariable, ConstantVariable), compare_by_value)]
|
||||
|
||||
if op in supported_const_comparison_ops.values():
|
||||
# Tensor is None, List is not None, etc
|
||||
|
|
@ -530,6 +585,7 @@ class BuiltinVariable(VariableTracker):
|
|||
def compare_via_method(tx: "InstructionTranslator", left, right):
|
||||
return left.call_method(tx, f"__{op.__name__}__", [right], {})
|
||||
|
||||
compare_user_defined: Callable[..., object]
|
||||
if op.__name__.startswith("is_"):
|
||||
compare_user_defined = compare_by_value
|
||||
else:
|
||||
|
|
@ -543,7 +599,12 @@ class BuiltinVariable(VariableTracker):
|
|||
(UserFunctionVariable, BuiltinVariable),
|
||||
(UserFunctionVariable, BuiltinVariable),
|
||||
),
|
||||
lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
|
||||
lambda tx, a, b: ConstantVariable(
|
||||
op(
|
||||
a.fn,
|
||||
b.fn,
|
||||
)
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
|
|
@ -710,7 +771,7 @@ class BuiltinVariable(VariableTracker):
|
|||
from .lazy import LazyVariableTracker
|
||||
|
||||
obj = BuiltinVariable(fn)
|
||||
handlers = []
|
||||
handlers: list[_HandlerCallback] = []
|
||||
|
||||
if any(issubclass(t, LazyVariableTracker) for t in arg_types):
|
||||
return lambda tx, args, kwargs: obj.call_function(
|
||||
|
|
@ -981,14 +1042,25 @@ class BuiltinVariable(VariableTracker):
|
|||
except NotImplementedError:
|
||||
unimplemented(f"partial tensor op: {self} {args} {kwargs}")
|
||||
|
||||
call_function_handler_cache = {}
|
||||
call_function_handler_cache: dict[
|
||||
tuple[object, ...],
|
||||
Callable[
|
||||
[
|
||||
"InstructionTranslator",
|
||||
Sequence[VariableTracker],
|
||||
dict[str, VariableTracker],
|
||||
],
|
||||
VariableTracker,
|
||||
],
|
||||
] = {}
|
||||
|
||||
def call_function(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
args: "list[VariableTracker]",
|
||||
args: Sequence["VariableTracker"],
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
key: tuple[object, ...]
|
||||
if kwargs:
|
||||
kwargs = {k: v.realize() for k, v in kwargs.items()}
|
||||
key = (self.fn, *(type(x) for x in args), True)
|
||||
|
|
@ -1130,7 +1202,7 @@ class BuiltinVariable(VariableTracker):
|
|||
else:
|
||||
# Overrides for custom str method
|
||||
# Pass method as function to call tx.inline_user_function_return
|
||||
bound_method = str_method.__func__
|
||||
bound_method = str_method.__func__ # type: ignore[attr-defined]
|
||||
|
||||
try:
|
||||
# Only supports certain function types
|
||||
|
|
@ -1191,6 +1263,7 @@ class BuiltinVariable(VariableTracker):
|
|||
|
||||
# convert min/max to torch ops
|
||||
if b.is_python_constant():
|
||||
fn: VariableTracker
|
||||
if isinstance(a, variables.NumpyNdarrayVariable):
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -1203,11 +1276,11 @@ class BuiltinVariable(VariableTracker):
|
|||
if isinstance(a, variables.NumpyNdarrayVariable):
|
||||
import numpy as np
|
||||
|
||||
fn = {max: np.maximum, min: np.minimum}[self.fn]
|
||||
fn = variables.NumpyVariable(fn)
|
||||
np_fn = {max: np.maximum, min: np.minimum}[self.fn]
|
||||
fn = variables.NumpyVariable(np_fn)
|
||||
else:
|
||||
fn = {max: torch.maximum, min: torch.minimum}[self.fn]
|
||||
fn = variables.TorchInGraphFunctionVariable(fn)
|
||||
torch_fn = {max: torch.maximum, min: torch.minimum}[self.fn]
|
||||
fn = variables.TorchInGraphFunctionVariable(torch_fn)
|
||||
result = fn.call_function(tx, [a, b], {})
|
||||
|
||||
# return unspec if both a, b are unspec or const
|
||||
|
|
@ -1245,9 +1318,9 @@ class BuiltinVariable(VariableTracker):
|
|||
else:
|
||||
return result
|
||||
elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
|
||||
fn = torch.sym_max if self.fn is max else torch.sym_min
|
||||
py_fn = torch.sym_max if self.fn is max else torch.sym_min
|
||||
proxy = tx.output.create_proxy(
|
||||
"call_function", fn, *proxy_args_kwargs([a, b], {})
|
||||
"call_function", py_fn, *proxy_args_kwargs([a, b], {})
|
||||
)
|
||||
return SymNodeVariable.create(tx, proxy, None)
|
||||
elif isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
|
||||
|
|
@ -1293,9 +1366,9 @@ class BuiltinVariable(VariableTracker):
|
|||
if check_unspec_or_constant_args(args, {}):
|
||||
return variables.RangeVariable(args)
|
||||
elif self._dynamic_args(*args):
|
||||
args = [
|
||||
args = tuple(
|
||||
variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args
|
||||
]
|
||||
)
|
||||
return variables.RangeVariable(args)
|
||||
# None no-ops this handler and lets the driving function proceed
|
||||
return None
|
||||
|
|
@ -1437,7 +1510,7 @@ class BuiltinVariable(VariableTracker):
|
|||
assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs
|
||||
args = (*args, kwargs.pop("value"))
|
||||
if len(args) == 0:
|
||||
raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0")
|
||||
raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0") # type: ignore[arg-type]
|
||||
if len(args) == 1:
|
||||
args = (*args, ConstantVariable.create(None))
|
||||
assert len(args) == 2
|
||||
|
|
@ -1714,7 +1787,7 @@ class BuiltinVariable(VariableTracker):
|
|||
member = obj.value.__dict__[name]
|
||||
|
||||
if config.replay_record_enabled:
|
||||
tx.exec_recorder.record_module_access(obj.value, name, member)
|
||||
tx.exec_recorder.record_module_access(obj.value, name, member) # type: ignore[arg-type, union-attr]
|
||||
return VariableTracker.build(tx, member, source)
|
||||
|
||||
elif istype(obj, variables.UserFunctionVariable) and name in (
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import inspect
|
|||
import logging
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch._C
|
||||
|
|
@ -242,7 +243,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
|
|||
def call_function(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
args: "list[VariableTracker]",
|
||||
args: Sequence[VariableTracker],
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
from . import (
|
||||
|
|
@ -931,7 +932,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||
def call_function(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
args: "list[VariableTracker]",
|
||||
args: Sequence[VariableTracker],
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
from . import ConstantVariable, SymNodeVariable, TensorVariable
|
||||
|
|
|
|||
|
|
@ -99,10 +99,12 @@ def is_forbidden_context_manager(ctx):
|
|||
|
||||
|
||||
class UserDefinedVariable(VariableTracker):
|
||||
pass
|
||||
value: object
|
||||
|
||||
|
||||
class UserDefinedClassVariable(UserDefinedVariable):
|
||||
value: type[object]
|
||||
|
||||
def __init__(self, value, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.value = value
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ class KnapsackEvaluator:
|
|||
"""
|
||||
results = self.evaluate_distribution_of_results_for_knapsack_algo(
|
||||
knapsack_algo=knapsack_algo,
|
||||
memory_budget_values=np.linspace(
|
||||
memory_budget_values=np.linspace( # type: ignore[arg-type]
|
||||
min_mem_budget, max_mem_budget, iterations
|
||||
).tolist(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -697,8 +697,8 @@ class SACEstimator(TorchDispatchMode):
|
|||
return SACTradeOffStats(
|
||||
n_segments=n_segments,
|
||||
slopes=slopes,
|
||||
intercepts=intercepts,
|
||||
fit_breaks=fit_breaks,
|
||||
intercepts=intercepts, # type: ignore[arg-type]
|
||||
fit_breaks=fit_breaks, # type: ignore[arg-type]
|
||||
tradeoff_curve=tradeoff_curve,
|
||||
sac_memory=sac_memory,
|
||||
sac_runtime=sac_runtime,
|
||||
|
|
|
|||
Loading…
Reference in a new issue