[ONNX] Disable ONNX ceil_mode and count_include_pad to align torch ceil_mode results in corner case (#87892)

ONNX and PyTorch have different equations for pooling and different strategies for ceil_mode, which leads to discrepancies in corner cases (#71549).
Specifically, PyTorch average pooling does not follow [the equation in the documentation](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html); instead, it allows the sliding window to go off-bound if it starts within the left padding or the input (see the NOTE section). More details can be found in #57178.

This PR changes avgpool in opsets 10 and 11 back to the opset 9 behavior, which stops using ceil_mode and count_include_pad in onnx::AveragePool.

A comprehensive test for all combinations of parameters can be found in the next PR. #87893
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87892
Approved by: https://github.com/BowenBao
This commit is contained in:
AllenTiTaiWang 2022-10-28 19:31:23 +00:00 committed by PyTorch MergeBot
parent c810489dd9
commit f2ae459311
4 changed files with 35 additions and 54 deletions

View file

@ -1,6 +1,6 @@
ir_version: 7
producer_name: "pytorch"
producer_version: "CURRENT_VERSION"
producer_version: "1.14.0"
graph {
node {
output: "onnx::Pad_1"
@ -33,11 +33,6 @@ graph {
output: "3"
name: "AveragePool_2"
op_type: "AveragePool"
attribute {
name: "ceil_mode"
i: 0
type: INT
}
attribute {
name: "kernel_shape"
ints: 3

View file

@ -1,7 +1,7 @@
import functools
import sys
import warnings
from typing import Callable, Sequence
from typing import Callable
import torch
import torch._C._onnx as _C_onnx
@ -251,47 +251,12 @@ def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool):
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
@_beartype.beartype
def symbolic_fn(
g,
input: _C.Value,
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Sequence[int],
ceil_mode: int,
count_include_pad: int,
divisor_override=None,
):
if not stride:
stride = kernel_size
padding = symbolic_helper._avgpool_helper(
tuple_fn, padding, kernel_size, stride, divisor_override, name
)
assert isinstance(padding, tuple)
if count_include_pad:
input = opset9._op_with_optional_float_cast(
g,
"Pad",
input,
pads_i=((0,) * 2 + padding) * 2,
mode_s="constant",
value_f=0.0,
opset_before=11,
)
padding = (0,) * len(padding)
output = g.op(
"AveragePool",
input,
kernel_shape_i=tuple_fn(kernel_size),
strides_i=tuple_fn(stride),
pads_i=padding * 2,
ceil_mode_i=ceil_mode,
)
return output
return symbolic_fn
# Although onnx::AvgPool provides count_include_pad and ceil_mode,
# The corner case of Average Pooling with ceil_mode on
# PyTorch allows sliding window go off bound, which leads to
# this accommodation.
# More detail on https://github.com/pytorch/pytorch/issues/57178
return opset9._avg_pool(name, tuple_fn)
@_onnx_symbolic(

View file

@ -590,12 +590,18 @@ def _avg_pool(name, tuple_fn):
count_include_pad: int,
divisor_override=None,
):
# Although onnx::AvgPool provides count_include_pad and ceil_mode,
# The corner case of Average Pooling with ceil_mode on
# PyTorch allows sliding window go off bound, which leads to
# this accommodation.
# More detail on https://github.com/pytorch/pytorch/issues/57178
if not stride:
stride = kernel_size
padding = symbolic_helper._avgpool_helper(
tuple_fn, padding, kernel_size, stride, divisor_override, name
)
assert isinstance(padding, tuple)
if not stride:
stride = kernel_size
adjusted_padding = padding
if count_include_pad:
input = g.op(
"Pad",
@ -603,14 +609,22 @@ def _avg_pool(name, tuple_fn):
g.op("Constant", value_t=torch.tensor(((0,) * 2 + padding) * 2)),
mode_s="constant",
)
padding = (0,) * len(padding)
adjusted_padding = (0,) * len(padding)
if ceil_mode:
padding_ceil = opset9.get_pool_ceil_padding(
input, kernel_size, stride, padding
)
adjusted_padding = adjusted_padding + tuple(
a + b for (a, b) in zip(padding_ceil, adjusted_padding)
)
else:
adjusted_padding = adjusted_padding * 2
output = g.op(
"AveragePool",
input,
kernel_shape_i=tuple_fn(kernel_size),
strides_i=tuple_fn(stride),
pads_i=padding * 2,
ceil_mode_i=ceil_mode,
pads_i=adjusted_padding,
)
return output

View file

@ -1650,13 +1650,20 @@ def _avg_pool(name, tuple_fn):
)
assert isinstance(padding, tuple)
adjusted_padding = padding
# Although onnx::AvgPool provides count_include_pad,
# The corner case of Average Pooling with ceil_mode on
# PyTorch allows sliding window go off bound, which leads to
# this accommodation.
# More detail on https://github.com/pytorch/pytorch/issues/57178
if count_include_pad:
input = g.op(
input = _op_with_optional_float_cast(
g,
"Pad",
input,
pads_i=((0,) * 2 + padding) * 2,
mode_s="constant",
value_f=0.0,
opset_before=11,
)
adjusted_padding = (0,) * len(padding)
if ceil_mode: