Retry: Low mem max_pool2d_with_indices (#122832)

Based on #105687

The low-memory path does not need to strictly return the int8 offsets; instead,
the offset-to-index computation can be separated from the inner function of the
max-pool lowering. The partitioner can then choose to move the offset-to-index
computation into the backward pass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122832
Approved by: https://github.com/peterbell10, https://github.com/eellison
This commit is contained in:
Andrew M. James 2024-05-08 14:40:07 +00:00 committed by PyTorch MergeBot
parent 005a12722d
commit 445a0c01da
6 changed files with 201 additions and 32 deletions

View file

@ -8690,6 +8690,7 @@ class CommonTemplate:
from torch._inductor.codegen.common import boolean_ops
from torch._inductor.compile_fx import _shape_env_from_inputs
from torch._inductor.debug import DebugContext
from torch._inductor.decomposition import decompositions
from torch._inductor.graph import GraphLowering
from torch._inductor.virtualized import V
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
@ -8715,7 +8716,9 @@ class CommonTemplate:
)
]
gm = torch.fx.symbolic_trace(func)
gm = make_fx(func, decomposition_table=decompositions, tracing_mode="fake")(
*example_inputs
)
shape_env = _shape_env_from_inputs(example_inputs)

View file

@ -12,6 +12,7 @@ from collections import defaultdict
from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
import torch
import torch._inductor.inductor_prims
import torch.fx as fx
import torch.utils._pytree as pytree
from torch.fx.experimental._backward_state import BackwardState
@ -932,6 +933,7 @@ def min_cut_rematerialization_partition(
aten.argmax,
aten.maximum,
prims.iota,
prims._low_memory_max_pool2d_offsets_to_indices,
] # noqa: E501,B950
view_ops += [
aten.view,

View file

@ -20,6 +20,7 @@ from torch._decomp.decompositions import (
)
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.utils import pad_listlike
from torch._prims_common import (
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
@ -754,3 +755,49 @@ def index_reduce(
reduction_type,
include_self=include_self,
)
@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
    x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
):
    """Decompose aten.max_pool2d_with_indices into the low-memory prim pair.

    Pooled values are produced together with compact int8 window offsets,
    and a separate prim expands those offsets into full int64 indices.
    Keeping the expansion as its own prim lets the partitioner move it into
    the backward pass.

    Returns NotImplemented (keeping the original op) when non-default
    dilation is used, when the lowering would fall back anyway, or when a
    window position cannot be encoded in an int8 offset.
    """
    # Normalize scalar shorthands into per-dimension [h, w] lists.
    if dilation == 1:
        dilation = [1, 1]
    if padding == 0:
        padding = [0, 0]
    if stride is None:
        stride = kernel_size
    kernel_size, dilation, padding, stride = (
        pad_listlike(arg, 2) for arg in (kernel_size, dilation, padding, stride)
    )

    window_size = kernel_size[0] * kernel_size[1]
    wants_fallback = torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
        kernel_size, dilation
    )
    # Offsets index positions inside one window, so the window must fit in int8.
    if wants_fallback or window_size > torch.iinfo(torch.int8).max:
        return NotImplemented

    vals, offsets = prims._low_memory_max_pool2d_with_offsets(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )
    indices = prims._low_memory_max_pool2d_offsets_to_indices(
        offsets,
        kernel_size[1],
        x.size(-1),
        stride,
        padding,
    )
    return vals, indices

View file

@ -993,7 +993,6 @@ def _register_quantization_maxpool2d():
KeywordArg("ceil_mode"),
],
]
for max_pool2d_args in max_pool2d_args_list:
dequantize_maxpool2d_pattern = CallFunction(
aten.max_pool2d_with_indices.default,
@ -1001,15 +1000,33 @@ def _register_quantization_maxpool2d():
KeywordArg("kernel_size"),
*max_pool2d_args,
)
dequantize_lowmem_maxpool2d_pattern = CallFunction(
prims._low_memory_max_pool2d_with_offsets.default,
dequantize_per_tensor_activation_pattern,
KeywordArg("kernel_size"),
*max_pool2d_args,
KeywordArg("offset_dtype"),
)
dequantize_maxpool2d_get_item_pattern = CallFunction(
operator.getitem,
dequantize_maxpool2d_pattern,
Arg(),
)
dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction(
operator.getitem,
dequantize_lowmem_maxpool2d_pattern,
Arg(),
)
_register_quantized_maxpool2d_lowering(
generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
quantized.max_pool2d.default,
)
_register_quantized_maxpool2d_lowering(
generate_pattern_with_output_quant(
dequantize_lowmem_maxpool2d_get_item_pattern
),
quantized.max_pool2d.default,
)
def _is_input_output_same_scale_zp(check_node):

