mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Enable typechecking for _inductor/fx_utils.py (#109415)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109415 Approved by: https://github.com/Skylion007 ghstack dependencies: #109269, #109347, #109335
This commit is contained in:
parent
fe452108fb
commit
5cd8a6d40a
2 changed files with 10 additions and 6 deletions
|
|
@ -199,7 +199,6 @@ exclude_patterns = [
|
|||
'torch/_inductor/scheduler.py',
|
||||
'torch/_inductor/sizevars.py',
|
||||
'torch/_inductor/pattern_matcher.py',
|
||||
'torch/_inductor/fx_utils.py',
|
||||
'torch/_inductor/codegen/triton_foreach.py',
|
||||
'torch/_inductor/codegen/cpp.py',
|
||||
'torch/_inductor/codegen/triton.py',
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections import defaultdict
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
|
|
@ -9,7 +9,11 @@ from .virtualized import V
|
|||
|
||||
# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
|
||||
# Works for length 2 patterns with 1 module and 1 function/method.
|
||||
def matches_module_function_pattern(pattern, node, modules):
|
||||
def matches_module_function_pattern(
|
||||
pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
|
||||
node: torch.fx.node.Node,
|
||||
modules: Dict[str, torch.nn.modules.Module],
|
||||
) -> bool:
|
||||
if len(node.args) == 0:
|
||||
return False
|
||||
if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
|
||||
|
|
@ -69,7 +73,7 @@ class FakeTensorUpdater:
|
|||
|
||||
def incremental_update(self):
|
||||
processed = set()
|
||||
existing_storages = defaultdict(int)
|
||||
existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
|
||||
for node in self.graph.nodes:
|
||||
existing_storages[get_node_storage(node)] += 1
|
||||
|
||||
|
|
@ -128,7 +132,8 @@ class FakeTensorUpdater:
|
|||
):
|
||||
continue
|
||||
updating_node.meta["val"] = new_fake_tensor
|
||||
existing_storages.add(get_node_storage(new_fake_tensor))
|
||||
# FIXME: defaultdict has no add() method
|
||||
existing_storages.add(get_node_storage(new_fake_tensor)) # type: ignore[attr-defined]
|
||||
processed.add(updating_node)
|
||||
for user in updating_node.users:
|
||||
processing.append(user)
|
||||
|
|
@ -158,7 +163,7 @@ def get_fake(x):
|
|||
return x
|
||||
|
||||
|
||||
def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, tuple, dict]:
|
||||
def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
|
||||
"""
|
||||
First value returns a boolean if any of the input nodes don't have a faketensor.
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue