Enable typechecking for _inductor/fx_utils.py (#109415)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109415
Approved by: https://github.com/Skylion007
ghstack dependencies: #109269, #109347, #109335
This commit is contained in:
Jez Ng 2023-09-15 17:50:17 -07:00 committed by PyTorch MergeBot
parent fe452108fb
commit 5cd8a6d40a
2 changed files with 10 additions and 6 deletions

View file

@ -199,7 +199,6 @@ exclude_patterns = [
'torch/_inductor/scheduler.py',
'torch/_inductor/sizevars.py',
'torch/_inductor/pattern_matcher.py',
'torch/_inductor/fx_utils.py',
'torch/_inductor/codegen/triton_foreach.py',
'torch/_inductor/codegen/cpp.py',
'torch/_inductor/codegen/triton.py',

View file

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Optional, Tuple
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type
import torch
import torch.fx
@ -9,7 +9,11 @@ from .virtualized import V
# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
# Works for length 2 patterns with 1 module and 1 function/method.
def matches_module_function_pattern(pattern, node, modules):
def matches_module_function_pattern(
pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
node: torch.fx.node.Node,
modules: Dict[str, torch.nn.modules.Module],
) -> bool:
if len(node.args) == 0:
return False
if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
@ -69,7 +73,7 @@ class FakeTensorUpdater:
def incremental_update(self):
processed = set()
existing_storages = defaultdict(int)
existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
for node in self.graph.nodes:
existing_storages[get_node_storage(node)] += 1
@ -128,7 +132,8 @@ class FakeTensorUpdater:
):
continue
updating_node.meta["val"] = new_fake_tensor
existing_storages.add(get_node_storage(new_fake_tensor))
# FIXME: defaultdict has no add() method
existing_storages.add(get_node_storage(new_fake_tensor)) # type: ignore[attr-defined]
processed.add(updating_node)
for user in updating_node.users:
processing.append(user)
@ -158,7 +163,7 @@ def get_fake(x):
return x
def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, tuple, dict]:
def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
"""
First value returns a boolean if any of the input nodes don't have a faketensor.
"""