From 5cd8a6d40a8a68f611cff34a561ac20d2ea5e561 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Fri, 15 Sep 2023 17:50:17 -0700 Subject: [PATCH] Enable typechecking for _inductor/fx_utils.py (#109415) Pull Request resolved: https://github.com/pytorch/pytorch/pull/109415 Approved by: https://github.com/Skylion007 ghstack dependencies: #109269, #109347, #109335 --- .lintrunner.toml | 1 - torch/_inductor/fx_utils.py | 15 ++++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index b0b5c4badd3..bbfba8232cf 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -199,7 +199,6 @@ exclude_patterns = [ 'torch/_inductor/scheduler.py', 'torch/_inductor/sizevars.py', 'torch/_inductor/pattern_matcher.py', - 'torch/_inductor/fx_utils.py', 'torch/_inductor/codegen/triton_foreach.py', 'torch/_inductor/codegen/cpp.py', 'torch/_inductor/codegen/triton.py', diff --git a/torch/_inductor/fx_utils.py b/torch/_inductor/fx_utils.py index 9673142f7da..53f17d60136 100644 --- a/torch/_inductor/fx_utils.py +++ b/torch/_inductor/fx_utils.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type import torch import torch.fx @@ -9,7 +9,11 @@ from .virtualized import V # Check the pattern: (nn.module, F.function/torch.Tensor.method) matched. # Works for length 2 patterns with 1 module and 1 function/method. -def matches_module_function_pattern(pattern, node, modules): +def matches_module_function_pattern( + pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]], + node: torch.fx.node.Node, + modules: Dict[str, torch.nn.modules.Module], +) -> bool: if len(node.args) == 0: return False if not isinstance(node.args[0], torch.fx.Node) or not isinstance( @@ -69,7 +73,7 @@ class FakeTensorUpdater: def incremental_update(self): processed = set() - existing_storages = defaultdict(int) + existing_storages: DefaultDict[Optional[int], int] = defaultdict(int) for node in self.graph.nodes: existing_storages[get_node_storage(node)] += 1 @@ -128,7 +132,8 @@ class FakeTensorUpdater: ): continue updating_node.meta["val"] = new_fake_tensor - existing_storages.add(get_node_storage(new_fake_tensor)) + # FIXME: defaultdict has no add() method + existing_storages.add(get_node_storage(new_fake_tensor)) # type: ignore[attr-defined] processed.add(updating_node) for user in updating_node.users: processing.append(user) @@ -158,7 +163,7 @@ def get_fake(x): return x -def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, tuple, dict]: +def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]: """ First value returns a boolean if any of the input nodes don't have a faketensor. """