class CaptureIndexing(WrapperHandler):
    """
    Ops handler wrapper that records the indexing performed by a loop body
    while that body is being traced.

    ``LoopBodyBlock.__init__`` builds it as
    ``CaptureIndexing(proxy_ops, body, tracer)`` and wraps the result in
    ``CountOps``.

    Index expressions handed to load/store style ops are registered on
    ``body`` through ``add_index_expr`` and replaced by a ``get_index``
    ``call_module`` proxy before the op is forwarded to the wrapped handler.
    Range-based simplification (``_simplify``) is applied by this handler
    itself to the ``index`` argument of ``load``, ``store``,
    ``store_reduction``, ``index_expr`` and ``check_bounds``.

    ``masked``, ``scan`` and ``indirect_indexing`` register a submodule on
    ``body`` and appear in the trace as ``call_module`` proxies.
    """

    name = "CaptureIndexing"

    def __init__(
        self,
        inner: OpsHandler[Any],
        body: LoopBody,
        tracer: LightTracer,
    ):
        """
        Args:
            inner: handler that every op is eventually forwarded to.
            body: receives the index expressions, indirect variables and
                submodules discovered while tracing.
            tracer: used to emit the ``call_module`` / ``output`` proxies.
        """
        super().__init__(inner)
        self.tracer = tracer
        self.body = body

    def _add_index(self, expr: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any):
        """
        Register ``expr`` on the body under usage ``mtype`` and return a
        ``get_index`` ``call_module`` proxy whose single argument is whatever
        ``body.add_index_expr`` returned.
        """
        registered = self.body.add_index_expr(expr, mtype, **kwargs)
        return self.tracer.create_proxy("call_module", "get_index", (registered,), {})

    def _simplify(self, expr: sympy.Expr) -> sympy.Expr:
        """Simplify ``expr`` using the body's variable ranges."""
        sizevars = V.graph.sizevars
        return sizevars.simplify_with_ranges(expr, self.body.var_ranges)

    def _capture_index(self, index: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any):
        """Simplify ``index`` first, then register it via ``_add_index``."""
        simplified = self._simplify(index)
        return self._add_index(simplified, mtype, **kwargs)

    def load(self, name: str, index: sympy.Expr):
        """Record the (simplified) load index of buffer ``name`` and forward."""
        index_proxy = self._capture_index(
            index, MemoryUsageType.LOAD, buffer_name=name
        )
        return self._inner.load(name, index_proxy)

    def load_seed(self, name: str, index: int):
        """
        Record the integer seed offset on the body and forward the call.

        Unlike ``load``, no ``get_index`` proxy is created: the wrapped
        handler receives the original integer.
        """
        assert isinstance(index, int)
        seed_offset = sympy.Integer(index)
        self.body.add_index_expr(
            seed_offset, MemoryUsageType.LOAD_SEED, buffer_name=name
        )
        return self._inner.load_seed(name, index)

    def store(self, name, index, value, mode=None):
        """Record the (simplified) store index, including ``mode``, and forward."""
        index_proxy = self._capture_index(
            index, MemoryUsageType.STORE, buffer_name=name, mode=mode
        )
        return self._inner.store(name, index_proxy, value, mode)

    def store_reduction(self, name, index, value):
        """Record the (simplified) reduction-store index and forward."""
        index_proxy = self._capture_index(
            index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
        )
        return self._inner.store_reduction(name, index_proxy, value)

    def reduction(self, dtype, src_dtype, reduction_type, value):
        """
        Forward the reduction; for reduction types containing ``"welford"``
        unpack the result into an explicit 3-tuple.
        """
        reduced = self._inner.reduction(dtype, src_dtype, reduction_type, value)
        if "welford" not in reduction_type:
            return reduced
        return (reduced[0], reduced[1], reduced[2])

    def index_expr(self, index, dtype):
        """
        Simplify ``index``; an ``int`` / ``sympy.Integer`` result is emitted as
        a constant, anything else is registered and forwarded as an index
        expression.
        """
        simplified = self._simplify(index)
        if not isinstance(simplified, (int, sympy.Integer)):
            index_proxy = self._add_index(simplified, MemoryUsageType.INDEX_EXPR)
            return self._inner.index_expr(index_proxy, dtype)
        return self._inner.constant(int(simplified), dtype)

    def check_bounds(self, index, size, lower, upper):
        """
        Register both ``index`` and ``size`` as CHECK_BOUNDS usages and
        forward. Only ``index`` is simplified; ``size`` is registered as given.
        """
        # Keep this order: index is registered before size.
        index_proxy = self._capture_index(index, MemoryUsageType.CHECK_BOUNDS)
        size_proxy = self._add_index(size, MemoryUsageType.CHECK_BOUNDS)
        return self._inner.check_bounds(index_proxy, size_proxy, lower, upper)

    def bucketize(
        self,
        values: T,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: T,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[T] = None,
    ) -> T:
        """
        See [Note: Inductor bucketize op]

        The three expressions in ``boundaries`` and, when given, the one in
        ``sorter`` are registered as BUCKETIZE usages of their buffer (the
        first tuple element) and replaced by ``get_index`` proxies. They are
        not simplified.
        """

        def register(expr, buffer):
            return self._add_index(
                expr, MemoryUsageType.BUCKETIZE, buffer_name=buffer
            )

        boundaries_buffer = boundaries[0]
        # Registration order matters: boundaries[1], [2], [3], then the sorter.
        traced_boundaries = (
            boundaries_buffer,
            *[register(boundaries[pos], boundaries_buffer) for pos in (1, 2, 3)],
        )

        traced_sorter = sorter
        if sorter is not None:
            sorter_buffer = sorter[0]
            traced_sorter = (sorter_buffer, register(sorter[1], sorter_buffer))

        return self._inner.bucketize(
            values,
            traced_boundaries,
            boundary_indices,
            indexing_dtype,
            right,
            traced_sorter,
            sorter_indices,
        )

    def masked(self, mask_proxy, masked_body: Callable[..., Any], other_proxy):
        """
        Trace ``masked_body`` into its own LoopBodyBlock, stored on the body
        under a ``masked_subblock`` submodule name, and return the
        ``call_module`` proxy invoking it with ``(mask_proxy, other_proxy)``.
        """
        body = self.body
        subblock_name = body.add_submodule(None, "masked_subblock")
        body.submodules[subblock_name] = body.bind_masked_shim(subblock_name)
        body.subblocks[subblock_name] = LoopBodyBlock(body, masked_body, [])
        call_args = (mask_proxy, other_proxy)
        return self.tracer.create_proxy("call_module", subblock_name, call_args, {})

    def scan(
        self,
        dtype_proxy,
        combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
        value_proxy,
    ):
        """
        Register a scan shim for ``combine_fn`` as a submodule and return its
        traced outputs as a tuple with one entry per element of
        ``value_proxy``.
        """
        scan_shim = self.body.bind_scan_shim(combine_fn)
        module_name = self.body.add_submodule(scan_shim, "scan")
        scanned = self.tracer.create_proxy(
            "call_module",
            module_name,
            (dtype_proxy, value_proxy),
            {},
        )
        # A proxy can be iterated, but callers may require a real tuple/list.
        return tuple([scanned[pos] for pos in range(len(value_proxy))])

    def sort(self, dtypes, values, stable, descending):
        """Forward the sort and return one result entry per element of ``values``."""
        sorted_proxy = self._inner.sort(dtypes, values, stable, descending)
        # A proxy can be iterated, but callers may require a real tuple/list.
        return tuple([sorted_proxy[pos] for pos in range(len(values))])

    def frexp(self, value_proxy):
        """Forward ``frexp`` and return its two results as an explicit pair."""
        pair = self._inner.frexp(value_proxy)
        # A proxy can be iterated, but callers may require a real tuple/list.
        first, second = pair[0], pair[1]
        return (first, second)

    def indirect_indexing(self, index_proxy, size, check=True, wrap_neg=True):
        """
        Let a traced value feed into indexing formulas.

        A new indirect variable is allocated on the body, and a
        ``call_module`` proxy (submodule ``set_<var>``) is emitted that
        receives ``index_proxy``. The indirect variable is returned.
        """
        body = self.body
        var = body.add_indirect(size)
        set_indirect = body.bind_set_indirect_shim(var, size, check, wrap_neg)
        module_name = body.add_submodule(set_indirect, f"set_{var}")
        self.tracer.create_proxy("call_module", module_name, (index_proxy,), {})
        return var

    def output(self, *result):
        """Emit the ``output`` proxy carrying ``result``."""
        self.tracer.create_proxy("output", "output", result, {})