From a1bfb39a31aff91dfeba8730f5b496cb1e44ceca Mon Sep 17 00:00:00 2001 From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com> Date: Sat, 8 Feb 2025 18:11:53 +0000 Subject: [PATCH] [Inductor] Expand Identity ops prior to block pattern matching (#146000) # Feature Inductor sometimes uses `Identity` functions to group various terms of an expression. While this is convenient in some scenarios, it can frustrate pattern matching. For example, when we're matching an indexing expression to tell if it can be represented as a block pointer, that analysis should be invariant to `Identity` ops. This PR adds a few features to achieve this invariance. - Create a new expansion mode `expr.expand(identity=True)`, which removes all `Identity` functions from the expression. - Preprocess the expression with this expansion prior to pattern matching. - Bonus: create a new test utility function called `dummy_graph()`, which creates a simple `GraphLowering`. This is useful for testing the pattern matcher, as we need to initialize `V.graph` before we can access `V.graph.sizevars`. # Test plan This PR adds a few new unit tests: - Added a unit test specifically for `expr.expand(identity=True)`. - Added a new unit test module for the block pattern matcher. Tested that we can correctly match some example patterns containing Identity ops. I originally intended to add an end-to-end test compiling pointwise cat, and mapping the corresponding memory accesses to block pointers. However, it looks like that will take more work, since the [relevant code path](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton.py#L1306) disables block pointer analysis. It might be better to defer that to a future PR. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146000 Approved by: https://github.com/eellison, https://github.com/jansel --- test/inductor/test_block_analysis.py | 102 ++++++++++++++++++++++ test/test_sympy_utils.py | 12 +++ torch/_inductor/codegen/block_analysis.py | 29 +++++- torch/_inductor/codegen/triton.py | 14 +-- torch/testing/_internal/inductor_utils.py | 18 ++++ torch/utils/_sympy/functions.py | 4 + 6 files changed, 170 insertions(+), 9 deletions(-) create mode 100644 test/inductor/test_block_analysis.py diff --git a/test/inductor/test_block_analysis.py b/test/inductor/test_block_analysis.py new file mode 100644 index 00000000000..5cf932d52e8 --- /dev/null +++ b/test/inductor/test_block_analysis.py @@ -0,0 +1,102 @@ +# Owner(s): ["module: inductor"] + +import sympy + +import torch +from torch._inductor.codegen.block_analysis import BlockPatternMatcher +from torch._inductor.virtualized import V +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TestCase, +) +from torch.testing._internal.inductor_utils import dummy_graph +from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing + + +# Some useful symbols +x, y = sympy.symbols("x y") + + +@instantiate_parametrized_tests +class BlockAnalysisTest(TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + # Create a GraphLowering, so we can access V.graph. + cls.graph = dummy_graph() + + @parametrize( + "stride,symbol,expr", + [ + (5, x, Identity(5 * x)), + (4, y, 4 * Identity(y)), + (3, x, Identity(3) * x), + ], + ) + def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr): + # Test that we can handle an identity expression in affine indexing. 
+ matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol) + self.assertEqual(matched_stride, stride) + + @parametrize( + "dims,strides,symbol,expr", + [ + ( + (2, 4), + (4, 1), + x, + 4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4), + ), + ( + (3, 9), + (5, 2), + x, + 5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9), + ), + ((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))), + ], + ) + def test_mod_div_identity( + self, + dims: tuple[int], + strides: tuple[int], + symbol: sympy.Symbol, + expr: sympy.Expr, + ): + # Test that we can handle an identity expression in modular indexing. + numel = int(torch.prod(torch.Tensor(dims))) + num_dims = len(dims) + with V.set_graph_handler(self.graph): + match_result = BlockPatternMatcher.match_mod_div_block_expr( + expr, symbol, numel, num_dims + ) + + # Check the matched block dimensions. + self.assertNotEqual(match_result, None) + matched_dims, matched_strides, matched_block_index_exprs = match_result + self.assertEqual(matched_dims, dims) + self.assertEqual(matched_strides, strides) + + @parametrize( + "symbol,expr,subexpr", + [ + (x, Identity(x), x), + (x, Identity(x + 5), x), + (y, Identity(x + 2 * y) + 5, 2 * y), + ], + ) + def test_subexpr_identity( + self, + symbol: sympy.Symbol, + expr: sympy.Expr, + subexpr: sympy.Expr, + ): + matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol) + self.assertEqual(matched_subexpr, subexpr) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index dddb73c2851..e804e289c1c 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import ( ) from torch.utils._sympy.functions import ( FloorDiv, + Identity, OpaqueUnaryFn_cos, simple_floordiv_gcd, ) @@ -955,6 +956,17 @@ class TestSingletonInt(TestCase): self.assertEqual(j1.free_symbols, set()) +class TestIdentity(TestCase): + 
def test_expand_identity(self): + """ + Test removing an identity via expansion. + """ + x = sympy.Symbol("x") + arg = x + sympy.S.One + expr = Identity(arg) + expanded = expr.expand(identity=True) + self.assertEqual(expanded.count(Identity), 0) + self.assertEqual(expanded, arg) instantiate_parametrized_tests(TestValueRanges) instantiate_parametrized_tests(TestSympyInterp) diff --git a/torch/_inductor/codegen/block_analysis.py b/torch/_inductor/codegen/block_analysis.py index 484fa135986..1c816eb8e29 100644 --- a/torch/_inductor/codegen/block_analysis.py +++ b/torch/_inductor/codegen/block_analysis.py @@ -17,8 +17,8 @@ class BlockPatternMatcher: Matches block indexing expressions. """ - @staticmethod - def get_subexpr_involving_symbol(expr: Expr, symbol: Symbol) -> Expr: + @classmethod + def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr: """ Given a sympy expression, return the subexpression comprised only of terms involving the specified symbol. @@ -26,6 +26,7 @@ class BlockPatternMatcher: For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`, this returns `x * 5 + x ** 2`. """ + expr = cls._preprocess(expr) return sympy.S.Zero + sum( term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols ) @@ -42,6 +43,11 @@ class BlockPatternMatcher: numels.appendleft(numel) return [*numels] + @staticmethod + def _preprocess(expr: Expr) -> Expr: + # Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y. + return expr.expand(identity=True) + @classmethod def match_mod_div_block_expr( cls, @@ -54,6 +60,7 @@ class BlockPatternMatcher: Matches modular indexing expressions, converting them to implied block dimensions and strides. See triton.py for more information. """ + index = cls._preprocess(index) # Pattern match to find the strides and offset. 
wild = functools.partial(sympy.Wild, exclude=[index_var]) @@ -141,3 +148,21 @@ class BlockPatternMatcher: ) return dims, strides, block_index_exprs + + @classmethod + def match_affine_block_expr( + cls, + index: Expr, + index_var: Symbol, + ) -> Optional[Expr]: + """ + Matches simple expressions of the form stride * index, returning the + stride. + """ + index = cls._preprocess(index) + stride = sympy.Wild("stride", exclude=[index_var]) + m = index.match(index_var * stride) + if m is None: + return None + + return m[stride] diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e0c1f988479..5f3e1b78783 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1790,7 +1790,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): and self.index_dtype == "tl.int32" ): - def match_strided_block( + def match_affine_block( index: sympy.Expr, range_tree: IterationRangesRoot ) -> Optional[BlockParameters]: """ @@ -1799,16 +1799,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): This implies stride (s,), and shape (XBLOCK,). """ - symbol = range_tree.symbol() - stride = sympy.Wild("stride", exclude=[symbol]) - m = index.match(symbol * stride) - if m is None: + stride = BlockPatternMatcher.match_affine_block_expr( + index, range_tree.symbol() + ) + if stride is None: return None return BlockParameters( shape=[range_tree.numel], block_shape=[TritonSymbols.get_block_size(range_tree)], - strides=[m[stride]], + strides=[stride], offsets=[TritonSymbols.get_block_offset(range_tree)], ) @@ -1917,7 +1917,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): Match a block indexing subexpression involving a single range tree. 
""" for match_func in ( - match_strided_block, + match_affine_block, match_mod_div_block, ): match = match_func(expr, range_tree) diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 3110c3947af..13de003b330 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -10,6 +10,9 @@ import os from subprocess import CalledProcessError import sys import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch.fx.experimental.proxy_tensor import make_fx +from torch._inductor.graph import GraphLowering +from torch._inductor.compile_fx import shape_env_from_inputs from torch._inductor.codecache import CppCodeCache from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu from torch._inductor.utils import GPU_TYPES, get_gpu_type @@ -142,6 +145,21 @@ IS_H100 = LazyVal( IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu()) +def dummy_graph() -> GraphLowering: + """ + Create a graph. This is useful for unit testing code which accesses + V.graph.sizevars. + """ + example_inputs = [torch.randn(10) for _ in range(2)] + gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs) + shape_env = shape_env_from_inputs(example_inputs) + graph = GraphLowering( + gm, + shape_env=shape_env, + ) + + return graph + def maybe_skip_size_asserts(op): """ For certain ops, there meta and eager implementation returns differents diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index ae0a1eee398..15db18f3307 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1286,6 +1286,10 @@ class Identity(sympy.Function): def _eval_is_integer(self): return self.args[0].is_integer # type: ignore[attr-defined] + def _eval_expand_identity(self, **hints): + # Removes the identity op. + return self.args[0] + def make_opaque_unary_fn(name): class OpaqueUnaryFn(sympy.Function):