diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index dff6806ca35..b5e35dfc0a5 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -2,7 +2,7 @@ import collections from enum import Enum -from typing import Any, Callable, Dict, List, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator @@ -142,9 +142,9 @@ class VariableTracker(metaclass=VariableTrackerMeta): def visit( cls, fn: Callable[["VariableTracker"], None], - value, - cache=None, - ): + value: Any, + cache: Optional[Dict[int, Any]] = None, + ) -> None: """ Walk value and call fn on all the VariableTracker instances """ diff --git a/torch/_dynamo/variables/lazy.py b/torch/_dynamo/variables/lazy.py index ccd2f1fd5a4..dd704500ef8 100644 --- a/torch/_dynamo/variables/lazy.py +++ b/torch/_dynamo/variables/lazy.py @@ -1,9 +1,10 @@ -# mypy: ignore-errors +# mypy: allow-untyped-defs import collections import functools from typing import Optional from .base import VariableTracker +from .tensor import SymNodeVariable class LazyCache: @@ -60,6 +61,7 @@ class LazyVariableTracker(VariableTracker): """Force construction of the real VariableTracker""" if self._cache.vt is None: self._cache.realize() + assert self._cache.vt is not None return self._cache.vt def unwrap(self): @@ -86,7 +88,7 @@ class LazyVariableTracker(VariableTracker): return getattr(self.realize(), item) # most methods are auto-generated below, these are the ones we want to exclude - visit = VariableTracker.visit + visit = VariableTracker.visit # type: ignore[assignment] __repr__ = VariableTracker.__repr__ @classmethod @@ -132,7 +134,7 @@ class LazyVariableTracker(VariableTracker): class LazySymNodeFormatString: def __init__( - self, sym_node_variable: VariableTracker, fmt_spec_var: VariableTracker + self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker ): from .constant import ConstantVariable