mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
use mkldnn_max_pool2d support CPU inductor
This commit is contained in:
parent
d2dc41a92d
commit
80000cfd2c
4 changed files with 47 additions and 1 deletions
|
|
@ -1012,6 +1012,49 @@ def max_pool2d_with_indices(
|
|||
)
|
||||
return vals, indices
|
||||
|
||||
@register_decomposition(aten.mkldnn_max_pool2d)
|
||||
def mkldnn_max_pool2d(
|
||||
x: torch.Tensor,
|
||||
kernel_size: list[int],
|
||||
stride: Optional[Union[int, list[int]]] = None,
|
||||
padding: Union[int, list[int]] = 0,
|
||||
dilation: Union[int, list[int]] = 1,
|
||||
ceil_mode: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if dilation == 1:
|
||||
dilation = [1, 1]
|
||||
|
||||
if padding == 0:
|
||||
padding = [0, 0]
|
||||
|
||||
if not stride:
|
||||
stride = kernel_size
|
||||
|
||||
kernel_size = pad_listlike(kernel_size, 2)
|
||||
dilation = pad_listlike(dilation, 2)
|
||||
padding = pad_listlike(padding, 2)
|
||||
stride = pad_listlike(stride, 2)
|
||||
|
||||
window_size = kernel_size[0] * kernel_size[1]
|
||||
# We fallback when using non-default dilation or when the window size is too large
|
||||
if (
|
||||
torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
|
||||
kernel_size, dilation
|
||||
)
|
||||
or window_size > torch.iinfo(torch.int8).max
|
||||
):
|
||||
return NotImplemented
|
||||
|
||||
vals, _ = prims._low_memory_max_pool2d_with_offsets(
|
||||
x,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
ceil_mode,
|
||||
)
|
||||
|
||||
return vals
|
||||
|
||||
@register_decomposition(aten.adaptive_max_pool2d)
|
||||
def adaptive_max_pool2d(
|
||||
|
|
|
|||
|
|
@ -218,6 +218,7 @@ add_needs_realized_inputs(
|
|||
aten.convolution_backward,
|
||||
aten.max_pool2d_with_indices,
|
||||
aten.max_pool2d_with_indices_backward,
|
||||
aten.mkldnn_max_pool2d,
|
||||
aten.mm,
|
||||
aten.upsample_nearest2d,
|
||||
aten._upsample_nearest_exact2d,
|
||||
|
|
@ -4297,7 +4298,7 @@ def _low_memory_max_pool2d_offsets_to_indices(
|
|||
|
||||
# Fallback selected when we do not decompose to the low-memory path.
|
||||
make_fallback(aten.max_pool2d_with_indices)
|
||||
|
||||
make_fallback(aten.mkldnn_max_pool2d)
|
||||
|
||||
fallback_max_pool2d_with_indices_backward = fallback_handler(
|
||||
aten.max_pool2d_with_indices_backward.default,
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool3d_with_indices_backward
|
|||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_unpool2d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_unpool3d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_median(AtenTensorHandle self, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mkldnn_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mode(AtenTensorHandle self, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ inductor_fallback_ops = {
|
|||
"aten.max_pool2d_with_indices.default",
|
||||
"aten.max_pool3d_with_indices.default",
|
||||
"aten.max_pool3d_with_indices_backward.default",
|
||||
"aten.mkldnn_max_pool2d.default",
|
||||
"aten.max_unpool2d.default",
|
||||
"aten.max_unpool3d.default",
|
||||
"aten.median.default",
|
||||
|
|
|
|||
Loading…
Reference in a new issue