diff --git a/test/onnx/expect/TestOperators.test_avg_pool2d.expect b/test/onnx/expect/TestOperators.test_avg_pool2d.expect index d551ff38f80..4839fb5a35a 100644 --- a/test/onnx/expect/TestOperators.test_avg_pool2d.expect +++ b/test/onnx/expect/TestOperators.test_avg_pool2d.expect @@ -1,6 +1,6 @@ ir_version: 7 producer_name: "pytorch" -producer_version: "CURRENT_VERSION" +producer_version: "1.14.0" graph { node { output: "onnx::Pad_1" @@ -33,11 +33,6 @@ graph { output: "3" name: "AveragePool_2" op_type: "AveragePool" - attribute { - name: "ceil_mode" - i: 0 - type: INT - } attribute { name: "kernel_shape" ints: 3 diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index f20a1290ca1..27cb161a1ae 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -1,7 +1,7 @@ import functools import sys import warnings -from typing import Callable, Sequence +from typing import Callable import torch import torch._C._onnx as _C_onnx @@ -251,47 +251,12 @@ def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool): ) @_beartype.beartype def _avg_pool(name, tuple_fn): - @symbolic_helper.quantized_args(True, False, False, False, False, False, False) - @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") - @_beartype.beartype - def symbolic_fn( - g, - input: _C.Value, - kernel_size: Sequence[int], - stride: Sequence[int], - padding: Sequence[int], - ceil_mode: int, - count_include_pad: int, - divisor_override=None, - ): - if not stride: - stride = kernel_size - padding = symbolic_helper._avgpool_helper( - tuple_fn, padding, kernel_size, stride, divisor_override, name - ) - assert isinstance(padding, tuple) - if count_include_pad: - input = opset9._op_with_optional_float_cast( - g, - "Pad", - input, - pads_i=((0,) * 2 + padding) * 2, - mode_s="constant", - value_f=0.0, - opset_before=11, - ) - padding = (0,) * len(padding) - output = g.op( - "AveragePool", - input, - kernel_shape_i=tuple_fn(kernel_size), - strides_i=tuple_fn(stride), - pads_i=padding * 2, - ceil_mode_i=ceil_mode, - ) - return output - - return symbolic_fn + # Although onnx::AvgPool provides count_include_pad and ceil_mode, + # The corner case of Average Pooling with ceil_mode on + # PyTorch allows sliding window go off bound, which leads to + # this accommodation. + # More detail on https://github.com/pytorch/pytorch/issues/57178 + return opset9._avg_pool(name, tuple_fn) @_onnx_symbolic( diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index c845d6dcc2e..6c71cc16515 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -590,12 +590,18 @@ def _avg_pool(name, tuple_fn): count_include_pad: int, divisor_override=None, ): + # Although onnx::AvgPool provides count_include_pad and ceil_mode, + # The corner case of Average Pooling with ceil_mode on + # PyTorch allows sliding window go off bound, which leads to + # this accommodation. + # More detail on https://github.com/pytorch/pytorch/issues/57178 + if not stride: + stride = kernel_size padding = symbolic_helper._avgpool_helper( tuple_fn, padding, kernel_size, stride, divisor_override, name ) assert isinstance(padding, tuple) - if not stride: - stride = kernel_size + adjusted_padding = padding if count_include_pad: input = g.op( "Pad", @@ -603,14 +609,22 @@ def _avg_pool(name, tuple_fn): g.op("Constant", value_t=torch.tensor(((0,) * 2 + padding) * 2)), mode_s="constant", ) - padding = (0,) * len(padding) + adjusted_padding = (0,) * len(padding) + if ceil_mode: + padding_ceil = opset9.get_pool_ceil_padding( + input, kernel_size, stride, padding + ) + adjusted_padding = adjusted_padding + tuple( + a + b for (a, b) in zip(padding_ceil, adjusted_padding) + ) + else: + adjusted_padding = adjusted_padding * 2 output = g.op( "AveragePool", input, kernel_shape_i=tuple_fn(kernel_size), strides_i=tuple_fn(stride), - pads_i=padding * 2, - ceil_mode_i=ceil_mode, + pads_i=adjusted_padding, ) return output diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index bbb97f3f8d7..d31bb8d1a9d 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -1650,13 +1650,20 @@ def _avg_pool(name, tuple_fn): ) assert isinstance(padding, tuple) adjusted_padding = padding + # Although onnx::AvgPool provides count_include_pad, + # The corner case of Average Pooling with ceil_mode on + # PyTorch allows sliding window go off bound, which leads to + # this accommodation. + # More detail on https://github.com/pytorch/pytorch/issues/57178 if count_include_pad: - input = g.op( + input = _op_with_optional_float_cast( + g, "Pad", input, pads_i=((0,) * 2 + padding) * 2, mode_s="constant", value_f=0.0, + opset_before=11, ) adjusted_padding = (0,) * len(padding) if ceil_mode: