[inductor] Improve type annotations in _inductor/pattern_matcher.py

ghstack-source-id: ebf09f9eab
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146626
This commit is contained in:
Tom Ritchford 2025-02-08 14:04:07 +00:00
parent c5902143dc
commit 55fb29e75c
2 changed files with 41 additions and 33 deletions

View file

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
"""
# Inductor Pattern Matcher
@ -50,9 +49,9 @@ import textwrap
import typing
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Generator, Iterable, Mapping, Sequence
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence
from pathlib import Path
from typing import Any, Callable, NoReturn, Optional, Protocol, TypeVar, Union
from typing import Any, Callable, cast, NoReturn, Optional, Protocol, TypeVar, Union
from typing_extensions import Self, TypeIs
import torch
@ -261,7 +260,7 @@ class Match:
fwd_only, run_functional_passes=run_functional_passes
)
replacement = trace_fn(
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type]
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
)
if len(self.nodes) == 1:
for n in replacement.graph.nodes:
@ -646,8 +645,9 @@ class _TargetArgsExpr(_TargetExpr):
if len(_kwargs) < len(self.kwargs):
from torch.fx.operator_schemas import normalize_function
assert callable(node.target)
normalized_args_and_kwargs = normalize_function(
node.target, node.args, node.kwargs # type: ignore[arg-type]
node.target, node.args, node.kwargs
)
if normalized_args_and_kwargs is None:
@ -1074,7 +1074,8 @@ class ReplacementPatternEntry(PatternEntry):
if node.op == "call_function":
target = node.target
args, kwargs = self.fetch_args_kwargs_from_env(node)
result = graph.call_function(target, args, kwargs) # type: ignore[arg-type]
assert callable(target)
result = graph.call_function(target, args, kwargs)
_transfer_meta(
new_meta=result.meta,
old_node=node,
@ -1123,7 +1124,8 @@ class ReplacementPatternEntry(PatternEntry):
queue.extend(arg.all_input_nodes)
with graph.inserting_before(last_node):
replacement = Replacer(replacement_graph).run(*args) # type: ignore[arg-type]
replacement_module = cast(torch.fx.GraphModule, replacement_graph)
replacement = Replacer(replacement_module).run(*args)
if isinstance(replacement, torch.fx.Node):
replacement = [replacement]
@ -1201,7 +1203,7 @@ class ReplacementPatternEntry(PatternEntry):
idx = maybe_getitem(user)
if idx is None:
raise AssertionError("can't handle")
replace(user, new[idx]) # type: ignore[index]
replace(user, new[idx])
graph.erase_node(old)
if len(output_nodes) == len(replacement):
@ -1320,10 +1322,11 @@ def register_replacement(
)
args = list(
torch.fx.map_arg( # type: ignore[arg-type]
torch.fx.map_arg(
[match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
)
)
sym_args: list[torch.SymInt] = []
with torch._dynamo.utils.detect_fake_mode(args):
for i, grad in enumerate(requires_grad):
@ -1618,7 +1621,7 @@ def gen_register_replacement(
)
@functorch_config.patch(functionalize_rng_ops=False)
@functorch_config.patch(functionalize_rng_ops=False) # type: ignore[misc]
def gen_pattern_and_search_gm(
search_fn: SearchFn,
example_inputs: Sequence[Any],
@ -1743,10 +1746,11 @@ def is_mutation_op(node: torch.fx.Node) -> bool:
):
return False
if node.op == "call_function":
if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr]
assert callable(node.target)
if _mutation_op_re.search(node.target.__name__):
return True
elif node.op == "call_method":
if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type]
if _mutation_op_re.search(cast(str, node.target)):
return True
return node.kwargs.get("out") is not None
@ -1770,13 +1774,13 @@ def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
return mutation_region_id
def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
return "mutation_region_id" not in next(iter(graph.nodes)).meta # type: ignore[arg-type]
def should_compute_mutation_region_ids(graph: torch.fx.Graph) -> bool:
return "mutation_region_id" not in next(iter(graph.nodes)).meta
def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None:
def compute_mutation_region_ids(graph: torch.fx.Graph) -> None:
mutation_region_id = 0
for nd in graph.nodes: # type: ignore[union-attr]
for nd in graph.nodes:
if is_mutation_op(nd):
mutation_region_id += 1
nd.meta["mutation_region_id"] = mutation_region_id
@ -1814,8 +1818,8 @@ class PatternMatcherPass:
raise RuntimeError(
f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
)
if should_compute_mutation_region_ids(graph): # type: ignore[arg-type]
compute_mutation_region_ids(graph) # type: ignore[arg-type]
if should_compute_mutation_region_ids(graph):
compute_mutation_region_ids(graph)
get_mutation_region_id_partial = functools.partial(
get_mutation_region_id, graph
)
@ -1830,8 +1834,7 @@ class PatternMatcherPass:
if has_call_module:
nodes.append(graph.find_nodes(op="call_module", sort=False))
pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
assert isinstance(gm, torch.fx.GraphModule)
with GraphTransformObserver(gm, pass_name):
with GraphTransformObserver(cast(torch.fx.GraphModule, gm), pass_name):
for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
target = extract_target(node)
if node.op == "call_module":
@ -1851,14 +1854,17 @@ class PatternMatcherPass:
# pattern match crosses mutation barrier - discard
if (
is_match(m)
and len(OrderedSet(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined]
and len(
OrderedSet(map(get_mutation_region_id_partial, m.nodes))
)
!= 1
):
continue
if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
if is_match(m) and entry.extra_check(m):
count += 1
entry.apply(m, graph, node) # type: ignore[arg-type]
entry.apply(m, graph, node)
counters["inductor"]["pattern_matcher_count"] += 1
counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
return count
@ -1953,14 +1959,15 @@ def fx_to_pattern(
def run_node(self, n: torch.fx.Node) -> Any:
rv = super().run_node(n)
if n.op == "output" and isinstance(rv, tuple):
assert len(rv) == len(n.args[0]) # type: ignore[arg-type]
for r, arg in zip(rv, n.args[0]): # type: ignore[arg-type]
args = cast(Collection[Any], n.args[0])
assert len(rv) == len(args)
for r, arg in zip(rv, args):
r.users = len(arg.users)
else:
rv.users = len(n.users)
return rv
pattern = Converter(gm).run() # type: ignore[arg-type]
pattern = Converter(cast(torch.fx.GraphModule, gm)).run()
if not isinstance(pattern, PatternExpr):
return MultiOutputPattern(pytree.tree_leaves(pattern))
return pattern
@ -2030,7 +2037,7 @@ def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.Graph
GraphPatternEntry(
pattern=pattern, handler=pointless_view, extra_check=_return_true
).register(matcher_pass.patterns)
matcher_pass.apply(gm.graph) # type: ignore[arg-type]
matcher_pass.apply(gm.graph)
# remove in/out specs
gm.graph._codegen = torch.fx.graph.CodeGen()
@ -2130,11 +2137,12 @@ _seen_patterns = OrderedSet[str]()
def get_arg_value(
node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None
) -> Any:
return (
node.args[arg_number]
if len(node.args) > arg_number
else node.kwargs.get(kwarg_name) # type: ignore[arg-type]
)
if len(node.args) > arg_number:
return node.args[arg_number]
elif kwarg_name is None:
return None
else:
return node.kwargs.get(kwarg_name)
def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> list[torch.fx.Node]:
@ -2151,5 +2159,5 @@ def extract_target(node: torch.fx.Node) -> torch.fx.node.Target:
as a function.
"""
if node.op == "call_module":
return _get_attr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type]
return _get_attr(node.graph.owning_module, cast(str, node.target)).__class__
return node.target

View file

@ -335,7 +335,7 @@ def type_matches(signature_type: Any, argument_type: Any):
@compatibility(is_backward_compatible=False)
def normalize_function(
target: Callable,
args: tuple[Any],
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
arg_types: Optional[tuple[Any]] = None,
kwarg_types: Optional[dict[str, Any]] = None,