Horace He 0485bf5398 Avoid saving pointwise intermediate to global memory if followed by a reduction (#93810)
Should fix https://github.com/pytorch/pytorch/issues/91880 and maybe https://github.com/pytorch/pytorch/issues/91799

For this code:
```
import torch

@torch.compile
def f(a, b):
    return (a - b).sum(dim=-1).amax(dim=-1)

N = 2**14
K = 5

A = torch.randn(N, 1, K, device='cuda')
B = torch.randn(1, N, K, device='cuda')
bench(lambda: f(A, B), name=f"K={K}")  # bench: a user-defined timing helper
print(f"peak Mem: {torch.cuda.max_memory_allocated()/1e9}GB")
```
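The `bench` helper in the snippet is not defined in the PR; a minimal CPU-only stand-in (hypothetical; a real GPU benchmark would also call `torch.cuda.synchronize()` around the timed region) could look like:

```python
import time

def bench(fn, name="", iters=10):
    """Crude wall-clock benchmark of fn().

    Hypothetical stand-in for the helper used in the PR; for CUDA work you
    would also call torch.cuda.synchronize() before and after timing.
    """
    fn()  # warm-up: also triggers torch.compile's first-call compilation
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    ms = (time.perf_counter() - start) / iters * 1e3
    print(f"{name}: {ms:.3f}ms / iter")
    return ms
```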

Before my change, we generated (simplified; note that the pointwise result `tmp18` is stored to `out_ptr0` and then re-loaded in a second loop):
```
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    ...
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
       ...
        tmp18 = tmp14 + tmp17
        tl.store(out_ptr0 + (r1 + (16384*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp18, rmask & xmask)
    _tmp20 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp19 = tl.load(out_ptr0 + (r1 + (16384*x0)), rmask & xmask, eviction_policy='evict_last')
        _tmp20 = tl.where(rmask & xmask & (_tmp20 < tmp19), tmp19, _tmp20)
    tmp20 = tl.max(_tmp20, 1)[:, None]
    tl.store(out_ptr1 + x0, tmp20, xmask)
```
and after, where the reduction accumulates `tmp18` directly, so the intermediate never touches global memory:
```
def triton_(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
   ...
    _tmp19 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
        ...
        tmp18 = tmp14 + tmp17
        _tmp19 = tl.where(rmask & xmask & (_tmp19 < tmp18), tmp18, _tmp19)
    tmp19 = tl.max(_tmp19, 1)[:, None]
    tl.store(out_ptr1 + x0, tmp19, xmask)
```
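Back-of-the-envelope (my arithmetic, not stated in the PR): with `N = 2**14`, the `out_ptr0` intermediate the fused kernel no longer materializes is an N×N float32 buffer, which accounts for roughly 1 GB of peak memory:

```python
# Size of the (N, N) float32 intermediate from (a - b).sum(dim=-1) that the
# "before" kernel stored to out_ptr0 and the "after" kernel keeps in registers.
N = 2 ** 14                 # matches xnumel = rnumel = 16384 in the kernels
bytes_avoided = N * N * 4   # 4 bytes per float32 element
print(f"{bytes_avoided / 1e9:.3f} GB")  # → 1.074 GB
```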
<details>
  <summary>full kernels here</summary>

Before:

```
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16384
    rnumel = 16384
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (5*x0), xmask)
    tmp3 = tl.load(in_ptr0 + (1 + (5*x0)), xmask)
    tmp7 = tl.load(in_ptr0 + (2 + (5*x0)), xmask)
    tmp11 = tl.load(in_ptr0 + (3 + (5*x0)), xmask)
    tmp15 = tl.load(in_ptr0 + (4 + (5*x0)), xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
        tmp4 = tl.load(in_ptr1 + (1 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp8 = tl.load(in_ptr1 + (2 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp12 = tl.load(in_ptr1 + (3 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp16 = tl.load(in_ptr1 + (4 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp2 = tmp0 - tmp1
        tmp5 = tmp3 - tmp4
        tmp6 = tmp2 + tmp5
        tmp9 = tmp7 - tmp8
        tmp10 = tmp6 + tmp9
        tmp13 = tmp11 - tmp12
        tmp14 = tmp10 + tmp13
        tmp17 = tmp15 - tmp16
        tmp18 = tmp14 + tmp17
        tl.store(out_ptr0 + (r1 + (16384*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp18, rmask & xmask)
    _tmp20 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp19 = tl.load(out_ptr0 + (r1 + (16384*x0)), rmask & xmask, eviction_policy='evict_last')
        _tmp20 = tl.where(rmask & xmask & (_tmp20 < tmp19), tmp19, _tmp20)
    tmp20 = tl.max(_tmp20, 1)[:, None]
    tl.store(out_ptr1 + x0, tmp20, xmask)
```
After:
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16384
    rnumel = 16384
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (5*x0), xmask)
    tmp3 = tl.load(in_ptr0 + (1 + (5*x0)), xmask)
    tmp7 = tl.load(in_ptr0 + (2 + (5*x0)), xmask)
    tmp11 = tl.load(in_ptr0 + (3 + (5*x0)), xmask)
    tmp15 = tl.load(in_ptr0 + (4 + (5*x0)), xmask)
    _tmp19 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
        tmp4 = tl.load(in_ptr1 + (1 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp8 = tl.load(in_ptr1 + (2 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp12 = tl.load(in_ptr1 + (3 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp16 = tl.load(in_ptr1 + (4 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp2 = tmp0 - tmp1
        tmp5 = tmp3 - tmp4
        tmp6 = tmp2 + tmp5
        tmp9 = tmp7 - tmp8
        tmp10 = tmp6 + tmp9
        tmp13 = tmp11 - tmp12
        tmp14 = tmp10 + tmp13
        tmp17 = tmp15 - tmp16
        tmp18 = tmp14 + tmp17
        _tmp19 = tl.where(rmask & xmask & (_tmp19 < tmp18), tmp18, _tmp19)
    tmp19 = tl.max(_tmp19, 1)[:, None]
    tl.store(out_ptr1 + x0, tmp19, xmask)
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93810
Approved by: https://github.com/ngimel, https://github.com/jansel
2023-02-02 00:02:14 +00:00