[inductor] Refactor CaptureIndexing into global scope (#146297)

And inline SimplifyIndexing into CaptureIndexing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146297
Approved by: https://github.com/shunting314
ghstack dependencies: #146252, #146254, #146255, #146257, #146282
This commit is contained in:
Jason Ansel 2025-02-07 13:32:55 -08:00 committed by PyTorch MergeBot
parent d35f6b2339
commit c098385cb3

View file

@ -17,7 +17,7 @@ from torch.utils._sympy.symbol import SymT
from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .ops_handler import DefaultHandler, OpsHandler
from .ops_handler import DefaultHandler, OpsHandler, WrapperHandler
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
from .virtualized import ops, V
@ -440,179 +440,13 @@ class LoopBodyBlock:
def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]):
self.body = body
def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs):
return tracer.create_proxy(
"call_module",
"get_index",
(body.add_index_expr(expr, mtype, **kwargs),),
{},
)
class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined]
name = "CaptureIndexing"
def load(self, name: str, index: sympy.Expr):
index = add_index(index, MemoryUsageType.LOAD, buffer_name=name)
return self._inner.load(name, index)
def load_seed(self, name: str, index: int):
assert isinstance(index, int)
body.add_index_expr(
sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name
)
return self._inner.load_seed(name, index)
def store(self, name, index, value, mode=None):
index = add_index(
index, MemoryUsageType.STORE, buffer_name=name, mode=mode
)
return self._inner.store(name, index, value, mode)
def store_reduction(self, name, index, value):
index = add_index(
index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
)
return self._inner.store_reduction(name, index, value)
def reduction(self, dtype, src_dtype, reduction_type, value):
result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
if "welford" in reduction_type:
return tuple(result[i] for i in range(3))
return result
def index_expr(self, index, dtype):
if isinstance(index, (int, sympy.Integer)):
return self._inner.constant(int(index), dtype)
index = add_index(index, MemoryUsageType.INDEX_EXPR)
return self._inner.index_expr(index, dtype)
def check_bounds(self, index, size, lower, upper):
index = add_index(index, MemoryUsageType.CHECK_BOUNDS)
size = add_index(size, MemoryUsageType.CHECK_BOUNDS)
return self._inner.check_bounds(index, size, lower, upper)
def bucketize(
self,
values: T,
boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
boundary_indices: T,
indexing_dtype: torch.dtype,
right: bool,
sorter: Optional[tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[T] = None,
) -> T:
"""
See [Note: Inductor bucketize op]
"""
boundaries = (
boundaries[0],
add_index(
boundaries[1],
MemoryUsageType.BUCKETIZE,
buffer_name=boundaries[0],
),
add_index(
boundaries[2],
MemoryUsageType.BUCKETIZE,
buffer_name=boundaries[0],
),
add_index(
boundaries[3],
MemoryUsageType.BUCKETIZE,
buffer_name=boundaries[0],
),
)
if sorter is not None:
sorter = (
sorter[0],
add_index(
sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0]
),
)
return self._inner.bucketize(
values,
boundaries,
boundary_indices,
indexing_dtype,
right,
sorter,
sorter_indices,
)
@staticmethod
def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
"""
Recursively capture the masked out body in another LoopBodyBlock
"""
name = self.body.add_submodule(None, "masked_subblock")
self.body.submodules[name] = self.body.bind_masked_shim(name)
self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
return tracer.create_proxy(
"call_module", name, (mask_proxy, other_proxy), {}
)
@staticmethod
def scan(
dtype_proxy,
combine_fn: Callable[
[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]
],
value_proxy,
):
shim = self.body.bind_scan_shim(combine_fn)
name = self.body.add_submodule(shim, "scan")
result = tracer.create_proxy(
"call_module",
name,
(dtype_proxy, value_proxy),
{},
)
# Proxies are iterable, but some methods expect tuples/lists
return tuple(result[i] for i in range(len(value_proxy)))
def sort(self, dtypes, values, stable, descending):
result = self._inner.sort(dtypes, values, stable, descending)
# Proxies are iterable, but some methods expect tuples/lists
return tuple(result[i] for i in range(len(values)))
def frexp(self, value_proxy):
result = self._inner.frexp(value_proxy)
# Proxies are iterable, but some methods expect tuples/lists
return (result[0], result[1])
@staticmethod
def indirect_indexing(index_proxy, size, check=True, wrap_neg=True):
"""
Flow data from tensors into indexing formulas.
Introduce a call_module to update the indexing.
"""
var = self.body.add_indirect(size)
set_indirect = self.body.bind_set_indirect_shim(
var, size, check, wrap_neg
)
tracer.create_proxy(
"call_module",
self.body.add_submodule(set_indirect, f"set_{var}"),
(index_proxy,),
{},
)
return var
@staticmethod
def output(*result):
tracer.create_proxy("output", "output", result, {})
tracer = LightTracer()
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
from .index_propagation import IndexPropagation
from .sizevars import SimplifyIndexing
handler: Any = CountOps(
SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges),
CaptureIndexing(proxy_ops, body, tracer),
body.op_counts,
)
if config.constant_and_index_propagation:
@ -662,3 +496,179 @@ class CountOps(DefaultHandler):
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
    """
    Tally one use of the op called ``name`` and forward the call.

    The counter is bumped before the attribute is looked up on the wrapped
    handler, so the op is counted even when the lookup or the call raises.
    """
    self._counts[name] += 1
    forwarded = getattr(self._inner, name)
    return forwarded(*args, **kwargs)
