Add mkldnn_max_pool2d support for the CPU inductor backend

This commit is contained in:
CaoE 2025-02-10 03:35:08 -05:00
parent d2dc41a92d
commit 80000cfd2c
4 changed files with 47 additions and 1 deletions

View file

@ -1012,6 +1012,49 @@ def max_pool2d_with_indices(
)
return vals, indices
@register_decomposition(aten.mkldnn_max_pool2d)
def mkldnn_max_pool2d(
    x: torch.Tensor,
    kernel_size: list[int],
    stride: Optional[Union[int, list[int]]] = None,
    padding: Union[int, list[int]] = 0,
    dilation: Union[int, list[int]] = 1,
    ceil_mode: bool = False,
) -> torch.Tensor:
    """Decompose ``aten.mkldnn_max_pool2d`` into the low-memory max-pool prim.

    Unlike ``max_pool2d_with_indices`` this op returns only the pooled
    values (the offsets produced by the prim are discarded), so the return
    type is a single tensor, not a tuple.

    Args:
        x: Input tensor (layout expected by mkldnn pooling — NOTE(review):
            presumably NCHW; confirm against the eager kernel).
        kernel_size: Pooling window size; 1 or 2 ints (expanded to 2).
        stride: Window stride; defaults to ``kernel_size`` when falsy.
        padding: Implicit zero padding added on both sides.
        dilation: Spacing between window elements.
        ceil_mode: Use ceil instead of floor when computing output shape.

    Returns:
        The max-pooled values tensor, or ``NotImplemented`` to signal that
        the caller should fall back to the eager mkldnn kernel.
    """
    # Normalize scalar defaults to length-2 lists, mirroring max_pool2d.
    if dilation == 1:
        dilation = [1, 1]
    if padding == 0:
        padding = [0, 0]
    if not stride:
        # Empty/None stride means "same as kernel size" (ATen convention).
        stride = kernel_size
    kernel_size = pad_listlike(kernel_size, 2)
    dilation = pad_listlike(dilation, 2)
    padding = pad_listlike(padding, 2)
    stride = pad_listlike(stride, 2)

    window_size = kernel_size[0] * kernel_size[1]
    # Fall back for non-default dilation, or when the window is too large
    # for the int8 offset encoding used by the low-memory prim.
    if (
        torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
            kernel_size, dilation
        )
        or window_size > torch.iinfo(torch.int8).max
    ):
        return NotImplemented
    vals, _ = prims._low_memory_max_pool2d_with_offsets(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )
    return vals
@register_decomposition(aten.adaptive_max_pool2d)
def adaptive_max_pool2d(

View file

@ -218,6 +218,7 @@ add_needs_realized_inputs(
aten.convolution_backward,
aten.max_pool2d_with_indices,
aten.max_pool2d_with_indices_backward,
aten.mkldnn_max_pool2d,
aten.mm,
aten.upsample_nearest2d,
aten._upsample_nearest_exact2d,
@ -4297,7 +4298,7 @@ def _low_memory_max_pool2d_offsets_to_indices(
# Fallback selected when we do not decompose to the low-memory path.
make_fallback(aten.max_pool2d_with_indices)
make_fallback(aten.mkldnn_max_pool2d)
fallback_max_pool2d_with_indices_backward = fallback_handler(
aten.max_pool2d_with_indices_backward.default,

View file

@ -93,6 +93,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool3d_with_indices_backward
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_unpool2d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_unpool3d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_median(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mkldnn_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mode(AtenTensorHandle self, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);

View file

@ -84,6 +84,7 @@ inductor_fallback_ops = {
"aten.max_pool2d_with_indices.default",
"aten.max_pool3d_with_indices.default",
"aten.max_pool3d_with_indices_backward.default",
"aten.mkldnn_max_pool2d.default",
"aten.max_unpool2d.default",
"aten.max_unpool3d.default",
"aten.median.default",