mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Inductor] move block pointer analysis to a new module (#141733)
# Summary Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243. This refactors the ModularIndexing block pointer analysis into its own module. That way, we can call it from other places besides Triton codegen. In the parent PR, we will use this to find tiling splits that simplify the indexing. # Test plan Tested by the existing CI. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141733 Approved by: https://github.com/jansel
This commit is contained in:
parent
49fde426ba
commit
c2fa544472
2 changed files with 141 additions and 82 deletions
116
torch/_inductor/codegen/block_analysis.py
Normal file
116
torch/_inductor/codegen/block_analysis.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
import collections
|
||||
import functools
|
||||
import textwrap
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import sympy
|
||||
from sympy import Expr, Symbol
|
||||
|
||||
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
|
||||
|
||||
from ..utils import sympy_dot, sympy_subs
|
||||
from ..virtualized import V
|
||||
|
||||
|
||||
class BlockPatternMatcher:
|
||||
"""
|
||||
Matches block indexing expressions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_subexpr_involving_symbol(expr: Expr, symbol: Symbol) -> Expr:
|
||||
"""
|
||||
Given a sympy expression, return the subexpression comprised only of terms
|
||||
involving the specified symbol.
|
||||
|
||||
For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`,
|
||||
this returns `x * 5 + x ** 2`.
|
||||
"""
|
||||
return sympy.S.Zero + sum(
|
||||
term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_slice_numels(dims: List[Expr]) -> List[Expr]:
|
||||
"""
|
||||
Compute the cumulative size of each dimension's slice.
|
||||
This proceeds from the last dim up to the second.
|
||||
"""
|
||||
numels = collections.deque([sympy.S.One])
|
||||
for dim in dims[:0:-1]:
|
||||
numel = dim * numels[0]
|
||||
numels.appendleft(numel)
|
||||
return [*numels]
|
||||
|
||||
@classmethod
|
||||
def match_mod_div_block_expr(
|
||||
cls,
|
||||
index: Expr,
|
||||
index_var: Symbol,
|
||||
numel: Expr,
|
||||
num_dims: int,
|
||||
) -> Optional[Tuple[List[Expr], List[Expr], List[Expr]]]:
|
||||
"""
|
||||
Matches modular indexing expressions, converting them to implied block dimensions and strides.
|
||||
See triton.py for more information.
|
||||
"""
|
||||
|
||||
# Pattern match to find the strides and offset.
|
||||
wild = functools.partial(sympy.Wild, exclude=[index_var])
|
||||
dims: List[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)]
|
||||
strides: List[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)]
|
||||
|
||||
# The first dimension's index is computed by division.
|
||||
# The remaining are computed by modulo.
|
||||
slice_numels = cls.get_slice_numels(dims[:num_dims])
|
||||
block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [
|
||||
ModularIndexing(index_var, numel, dim)
|
||||
for dim, numel in zip(dims[1:], slice_numels[1:])
|
||||
]
|
||||
|
||||
# Calculate a linear index from block indices.
|
||||
match_expr = sympy_dot(strides, block_index_exprs)
|
||||
|
||||
# Pattern match.
|
||||
match = index.match(match_expr)
|
||||
if match is None:
|
||||
return None
|
||||
|
||||
# Provide default values for unmatched dims and strides.
|
||||
for dim in dims[1:]:
|
||||
if dim not in match:
|
||||
match[dim] = sympy.S.One
|
||||
for stride in strides[1:]:
|
||||
if stride not in match:
|
||||
match[stride] = sympy.S.Zero
|
||||
|
||||
sizevars = V.graph.sizevars
|
||||
|
||||
def get_match(expr: Expr) -> Expr:
|
||||
return sizevars.lookup_precomputed_size(match[expr])
|
||||
|
||||
# Replace wildcards with matched expressions.
|
||||
dims = [dims[0]] + [get_match(dim) for dim in dims[1:]]
|
||||
strides = [get_match(stride) for stride in strides]
|
||||
slice_numels = cls.get_slice_numels(dims)
|
||||
block_index_exprs = [sympy_subs(expr, match) for expr in block_index_exprs]
|
||||
|
||||
# The leading dimension is not directly matched in our expression.
|
||||
# We solve for it by dividing the range tree numel by the product of
|
||||
# all other dimensions. We quit if they are not known to be divisible.
|
||||
assert dims[0] not in match, "Expected not to match the leading dimension!"
|
||||
if not sizevars.statically_known_multiple_of(numel, slice_numels[0]):
|
||||
return None
|
||||
dims[0] = numel / slice_numels[0]
|
||||
|
||||
# Sanity check that we can recover the index from the matched subexpressions.
|
||||
matched_index = sympy_dot(strides, block_index_exprs)
|
||||
assert sizevars.statically_known_equals(matched_index, index), textwrap.dedent(
|
||||
f"""
|
||||
Invalid match!
|
||||
Index: {index}
|
||||
Matched expression: {matched_index}
|
||||
"""
|
||||
)
|
||||
|
||||
return dims, strides, block_index_exprs
|
||||
|
|
@ -61,12 +61,12 @@ from ..utils import (
|
|||
get_kernel_metadata,
|
||||
is_welford_reduction,
|
||||
Placeholder,
|
||||
sympy_dot,
|
||||
sympy_subs,
|
||||
upcast_compute_type,
|
||||
)
|
||||
from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
|
||||
from ..wrapper_benchmark import get_kernel_category_by_source_code
|
||||
from .block_analysis import BlockPatternMatcher
|
||||
from .common import (
|
||||
BackendFeature,
|
||||
CSE,
|
||||
|
|
@ -1651,73 +1651,18 @@ class TritonKernel(SIMDKernel):
|
|||
|
||||
# Pattern match to find the strides and offset.
|
||||
index_var = range_tree.symbol()
|
||||
wild = functools.partial(sympy.Wild, exclude=[index_var])
|
||||
dims: List[sympy.Expr] = [
|
||||
wild(f"dim_mod{idx}") for idx in range(num_dims)
|
||||
]
|
||||
strides: List[sympy.Expr] = [
|
||||
wild(f"stride_mod{idx}") for idx in range(num_dims)
|
||||
]
|
||||
|
||||
def get_slice_numels(dims: List[Any]) -> List[Any]:
|
||||
"""
|
||||
Compute the cumulative size of each dimension's slice.
|
||||
This proceeds from the last dim up to the second.
|
||||
"""
|
||||
numels = [sympy.S.One]
|
||||
for dim in dims[:0:-1]:
|
||||
numel = dim * numels[0]
|
||||
numels.insert(0, numel)
|
||||
return numels
|
||||
|
||||
# The first dimension's index is computed by division.
|
||||
# The remaining are computed by modulo.
|
||||
slice_numels = get_slice_numels(dims[:num_dims])
|
||||
block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [
|
||||
ModularIndexing(index_var, numel, dim)
|
||||
for dim, numel in zip(dims[1:], slice_numels[1:])
|
||||
]
|
||||
|
||||
# Calculate a linear index from block indices.
|
||||
match_expr = sympy_dot(strides, block_index_exprs)
|
||||
|
||||
# Pattern match.
|
||||
match = index.match(match_expr)
|
||||
if match is None:
|
||||
match_result = BlockPatternMatcher.match_mod_div_block_expr(
|
||||
index, index_var, range_tree.numel, num_dims
|
||||
)
|
||||
if match_result is None:
|
||||
return None
|
||||
|
||||
# Provide default values for unmatched dims and strides.
|
||||
for dim in dims[1:]:
|
||||
if dim not in match:
|
||||
match[dim] = sympy.S.One
|
||||
for stride in strides[1:]:
|
||||
if stride not in match:
|
||||
match[stride] = sympy.S.Zero
|
||||
|
||||
sizevars = V.graph.sizevars
|
||||
|
||||
def get_match(expr: sympy.Expr) -> sympy.Expr:
|
||||
return sizevars.lookup_precomputed_size(match[expr])
|
||||
|
||||
# Replace wildcards with matched expressions.
|
||||
dims = [dims[0]] + [get_match(dim) for dim in dims[1:]]
|
||||
strides = [get_match(stride) for stride in strides]
|
||||
slice_numels = get_slice_numels(dims)
|
||||
block_index_exprs = [
|
||||
sympy_subs(expr, match) for expr in block_index_exprs
|
||||
]
|
||||
|
||||
# The leading dimension is not directly matched in our expression.
|
||||
# We solve for it by dividing the range tree numel by the product of
|
||||
# all other dimensions. We quit if they are not known to be divisible.
|
||||
assert (
|
||||
dims[0] not in match
|
||||
), "Expected not to match the leading dimension!"
|
||||
if not sizevars.statically_known_multiple_of(
|
||||
range_tree.numel, slice_numels[0]
|
||||
):
|
||||
return None
|
||||
dims[0] = range_tree.numel / slice_numels[0]
|
||||
(
|
||||
dims,
|
||||
strides,
|
||||
block_index_exprs,
|
||||
) = match_result
|
||||
slice_numels = BlockPatternMatcher.get_slice_numels(dims)
|
||||
|
||||
# Check for applicable iteration range sizes.
|
||||
# When mapping a 1D block into an ND one, we need to know that
|
||||
|
|
@ -1728,6 +1673,7 @@ class TritonKernel(SIMDKernel):
|
|||
# with n and m integers, then either numel is a multiple of XBLOCK, or numel
|
||||
# is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.)
|
||||
# 2. Numels are multiples of the maximum possible block size.
|
||||
sizevars = V.graph.sizevars
|
||||
max_block = self.max_block(range_tree.prefix)
|
||||
if any(
|
||||
not sizevars.statically_known_multiple_of(numel, max_block)
|
||||
|
|
@ -1786,19 +1732,20 @@ class TritonKernel(SIMDKernel):
|
|||
)
|
||||
range_trees = self.active_range_trees(reorder=True)
|
||||
|
||||
# Match each range tree separately.
|
||||
range_symbols = {tree.symbol() for tree in range_trees}
|
||||
index_terms = sympy.Add.make_args(index_relative_to_xyr_index)
|
||||
block_params = BlockParameters()
|
||||
for tree in range_trees:
|
||||
# Partition the index into subexpressions pertaining to each range tree.
|
||||
# For example xindex * 5 + rindex * 3 is partitioned to
|
||||
# (xindex * 5, rindex * 3).
|
||||
symbol = tree.symbol()
|
||||
subexpr = sympy.S.Zero + sum(
|
||||
expr for expr in index_terms if symbol in expr.free_symbols
|
||||
# Partition the index into subexpressions pertaining to each range tree.
|
||||
# For example xindex * 5 + rindex * 3 is partitioned to
|
||||
# (xindex * 5, rindex * 3).
|
||||
index_subexprs = [
|
||||
BlockPatternMatcher.get_subexpr_involving_symbol(
|
||||
index_relative_to_xyr_index, tree.symbol()
|
||||
)
|
||||
for tree in range_trees
|
||||
]
|
||||
|
||||
# Match each range tree's subexpression separately.
|
||||
range_symbols = {tree.symbol() for tree in range_trees}
|
||||
block_params = BlockParameters()
|
||||
for tree, subexpr in zip(range_trees, index_subexprs):
|
||||
# Reject mixed terms, e.g. xindex * rindex.
|
||||
# NB: the zero expression is allowed, for broadcasting.
|
||||
if len(range_symbols.intersection(subexpr.free_symbols)) > 1:
|
||||
|
|
@ -1811,11 +1758,7 @@ class TritonKernel(SIMDKernel):
|
|||
block_params += params
|
||||
|
||||
# Collect leftover terms as a constant offset.
|
||||
offset = sum(
|
||||
expr
|
||||
for expr in index_terms
|
||||
if not range_symbols.intersection(expr.free_symbols)
|
||||
)
|
||||
offset = index_relative_to_xyr_index - sum(index_subexprs)
|
||||
|
||||
# Form the block pointer.
|
||||
self.filter_masks(mask_vars)
|
||||
|
|
|
|||
Loading…
Reference in a new issue