fixes torch.jit.script lp_pool bug. (#73287)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/60258

I used the solution proposed in https://github.com/pytorch/pytorch/issues/61275. That proposed fix had failed unit tests, and there was no further progress on it after 2021-08-07. I'm willing to fix any problems that arise during CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73287

Reviewed By: navahgar, zou3519

Differential Revision: D35057812

Pulled By: eellison

fbshipit-source-id: 8e82e9f73b9536979aecf476c5c65336cdffc93a
(cherry picked from commit e85e912a4edec1111623c5cbbba4171fe3bc5b1d)
This commit is contained in:
Davit Kobaladze 2022-03-28 16:09:51 -07:00 committed by PyTorch MergeBot
parent 8ed6cb42ba
commit 8e12d2bf25
2 changed files with 55 additions and 5 deletions

View file

@ -970,6 +970,56 @@ class TestJit(JitTestCase):
m_dropout.eval()
self.assertEqual(dropout(input) + 1, m_dropout(input))
def test_nn_lp_pool2d(self):
    """Check that LPPool2d modules and lp_pool2d functional calls script.

    Covers both int and float ``norm_type`` and both int and tuple
    ``kernel_size`` arguments (regression test for the lp_pool
    torch.jit.script overload bug).
    """
    class LpPool2dModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # square kernel vs. rectangular (tuple) kernel
            self.square_pool = torch.nn.LPPool2d(2, 3)
            self.rect_pool = torch.nn.LPPool2d(2, (7, 1))

        def forward(self, x):
            outputs = (
                self.square_pool(x),
                self.rect_pool(x),
                torch.nn.functional.lp_pool2d(x, float(2), 3),
                torch.nn.functional.lp_pool2d(x, 2, 3),
                torch.nn.functional.lp_pool2d(x, float(2), (7, 1)),
            )
            return outputs

    self.checkModule(LpPool2dModule(), (torch.rand(1, 3, 7, 7),))
def test_nn_lp_pool1d(self):
    """Check that LPPool1d modules and lp_pool1d functional calls script.

    Covers both int and float ``norm_type`` with different kernel sizes
    (regression test for the lp_pool torch.jit.script overload bug).
    """
    class LpPool1dModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # two pool sizes exercised through module attributes
            self.small_pool = torch.nn.LPPool1d(2, 3)
            self.full_pool = torch.nn.LPPool1d(2, 7)

        def forward(self, x):
            outputs = (
                self.small_pool(x),
                self.full_pool(x),
                torch.nn.functional.lp_pool1d(x, float(2), 3),
                torch.nn.functional.lp_pool1d(x, 2, 3),
                torch.nn.functional.lp_pool1d(x, float(2), 7),
            )
            return outputs

    self.checkModule(LpPool1dModule(), (torch.rand(1, 3, 7),))
def test_nn_padding_functional(self):
    """Check that F.pad with constant mode scripts for 1D/2D/3D pad specs."""
    class PadModule(nn.Module):
        def __init__(self, *pad):
            super().__init__()
            # pad spec is captured as a tuple at construction time
            self.pad = pad

        def forward(self, x):
            return F.pad(x, self.pad, mode='constant', value=3.5)

    cases = [
        (PadModule(1, 2), torch.randn(1, 3, 4)),              # 1D
        (PadModule(1, 2, 3, 4), torch.randn(1, 3, 4)),        # 2D
        (PadModule(1, 2, 3, 4, 5, 6), torch.randn(1, 3, 4)),  # 3D
    ]
    for module, example_input in cases:
        self.checkModule(module, (example_input,))
def test_nn_padding(self):
class Mod(nn.Module):
def __init__(self, padding):

View file

@ -1,5 +1,5 @@
r"""Functional interface"""
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Union
import math
import warnings
@ -1004,8 +1004,8 @@ def max_unpool3d(
def lp_pool2d(
input: Tensor, norm_type: float,
kernel_size: int,
input: Tensor, norm_type: Union[int, float],
kernel_size: BroadcastingList2[int],
stride: Optional[BroadcastingList2[int]] = None,
ceil_mode: bool = False
) -> Tensor:
@ -1029,7 +1029,7 @@ def lp_pool2d(
def lp_pool1d(
input: Tensor, norm_type: float,
input: Tensor, norm_type: Union[int, float],
kernel_size: int,
stride: Optional[BroadcastingList1[int]] = None,
ceil_mode: bool = False
@ -4263,7 +4263,7 @@ def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] =
return torch.affine_grid_generator(theta, size, align_corners)
def _pad(input: Tensor, pad: List[int], mode: str = "constant", value: float = 0.0) -> Tensor:
def _pad(input: Tensor, pad: BroadcastingList1[int], mode: str = "constant", value: Union[int, float] = 0.0) -> Tensor:
r"""Pads tensor.
Padding size: