From 80000cfd2c0598d2c65daa11207b80d51aadef51 Mon Sep 17 00:00:00 2001 From: CaoE Date: Mon, 10 Feb 2025 03:35:08 -0500 Subject: [PATCH] use mkldnn_max_pool2d support CPU inductor --- torch/_inductor/decomposition.py | 43 +++++++++++++++++++ torch/_inductor/lowering.py | 3 +- .../aoti_torch/generated/c_shim_cpu.h | 1 + torchgen/aoti/fallback_ops.py | 1 + 4 files changed, 47 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 19ceafc5e76..c691e59d40d 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -1012,6 +1012,49 @@ def max_pool2d_with_indices( ) return vals, indices +@register_decomposition(aten.mkldnn_max_pool2d) +def mkldnn_max_pool2d( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]] = None, + padding: Union[int, list[int]] = 0, + dilation: Union[int, list[int]] = 1, + ceil_mode: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + if dilation == 1: + dilation = [1, 1] + + if padding == 0: + padding = [0, 0] + + if not stride: + stride = kernel_size + + kernel_size = pad_listlike(kernel_size, 2) + dilation = pad_listlike(dilation, 2) + padding = pad_listlike(padding, 2) + stride = pad_listlike(stride, 2) + + window_size = kernel_size[0] * kernel_size[1] + # We fallback when using non-default dilation or when the window size is too large + if ( + torch._inductor.lowering.should_fallback_max_pool2d_with_indices( + kernel_size, dilation + ) + or window_size > torch.iinfo(torch.int8).max + ): + return NotImplemented + + vals, _ = prims._low_memory_max_pool2d_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + + return vals @register_decomposition(aten.adaptive_max_pool2d) def adaptive_max_pool2d( diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index dd89cecb0fc..2a2d85fa0f2 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -218,6 +218,7 @@ add_needs_realized_inputs( aten.convolution_backward, aten.max_pool2d_with_indices, aten.max_pool2d_with_indices_backward, + aten.mkldnn_max_pool2d, aten.mm, aten.upsample_nearest2d, aten._upsample_nearest_exact2d, @@ -4297,7 +4298,7 @@ def _low_memory_max_pool2d_offsets_to_indices( # Fallback selected when we do not decompose to the low-memory path. make_fallback(aten.max_pool2d_with_indices) - +make_fallback(aten.mkldnn_max_pool2d) fallback_max_pool2d_with_indices_backward = fallback_handler( aten.max_pool2d_with_indices_backward.default, diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 2a5eb60e9c8..b8afe60d35b 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -93,6 +93,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool3d_with_indices_backward AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_unpool2d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_unpool3d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_median(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mkldnn_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mode(AtenTensorHandle self, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index dead690831f..8396e0f6af6 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -84,6 +84,7 @@ inductor_fallback_ops = { "aten.max_pool2d_with_indices.default", "aten.max_pool3d_with_indices.default", "aten.max_pool3d_with_indices_backward.default", + "aten.mkldnn_max_pool2d.default", "aten.max_unpool2d.default", "aten.max_unpool3d.default", "aten.median.default",