mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Retry: Low mem max_pool2d_with_indices (#122832)
Based on #105687. The low-memory path does not need to strictly return the int8 offsets; instead, the offset-to-index computation can be separated from the inner function of the max pool lowering. The partitioner can then choose to move the offset-to-index computation into the backward pass. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122832 Approved by: https://github.com/peterbell10, https://github.com/eellison
This commit is contained in:
parent
005a12722d
commit
445a0c01da
6 changed files with 201 additions and 32 deletions
|
|
@ -8690,6 +8690,7 @@ class CommonTemplate:
|
|||
from torch._inductor.codegen.common import boolean_ops
|
||||
from torch._inductor.compile_fx import _shape_env_from_inputs
|
||||
from torch._inductor.debug import DebugContext
|
||||
from torch._inductor.decomposition import decompositions
|
||||
from torch._inductor.graph import GraphLowering
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
|
||||
|
|
@ -8715,7 +8716,9 @@ class CommonTemplate:
|
|||
)
|
||||
]
|
||||
|
||||
gm = torch.fx.symbolic_trace(func)
|
||||
gm = make_fx(func, decomposition_table=decompositions, tracing_mode="fake")(
|
||||
*example_inputs
|
||||
)
|
||||
|
||||
shape_env = _shape_env_from_inputs(example_inputs)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from collections import defaultdict
|
|||
from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch._inductor.inductor_prims
|
||||
import torch.fx as fx
|
||||
import torch.utils._pytree as pytree
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
|
|
@ -932,6 +933,7 @@ def min_cut_rematerialization_partition(
|
|||
aten.argmax,
|
||||
aten.maximum,
|
||||
prims.iota,
|
||||
prims._low_memory_max_pool2d_offsets_to_indices,
|
||||
] # noqa: E501,B950
|
||||
view_ops += [
|
||||
aten.view,
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from torch._decomp.decompositions import (
|
|||
)
|
||||
from torch._decomp.decompositions_for_rng import extra_random_decomps
|
||||
from torch._higher_order_ops.out_dtype import out_dtype
|
||||
from torch._inductor.utils import pad_listlike
|
||||
from torch._prims_common import (
|
||||
elementwise_dtypes,
|
||||
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
|
|
@ -754,3 +755,49 @@ def index_reduce(
|
|||
reduction_type,
|
||||
include_self=include_self,
|
||||
)
|
||||
|
||||
|
||||
@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
    x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
):
    """Decompose aten.max_pool2d_with_indices into the low-memory prims.

    The pooling prim returns compact int8 per-window offsets; a second prim
    converts them to regular int64 indices. Splitting the conversion out
    lets the partitioner move it into the backward pass.

    Returns NotImplemented to fall back to the eager kernel when the
    low-memory path does not apply (non-default dilation, or a window with
    more elements than an int8 offset can address).
    """
    # Normalize scalar shorthands to the canonical two-element form.
    if dilation == 1:
        dilation = [1, 1]

    if padding == 0:
        padding = [0, 0]

    if stride is None:
        stride = kernel_size

    kernel_size = pad_listlike(kernel_size, 2)
    dilation = pad_listlike(dilation, 2)
    padding = pad_listlike(padding, 2)
    stride = pad_listlike(stride, 2)

    window_size = kernel_size[0] * kernel_size[1]
    # We fallback when using non-default dilation or when the window size is too large
    if (
        torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
            kernel_size, dilation
        )
        or window_size > torch.iinfo(torch.int8).max
    ):
        return NotImplemented

    vals, offsets = prims._low_memory_max_pool2d_with_offsets(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )
    indices = prims._low_memory_max_pool2d_offsets_to_indices(
        offsets,
        kernel_size[1],
        x.size(-1),
        stride,
        padding,
    )
    return vals, indices