class CaptureIndexing(WrapperHandler):
    """
    Ops handler that registers index expressions on a LoopBody while ops
    are being traced.

    Index arguments of memory ops are first simplified with the body's
    variable ranges, then registered through ``body.add_index_expr`` and
    replaced by a ``get_index`` ``call_module`` proxy created on ``tracer``
    before the op is forwarded to the wrapped handler.
    """

    name = "CaptureIndexing"

    def __init__(
        self,
        inner: OpsHandler[Any],
        body: LoopBody,
        tracer: LightTracer,
    ):
        super().__init__(inner)
        self.body = body
        self.tracer = tracer

    def _add_index(self, expr: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any):
        """
        Register ``expr`` on the body and return a ``get_index`` proxy for it.

        Extra keyword arguments are passed straight to ``body.add_index_expr``.
        """
        registered = self.body.add_index_expr(expr, mtype, **kwargs)
        return self.tracer.create_proxy(
            "call_module", "get_index", (registered,), {}
        )

    def _simplify(self, expr: sympy.Expr) -> sympy.Expr:
        """Simplify ``expr`` using the variable ranges stored on the body."""
        sizevars = V.graph.sizevars
        return sizevars.simplify_with_ranges(expr, self.body.var_ranges)

    def load(self, name: str, index: sympy.Expr):
        """Forward a load whose index has been simplified and captured."""
        captured = self._add_index(
            self._simplify(index), MemoryUsageType.LOAD, buffer_name=name
        )
        return self._inner.load(name, captured)

    def load_seed(self, name: str, index: int):
        """
        Register the integer seed offset on the body, then forward unchanged.

        Only ``body.add_index_expr`` is called here (no proxy is created), and
        the wrapped handler still receives the plain ``int``.
        """
        assert isinstance(index, int)
        seed_offset = sympy.Integer(index)
        self.body.add_index_expr(
            seed_offset, MemoryUsageType.LOAD_SEED, buffer_name=name
        )
        return self._inner.load_seed(name, index)

    def store(self, name, index, value, mode=None):
        """Forward a store whose index has been simplified and captured."""
        captured = self._add_index(
            self._simplify(index),
            MemoryUsageType.STORE,
            buffer_name=name,
            mode=mode,
        )
        return self._inner.store(name, captured, value, mode)

    def store_reduction(self, name, index, value):
        """Forward a reduction store with a simplified, captured index."""
        captured = self._add_index(
            self._simplify(index),
            MemoryUsageType.STORE_REDUCTION,
            buffer_name=name,
        )
        return self._inner.store_reduction(name, captured, value)

    def reduction(self, dtype, src_dtype, reduction_type, value):
        """
        Forward a reduction; welford-style results are unpacked to a 3-tuple.
        """
        result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
        if "welford" not in reduction_type:
            return result
        # Index the result explicitly so callers get a real tuple of three items.
        return (result[0], result[1], result[2])

    def index_expr(self, index, dtype):
        """
        Forward an index expression, or emit a constant when the simplified
        index is a plain integer.
        """
        simplified = self._simplify(index)
        if not isinstance(simplified, (int, sympy.Integer)):
            captured = self._add_index(simplified, MemoryUsageType.INDEX_EXPR)
            return self._inner.index_expr(captured, dtype)
        return self._inner.constant(int(simplified), dtype)

    def check_bounds(self, index, size, lower, upper):
        """
        Forward a bounds check with both ``index`` and ``size`` captured.

        Only ``index`` is simplified; ``size`` is registered as given.
        """
        index_proxy = self._add_index(
            self._simplify(index), MemoryUsageType.CHECK_BOUNDS
        )
        size_proxy = self._add_index(size, MemoryUsageType.CHECK_BOUNDS)
        return self._inner.check_bounds(index_proxy, size_proxy, lower, upper)

    def bucketize(
        self,
        values: T,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: T,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[T] = None,
    ) -> T:
        """
        See [Note: Inductor bucketize op]

        The three expressions in ``boundaries`` (and the one in ``sorter``,
        when given) are captured against their buffer name before forwarding.
        """
        boundaries_buffer = boundaries[0]

        def capture_boundary(expr):
            return self._add_index(
                expr, MemoryUsageType.BUCKETIZE, buffer_name=boundaries_buffer
            )

        boundaries = (
            boundaries_buffer,
            capture_boundary(boundaries[1]),
            capture_boundary(boundaries[2]),
            capture_boundary(boundaries[3]),
        )

        if sorter is not None:
            sorter_buffer = sorter[0]
            sorter = (
                sorter_buffer,
                self._add_index(
                    sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter_buffer
                ),
            )

        return self._inner.bucketize(
            values,
            boundaries,
            boundary_indices,
            indexing_dtype,
            right,
            sorter,
            sorter_indices,
        )

    def masked(self, mask_proxy, masked_body: Callable[..., Any], other_proxy):
        """
        Recursively capture the masked out body in another LoopBodyBlock
        """
        body = self.body
        # Obtain the submodule name first; the shim and subblock are keyed by it.
        subblock_name = body.add_submodule(None, "masked_subblock")
        body.submodules[subblock_name] = body.bind_masked_shim(subblock_name)
        body.subblocks[subblock_name] = LoopBodyBlock(body, masked_body, [])
        call_args = (mask_proxy, other_proxy)
        return self.tracer.create_proxy("call_module", subblock_name, call_args, {})

    def scan(
        self,
        dtype_proxy,
        combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
        value_proxy,
    ):
        """
        Register a scan shim for ``combine_fn`` as a submodule and call it,
        returning one item per entry of ``value_proxy``.
        """
        scan_shim = self.body.bind_scan_shim(combine_fn)
        module_name = self.body.add_submodule(scan_shim, "scan")
        result = self.tracer.create_proxy(
            "call_module", module_name, (dtype_proxy, value_proxy), {}
        )
        # Proxies are iterable, but some methods expect tuples/lists
        count = len(value_proxy)
        return tuple(result[pos] for pos in range(count))

    def sort(self, dtypes, values, stable, descending):
        """Forward a sort and return one result item per entry of ``values``."""
        result = self._inner.sort(dtypes, values, stable, descending)
        # Proxies are iterable, but some methods expect tuples/lists
        count = len(values)
        return tuple(result[pos] for pos in range(count))

    def frexp(self, value_proxy):
        """Forward frexp and return its two results as a real 2-tuple."""
        result = self._inner.frexp(value_proxy)
        # Proxies are iterable, but some methods expect tuples/lists
        first, second = result[0], result[1]
        return (first, second)

    def indirect_indexing(self, index_proxy, size, check=True, wrap_neg=True):
        """
        Flow data from tensors into indexing formulas.
        Introduce a call_module to update the indexing.
        """
        body = self.body
        var = body.add_indirect(size)
        setter_shim = body.bind_set_indirect_shim(var, size, check, wrap_neg)
        module_name = body.add_submodule(setter_shim, f"set_{var}")
        # The proxy is created only for its effect on the traced graph; the
        # caller receives the variable returned by ``add_indirect``.
        self.tracer.create_proxy("call_module", module_name, (index_proxy,), {})
        return var

    def output(self, *result):
        """Record the traced results as the graph's ``output`` node."""
        self.tracer.create_proxy("output", "output", result, {})