[inductor] Refactor op handlers part 3 (#146254)

Fixes type errors that arise from typing `V.ops`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146254
Approved by: https://github.com/shunting314
ghstack dependencies: #146225, #146226, #146235, #146252
This commit is contained in:
Jason Ansel 2025-02-04 08:34:01 -08:00 committed by PyTorch MergeBot
parent 13f0436abd
commit 8e9bda8d89
3 changed files with 44 additions and 30 deletions

View file

@ -74,7 +74,7 @@ from .dependencies import (
var_builder,
)
from .loop_body import LoopBody
from .ops_handler import OpCounterCSE, OpCountResult
from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode
from .runtime.benchmarking import benchmarker
from .runtime.hints import DeviceProperties, ReductionHint
from .utils import (
@ -915,9 +915,9 @@ class Pointwise(Loops):
output_name: Optional[str],
indexer: Callable[[Sequence[Expr]], Never],
vars: Sequence[Expr],
) -> OpsValue:
) -> None:
loader = self.make_loader()
return ops.store(output_name, indexer(vars), loader(vars))
return ops.store(output_name or "unnamed", indexer(vars), loader(vars))
def constant_to_device(self, device: torch.device) -> IRNode:
"""Move this to a given device. Requires that all reads are to constants."""
@ -931,7 +931,7 @@ class Pointwise(Loops):
@ir_dataclass
class Scatter(Pointwise):
output_indexer: Callable[[Sequence[Expr]], Expr]
scatter_mode: Optional[str] = None
scatter_mode: StoreMode = None
def constant_to_device(self, device: torch.device) -> IRNode:
"""Move this to a given device. Requires that all reads are to constants."""
@ -951,8 +951,10 @@ class Scatter(Pointwise):
output_name: Optional[str],
indexer: Callable[[Sequence[Expr]], Never],
vars: Sequence[Expr],
) -> OpsValue:
) -> None:
loader = self.make_loader()
if output_name is None:
output_name = "unnamed"
return ops.store(
output_name,
indexer(self.output_indexer(vars)),
@ -1037,7 +1039,7 @@ def get_reduction_combine_fn(
@ir_dataclass
class Reduction(Loops):
reduction_ranges: Sequence[_IntLike]
reduction_type: str
reduction_type: ReductionType
# self.dtype represents the dst dtype
src_dtype: torch.dtype
reduction_hint: ReductionHint
@ -1064,14 +1066,14 @@ class Reduction(Loops):
indexer: Callable[[Sequence[Expr]], Never],
vars: Sequence[Expr],
reduction_vars: Sequence[Symbol],
) -> OpsValue:
) -> None:
value = ops.reduction(
self.dtype,
self.src_dtype,
self.reduction_type,
self.inner_fn(vars, reduction_vars),
)
return ops.store_reduction(output_name, indexer(vars), value)
return ops.store_reduction(output_name or "unnamed", indexer(vars), value)
def index_length(self) -> int:
return len(self.ranges) + len(self.reduction_ranges)
@ -1109,7 +1111,7 @@ class Reduction(Loops):
inner_fn: Callable[..., OpsValue],
ranges: Sequence[_IntLike],
reduction_ranges: Sequence[_IntLike],
reduction_type: str,
reduction_type: Union[ReductionType, Literal["scan"]],
reduction_numel: Expr,
input_node: Optional[IRNode] = None,
) -> tuple[ReductionHint, _IntLike]:
@ -1195,7 +1197,7 @@ class Reduction(Loops):
inner_fn=inner_fn,
ranges=ranges,
reduction_ranges=reduction_ranges,
reduction_type=reduction_type,
reduction_type=reduction_type if reduction_type != "scan" else "sum",
src_dtype=src_dtype,
reduction_hint=ReductionHint.DEFAULT,
)
@ -1322,7 +1324,7 @@ class Reduction(Loops):
inner_fn: Callable[..., Any],
ranges: Sequence[Expr],
reduction_ranges: Sequence[Expr],
reduction_type: str,
reduction_type: ReductionType,
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
input_node: Optional[IRNode] = None,
) -> TensorBox:
@ -1590,7 +1592,7 @@ class Reduction(Loops):
original_reduction_ranges: Sequence[Expr],
new_ranges: list[Expr],
new_reduction_ranges: list[Integer],
reduction_type: str,
reduction_type: ReductionType,
split: _IntLike,
reduction_hint: ReductionHint,
) -> TensorBox:
@ -1652,7 +1654,7 @@ class Reduction(Loops):
inner_fn: Callable[..., Any],
ranges: Sequence[Expr],
reduction_ranges: Sequence[Expr],
reduction_type: str,
reduction_type: ReductionType,
split: _IntLike,
reduction_hint: ReductionHint,
) -> TensorBox:
@ -1693,7 +1695,7 @@ class Reduction(Loops):
original_reduction_ranges: Sequence[Expr],
new_ranges: list[Integer],
new_reduction_ranges: list[Integer],
reduction_type: str,
reduction_type: ReductionType,
reduction_hint: ReductionHint,
) -> TensorBox:
"""
@ -1732,7 +1734,7 @@ class WelfordReduction(Reduction):
inner_fns: Sequence[Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]],
ranges: Sequence[Integer],
reduction_ranges: Sequence[Integer],
reduction_type: str,
reduction_type: ReductionType,
reduction_hint: ReductionHint,
output_index: int,
) -> None:
@ -1764,7 +1766,7 @@ class WelfordReduction(Reduction):
indexer: Callable[[Sequence[Expr]], Never],
vars: Sequence[Expr],
reduction_vars: Sequence[Symbol],
) -> OpsValue:
) -> None:
values = ops.reduction(
self.dtype,
self.src_dtype,
@ -1772,7 +1774,7 @@ class WelfordReduction(Reduction):
self.inner_fn(vars, reduction_vars),
)
value = values[self.output_index]
return ops.store_reduction(output_name, indexer(vars), value)
return ops.store_reduction(output_name or "unnamed", indexer(vars), value)
@classmethod
def create( # type: ignore[override]
@ -1782,7 +1784,7 @@ class WelfordReduction(Reduction):
inner_fns: Sequence[Callable[..., Any]],
ranges: list[Integer],
reduction_ranges: list[Integer],
reduction_type: str,
reduction_type: ReductionType,
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
) -> Sequence[TensorBox]:
assert reduction_type in ("welford_reduce", "welford_combine")
@ -1908,7 +1910,7 @@ class WelfordReduction(Reduction):
inner_fns: Sequence[Callable[..., Any]],
ranges: list[Integer],
reduction_ranges: list[Integer],
reduction_type: str,
reduction_type: ReductionType,
split: _IntLike,
reduction_hint: ReductionHint,
) -> Sequence[TensorBox]:
@ -2028,11 +2030,13 @@ class Scan(Loops):
indexer: Callable[[Sequence[_IntLike]], Never],
vars: Sequence[Expr],
scan_vars: Sequence[Symbol],
) -> OpsValue:
) -> None:
idx = self.reindex(vars, scan_vars)
values = [inner_fn(idx) for inner_fn in self.inner_fns]
values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
result = ops.scan(self.dtypes, self.combine_fn, values)
return ops.store(output_name, indexer(idx), result[self.output_index])
return ops.store(
output_name or "unnamed", indexer(idx), result[self.output_index]
)
def get_reduction_type(self) -> Optional[str]:
# return self.scan_op
@ -2226,11 +2230,13 @@ class Sort(Loops):
indexer: Callable[[Sequence[Expr]], Expr],
vars: Sequence[Expr],
reduction_vars: Sequence[Expr],
) -> OpsValue:
) -> None:
idx = self.reindex(vars, reduction_vars)
values = [inner_fn(idx) for inner_fn in self.inner_fns]
values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
result = ops.sort(self.dtypes, values, self.stable, self.descending)
return ops.store(output_name, indexer(idx), result[self.output_index])
return ops.store(
output_name or "unnamed", indexer(idx), result[self.output_index]
)
def get_reduction_type(self) -> Optional[str]:
return "sort"
@ -3787,7 +3793,7 @@ class Buffer(IRNode):
def loader(index): # type: ignore[no-untyped-def]
indexer = self.make_indexer()
return ops.load(self.name, indexer(index))
return ops.load(self.name or "unnamed", indexer(index))
return loader
@ -3980,7 +3986,7 @@ class ComputedBuffer(OperationBuffer):
return self.data.make_loader()
return super().make_loader()
def get_store_function(self) -> Callable[..., OpsValue]:
def get_store_function(self) -> Callable[..., None]:
indexer = self.get_layout().as_fixed().make_indexer()
if isinstance(self.data, (Reduction, Scan, Sort)):
return partial(self.data.store_reduction, self.name, indexer)

View file

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import dataclasses
import functools
import itertools
@ -9,7 +11,7 @@ import os
import warnings
from collections import defaultdict
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
from unittest.mock import patch
@ -78,6 +80,10 @@ from .utils import (
from .virtualized import ops, V
if TYPE_CHECKING:
from .ops_handler import ReductionType
_T = TypeVar("_T")
_P = ParamSpec("_P")
@ -5605,7 +5611,7 @@ def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype):
)
def make_reduction(reduction_type: str, override_return_dtype=None):
def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
def inner(x, axis=None, keepdims=False, *, dtype=None):
kwargs = _make_reduction_inner(
x,

View file

@ -308,7 +308,9 @@ class OpsWrapper:
return _ops.indirect_indexing(index, size, check, wrap_neg)
ops = OpsWrapper()
# we lie about the type of ops so the rest of the codebase typecheck properly
# DefaultHandler implements the OpsHandler protocol via metaprogramming
ops = cast(OpsHandler[Any], OpsWrapper())
class _V: