diff --git a/test/inductor/test_block_analysis.py b/test/inductor/test_block_analysis.py
new file mode 100644
index 00000000000..5cf932d52e8
--- /dev/null
+++ b/test/inductor/test_block_analysis.py
@@ -0,0 +1,102 @@
+# Owner(s): ["module: inductor"]
+
+import sympy
+
+import torch
+from torch._inductor.codegen.block_analysis import BlockPatternMatcher
+from torch._inductor.virtualized import V
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+    TestCase,
+)
+from torch.testing._internal.inductor_utils import dummy_graph
+from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
+
+
+# Some useful symbols
+x, y = sympy.symbols("x y")
+
+
+@instantiate_parametrized_tests
+class BlockAnalysisTest(TestCase):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+
+        # Create a GraphLowering, so we can access V.graph.
+        cls.graph = dummy_graph()
+
+    @parametrize(
+        "stride,symbol,expr",
+        [
+            (5, x, Identity(5 * x)),
+            (4, y, 4 * Identity(y)),
+            (3, x, Identity(3) * x),
+        ],
+    )
+    def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr):
+        # Test that we can handle an identity expression in affine indexing.
+        matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol)
+        self.assertEqual(matched_stride, stride)
+
+    @parametrize(
+        "dims,strides,symbol,expr",
+        [
+            (
+                (2, 4),
+                (4, 1),
+                x,
+                4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4),
+            ),
+            (
+                (3, 9),
+                (5, 2),
+                x,
+                5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9),
+            ),
+            ((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))),
+        ],
+    )
+    def test_mod_div_identity(
+        self,
+        dims: tuple[int],
+        strides: tuple[int],
+        symbol: sympy.Symbol,
+        expr: sympy.Expr,
+    ):
+        # Test that we can handle an identity expression in modular indexing.
+        numel = int(torch.prod(torch.Tensor(dims)))
+        num_dims = len(dims)
+        with V.set_graph_handler(self.graph):
+            match_result = BlockPatternMatcher.match_mod_div_block_expr(
+                expr, symbol, numel, num_dims
+            )
+
+        # Check the matched block dimensions.
+        self.assertNotEqual(match_result, None)
+        matched_dims, matched_strides, matched_block_index_exprs = match_result
+        self.assertEqual(matched_dims, dims)
+        self.assertEqual(matched_strides, strides)
+
+    @parametrize(
+        "symbol,expr,subexpr",
+        [
+            (x, Identity(x), x),
+            (x, Identity(x + 5), x),
+            (y, Identity(x + 2 * y) + 5, 2 * y),
+        ],
+    )
+    def test_subexpr_identity(
+        self,
+        symbol: sympy.Symbol,
+        expr: sympy.Expr,
+        subexpr: sympy.Expr,
+    ):
+        matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
+        self.assertEqual(matched_subexpr, subexpr)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py
index dddb73c2851..e804e289c1c 100644
--- a/test/test_sympy_utils.py
+++ b/test/test_sympy_utils.py
@@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import (
 )
 from torch.utils._sympy.functions import (
     FloorDiv,
+    Identity,
     OpaqueUnaryFn_cos,
     simple_floordiv_gcd,
 )
@@ -955,6 +956,17 @@ class TestSingletonInt(TestCase):
         self.assertEqual(j1.free_symbols, set())
 
 
+class TestIdentity(TestCase):
+    def test_expand_identity(self):
+        """
+        Test removing an identity via expansion.
+        """
+        x = sympy.Symbol("x")
+        arg = x + sympy.S.One
+        expr = Identity(arg)
+        expanded = expr.expand(identity=True)
+        self.assertEqual(expanded.count(Identity), 0)
+        self.assertEqual(expanded, arg)
+
+
 instantiate_parametrized_tests(TestValueRanges)
 instantiate_parametrized_tests(TestSympyInterp)
diff --git a/torch/_inductor/codegen/block_analysis.py b/torch/_inductor/codegen/block_analysis.py
index 484fa135986..1c816eb8e29 100644
--- a/torch/_inductor/codegen/block_analysis.py
+++ b/torch/_inductor/codegen/block_analysis.py
@@ -17,8 +17,8 @@ class BlockPatternMatcher:
     Matches block indexing expressions.
     """
 
-    @staticmethod
-    def get_subexpr_involving_symbol(expr: Expr, symbol: Symbol) -> Expr:
+    @classmethod
+    def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
         """
         Given a sympy expression, return the subexpression comprised only of terms
         involving the specified symbol.
