From 172dd13ecff965a549bf4f2a58f5b2900a00497a Mon Sep 17 00:00:00 2001
From: Jiong Gong
Date: Fri, 12 Jan 2024 15:10:42 +0800
Subject: [PATCH] [inductor][cpp] improve vector contiguous checks for FloorDiv
 and ModularIndexing (#117221)

Fix https://github.com/pytorch/pytorch/issues/114488

This PR enables contiguous vector loads for cases where the `FloorDiv` and
`ModularIndexing` subexpressions of an index can be reduced within the
vectorized loop. Take the index expression from the test case
`test_vec_contiguous_ModularIndexing` as an example:

  14336*x0 + 256*x1 + 128*((x2//256)) + ModularIndexing(x2, 1, 128)
  + 7168*ModularIndexing(x2, 128, 2)

can be reduced to

  14336*x0 + 256*x1 + x2 + 128*x2_div_c0 + 7168*x2_mod_c0 + x2_mod_c1

where `x2` is a vectorized loop variable and the vector length is 16. This
means we can do a contiguous vectorized load for this index. Check the code
comment for more details:
https://github.com/pytorch/pytorch/pull/117221/files#diff-5ab7b0235e2076a5fc6629ba0b109208940f5b94f5c13babc3e0f87cf4fcec82R317-R329

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117221
Approved by: https://github.com/jansel
---
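A minimal standalone sketch (illustration only, plain sympy rather than
inductor's FloorDiv/ModularIndexing classes, with floor/Mod standing in for
them) that checks numerically why the x2-dependent part of the example index
above is contiguous over a 16-wide vector range: every floor/mod divisor
(256, 128) is a multiple of the vector length, so those terms are either
lane-invariant or vary with stride 1. The x0/x1 terms are dropped since they
do not depend on x2.

import sympy

x2 = sympy.symbols("x2", integer=True, nonnegative=True)

# x2-dependent part of the index; ModularIndexing(x, q, m) == Mod(floor(x/q), m)
index = (
    128 * sympy.floor(x2 / 256)                   # FloorDiv(x2, 256)
    + sympy.Mod(x2, 128)                          # ModularIndexing(x2, 1, 128)
    + 7168 * sympy.Mod(sympy.floor(x2 / 128), 2)  # ModularIndexing(x2, 128, 2)
)

# For any 16-aligned base, the 16 lanes address consecutive elements, i.e.
# the index has stride 1 within the vector range: a contiguous vector load.
for base in (0, 128, 256, 4096):
    lanes = [index.subs(x2, base + i) for i in range(16)]
    assert [v - lanes[0] for v in lanes] == list(range(16))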
 test/inductor/test_cpu_repro.py |  39 +++++++++++-
 torch/_inductor/codegen/cpp.py  | 104 +++++++++++++++++++++++++-------
 2 files changed, 119 insertions(+), 24 deletions(-)

diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index a36913192fa..925d502ce9c 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -2724,10 +2724,47 @@ class CPUReproTests(TestCase):
             return y.softmax(dim=-1)

         x = torch.randn(128, 2048)
+        opt_fn = torch.compile(fn)
         metrics.reset()
-        self.common(fn, (x,))
+        _, code = run_and_get_cpp_code(opt_fn, x)
+        self.assertTrue(same(fn(x), opt_fn(x)))
         # 4 kernels for max, exp, sum and div
         assert metrics.generated_cpp_vec_kernel_count == 4
+        FileCheck().check_count(
+            "Vectorized::loadu(tmpbuf.data())", 0, exactly=True
+        ).run(code)
+
+    def test_vec_contiguous_ModularIndexing(self):
+        # https://github.com/pytorch/pytorch/issues/114488
+        class M(torch.nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.norm = torch.nn.LayerNorm(dim * 4)
+
+            def forward(self, x):
+                # the pattern from swin_base_patch4_window7_224
+                B, H, W, C = x.shape
+                x = (
+                    x.reshape(B, H // 2, 2, W // 2, 2, C)
+                    .permute(0, 1, 3, 4, 2, 5)
+                    .flatten(3)
+                )
+                x = self.norm(x)
+                return x
+
+        x = torch.randn(1, 56, 56, 128)
+        m = M(128)
+        opt_m = torch.compile(m)
+        with torch.no_grad():
+            metrics.reset()
+            _, code = run_and_get_cpp_code(opt_m, x)
+            self.assertTrue(same(m(x), opt_m(x)))
+            # Two kernels: one for the reduction, one for the pointwise ops
+            assert metrics.generated_cpp_vec_kernel_count == 2
+            # Only one kernel has a non-contiguous load
+            FileCheck().check_count(
+                "Vectorized::loadu(tmpbuf.data())", 1, exactly=True
+            ).run(code)


 if __name__ == "__main__":
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index fb2a5aaa98c..f2f40143ff4 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -16,7 +16,7 @@ import torch.fx
 from torch._inductor import dependencies
 from torch._inductor.ir import StorageBox, TensorBox
 from torch._prims_common import is_float_dtype
-from torch.utils._sympy.functions import FloorDiv
+from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

 from .. import codecache, config, ir, metrics
@@ -311,6 +311,69 @@ def stride_at(var: sympy.Symbol, index: sympy.Expr):
     return sympy.simplify(new_index - index)


+@functools.lru_cache
+def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int):
+    """
+    Simplifies the index expression within the range of a vectorized loop.
+    Given a vectorized loop variable `var` in the range of a loop with `vec_length`,
+    this function transforms the `index` into an equivalent form. It handles
+    simplifications for cases where `var` can be expressed as `vec_length * a + b`,
+    where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences
+    of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations.
+
+    NOTE:
+    The simplified index expression is intended for analysis purposes only, not
+    for code generation. It replaces `FloorDiv` and `ModularIndexing` with free
+    variables that do not depend on the loop variable `var` in the vectorized range.
+    Check https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217
+    for more details.
+
+    Examples:
+    1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then
+       `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable
+       when `div` is divisible by 16.
+    2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free
+       variable when `mod` is divisible by 16.
+    """
+
+    div_freevar_id = 0
+    mod_freevar_id = 0
+
+    def visit_indexing_div(divisor):
+        nonlocal div_freevar_id
+        result = FloorDiv(var, divisor)
+        if sympy.gcd(divisor, vec_length) == vec_length:
+            result = sympy.Symbol(f"{var}_div_c{div_freevar_id}")
+            div_freevar_id += 1
+        return result
+
+    def visit_modular_indexing(divisor, modulus):
+        nonlocal mod_freevar_id
+        result = ModularIndexing(var, divisor, modulus)
+        if sympy.gcd(divisor, vec_length) == vec_length:
+            result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
+            mod_freevar_id += 1
+        elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length:
+            result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
+            mod_freevar_id += 1
+        return result
+
+    original_index = index
+
+    div = sympy.Wild("divisor")
+    if index.has(FloorDiv):
+        index = index.replace(FloorDiv(var, div), visit_indexing_div)
+
+    mod = sympy.Wild("modulus")
+    if index.has(ModularIndexing):
+        index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing)
+
+    index = sympy.simplify(index)
+    if index != original_index:
+        return simplify_index_in_vec_range(index, var, vec_length)
+
+    return index
+
+
 class CppPrinter(ExprPrinter):
     def _print_Integer(self, expr):
         return f"{int(expr)}L"
@@ -1355,12 +1418,15 @@ class CppVecOverrides(CppOverrides):
         assert isinstance(V.kernel, CppVecKernel)
         index = V.kernel.rename_indexing(expr)
         tiling_var = V.kernel.itervars[V.kernel.tiling_idx]
-        if V.kernel.index_is_vector_invariant(index):
-            return CppOverrides.index_expr(expr, dtype)
-        if stride_at(
-            tiling_var, index
-        ).is_number and not V.kernel.index_indirect_depends_on(index, tiling_var):
-            stride = stride_at(tiling_var, index)
+        index_vec_simplified = simplify_index_in_vec_range(
+            index, tiling_var, V.kernel.tiling_factor
+        )
+        stride = stride_at(tiling_var, index_vec_simplified)
+        if stride.is_number and not V.kernel.index_indirect_depends_on(
+            index, tiling_var
+        ):
+            if stride == 0:
+                return CppOverrides.index_expr(expr, dtype)
             value = ops.to_dtype(cexpr(index), dtype)
             if isinstance(value, OpsValue):
                 value = value.value
@@ -1711,18 +1777,6 @@ class CppVecKernel(CppKernel):
         self.tiling_idx = tiling_idx
         metrics.generated_cpp_vec_kernel_count += 1

-    def index_is_vector_invariant(self, index: sympy.Expr):
-        """`index` is either independent from the tiling itervar or unchanged in the vector range"""
-        tiling_var = self.itervars[self.tiling_idx]
-        if not self.index_depends_on(index, tiling_var):
-            return True
-        if not self.index_indirect_depends_on(index, tiling_var):
-            vec_range = [
-                sympy_subs(index, {tiling_var: i}) for i in range(self.tiling_factor)
-            ]
-            return all(expr == vec_range[0] for expr in vec_range)
-        return False
-
     def _get_vec_load_line(
         self,
         var: str,
@@ -1883,12 +1937,16 @@ class CppVecKernel(CppKernel):
         index = self.rename_indexing(index)
         dtype = V.graph.get_dtype(name)
         tiling_var = self.itervars[self.tiling_idx]
-        if self.index_is_vector_invariant(index):
+        index_vec_simplified = simplify_index_in_vec_range(
+            index, tiling_var, self.tiling_factor
+        )
+        stride = stride_at(tiling_var, index_vec_simplified)
+        if stride == 0:
             # load scalar and lazily broadcast it on demand
             return super().load(name, index)
-        non_contiguous = stride_at(
-            tiling_var, index
-        ) != 1 or self.index_indirect_depends_on(index, tiling_var)
+        non_contiguous = stride != 1 or self.index_indirect_depends_on(
+            index, tiling_var
+        )
         if non_contiguous:
             csevar = self.load_non_contiguous(var, index, dtype)
         else:
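A minimal sketch, assuming the simplified example index from the commit
message, of the decision the patched load()/index_expr() paths now make:
compute the stride of the simplified index with respect to the tiling
itervar and pick a load strategy from it. stride_at_ below mirrors the idea
of the stride_at() helper in cpp.py (index at var+1 minus index at var); the
strategy labels are illustrative, not inductor's own names.

import sympy

def stride_at_(var, index):
    # change in the index when the itervar advances by one step
    return sympy.simplify(index.subs(var, var + 1) - index)

def choose_load(var, simplified_index):
    stride = stride_at_(var, simplified_index)
    if stride == 0:
        return "scalar load + broadcast"  # index invariant in the vector range
    if stride == 1:
        return "contiguous vector load"   # dense Vectorized::loadu
    return "non-contiguous load"          # per-element gather through tmpbuf

x0, x1, x2, div_c0, mod_c0, mod_c1 = sympy.symbols(
    "x0 x1 x2 x2_div_c0 x2_mod_c0 x2_mod_c1", integer=True
)
# the simplified form from the commit message: FloorDiv/ModularIndexing terms
# have become free variables that do not depend on x2 in the vector range
simplified = 14336 * x0 + 256 * x1 + x2 + 128 * div_c0 + 7168 * mod_c0 + mod_c1
print(choose_load(x2, simplified))  # -> contiguous vector load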