Turn on mypy for _dynamo/variables/builtin.py (#145552)

A file-level ignore of mypy errors was hiding several bugs in builtin.py (for example, the previous diff's incorrect override and use of `call_getattr`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145552
Approved by: https://github.com/anijain2305, https://github.com/Skylion007
ghstack dependencies: #145551
Author: Aaron Orenstein (2025-01-27 20:39:30 -08:00), committed by PyTorch MergeBot
Parent: f3120f6d26
Commit: ccbbc88bbb

8 changed files with 113 additions and 34 deletions
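For context on the class of bug this surfaces: a file-level "# mypy: ignore-errors" silences every diagnostic in the file, including genuinely incompatible overrides. A minimal sketch of such a bug (hypothetical names, not code from this diff):

    class Base:
        def call_function(self, args: list[str]) -> str:
            return ",".join(args)


    class Child(Base):
        # With "# mypy: ignore-errors" at the top of the file this passes
        # silently; without it, mypy reports roughly:
        #   error: Argument 1 of "call_function" is incompatible with
        #   supertype "Base"  [override]
        def call_function(self, args: str) -> str:
            return args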

torch/_dynamo/output_graph.py

@@ -252,6 +252,8 @@ class OutputGraph:
     the root InstructionTranslator's OutputGraph.
     """

+    side_effects: SideEffects
+
     def __init__(
         self,
         code_options: dict[str, Any],

torch/_dynamo/replay_record.py

@@ -20,6 +20,7 @@ class ModuleRecord:
 class DummyModule:
     name: str
     is_torch: bool = False
+    value: object = None

     @property
     def __name__(self) -> str:

torch/_dynamo/variables/base.py

@@ -2,7 +2,7 @@
 import collections
 from enum import Enum
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING

 from .. import variables
 from ..current_scope_id import current_scope_id
@@ -390,7 +390,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
     def call_function(
         self,
         tx: "InstructionTranslator",
-        args: "list[VariableTracker]",
+        args: Sequence["VariableTracker"],
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
         unimplemented(f"call_function {self} {args} {kwargs}")

torch/_dynamo/variables/builtin.py

@@ -1,4 +1,4 @@
-# mypy: ignore-errors
+# mypy: allow-untyped-defs
 import contextlib
 import functools
@@ -8,9 +8,10 @@ import logging
 import math
 import operator
 import types
+import typing
 from collections import defaultdict, OrderedDict
 from collections.abc import KeysView
-from typing import TYPE_CHECKING
+from typing import Callable, Sequence, TYPE_CHECKING, Union

 import torch
 from torch import sym_float, sym_int
@@ -78,6 +79,7 @@ from .user_defined import UserDefinedObjectVariable, UserDefinedVariable

 if TYPE_CHECKING:
+    # Cyclic dependency...
     from torch._dynamo.symbolic_convert import InstructionTranslator

 log = logging.getLogger(__name__)
@@ -100,6 +102,12 @@ IN_PLACE_DESUGARING_MAP = {
 }

+_HandlerCallback = Callable[
+    ["InstructionTranslator", typing.Any, typing.Any], VariableTracker
+]
+_TrackersType = Union[type[VariableTracker], tuple[type[VariableTracker], ...]]
+
+
 class BuiltinVariable(VariableTracker):
     _SENTINEL = object()
     _nonvar_fields = {
@@ -230,9 +238,11 @@ class BuiltinVariable(VariableTracker):
     @staticmethod
     @functools.lru_cache(None)
-    def _binops():
+    def _binops() -> (
+        dict[Callable[..., object], tuple[list[str], Callable[..., object]]]
+    ):
         # function -> ([forward name, reverse name, in-place name], in-place op)
-        fns = {
+        fns: dict[Callable[..., object], tuple[list[str], Callable[..., object]]] = {
             operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd),
             operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub),
             operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul),
@@ -284,7 +294,18 @@ class BuiltinVariable(VariableTracker):
         )

         # Override table contains: op_fn -> [list of handlers]
-        op_handlers = {}
+        op_handlers: dict[
+            Callable[..., object],
+            list[
+                tuple[
+                    tuple[
+                        type[VariableTracker],
+                        _TrackersType,
+                    ],
+                    _HandlerCallback,
+                ]
+            ],
+        ] = {}
         for (
             op,
             (magic_method_names, in_place_op),
@@ -376,7 +397,15 @@ class BuiltinVariable(VariableTracker):
         def size_add_handler(tx: "InstructionTranslator", a, b):
             return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])

-        list_like_addition_handlers = [
+        list_like_addition_handlers: list[
+            tuple[
+                tuple[
+                    type[VariableTracker],
+                    _TrackersType,
+                ],
+                _HandlerCallback,
+            ]
+        ] = [
             # NB: Prefer the tuple-specific logic over base logic because of
             # some SizeVariable weirdness. Specifically, the tuple-specific logic
             # drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
@@ -395,7 +424,10 @@ class BuiltinVariable(VariableTracker):
             (
                 (ConstantVariable, TupleVariable),
                 lambda tx, a, b: TupleVariable(
-                    [*a.unpack_var_sequence(tx), *b.items],
+                    [
+                        *a.unpack_var_sequence(tx),
+                        *b.items,
+                    ],
                 ),
             ),
             (
@@ -410,7 +442,12 @@ class BuiltinVariable(VariableTracker):
             ),
             (
                 (BaseListVariable, BaseListVariable),
-                lambda tx, a, b: type(a)([*a.items, *b.items]),
+                lambda tx, a, b: type(a)(
+                    [
+                        *a.items,
+                        *b.items,
+                    ]
+                ),
             ),
         ]
         op_handlers[operator.add].extend(list_like_addition_handlers)
@@ -425,7 +462,12 @@ class BuiltinVariable(VariableTracker):
             a.items.extend(seq)
             return a

-        list_like_iadd_handlers = [
+        list_like_iadd_handlers: list[
+            tuple[
+                tuple[type[VariableTracker], type[VariableTracker]],
+                _HandlerCallback,
+            ]
+        ] = [
             (
                 (ListVariable, VariableTracker),
                 list_iadd_handler,
@@ -450,7 +492,12 @@ class BuiltinVariable(VariableTracker):
                 mutation_type=ValueMutationNew(),
             )

-        list_like_expansion_handlers = [
+        list_like_expansion_handlers: list[
+            tuple[
+                tuple[type[VariableTracker], type[VariableTracker]],
+                _HandlerCallback,
+            ]
+        ] = [
             ((ListVariable, ConstantVariable), expand_list_like),
             ((TupleVariable, ConstantVariable), expand_list_like),
             ((ConstantVariable, ListVariable), expand_list_like),
@@ -465,7 +512,15 @@ class BuiltinVariable(VariableTracker):
         def compare_by_value(tx: "InstructionTranslator", a, b):
             return ConstantVariable(op(a.value, b.value))

-        result = [((ConstantVariable, ConstantVariable), compare_by_value)]
+        result: list[
+            tuple[
+                tuple[
+                    _TrackersType,
+                    _TrackersType,
+                ],
+                _HandlerCallback,
+            ]
+        ] = [((ConstantVariable, ConstantVariable), compare_by_value)]

         if op in supported_const_comparison_ops.values():
             # Tensor is None, List is not None, etc
@@ -530,6 +585,7 @@ class BuiltinVariable(VariableTracker):
         def compare_via_method(tx: "InstructionTranslator", left, right):
             return left.call_method(tx, f"__{op.__name__}__", [right], {})

+        compare_user_defined: Callable[..., object]
         if op.__name__.startswith("is_"):
             compare_user_defined = compare_by_value
         else:
@@ -543,7 +599,12 @@ class BuiltinVariable(VariableTracker):
                 (
                     (UserFunctionVariable, BuiltinVariable),
                     (UserFunctionVariable, BuiltinVariable),
                 ),
-                lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
+                lambda tx, a, b: ConstantVariable(
+                    op(
+                        a.fn,
+                        b.fn,
+                    )
+                ),
             ),
             (
                 (
@@ -710,7 +771,7 @@ class BuiltinVariable(VariableTracker):
         from .lazy import LazyVariableTracker

         obj = BuiltinVariable(fn)
-        handlers = []
+        handlers: list[_HandlerCallback] = []

         if any(issubclass(t, LazyVariableTracker) for t in arg_types):
             return lambda tx, args, kwargs: obj.call_function(
@@ -981,14 +1042,25 @@ class BuiltinVariable(VariableTracker):
         except NotImplementedError:
             unimplemented(f"partial tensor op: {self} {args} {kwargs}")

-    call_function_handler_cache = {}
+    call_function_handler_cache: dict[
+        tuple[object, ...],
+        Callable[
+            [
+                "InstructionTranslator",
+                Sequence[VariableTracker],
+                dict[str, VariableTracker],
+            ],
+            VariableTracker,
+        ],
+    ] = {}

     def call_function(
         self,
         tx: "InstructionTranslator",
-        args: "list[VariableTracker]",
+        args: Sequence["VariableTracker"],
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
+        key: tuple[object, ...]
         if kwargs:
             kwargs = {k: v.realize() for k, v in kwargs.items()}
             key = (self.fn, *(type(x) for x in args), True)
@@ -1130,7 +1202,7 @@ class BuiltinVariable(VariableTracker):
         else:
             # Overrides for custom str method
             # Pass method as function to call tx.inline_user_function_return
-            bound_method = str_method.__func__
+            bound_method = str_method.__func__  # type: ignore[attr-defined]

             try:
                 # Only supports certain function types
@@ -1191,6 +1263,7 @@ class BuiltinVariable(VariableTracker):
         # convert min/max to torch ops
         if b.is_python_constant():
+            fn: VariableTracker
             if isinstance(a, variables.NumpyNdarrayVariable):
                 import numpy as np
@@ -1203,11 +1276,11 @@ class BuiltinVariable(VariableTracker):
             if isinstance(a, variables.NumpyNdarrayVariable):
                 import numpy as np

-                fn = {max: np.maximum, min: np.minimum}[self.fn]
-                fn = variables.NumpyVariable(fn)
+                np_fn = {max: np.maximum, min: np.minimum}[self.fn]
+                fn = variables.NumpyVariable(np_fn)
             else:
-                fn = {max: torch.maximum, min: torch.minimum}[self.fn]
-                fn = variables.TorchInGraphFunctionVariable(fn)
+                torch_fn = {max: torch.maximum, min: torch.minimum}[self.fn]
+                fn = variables.TorchInGraphFunctionVariable(torch_fn)
             result = fn.call_function(tx, [a, b], {})

             # return unspec if both a, b are unspec or const
@@ -1245,9 +1318,9 @@ class BuiltinVariable(VariableTracker):
             else:
                 return result
         elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
-            fn = torch.sym_max if self.fn is max else torch.sym_min
+            py_fn = torch.sym_max if self.fn is max else torch.sym_min
             proxy = tx.output.create_proxy(
-                "call_function", fn, *proxy_args_kwargs([a, b], {})
+                "call_function", py_fn, *proxy_args_kwargs([a, b], {})
             )
             return SymNodeVariable.create(tx, proxy, None)
         elif isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
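The fn/np_fn/torch_fn/py_fn renames in the min/max handling above follow from how mypy types local variables: a name's type is fixed by its declaration or first assignment, so re-binding it to an unrelated type in another branch is an [assignment] error. The fix is either a pre-declaration with a wide type (as with "fn: VariableTracker") or a distinct name per type. A minimal sketch of the pre-declaration option (hypothetical names):

    from typing import Union


    def pick(flag: bool) -> Union[int, str]:
        # Without the pre-declaration, mypy infers `value: int` from the
        # first assignment and rejects the `str` binding in the else branch.
        value: Union[int, str]
        if flag:
            value = 1
        else:
            value = "one"
        return value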
@@ -1293,9 +1366,9 @@ class BuiltinVariable(VariableTracker):
         if check_unspec_or_constant_args(args, {}):
             return variables.RangeVariable(args)
         elif self._dynamic_args(*args):
-            args = [
+            args = tuple(
                 variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args
-            ]
+            )
             return variables.RangeVariable(args)
         # None no-ops this handler and lets the driving function proceed
         return None
@@ -1437,7 +1510,7 @@ class BuiltinVariable(VariableTracker):
             assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs
             args = (*args, kwargs.pop("value"))
         if len(args) == 0:
-            raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0")
+            raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0")  # type: ignore[arg-type]
         if len(args) == 1:
             args = (*args, ConstantVariable.create(None))
         assert len(args) == 2
@@ -1714,7 +1787,7 @@ class BuiltinVariable(VariableTracker):
                 member = obj.value.__dict__[name]

                 if config.replay_record_enabled:
-                    tx.exec_recorder.record_module_access(obj.value, name, member)
+                    tx.exec_recorder.record_module_access(obj.value, name, member)  # type: ignore[arg-type, union-attr]

                 return VariableTracker.build(tx, member, source)
             elif istype(obj, variables.UserFunctionVariable) and name in (
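A note on the _HandlerCallback and _TrackersType aliases introduced above: declaring a long Callable or Union annotation once as a module-level alias keeps the many handler-table annotations in this file readable and consistent. The same idiom in miniature (hypothetical names, not from this diff):

    from typing import Any, Callable, Union


    class Tracker: ...


    # Declare the shape once; every handler table reuses the alias, and a
    # signature change is then a one-line edit.
    Handler = Callable[[Any, Any], Tracker]
    TrackerTypes = Union[type[Tracker], tuple[type[Tracker], ...]]

    handlers: dict[TrackerTypes, list[Handler]] = {}


    def register(types: TrackerTypes, fn: Handler) -> None:
        handlers.setdefault(types, []).append(fn)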

torch/_dynamo/variables/torch.py

@@ -5,6 +5,7 @@ import inspect
 import logging
 import math
 import re
+from collections.abc import Sequence
 from typing import TYPE_CHECKING

 import torch._C
@@ -242,7 +243,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
     def call_function(
         self,
         tx: "InstructionTranslator",
-        args: "list[VariableTracker]",
+        args: Sequence[VariableTracker],
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
         from . import (
@@ -931,7 +932,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
     def call_function(
         self,
         tx: "InstructionTranslator",
-        args: "list[VariableTracker]",
+        args: Sequence[VariableTracker],
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
         from . import ConstantVariable, SymNodeVariable, TensorVariable

torch/_dynamo/variables/user_defined.py

@@ -99,10 +99,12 @@ def is_forbidden_context_manager(ctx):
 class UserDefinedVariable(VariableTracker):
-    pass
+    value: object


 class UserDefinedClassVariable(UserDefinedVariable):
+    value: type[object]
+
     def __init__(self, value, **kwargs) -> None:
         super().__init__(**kwargs)
         self.value = value

torch/_functorch/_activation_checkpointing/knapsack_evaluator.py

@@ -239,7 +239,7 @@ class KnapsackEvaluator:
         """
         results = self.evaluate_distribution_of_results_for_knapsack_algo(
             knapsack_algo=knapsack_algo,
-            memory_budget_values=np.linspace(
+            memory_budget_values=np.linspace(  # type: ignore[arg-type]
                 min_mem_budget, max_mem_budget, iterations
             ).tolist(),
         )

torch/distributed/_tools/sac_estimator.py

@@ -697,8 +697,8 @@ class SACEstimator(TorchDispatchMode):
         return SACTradeOffStats(
             n_segments=n_segments,
             slopes=slopes,
-            intercepts=intercepts,
-            fit_breaks=fit_breaks,
+            intercepts=intercepts,  # type: ignore[arg-type]
+            fit_breaks=fit_breaks,  # type: ignore[arg-type]
             tradeoff_curve=tradeoff_curve,
             sac_memory=sac_memory,
             sac_runtime=sac_runtime,