View file

@ -16,8 +16,15 @@ def make_prim(
doc: str = "",
tags: Optional[Sequence[torch.Tag]] = None,
):
def meta(*args, **kwargs):
return _prims.TensorMeta(impl_aten(*args, **kwargs))
if isinstance(return_type, tuple):
def meta(*args, **kwargs):
return tuple(_prims.TensorMeta(o) for o in impl_aten(*args, **kwargs))
else:
def meta(*args, **kwargs):
return _prims.TensorMeta(impl_aten(*args, **kwargs))
return _prims._make_prim(
schema=schema,
@ -93,3 +100,31 @@ fma = make_prim(
lambda a, b, c: (a * b) + c,
doc="Fused multiply add: fma(a, b, c) -> (a * b) + c without rounding after the multiplication",
)
def _low_memory_max_pool2d_with_offsets_aten(
self,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
):
vals, indices = torch.ops.aten.max_pool2d_with_indices(
self, kernel_size, stride, padding, dilation, ceil_mode
)
return vals, indices.to(torch.int8)
# Prim that returns (pooled values, int8 offsets): the offsets are compact
# positions rather than the full int64 indices aten.max_pool2d_with_indices
# produces, which keeps the saved-for-backward tensor small.
_low_memory_max_pool2d_with_offsets = make_prim(
    "_low_memory_max_pool2d_with_offsets(Tensor self, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950
    _low_memory_max_pool2d_with_offsets_aten,
    return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW),
    doc="Instead of returning indices, returns indices offsets.",
)
# Prim that expands the small int offsets above into regular int64 indices.
# The eager impl is a shape/dtype-preserving cast only — presumably a
# placeholder for meta inference; the real computation is the inductor
# lowering registered for this prim.  TODO(review): confirm.
_low_memory_max_pool2d_offsets_to_indices = make_prim(
    "_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor", # noqa: B950
    lambda self, *args: self.to(torch.int64),
    doc="Convert small int offsets to regular indices.",
)

View file