@@ -26,6 +26,7 @@ class BlockPatternMatcher:
         For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`, this
         returns `x * 5 + x ** 2`.
         """
+        expr = cls._preprocess(expr)
         return sympy.S.Zero + sum(
             term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols
         )
@@ -42,6 +43,11 @@ class BlockPatternMatcher:
             numels.appendleft(numel)
         return [*numels]
 
+    @staticmethod
+    def _preprocess(expr: Expr) -> Expr:
+        # Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y.
+        return expr.expand(identity=True)
+
     @classmethod
     def match_mod_div_block_expr(
         cls,
@@ -54,6 +60,7 @@
         Matches modular indexing expressions, converting them to implied block
         dimensions and strides. See triton.py for more information.
         """
+        index = cls._preprocess(index)
 
         # Pattern match to find the strides and offset.
         wild = functools.partial(sympy.Wild, exclude=[index_var])
@@ -141,3 +148,21 @@ class BlockPatternMatcher:
         )
 
         return dims, strides, block_index_exprs
+
+    @classmethod
+    def match_affine_block_expr(
+        cls,
+        index: Expr,
+        index_var: Symbol,
+    ) -> Optional[Expr]:
+        """
+        Matches simple expressions of the form stride * index, returning the
+        stride.
+        """
+        index = cls._preprocess(index)
+        stride = sympy.Wild("stride", exclude=[index_var])
+        m = index.match(index_var * stride)
+        if m is None:
+            return None
+
+        return m[stride]
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index e0c1f988479..5f3e1b78783 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1790,7 +1790,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
             and self.index_dtype == "tl.int32"
         ):
 
-            def match_strided_block(
+            def match_affine_block(
                 index: sympy.Expr, range_tree: IterationRangesRoot
             ) -> Optional[BlockParameters]:
                 """
@@ -1799,16 +1799,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
 
                 This implies stride (s,), and shape (XBLOCK,).
                 """
-                symbol = range_tree.symbol()
-                stride = sympy.Wild("stride", exclude=[symbol])
-                m = index.match(symbol * stride)
-                if m is None:
+                stride = BlockPatternMatcher.match_affine_block_expr(
+                    index, range_tree.symbol()
+                )
+                if stride is None:
                     return None
 
                 return BlockParameters(
                     shape=[range_tree.numel],
                     block_shape=[TritonSymbols.get_block_size(range_tree)],
-                    strides=[m[stride]],
+                    strides=[stride],
                     offsets=[TritonSymbols.get_block_offset(range_tree)],
                 )
 
@@ -1917,7 +1917,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
                 Match a block indexing subexpression involving a single range tree.
                 """
                 for match_func in (
-                    match_strided_block,
+                    match_affine_block,
                     match_mod_div_block,
                 ):
                     match = match_func(expr, range_tree)
diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py
index 3110c3947af..13de003b330 100644
--- a/torch/testing/_internal/inductor_utils.py
+++ b/torch/testing/_internal/inductor_utils.py
@@ -10,6 +10,9 @@ import os
 from subprocess import CalledProcessError
 import sys
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch._inductor.graph import GraphLowering
+from torch._inductor.compile_fx import shape_env_from_inputs
 from torch._inductor.codecache import CppCodeCache
 from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
 from torch._inductor.utils import GPU_TYPES, get_gpu_type
@@ -142,6 +145,21 @@ IS_H100 = LazyVal(
 
 IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu())
 
+def dummy_graph() -> GraphLowering:
+    """
+    Create a graph. This is useful for unit testing code which accesses
+    V.graph.sizevars.
+    """
+    example_inputs = [torch.randn(10) for _ in range(2)]
+    gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs)
+    shape_env = shape_env_from_inputs(example_inputs)
+    graph = GraphLowering(
+        gm,
+        shape_env=shape_env,
+    )
+
+    return graph
+
 def maybe_skip_size_asserts(op):
     """
     For certain ops, there meta and eager implementation returns differents
diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py
index ae0a1eee398..15db18f3307 100644
--- a/torch/utils/_sympy/functions.py
+++ b/torch/utils/_sympy/functions.py
@@ -1286,6 +1286,10 @@ class Identity(sympy.Function):
    def _eval_is_integer(self):
        return self.args[0].is_integer  # type: ignore[attr-defined]

+    def _eval_expand_identity(self, **hints):
+        # Removes the identity op.
+        return self.args[0]
+

 def make_opaque_unary_fn(name):
     class OpaqueUnaryFn(sympy.Function):