pytorch/test/inductor/test_block_analysis.py
Blaine Burton Rister a1bfb39a31 [Inductor] Expand Identity ops prior to block pattern matching (#146000)
# Feature

Inductor sometimes uses `Identity` functions to group various terms of an expression. While this is convenient in some scenarios, it can frustrate pattern matching. For example, when we're matching an indexing expression to tell if it can be represented as a block pointer, that analysis should be invariant to `Identity` functions.

This PR adds a few features to achieve this invariance.
 - Create a new expansion mode `expr.expand(identity=True)`, which removes all `Identity` functions from the expression.
 -  Preprocess the expression with this expansion prior to pattern matching.
 - Bonus: create a new test utility function called `dummy_graph()`, which creates a simple `GraphLowering`. This is useful for testing the pattern matcher, as we need to initialize `V.graph` before we can access `V.graph.sizevars`.

# Test plan
This PR adds a few new unit tests:
 - Added a unit test specifically for `expr.expand(identity=True)`.
 - Added a new unit test module for the block pattern matcher. Tested that we can correctly match some example patterns containing Identity ops.

I originally intended to add an end-to-end test compiling pointwise cat, and mapping the corresponding memory accesses to block pointers. However, it looks like that will take more work, since the [relevant code path](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton.py#L1306) disables block pointer analysis. It might be better to defer that to a future PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146000
Approved by: https://github.com/eellison, https://github.com/jansel
2025-02-08 18:11:53 +00:00

102 lines
3 KiB
Python

# Owner(s): ["module: inductor"]
import math

import sympy

import torch
from torch._inductor.codegen.block_analysis import BlockPatternMatcher
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)
from torch.testing._internal.inductor_utils import dummy_graph
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
# Shared sympy symbols, used as the indexing variables in the parametrized
# test cases of this module.
x, y = sympy.symbols("x y")
@instantiate_parametrized_tests
class BlockAnalysisTest(TestCase):
    """
    Tests that BlockPatternMatcher handles expressions containing Identity ops.

    Each test wraps part (or all) of an indexing expression in ``Identity`` and
    checks that the matcher still produces the expected result.
    """

    @classmethod
    def setUpClass(cls):
        """Create one dummy graph that is shared by all tests in the class."""
        super().setUpClass()

        # Create a GraphLowering, so we can access V.graph.
        cls.graph = dummy_graph()

    @parametrize(
        "stride,symbol,expr",
        [
            (5, x, Identity(5 * x)),
            (4, y, 4 * Identity(y)),
            (3, x, Identity(3) * x),
        ],
    )
    def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr):
        """
        Test that we can handle an identity expression in affine indexing.

        Args:
            stride: The stride the matcher is expected to report.
            symbol: The symbol the expression is matched against.
            expr: An affine expression in ``symbol`` containing ``Identity``.
        """
        matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol)
        self.assertEqual(matched_stride, stride)

    @parametrize(
        "dims,strides,symbol,expr",
        [
            (
                (2, 4),
                (4, 1),
                x,
                4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4),
            ),
            (
                (3, 9),
                (5, 2),
                x,
                5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9),
            ),
            ((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))),
        ],
    )
    def test_mod_div_identity(
        self,
        dims: tuple[int, ...],
        strides: tuple[int, ...],
        symbol: sympy.Symbol,
        expr: sympy.Expr,
    ):
        """
        Test that we can handle an identity expression in modular indexing.

        Args:
            dims: The block dimensions the matcher is expected to report.
            strides: The per-dimension strides the matcher is expected to report.
            symbol: The symbol the expression is matched against.
            expr: A FloorDiv/ModularIndexing expression containing ``Identity``.
        """
        # Exact integer product of the dims; avoids a round trip through a
        # floating point tensor.
        numel = math.prod(dims)
        num_dims = len(dims)

        # The matcher is called with V.graph set to the dummy graph.
        with V.set_graph_handler(self.graph):
            match_result = BlockPatternMatcher.match_mod_div_block_expr(
                expr, symbol, numel, num_dims
            )

        # Check the matched block dimensions.
        self.assertIsNotNone(match_result)
        # The third element (block index expressions) is not checked here.
        matched_dims, matched_strides, _matched_block_index_exprs = match_result
        self.assertEqual(matched_dims, dims)
        self.assertEqual(matched_strides, strides)

    @parametrize(
        "symbol,expr,subexpr",
        [
            (x, Identity(x), x),
            (x, Identity(x + 5), x),
            (y, Identity(x + 2 * y) + 5, 2 * y),
        ],
    )
    def test_subexpr_identity(
        self,
        symbol: sympy.Symbol,
        expr: sympy.Expr,
        subexpr: sympy.Expr,
    ):
        """
        Test that the subexpression involving a symbol is found through Identity.

        Args:
            symbol: The symbol to search for.
            expr: The full expression, containing ``Identity``.
            subexpr: The subexpression expected to be returned for ``symbol``.
        """
        matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
        self.assertEqual(matched_subexpr, subexpr)
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()