mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] Refactor op handlers part 3 (#146254)
Fixes type errors that arise from typing `V.ops`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146254 Approved by: https://github.com/shunting314 ghstack dependencies: #146225, #146226, #146235, #146252
This commit is contained in:
parent
13f0436abd
commit
8e9bda8d89
3 changed files with 44 additions and 30 deletions
|
|
@ -74,7 +74,7 @@ from .dependencies import (
|
|||
var_builder,
|
||||
)
|
||||
from .loop_body import LoopBody
|
||||
from .ops_handler import OpCounterCSE, OpCountResult
|
||||
from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode
|
||||
from .runtime.benchmarking import benchmarker
|
||||
from .runtime.hints import DeviceProperties, ReductionHint
|
||||
from .utils import (
|
||||
|
|
@ -915,9 +915,9 @@ class Pointwise(Loops):
|
|||
output_name: Optional[str],
|
||||
indexer: Callable[[Sequence[Expr]], Never],
|
||||
vars: Sequence[Expr],
|
||||
) -> OpsValue:
|
||||
) -> None:
|
||||
loader = self.make_loader()
|
||||
return ops.store(output_name, indexer(vars), loader(vars))
|
||||
return ops.store(output_name or "unnamed", indexer(vars), loader(vars))
|
||||
|
||||
def constant_to_device(self, device: torch.device) -> IRNode:
|
||||
"""Move this to a given device. Requires that all reads are to constants."""
|
||||
|
|
@ -931,7 +931,7 @@ class Pointwise(Loops):
|
|||
@ir_dataclass
|
||||
class Scatter(Pointwise):
|
||||
output_indexer: Callable[[Sequence[Expr]], Expr]
|
||||
scatter_mode: Optional[str] = None
|
||||
scatter_mode: StoreMode = None
|
||||
|
||||
def constant_to_device(self, device: torch.device) -> IRNode:
|
||||
"""Move this to a given device. Requires that all reads are to constants."""
|
||||
|
|
@ -951,8 +951,10 @@ class Scatter(Pointwise):
|
|||
output_name: Optional[str],
|
||||
indexer: Callable[[Sequence[Expr]], Never],
|
||||
vars: Sequence[Expr],
|
||||
) -> OpsValue:
|
||||
) -> None:
|
||||
loader = self.make_loader()
|
||||
if output_name is None:
|
||||
output_name = "unnamed"
|
||||
return ops.store(
|
||||
output_name,
|
||||
indexer(self.output_indexer(vars)),
|
||||
|
|
@ -1037,7 +1039,7 @@ def get_reduction_combine_fn(
|
|||
@ir_dataclass
|
||||
class Reduction(Loops):
|
||||
reduction_ranges: Sequence[_IntLike]
|
||||
reduction_type: str
|
||||
reduction_type: ReductionType
|
||||
# self.dtype represents the dst dtype
|
||||
src_dtype: torch.dtype
|
||||
reduction_hint: ReductionHint
|
||||
|
|
@ -1064,14 +1066,14 @@ class Reduction(Loops):
|
|||
indexer: Callable[[Sequence[Expr]], Never],
|
||||
vars: Sequence[Expr],
|
||||
reduction_vars: Sequence[Symbol],
|
||||
) -> OpsValue:
|
||||
) -> None:
|
||||
value = ops.reduction(
|
||||
self.dtype,
|
||||
self.src_dtype,
|
||||
self.reduction_type,
|
||||
self.inner_fn(vars, reduction_vars),
|
||||
)
|
||||
return ops.store_reduction(output_name, indexer(vars), value)
|
||||
return ops.store_reduction(output_name or "unnamed", indexer(vars), value)
|
||||
|
||||
def index_length(self) -> int:
|
||||
return len(self.ranges) + len(self.reduction_ranges)
|
||||
|
|
@ -1109,7 +1111,7 @@ class Reduction(Loops):
|
|||
inner_fn: Callable[..., OpsValue],
|
||||
ranges: Sequence[_IntLike],
|
||||
reduction_ranges: Sequence[_IntLike],
|
||||
reduction_type: str,
|
||||
reduction_type: Union[ReductionType, Literal["scan"]],
|
||||
reduction_numel: Expr,
|
||||
input_node: Optional[IRNode] = None,
|
||||
) -> tuple[ReductionHint, _IntLike]:
|
||||
|
|
@ -1195,7 +1197,7 @@ class Reduction(Loops):
|
|||
inner_fn=inner_fn,
|
||||
ranges=ranges,
|
||||
reduction_ranges=reduction_ranges,
|
||||
reduction_type=reduction_type,
|
||||
reduction_type=reduction_type if reduction_type != "scan" else "sum",
|
||||
src_dtype=src_dtype,
|
||||
reduction_hint=ReductionHint.DEFAULT,
|
||||
)
|
||||
|
|
@ -1322,7 +1324,7 @@ class Reduction(Loops):
|
|||
inner_fn: Callable[..., Any],
|
||||
ranges: Sequence[Expr],
|
||||
reduction_ranges: Sequence[Expr],
|
||||
reduction_type: str,
|
||||
reduction_type: ReductionType,
|
||||
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
|
||||
input_node: Optional[IRNode] = None,
|
||||
) -> TensorBox:
|
||||
|
|
@ -1590,7 +1592,7 @@ class Reduction(Loops):
|
|||
original_reduction_ranges: Sequence[Expr],
|
||||
new_ranges: list[Expr],
|
||||
new_reduction_ranges: list[Integer],
|
||||
reduction_type: str,
|
||||
reduction_type: ReductionType,
|
||||
split: _IntLike,
|
||||
reduction_hint: ReductionHint,
|
||||
) -> TensorBox:
|
||||
|
|
@ -1652,7 +1654,7 @@ class Reduction(Loops):
|
|||
inner_fn: Callable[..., Any],
|
||||
ranges: Sequence[Expr],
|
||||
reduction_ranges: Sequence[Expr],
|
||||
reduction_type: str,
|
||||
reduction_type: ReductionType,
|
||||
split: _IntLike,
|
||||
reduction_hint: ReductionHint,
|
||||
) -> TensorBox:
|
||||
|
|
@ -1693,7 +1695,7 @@ class Reduction(Loops):
|
|||
original_reduction_ranges: Sequence[Expr],
|
||||
new_ranges: list[Integer],
|
||||
new_reduction_ranges: list[Integer],
|
||||
reduction_type: str,
|
||||
reduction_type: ReductionType,
|
||||
reduction_hint: ReductionHint,
|
||||
) -> TensorBox:
|
||||
"""
|
||||
|
|
@ -1732,7 +1734,7 @@ class WelfordReduction(Reduction):
|
|||
inner_fns: Sequence[Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]],
|
||||
ranges: Sequence[Integer],
|
||||
reduction_ranges: Sequence[Integer],
|
||||
reduction_type: str,
|
||||
reduction_type: ReductionType,
|
||||
reduction_hint: ReductionHint,
|
||||
output_index: int,
|
||||
) -> None:
|
||||
|
|
@ -1764,7 +1766,7 @@ class WelfordReduction(Reduction):
|
|||
indexer: Callable[[Sequence[Expr]], Never],
|
||||
vars: Sequence[Expr],
|
||||
reduction_vars: Sequence[Symbol],
|
||||
) -> OpsValue:
|
||||
) -> None:
|
||||
values = ops.reduction(
|
||||
self.dtype,
|
||||
self.src_dtype,
|
||||
|
|
@ -1772,7 +1774,7 @@ class WelfordReduction(Reduction):
|
|||
self.inner_fn(vars, reduction_vars),
|
||||
)
|
||||
value = values[self.output_index]
|
||||
return ops.store_reduction(output_name, indexer(vars), value)
|
||||
return ops.store_reduction(output_name or "unnamed", indexer(vars), value)
|
||||
|
||||
@classmethod
|
||||
def create( # type: ignore[override]
|
||||
|
|
@ -1782,7 +1784,7 @@ class WelfordReduction(Reduction):
|
|||
inner_fns: Sequence[Callable[..., Any]],
|
||||
ranges: list[Integer],
|
||||
reduction_ranges: list[Integer],
|
||||
reduction_type: str,
|
||||
reduction_type: ReductionType,
|
||||
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
|
||||
) -> Sequence[TensorBox]:
|
||||
assert reduction_type in ("welford_reduce", "welford_combine")
|
||||
|
|
@ -1908,7 +1910,7 @@ class WelfordReduction(Reduction):
|
|||
inner_fns: Sequence[Callable[..., Any]],
|
||||
ranges: list[Integer],
|
||||
reduction_ranges: list[Integer],
|
||||
reduction_type: str,
|
||||
reduction_type: ReductionType,
|
||||
split: _IntLike,
|
||||
reduction_hint: ReductionHint,
|
||||
) -> Sequence[TensorBox]:
|
||||
|
|
@ -2028,11 +2030,13 @@ class Scan(Loops):
|
|||
indexer: Callable[[Sequence[_IntLike]], Never],
|
||||
vars: Sequence[Expr],
|
||||
scan_vars: Sequence[Symbol],
|
||||
) -> OpsValue:
|
||||
) -> None:
|
||||
idx = self.reindex(vars, scan_vars)
|
||||
values = [inner_fn(idx) for inner_fn in self.inner_fns]
|
||||
values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
|
||||
result = ops.scan(self.dtypes, self.combine_fn, values)
|
||||
return ops.store(output_name, indexer(idx), result[self.output_index])
|
||||
return ops.store(
|
||||
output_name or "unnamed", indexer(idx), result[self.output_index]
|
||||
)
|
||||
|
||||
def get_reduction_type(self) -> Optional[str]:
|
||||
# return self.scan_op
|
||||
|
|
@ -2226,11 +2230,13 @@ class Sort(Loops):
|
|||
indexer: Callable[[Sequence[Expr]], Expr],
|
||||
vars: Sequence[Expr],
|
||||
reduction_vars: Sequence[Expr],
|
||||
) -> OpsValue:
|
||||
) -> None:
|
||||
idx = self.reindex(vars, reduction_vars)
|
||||
values = [inner_fn(idx) for inner_fn in self.inner_fns]
|
||||
values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
|
||||
result = ops.sort(self.dtypes, values, self.stable, self.descending)
|
||||
return ops.store(output_name, indexer(idx), result[self.output_index])
|
||||
return ops.store(
|
||||
output_name or "unnamed", indexer(idx), result[self.output_index]
|
||||
)
|
||||
|
||||
def get_reduction_type(self) -> Optional[str]:
|
||||
return "sort"
|
||||
|
|
@ -3787,7 +3793,7 @@ class Buffer(IRNode):
|
|||
|
||||
def loader(index): # type: ignore[no-untyped-def]
|
||||
indexer = self.make_indexer()
|
||||
return ops.load(self.name, indexer(index))
|
||||
return ops.load(self.name or "unnamed", indexer(index))
|
||||
|
||||
return loader
|
||||
|
||||
|
|
@ -3980,7 +3986,7 @@ class ComputedBuffer(OperationBuffer):
|
|||
return self.data.make_loader()
|
||||
return super().make_loader()
|
||||
|
||||
def get_store_function(self) -> Callable[..., OpsValue]:
|
||||
def get_store_function(self) -> Callable[..., None]:
|
||||
indexer = self.get_layout().as_fixed().make_indexer()
|
||||
if isinstance(self.data, (Reduction, Scan, Sort)):
|
||||
return partial(self.data.store_reduction, self.name, indexer)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
|
|
@ -9,7 +11,7 @@ import os
|
|||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -78,6 +80,10 @@ from .utils import (
|
|||
from .virtualized import ops, V
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .ops_handler import ReductionType
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
|
@ -5605,7 +5611,7 @@ def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype):
|
|||
)
|
||||
|
||||
|
||||
def make_reduction(reduction_type: str, override_return_dtype=None):
|
||||
def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
|
||||
def inner(x, axis=None, keepdims=False, *, dtype=None):
|
||||
kwargs = _make_reduction_inner(
|
||||
x,
|
||||
|
|
|
|||
|
|
@ -308,7 +308,9 @@ class OpsWrapper:
|
|||
return _ops.indirect_indexing(index, size, check, wrap_neg)
|
||||
|
||||
|
||||
ops = OpsWrapper()
|
||||
# we lie about the type of ops so the rest of the codebase typecheck properly
|
||||
# DefaultHandler implements the OpsHandler protocol via metaprogramming
|
||||
ops = cast(OpsHandler[Any], OpsWrapper())
|
||||
|
||||
|
||||
class _V:
|
||||
|
|
|
|||
Loading…
Reference in a new issue