|
||||
|
|
|
|||
|
|
@ -993,7 +993,6 @@ def _register_quantization_maxpool2d():
|
|||
KeywordArg("ceil_mode"),
|
||||
],
|
||||
]
|
||||
|
||||
for max_pool2d_args in max_pool2d_args_list:
|
||||
dequantize_maxpool2d_pattern = CallFunction(
|
||||
aten.max_pool2d_with_indices.default,
|
||||
|
|
@ -1001,15 +1000,33 @@ def _register_quantization_maxpool2d():
|
|||
KeywordArg("kernel_size"),
|
||||
*max_pool2d_args,
|
||||
)
|
||||
dequantize_lowmem_maxpool2d_pattern = CallFunction(
|
||||
prims._low_memory_max_pool2d_with_offsets.default,
|
||||
dequantize_per_tensor_activation_pattern,
|
||||
KeywordArg("kernel_size"),
|
||||
*max_pool2d_args,
|
||||
KeywordArg("offset_dtype"),
|
||||
)
|
||||
dequantize_maxpool2d_get_item_pattern = CallFunction(
|
||||
operator.getitem,
|
||||
dequantize_maxpool2d_pattern,
|
||||
Arg(),
|
||||
)
|
||||
dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction(
|
||||
operator.getitem,
|
||||
dequantize_lowmem_maxpool2d_pattern,
|
||||
Arg(),
|
||||
)
|
||||
_register_quantized_maxpool2d_lowering(
|
||||
generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
|
||||
quantized.max_pool2d.default,
|
||||
)
|
||||
_register_quantized_maxpool2d_lowering(
|
||||
generate_pattern_with_output_quant(
|
||||
dequantize_lowmem_maxpool2d_get_item_pattern
|
||||
),
|
||||
quantized.max_pool2d.default,
|
||||
)
|
||||
|
||||
|
||||
def _is_input_output_same_scale_zp(check_node):
|
||||
|
|
|
|||
|
|
@ -16,8 +16,15 @@ def make_prim(
|
|||
doc: str = "",
|
||||
tags: Optional[Sequence[torch.Tag]] = None,
|
||||
):
|
||||
def meta(*args, **kwargs):
|
||||
return _prims.TensorMeta(impl_aten(*args, **kwargs))
|
||||
if isinstance(return_type, tuple):
|
||||
|
||||
def meta(*args, **kwargs):
|
||||
return tuple(_prims.TensorMeta(o) for o in impl_aten(*args, **kwargs))
|
||||
|
||||
else:
|
||||
|
||||
def meta(*args, **kwargs):
|
||||
return _prims.TensorMeta(impl_aten(*args, **kwargs))
|
||||
|
||||
return _prims._make_prim(
|
||||
schema=schema,
|
||||
|
|
@ -93,3 +100,31 @@ fma = make_prim(
|
|||
lambda a, b, c: (a * b) + c,
|
||||
doc="Fused multiply add: fma(a, b, c) -> (a * b) + c without rounding after the multiplication",
|
||||
)
|
||||
|
||||
|
||||
def _low_memory_max_pool2d_with_offsets_aten(
|
||||
self,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
):
|
||||
vals, indices = torch.ops.aten.max_pool2d_with_indices(
|
||||
self, kernel_size, stride, padding, dilation, ceil_mode
|
||||
)
|
||||
return vals, indices.to(torch.int8)
|
||||
|
||||
|
||||
# Low-memory max pool prim: returns (values, compact int8 window offsets)
# instead of (values, int64 flat indices), shrinking the tensor saved for
# the backward pass.
_low_memory_max_pool2d_with_offsets = make_prim(
    "_low_memory_max_pool2d_with_offsets(Tensor self, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, bool ceil_mode) -> (Tensor, Tensor)",  # noqa: B950
    _low_memory_max_pool2d_with_offsets_aten,
    return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW),
    doc="Instead of returning indices, returns indices offsets.",
)