@ -3511,15 +3511,14 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
return x_out, ceil_mode
fallback_max_pool2d_with_indices = fallback_handler(
aten.max_pool2d_with_indices.default,
add_to_fallback_set=False,
)
def should_fallback_max_pool2d_with_indices(kernel_size, dilation):
kernel_size = pad_listlike(kernel_size, 2)
window_size = kernel_size[0] * kernel_size[1]
return (window_size > 25) or any(d > 1 for d in dilation)
@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
def max_pool2d_with_indices(
x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
def max_pool2d_checks(
x, kernel_size, stride, padding, dilation, *, assert_fallback=None
):
if padding == 0:
padding = [0, 0]
@ -3527,6 +3526,7 @@ def max_pool2d_with_indices(
dilation = [1, 1]
if not stride:
stride = kernel_size
kernel_size = pad_listlike(kernel_size, 2)
stride = pad_listlike(stride, 2)
padding = pad_listlike(padding, 2)
@ -3539,36 +3539,51 @@ def max_pool2d_with_indices(
assert len(dilation) == 2
assert len(x.get_size()) in (3, 4)
use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation)
if assert_fallback is not None:
assert use_fallback == assert_fallback
return kernel_size, stride, padding, dilation, use_fallback
@register_lowering(prims._low_memory_max_pool2d_with_offsets, type_promotion_kind=None)
def _low_memory_max_pool2d_with_offsets(
x,
kernel_size,
stride,
padding,
dilation,
ceil_mode=False,
):
# assert we are not on a fallback path, the inductor decomp should have guaranteed this
kernel_size, stride, padding, dilation, _ = max_pool2d_checks(
x, kernel_size, stride, padding, dilation, assert_fallback=False
)
x.realize_hint()
*batch, h, w = x.get_size()
h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
new_size = list(batch) + [h_out, w_out]
if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
x_loader = constant_boundary_condition(x, float("-inf"), dim=2)
else:
x_loader = x.make_loader()
new_size = list(batch) + [h_out, w_out]
window_size = kernel_size[0] * kernel_size[1]
if window_size > 25 or any(d != 1 for d in dilation):
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
return fallback_max_pool2d_with_indices(
x, kernel_size, stride, padding, dilation, ceil_mode
)
def fn(idx, return_index):
*prefix, bh, bw = idx
maxval = None
maxindex = None
for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
ih = bh * stride[0] + ih - padding[0]
iw = bw * stride[1] + iw - padding[1]
for h_inc, w_inc in itertools.product(
range(kernel_size[0]), range(kernel_size[1])
):
ih = bh * stride[0] + h_inc - padding[0]
iw = bw * stride[1] + w_inc - padding[1]
val = x_loader([*prefix, ih, iw])
if return_index:
index = ops.index_expr(ih * w + iw, torch.int64)
index = ops.index_expr(h_inc * kernel_size[1] + w_inc, torch.int8)
if maxindex is None:
maxindex = index
else:
@ -3582,20 +3597,58 @@ def max_pool2d_with_indices(
else:
return maxval
r1 = Pointwise.create(
out = Pointwise.create(
device=x.get_device(),
dtype=x.get_dtype(),
inner_fn=functools.partial(fn, return_index=False),
ranges=new_size,
)
r2 = Pointwise.create(
offsets = Pointwise.create(
device=x.get_device(),
dtype=torch.int64,
dtype=torch.int8,
inner_fn=functools.partial(fn, return_index=True),
ranges=new_size,
)
# TODO(jansel): should we force these to be realized?
return r1, r2
return out, offsets
@register_lowering(
    prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None
)
def _low_memory_max_pool2d_offsets_to_indices(
    offsets, kernel_width, input_width, stride, padding
):
    """Lower the offsets->indices prim as a pointwise int64 kernel.

    Each int8 offset encodes a position inside one pooling window; this
    expands it into a flat index over the input plane (row * input_width +
    col), as aten.max_pool2d_with_indices would have returned.
    """
    # TODO: Generalize to other max pooling flavors, and arbitrary dim
    load_offset = offsets.make_loader()

    def to_index(idx):
        *prefix, out_h, out_w = idx
        offset = load_offset([*prefix, out_h, out_w])
        # Split the flat window offset into (row, col) increments.
        kw = ops.constant(kernel_width, torch.int32)
        row_inc = offset // kw
        col_inc = offset - (row_inc * kw)
        # Translate window-relative increments into a flat input index,
        # accounting for stride and padding of the window's origin.
        plane_w = ops.index_expr(input_width, torch.int64)
        row_base = ops.index_expr(out_h * stride[0] - padding[0], torch.int64)
        col_base = ops.index_expr(out_w * stride[1] - padding[1], torch.int64)
        return (row_base + row_inc) * plane_w + (col_base + col_inc)

    return Pointwise.create(
        device=offsets.get_device(),
        dtype=torch.int64,
        inner_fn=to_index,
        ranges=offsets.get_size(),
    )
# Fallback selected when we do not decompose to the low-memory path.
make_fallback(aten.max_pool2d_with_indices)
fallback_max_pool2d_with_indices_backward = fallback_handler(
@ -3658,8 +3711,6 @@ def max_pool2d_with_indices_backward(
grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
)
indices.realize_hint()
*batch, height, width = x.get_size()
*_, pooled_height, pooled_width = grad_output.get_size()
@ -3970,7 +4021,21 @@ def adaptive_max_pool2d(x, output_size):
)
if h_in % h_out == 0 and w_in % w_out == 0:
kernel_size = [h_in // h_out, w_in // w_out]
return max_pool2d_with_indices(x, kernel_size)
if should_fallback_max_pool2d_with_indices(kernel_size, dilation=[1, 1]):
return max_pool2d_with_indices(x, kernel_size) # type: ignore[name-defined] # noqa: F821
else:
v, offsets = _low_memory_max_pool2d_with_offsets(
x,
kernel_size,
stride=kernel_size,
padding=[0, 0],
dilation=[1, 1],
ceil_mode=False,
)
indices = _low_memory_max_pool2d_offsets_to_indices(
offsets, kernel_size[1], w_in, kernel_size, padding=[0, 0]
)
return v, indices
h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
w_kernel_max = ceildiv((w_in + w_out - 1), w_out)