[inductor][cpp] improve vector contiguous checks for FloorDiv and ModularIndexing (#117221)
Fixes https://github.com/pytorch/pytorch/issues/114488.

This PR enables contiguous vector loads for cases where we can reduce `FloorDiv` and `ModularIndexing` in the vectorized loop. Take the index expression in the test case `test_vec_contiguous_ModularIndexing` as an example. The expression

    14336*x0 + 256*x1 + 128*((x2//256)) + ModularIndexing(x2, 1, 128) + 7168*ModularIndexing(x2, 128, 2)

can be reduced to

    14336*x0 + 256*x1 + x2 + 128*x2_div_c0 + 7168*x2_mod_c0 + x2_mod_c1

where `x2` is the vectorized loop variable and the vector length is 16. This means we can do a vectorized load for this index. See the code comment for more details: https://github.com/pytorch/pytorch/pull/117221/files#diff-5ab7b0235e2076a5fc6629ba0b109208940f5b94f5c13babc3e0f87cf4fcec82R317-R329

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117221
Approved by: https://github.com/jansel
This commit is contained in:
parent 6c624aad37
commit 172dd13ecf

2 changed files with 119 additions and 24 deletions
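As a quick sanity check of the contiguity claim in the commit message (an illustration added for this write-up, not part of the PR): within any 16-aligned block of `x2`, the `FloorDiv`/`ModularIndexing` terms above are constant across the 16 lanes, so the 16 loaded offsets are consecutive.

```python
# Standalone numeric check (illustration only): for any vector iteration whose base is
# a multiple of 16, the 16 offsets produced by the index expression are consecutive,
# i.e. a contiguous vector load is legal.
def floordiv(x, q):
    return x // q


def modular_indexing(x, q, m):
    # ModularIndexing(x, q, m) == (x // q) % m for non-negative integer arguments
    return (x // q) % m


def load_index(x0, x1, x2):
    # the index expression from the commit message
    return (
        14336 * x0
        + 256 * x1
        + 128 * floordiv(x2, 256)
        + modular_indexing(x2, 1, 128)
        + 7168 * modular_indexing(x2, 128, 2)
    )


# x2 is the vectorized loop variable over the flattened dim of size 512; every
# vector iteration starts at a multiple of the vector length 16.
for base in range(0, 512, 16):
    offsets = [load_index(0, 0, base + lane) for lane in range(16)]
    assert offsets == list(range(offsets[0], offsets[0] + 16))
```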
test/inductor/test_cpu_repro.py
@@ -2724,10 +2724,47 @@ class CPUReproTests(TestCase):
            return y.softmax(dim=-1)

        x = torch.randn(128, 2048)
        opt_fn = torch.compile(fn)
        metrics.reset()
        self.common(fn, (x,))
        _, code = run_and_get_cpp_code(opt_fn, x)
        self.assertTrue(same(fn(x), opt_fn(x)))
        # 4 kernels for max, exp, sum and div
        assert metrics.generated_cpp_vec_kernel_count == 4
        FileCheck().check_count(
            "Vectorized<int>::loadu(tmpbuf.data())", 0, exactly=True
        ).run(code)

    def test_vec_contiguous_ModularIndexing(self):
        # https://github.com/pytorch/pytorch/issues/114488
        class M(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.norm = torch.nn.LayerNorm(dim * 4)

            def forward(self, x):
                # the pattern from swin_base_patch4_window7_224
                B, H, W, C = x.shape
                x = (
                    x.reshape(B, H // 2, 2, W // 2, 2, C)
                    .permute(0, 1, 3, 4, 2, 5)
                    .flatten(3)
                )
                x = self.norm(x)
                return x

        x = torch.randn(1, 56, 56, 128)
        m = M(128)
        opt_m = torch.compile(m)
        with torch.no_grad():
            metrics.reset()
            _, code = run_and_get_cpp_code(opt_m, x)
            self.assertTrue(same(m(x), opt_m(x)))
            # Two kernels: one for reduction, one for pointwise ops
            assert metrics.generated_cpp_vec_kernel_count == 2
            # Only one kernel has a non-contiguous load
            FileCheck().check_count(
                "Vectorized<float>::loadu(tmpbuf.data())", 1, exactly=True
            ).run(code)


if __name__ == "__main__":
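For context (again an illustration, not part of the diff), the index expression quoted in the commit message can be read directly off the strides of the intermediate permuted view that the new test builds; only public `torch` APIs are used here.

```python
# Illustration only: the permuted view keeps the strides of the original contiguous
# (1, 56, 56, 128) buffer, so LayerNorm reads the buffer through a strided index.
import torch

x = torch.randn(1, 56, 56, 128)
B, H, W, C = x.shape
v = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5)
print(v.shape)    # torch.Size([1, 28, 28, 2, 2, 128])
print(v.stride()) # (401408, 14336, 256, 128, 7168, 1)
# Flattening the last three dims into a single index x2 of size 512 (with x0, x1
# iterating over the 28x28 outer dims) gives the load offset
#   14336*x0 + 256*x1 + 128*(x2 // 256) + 7168*((x2 // 128) % 2) + (x2 % 128),
# i.e. the FloorDiv/ModularIndexing expression from the commit message.
```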
torch/_inductor/codegen/cpp.py
@@ -16,7 +16,7 @@ import torch.fx
 from torch._inductor import dependencies
 from torch._inductor.ir import StorageBox, TensorBox
 from torch._prims_common import is_float_dtype
-from torch.utils._sympy.functions import FloorDiv
+from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

 from .. import codecache, config, ir, metrics
@@ -311,6 +311,69 @@ def stride_at(var: sympy.Symbol, index: sympy.Expr):
    return sympy.simplify(new_index - index)


@functools.lru_cache
def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int):
    """
    Simplifies the index expression within the range of a vectorized loop.

    Given a vectorized loop variable `var` in the range of a loop with `vec_length`,
    this function transforms the `index` into an equivalent form. It handles
    simplifications for cases where `var` can be expressed as `vec_length * a + b`,
    where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences
    of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations.

    NOTE:
    The simplified index expression is intended for analysis purposes only, not
    for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables
    which are not dependent on the loop variable `var` in the vectorized range. Check
    https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details.

    Examples:
    1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then
       `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable
       when `div` is divisible by 16.
    2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free
       variable when `mod` is divisible by 16.
    """

    div_freevar_id = 0
    mod_freevar_id = 0

    def visit_indexing_div(divisor):
        nonlocal div_freevar_id
        result = FloorDiv(var, divisor)
        if sympy.gcd(divisor, vec_length) == vec_length:
            result = sympy.Symbol(f"{var}_div_c{div_freevar_id}")
            div_freevar_id += 1
        return result

    def visit_modular_indexing(divisor, modulus):
        nonlocal mod_freevar_id
        result = ModularIndexing(var, divisor, modulus)
        if sympy.gcd(divisor, vec_length) == vec_length:
            result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
            mod_freevar_id += 1
        elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length:
            result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
            mod_freevar_id += 1
        return result

    original_index = index

    div = sympy.Wild("divisor")
    if index.has(FloorDiv):
        index = index.replace(FloorDiv(var, div), visit_indexing_div)

    mod = sympy.Wild("modulus")
    if index.has(ModularIndexing):
        index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing)

    index = sympy.simplify(index)
    if index != original_index:
        return simplify_index_in_vec_range(index, var, vec_length)

    return index


class CppPrinter(ExprPrinter):
    def _print_Integer(self, expr):
        return f"{int(expr)}L"
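A minimal sketch of the new helper in action, assuming a PyTorch checkout that contains this change (the `torch._inductor.codegen.cpp` import path is an internal detail and may differ across versions):

```python
# Sketch only: reproduce the reduction quoted in the commit message with the new helper.
import sympy
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch._inductor.codegen.cpp import simplify_index_in_vec_range

x0, x1, x2 = sympy.symbols("x0 x1 x2", integer=True)
index = (
    14336 * x0
    + 256 * x1
    + 128 * FloorDiv(x2, 256)
    + ModularIndexing(x2, 1, 128)
    + 7168 * ModularIndexing(x2, 128, 2)
)
# Expected, up to the numbering of the generated free symbols:
#   14336*x0 + 256*x1 + x2 + 128*x2_div_c0 + 7168*x2_mod_c0 + x2_mod_c1
print(simplify_index_in_vec_range(index, x2, 16))
```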
@@ -1355,12 +1418,15 @@ class CppVecOverrides(CppOverrides):
         assert isinstance(V.kernel, CppVecKernel)
         index = V.kernel.rename_indexing(expr)
         tiling_var = V.kernel.itervars[V.kernel.tiling_idx]
-        if V.kernel.index_is_vector_invariant(index):
-            return CppOverrides.index_expr(expr, dtype)
-        if stride_at(
-            tiling_var, index
-        ).is_number and not V.kernel.index_indirect_depends_on(index, tiling_var):
-            stride = stride_at(tiling_var, index)
+        index_vec_simplified = simplify_index_in_vec_range(
+            index, tiling_var, V.kernel.tiling_factor
+        )
+        stride = stride_at(tiling_var, index_vec_simplified)
+        if stride.is_number and not V.kernel.index_indirect_depends_on(
+            index, tiling_var
+        ):
+            if stride == 0:
+                return CppOverrides.index_expr(expr, dtype)
             value = ops.to_dtype(cexpr(index), dtype)
             if isinstance(value, OpsValue):
                 value = value.value
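Why the stride is now computed on the simplified index (a sketch under the same assumptions as above): on the raw index the per-lane difference still contains `FloorDiv`/`ModularIndexing`, so the old `stride_at(...).is_number` gate treated it as non-constant.

```python
# Sketch only: stride_at on the raw vs. the simplified index.
import sympy
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch._inductor.codegen.cpp import simplify_index_in_vec_range, stride_at

x0, x1, x2 = sympy.symbols("x0 x1 x2", integer=True)
index = (
    14336 * x0
    + 256 * x1
    + 128 * FloorDiv(x2, 256)
    + ModularIndexing(x2, 1, 128)
    + 7168 * ModularIndexing(x2, 128, 2)
)
# raw index: the per-lane difference keeps FloorDiv/ModularIndexing, so it is not a number
print(stride_at(x2, index).is_number)
# simplified index: the per-lane difference reduces to 1, i.e. contiguous in the vector range
print(stride_at(x2, simplify_index_in_vec_range(index, x2, 16)))
```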
@@ -1711,18 +1777,6 @@ class CppVecKernel(CppKernel):
         self.tiling_idx = tiling_idx
         metrics.generated_cpp_vec_kernel_count += 1

-    def index_is_vector_invariant(self, index: sympy.Expr):
-        """`index` is either independent from the tiling itervar or unchanged in the vector range"""
-        tiling_var = self.itervars[self.tiling_idx]
-        if not self.index_depends_on(index, tiling_var):
-            return True
-        if not self.index_indirect_depends_on(index, tiling_var):
-            vec_range = [
-                sympy_subs(index, {tiling_var: i}) for i in range(self.tiling_factor)
-            ]
-            return all(expr == vec_range[0] for expr in vec_range)
-        return False
-
     def _get_vec_load_line(
         self,
         var: str,
@@ -1883,12 +1937,16 @@ class CppVecKernel(CppKernel):
         index = self.rename_indexing(index)
         dtype = V.graph.get_dtype(name)
         tiling_var = self.itervars[self.tiling_idx]
-        if self.index_is_vector_invariant(index):
+        index_vec_simplified = simplify_index_in_vec_range(
+            index, tiling_var, self.tiling_factor
+        )
+        stride = stride_at(tiling_var, index_vec_simplified)
+        if stride == 0:
             # load scalar and lazily broadcast it on demand
             return super().load(name, index)
-        non_contiguous = stride_at(
-            tiling_var, index
-        ) != 1 or self.index_indirect_depends_on(index, tiling_var)
+        non_contiguous = stride != 1 or self.index_indirect_depends_on(
+            index, tiling_var
+        )
         if non_contiguous:
             csevar = self.load_non_contiguous(var, index, dtype)
         else:
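Putting the pieces together for `load` (again a sketch under the same assumptions as above): the per-lane stride of the simplified index selects between a scalar load plus broadcast (stride 0), a contiguous `Vectorized::loadu` (stride 1), and the non-contiguous `tmpbuf` path that the new test counts (any other stride).

```python
# Sketch only: the three load classifications derived from the simplified stride.
import sympy
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch._inductor.codegen.cpp import simplify_index_in_vec_range, stride_at

x0, x2 = sympy.symbols("x0 x2", integer=True)
examples = {
    "stride 0 -> scalar load + broadcast": 512 * x0 + FloorDiv(x2, 16),
    "stride 1 -> contiguous Vectorized::loadu": 128 * FloorDiv(x2, 128) + ModularIndexing(x2, 1, 128),
    "stride 2 -> non-contiguous (tmpbuf) load": 512 * x0 + 2 * x2,
}
for label, index in examples.items():
    print(label, "=>", stride_at(x2, simplify_index_in_vec_range(index, x2, 16)))
```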