# Companion prim: expands the compact offsets back into regular int64 flat
# indices. The impl here only upcasts; the real expansion presumably lives
# in the inductor lowering for this prim — confirm against lowering.py.
_low_memory_max_pool2d_offsets_to_indices = make_prim(
    "_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor",  # noqa: B950
    lambda self, *args: self.to(torch.int64),
    doc="Convert small int offsets to regular indices.",
)
|
||||
|
|
|
|||
|
|
@ -3511,15 +3511,14 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
|
|||
return x_out, ceil_mode
|
||||
|
||||
|
||||
# Eager-kernel fallback handler for max_pool2d_with_indices, used when the
# lowering does not generate code for the pooling window itself.
fallback_max_pool2d_with_indices = fallback_handler(
    aten.max_pool2d_with_indices.default,
    add_to_fallback_set=False,
)
|
||||
def should_fallback_max_pool2d_with_indices(kernel_size, dilation):
|
||||
kernel_size = pad_listlike(kernel_size, 2)
|
||||
window_size = kernel_size[0] * kernel_size[1]
|
||||
return (window_size > 25) or any(d > 1 for d in dilation)
|
||||
|
||||
|
||||
@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
|
||||
def max_pool2d_with_indices(
|
||||
x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
|
||||
def max_pool2d_checks(
|
||||
x, kernel_size, stride, padding, dilation, *, assert_fallback=None
|
||||
):
|
||||
if padding == 0:
|
||||
padding = [0, 0]
|
||||
|
|
@ -3527,6 +3526,7 @@ def max_pool2d_with_indices(
|
|||
dilation = [1, 1]
|
||||
if not stride:
|
||||
stride = kernel_size
|
||||
|
||||
kernel_size = pad_listlike(kernel_size, 2)
|
||||
stride = pad_listlike(stride, 2)
|
||||
padding = pad_listlike(padding, 2)
|
||||
|
|
@ -3539,36 +3539,51 @@ def max_pool2d_with_indices(
|
|||
assert len(dilation) == 2
|
||||
assert len(x.get_size()) in (3, 4)
|
||||
|
||||
use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation)
|
||||
if assert_fallback is not None:
|
||||
assert use_fallback == assert_fallback
|
||||
|
||||
return kernel_size, stride, padding, dilation, use_fallback
|
||||
|
||||
|
||||
@register_lowering(prims._low_memory_max_pool2d_with_offsets, type_promotion_kind=None)
|
||||
def _low_memory_max_pool2d_with_offsets(
|
||||
x,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode=False,
|
||||
):
|
||||
# assert we are not on a fallback path, the inductor decomp should have guaranteed this
|
||||
kernel_size, stride, padding, dilation, _ = max_pool2d_checks(
|
||||
x, kernel_size, stride, padding, dilation, assert_fallback=False
|
||||
)
|
||||
|
||||
x.realize_hint()
|
||||
*batch, h, w = x.get_size()
|
||||
|
||||
h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
|
||||
w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
|
||||
|
||||
new_size = list(batch) + [h_out, w_out]
|
||||
if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
|
||||
x_loader = constant_boundary_condition(x, float("-inf"), dim=2)
|
||||
else:
|
||||
x_loader = x.make_loader()
|
||||
|
||||
new_size = list(batch) + [h_out, w_out]
|
||||
window_size = kernel_size[0] * kernel_size[1]
|
||||
|
||||
if window_size > 25 or any(d != 1 for d in dilation):
|
||||
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
|
||||
return fallback_max_pool2d_with_indices(
|
||||
x, kernel_size, stride, padding, dilation, ceil_mode
|
||||
)
|
||||
|
||||
def fn(idx, return_index):
|
||||
*prefix, bh, bw = idx
|
||||
maxval = None
|
||||
maxindex = None
|
||||
for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
|
||||
ih = bh * stride[0] + ih - padding[0]
|
||||
iw = bw * stride[1] + iw - padding[1]
|
||||
for h_inc, w_inc in itertools.product(
|
||||
range(kernel_size[0]), range(kernel_size[1])
|
||||
):
|
||||
ih = bh * stride[0] + h_inc - padding[0]
|
||||
iw = bw * stride[1] + w_inc - padding[1]
|
||||
val = x_loader([*prefix, ih, iw])
|
||||
if return_index:
|
||||
index = ops.index_expr(ih * w + iw, torch.int64)
|
||||
index = ops.index_expr(h_inc * kernel_size[1] + w_inc, torch.int8)
|
||||
if maxindex is None:
|
||||
maxindex = index
|
||||
else:
|
||||
|
|
@ -3582,20 +3597,58 @@ def max_pool2d_with_indices(
|
|||
else:
|
||||
return maxval
|
||||
|
||||
r1 = Pointwise.create(
|
||||
out = Pointwise.create(
|
||||
device=x.get_device(),
|
||||
dtype=x.get_dtype(),
|
||||
inner_fn=functools.partial(fn, return_index=False),
|
||||
ranges=new_size,
|
||||
)
|
||||
r2 = Pointwise.create(
|
||||
offsets = Pointwise.create(
|
||||
device=x.get_device(),
|
||||
dtype=torch.int64,
|
||||
dtype=torch.int8,
|
||||
inner_fn=functools.partial(fn, return_index=True),
|
||||
ranges=new_size,
|
||||
)
|
||||
# TODO(jansel): should we force these to be realized?
|
||||
return r1, r2
|
||||
return out, offsets
|
||||
|
||||
|
||||
@register_lowering(
    prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None
)
def _low_memory_max_pool2d_offsets_to_indices(
    offsets, kernel_width, input_width, stride, padding
):
    """Lowering for the offsets->indices prim.

    Expands each compact window offset (row-major position inside the
    pooling window) into an int64 flat index over the input's last
    dimension, i.e. ih * input_width + iw.
    """
    # TODO: Generalize to other max pooling flavors, and arbitrary dim

    offsets_loader = offsets.make_loader()

    def increments_to_index(h_inc, w_inc, bh, bw):
        # Window origin for output position (bh, bw), then add the
        # in-window increments to get the absolute (ih, iw).
        w_in = ops.index_expr(input_width, torch.int64)
        hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64)
        wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64)
        ih = hbase + h_inc
        iw = wbase + w_inc
        return ih * w_in + iw

    def offsets_to_indices(idx):
        *prefix, bh, bw = idx
        offset = offsets_loader([*prefix, bh, bw])
        kw_const = ops.constant(kernel_width, torch.int32)
        # Decompose the row-major offset into (h_inc, w_inc).
        h_inc = offset // kw_const
        # Equivalent to offset % kw_const; written via the already-computed
        # quotient — presumably to emit a single division. TODO confirm.
        w_inc = offset - (h_inc * kw_const)
        return increments_to_index(h_inc, w_inc, bh, bw)

    indices = Pointwise.create(
        device=offsets.get_device(),
        dtype=torch.int64,
        inner_fn=offsets_to_indices,
        ranges=offsets.get_size(),
    )
    return indices


# Fallback selected when we do not decompose to the low-memory path.
make_fallback(aten.max_pool2d_with_indices)
|
||||
|
||||
|
||||
fallback_max_pool2d_with_indices_backward = fallback_handler(
|
||||
|
|
@ -3658,8 +3711,6 @@ def max_pool2d_with_indices_backward(
|
|||
grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
|
||||
)
|
||||
|
||||
indices.realize_hint()
|
||||
|
||||
*batch, height, width = x.get_size()
|
||||
*_, pooled_height, pooled_width = grad_output.get_size()
|
||||
|
||||
|
|
@ -3970,7 +4021,21 @@ def adaptive_max_pool2d(x, output_size):
|
|||
)
|
||||
if h_in % h_out == 0 and w_in % w_out == 0:
|
||||
kernel_size = [h_in // h_out, w_in // w_out]
|
||||
return max_pool2d_with_indices(x, kernel_size)
|
||||
if should_fallback_max_pool2d_with_indices(kernel_size, dilation=[1, 1]):
|
||||
return max_pool2d_with_indices(x, kernel_size) # type: ignore[name-defined] # noqa: F821
|
||||
else:
|
||||
v, offsets = _low_memory_max_pool2d_with_offsets(
|
||||
x,
|
||||
kernel_size,
|
||||
stride=kernel_size,
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
ceil_mode=False,
|
||||
)
|
||||
indices = _low_memory_max_pool2d_offsets_to_indices(
|
||||
offsets, kernel_size[1], w_in, kernel_size, padding=[0, 0]
|
||||
)
|
||||
return v, indices
|
||||
|
||||
h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
|
||||
w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
|
||||
|
|
|
|||
Loading…
Reference in a new issue