Add type check for dilation in torch.quantized_max_pool3d() (#137845)

Fixes #136716

repro:

```python
import torch

input = torch.randn([1, 1, 1, 1, 1])
input = torch.quantize_per_tensor(input, 0.1, 10, torch.qint32)
torch.quantized_max_pool3d(input, (1, 1, 1), (1, 1, 1), (0, 0, 0), (-3, 1, 1)) # crash

input = torch.randn([1, 1, 1, 1, 1])
input = torch.quantize_per_tensor(input, 0.1, 10, torch.qint32)
result = torch.nn.functional.max_pool3d(input, (1, 1, 1), (1, 1, 1), (0, 0, 0), (-3, 1, 1))  # crash
```

result:

```
RuntimeError: Expected dilation >= 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137845
Approved by: https://github.com/albanD
This commit is contained in:
Yuanhao Ji 2024-10-21 16:15:54 +00:00 committed by PyTorch MergeBot
parent a8b912f39d
commit 279ddfc6ee
3 changed files with 19 additions and 0 deletions

View file

@ -478,6 +478,8 @@ void check_maxpool2d_params(
"Expected 1d or 2d padding, got ", padding.size());
TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
"Expected 1d or 2d dilation, got ", dilation.size());
TORCH_CHECK(dilation.allMatch([](const auto& ele) { return ele >= 1L; }),
"Expected dilation >= 1");
}
void check_maxpool3d_params(
@ -490,6 +492,8 @@ void check_maxpool3d_params(
"Expected no strides or 3d strides, got", stride.size());
TORCH_CHECK(padding.size() == 3, "Expected 3d padding, got ", padding.size());
TORCH_CHECK(dilation.size() == 3, "Expected 1d or 3d dilation, got ", dilation.size());
TORCH_CHECK(dilation.allMatch([](const auto& ele) { return ele >= 1L; }),
"Expected dilation >= 1");
}
#ifdef USE_PYTORCH_QNNPACK

View file

@ -162,6 +162,11 @@ class ArrayRef final {
return reverse_iterator(begin());
}
/// Check if all elements in the array satisfy the given expression
constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
return std::all_of(cbegin(), cend(), pred);
}
/// empty - Check if the array is empty.
constexpr bool empty() const {
return Length == 0;

View file

@ -492,6 +492,16 @@ class TestPoolingNN(NNTestCase):
with self.assertRaises(RuntimeError):
torch.quantized_max_pool1d(temp_tensor, [])
def test_quantized_max_pool3d(self):
# This used to segfault when called with a negative dilation
# see https://github.com/pytorch/pytorch/issues/136716
input = torch.randn([1, 1, 1, 1, 1])
input = torch.quantize_per_tensor(input, -0.1, -10, torch.qint32)
with self.assertRaisesRegex(RuntimeError, "Expected dilation >= 1"):
torch.quantized_max_pool3d(
input, (1, 1, 1), (1, 1, 1), (0, 0, 0), (-3, 1, 1)
)
class TestPoolingNNDeviceType(NNTestCase):
@onlyNativeDeviceTypes