mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] Refactor CaptureIndexing into global scope (#146297)
And inline SimplifyIndexing into it CaptureIndexing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146297 Approved by: https://github.com/shunting314 ghstack dependencies: #146252, #146254, #146255, #146257, #146282
This commit is contained in:
parent
d35f6b2339
commit
c098385cb3
1 changed files with 178 additions and 168 deletions
|
|
@ -17,7 +17,7 @@ from torch.utils._sympy.symbol import SymT
|
|||
|
||||
from . import config, dependencies
|
||||
from .codegen.common import index_prevent_reordering
|
||||
from .ops_handler import DefaultHandler, OpsHandler
|
||||
from .ops_handler import DefaultHandler, OpsHandler, WrapperHandler
|
||||
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
|
||||
from .virtualized import ops, V
|
||||
|
||||
|
|
@ -440,179 +440,13 @@ class LoopBodyBlock:
|
|||
|
||||
def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]):
|
||||
self.body = body
|
||||
|
||||
def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs):
|
||||
return tracer.create_proxy(
|
||||
"call_module",
|
||||
"get_index",
|
||||
(body.add_index_expr(expr, mtype, **kwargs),),
|
||||
{},
|
||||
)
|
||||
|
||||
class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined]
|
||||
name = "CaptureIndexing"
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
index = add_index(index, MemoryUsageType.LOAD, buffer_name=name)
|
||||
return self._inner.load(name, index)
|
||||
|
||||
def load_seed(self, name: str, index: int):
|
||||
assert isinstance(index, int)
|
||||
body.add_index_expr(
|
||||
sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name
|
||||
)
|
||||
return self._inner.load_seed(name, index)
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
index = add_index(
|
||||
index, MemoryUsageType.STORE, buffer_name=name, mode=mode
|
||||
)
|
||||
return self._inner.store(name, index, value, mode)
|
||||
|
||||
def store_reduction(self, name, index, value):
|
||||
index = add_index(
|
||||
index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
|
||||
)
|
||||
return self._inner.store_reduction(name, index, value)
|
||||
|
||||
def reduction(self, dtype, src_dtype, reduction_type, value):
|
||||
result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
|
||||
if "welford" in reduction_type:
|
||||
return tuple(result[i] for i in range(3))
|
||||
return result
|
||||
|
||||
def index_expr(self, index, dtype):
|
||||
if isinstance(index, (int, sympy.Integer)):
|
||||
return self._inner.constant(int(index), dtype)
|
||||
index = add_index(index, MemoryUsageType.INDEX_EXPR)
|
||||
return self._inner.index_expr(index, dtype)
|
||||
|
||||
def check_bounds(self, index, size, lower, upper):
|
||||
index = add_index(index, MemoryUsageType.CHECK_BOUNDS)
|
||||
size = add_index(size, MemoryUsageType.CHECK_BOUNDS)
|
||||
return self._inner.check_bounds(index, size, lower, upper)
|
||||
|
||||
def bucketize(
|
||||
self,
|
||||
values: T,
|
||||
boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
|
||||
boundary_indices: T,
|
||||
indexing_dtype: torch.dtype,
|
||||
right: bool,
|
||||
sorter: Optional[tuple[str, sympy.Expr]] = None,
|
||||
sorter_indices: Optional[T] = None,
|
||||
) -> T:
|
||||
"""
|
||||
See [Note: Inductor bucketize op]
|
||||
"""
|
||||
boundaries = (
|
||||
boundaries[0],
|
||||
add_index(
|
||||
boundaries[1],
|
||||
MemoryUsageType.BUCKETIZE,
|
||||
buffer_name=boundaries[0],
|
||||
),
|
||||
add_index(
|
||||
boundaries[2],
|
||||
MemoryUsageType.BUCKETIZE,
|
||||
buffer_name=boundaries[0],
|
||||
),
|
||||
add_index(
|
||||
boundaries[3],
|
||||
MemoryUsageType.BUCKETIZE,
|
||||
buffer_name=boundaries[0],
|
||||
),
|
||||
)
|
||||
if sorter is not None:
|
||||
sorter = (
|
||||
sorter[0],
|
||||
add_index(
|
||||
sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0]
|
||||
),
|
||||
)
|
||||
|
||||
return self._inner.bucketize(
|
||||
values,
|
||||
boundaries,
|
||||
boundary_indices,
|
||||
indexing_dtype,
|
||||
right,
|
||||
sorter,
|
||||
sorter_indices,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
|
||||
"""
|
||||
Recursively capture the masked out body in another LoopBodyBlock
|
||||
"""
|
||||
name = self.body.add_submodule(None, "masked_subblock")
|
||||
self.body.submodules[name] = self.body.bind_masked_shim(name)
|
||||
self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
|
||||
return tracer.create_proxy(
|
||||
"call_module", name, (mask_proxy, other_proxy), {}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def scan(
|
||||
dtype_proxy,
|
||||
combine_fn: Callable[
|
||||
[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]
|
||||
],
|
||||
value_proxy,
|
||||
):
|
||||
shim = self.body.bind_scan_shim(combine_fn)
|
||||
name = self.body.add_submodule(shim, "scan")
|
||||
result = tracer.create_proxy(
|
||||
"call_module",
|
||||
name,
|
||||
(dtype_proxy, value_proxy),
|
||||
{},
|
||||
)
|
||||
# Proxies are iterable, but some methods expect tuples/lists
|
||||
return tuple(result[i] for i in range(len(value_proxy)))
|
||||
|
||||
def sort(self, dtypes, values, stable, descending):
|
||||
result = self._inner.sort(dtypes, values, stable, descending)
|
||||
# Proxies are iterable, but some methods expect tuples/lists
|
||||
return tuple(result[i] for i in range(len(values)))
|
||||
|
||||
def frexp(self, value_proxy):
|
||||
result = self._inner.frexp(value_proxy)
|
||||
# Proxies are iterable, but some methods expect tuples/lists
|
||||
return (result[0], result[1])
|
||||
|
||||
@staticmethod
|
||||
def indirect_indexing(index_proxy, size, check=True, wrap_neg=True):
|
||||
"""
|
||||
Flow data from tensors into indexing formulas.
|
||||
Introduce a call_module to update the indexing.
|
||||
"""
|
||||
|
||||
var = self.body.add_indirect(size)
|
||||
set_indirect = self.body.bind_set_indirect_shim(
|
||||
var, size, check, wrap_neg
|
||||
)
|
||||
tracer.create_proxy(
|
||||
"call_module",
|
||||
self.body.add_submodule(set_indirect, f"set_{var}"),
|
||||
(index_proxy,),
|
||||
{},
|
||||
)
|
||||
return var
|
||||
|
||||
@staticmethod
|
||||
def output(*result):
|
||||
tracer.create_proxy("output", "output", result, {})
|
||||
|
||||
tracer = LightTracer()
|
||||
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
|
||||
|
||||
from .index_propagation import IndexPropagation
|
||||
from .sizevars import SimplifyIndexing
|
||||
|
||||
handler: Any = CountOps(
|
||||
SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges),
|
||||
CaptureIndexing(proxy_ops, body, tracer),
|
||||
body.op_counts,
|
||||
)
|
||||
if config.constant_and_index_propagation:
|
||||
|
|
@ -662,3 +496,179 @@ class CountOps(DefaultHandler):
|
|||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
self._counts[name] += 1
|
||||
return getattr(self._inner, name)(*args, **kwargs)
|
||||
|
||||
|
||||
class CaptureIndexing(WrapperHandler):
|
||||
name = "CaptureIndexing"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner: OpsHandler[Any],
|
||||
body: LoopBody,
|
||||
tracer: LightTracer,
|
||||
):
|
||||
super().__init__(inner)
|
||||
self.body = body
|
||||
self.tracer = tracer
|
||||
|
||||
def _add_index(self, expr: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any):
|
||||
return self.tracer.create_proxy(
|
||||
"call_module",
|
||||
"get_index",
|
||||
(self.body.add_index_expr(expr, mtype, **kwargs),),
|
||||
{},
|
||||
)
|
||||
|
||||
def _simplify(self, expr: sympy.Expr) -> sympy.Expr:
|
||||
return V.graph.sizevars.simplify_with_ranges(expr, self.body.var_ranges)
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
index = self._simplify(index)
|
||||
index = self._add_index(index, MemoryUsageType.LOAD, buffer_name=name)
|
||||
return self._inner.load(name, index)
|
||||
|
||||
def load_seed(self, name: str, index: int):
|
||||
assert isinstance(index, int)
|
||||
self.body.add_index_expr(
|
||||
sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name
|
||||
)
|
||||
return self._inner.load_seed(name, index)
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
index = self._simplify(index)
|
||||
index = self._add_index(
|
||||
index, MemoryUsageType.STORE, buffer_name=name, mode=mode
|
||||
)
|
||||
return self._inner.store(name, index, value, mode)
|
||||
|
||||
def store_reduction(self, name, index, value):
|
||||
index = self._simplify(index)
|
||||
index = self._add_index(
|
||||
index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
|
||||
)
|
||||
return self._inner.store_reduction(name, index, value)
|
||||
|
||||
def reduction(self, dtype, src_dtype, reduction_type, value):
|
||||
result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
|
||||
if "welford" in reduction_type:
|
||||
return tuple(result[i] for i in range(3))
|
||||
return result
|
||||
|
||||
def index_expr(self, index, dtype):
|
||||
index = self._simplify(index)
|
||||
if isinstance(index, (int, sympy.Integer)):
|
||||
return self._inner.constant(int(index), dtype)
|
||||
index = self._add_index(index, MemoryUsageType.INDEX_EXPR)
|
||||
return self._inner.index_expr(index, dtype)
|
||||
|
||||
def check_bounds(self, index, size, lower, upper):
|
||||
index = self._simplify(index)
|
||||
index = self._add_index(index, MemoryUsageType.CHECK_BOUNDS)
|
||||
size = self._add_index(size, MemoryUsageType.CHECK_BOUNDS)
|
||||
return self._inner.check_bounds(index, size, lower, upper)
|
||||
|
||||
def bucketize(
|
||||
self,
|
||||
values: T,
|
||||
boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
|
||||
boundary_indices: T,
|
||||
indexing_dtype: torch.dtype,
|
||||
right: bool,
|
||||
sorter: Optional[tuple[str, sympy.Expr]] = None,
|
||||
sorter_indices: Optional[T] = None,
|
||||
) -> T:
|
||||
"""
|
||||
See [Note: Inductor bucketize op]
|
||||
"""
|
||||
boundaries = (
|
||||
boundaries[0],
|
||||
self._add_index(
|
||||
boundaries[1],
|
||||
MemoryUsageType.BUCKETIZE,
|
||||
buffer_name=boundaries[0],
|
||||
),
|
||||
self._add_index(
|
||||
boundaries[2],
|
||||
MemoryUsageType.BUCKETIZE,
|
||||
buffer_name=boundaries[0],
|
||||
),
|
||||
self._add_index(
|
||||
boundaries[3],
|
||||
MemoryUsageType.BUCKETIZE,
|
||||
buffer_name=boundaries[0],
|
||||
),
|
||||
)
|
||||
if sorter is not None:
|
||||
sorter = (
|
||||
sorter[0],
|
||||
self._add_index(
|
||||
sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0]
|
||||
),
|
||||
)
|
||||
|
||||
return self._inner.bucketize(
|
||||
values,
|
||||
boundaries,
|
||||
boundary_indices,
|
||||
indexing_dtype,
|
||||
right,
|
||||
sorter,
|
||||
sorter_indices,
|
||||
)
|
||||
|
||||
def masked(self, mask_proxy, masked_body: Callable[..., Any], other_proxy):
|
||||
"""
|
||||
Recursively capture the masked out body in another LoopBodyBlock
|
||||
"""
|
||||
name = self.body.add_submodule(None, "masked_subblock")
|
||||
self.body.submodules[name] = self.body.bind_masked_shim(name)
|
||||
self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
|
||||
return self.tracer.create_proxy(
|
||||
"call_module", name, (mask_proxy, other_proxy), {}
|
||||
)
|
||||
|
||||
def scan(
|
||||
self,
|
||||
dtype_proxy,
|
||||
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
|
||||
value_proxy,
|
||||
):
|
||||
shim = self.body.bind_scan_shim(combine_fn)
|
||||
name = self.body.add_submodule(shim, "scan")
|
||||
result = self.tracer.create_proxy(
|
||||
"call_module",
|
||||
name,
|
||||
(dtype_proxy, value_proxy),
|
||||
{},
|
||||
)
|
||||
# Proxies are iterable, but some methods expect tuples/lists
|
||||
return tuple(result[i] for i in range(len(value_proxy)))
|
||||
|
||||
def sort(self, dtypes, values, stable, descending):
|
||||
result = self._inner.sort(dtypes, values, stable, descending)
|
||||
# Proxies are iterable, but some methods expect tuples/lists
|
||||
return tuple(result[i] for i in range(len(values)))
|
||||
|
||||
def frexp(self, value_proxy):
|
||||
result = self._inner.frexp(value_proxy)
|
||||
# Proxies are iterable, but some methods expect tuples/lists
|
||||
return (result[0], result[1])
|
||||
|
||||
def indirect_indexing(self, index_proxy, size, check=True, wrap_neg=True):
|
||||
"""
|
||||
Flow data from tensors into indexing formulas.
|
||||
Introduce a call_module to update the indexing.
|
||||
"""
|
||||
|
||||
var = self.body.add_indirect(size)
|
||||
set_indirect = self.body.bind_set_indirect_shim(var, size, check, wrap_neg)
|
||||
self.tracer.create_proxy(
|
||||
"call_module",
|
||||
self.body.add_submodule(set_indirect, f"set_{var}"),
|
||||
(index_proxy,),
|
||||
{},
|
||||
)
|
||||
return var
|
||||
|
||||
def output(self, *result):
|
||||
self.tracer.create_proxy("output", "output", result, {})
|
||||
|
|
|
|||
Loading…
Reference in a new issue