mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add type check for dilation in torch.quantized_max_pool3d() (#137845)

Fixes #136716

Repro:

```python
import torch

input = torch.randn([1, 1, 1, 1, 1])
input = torch.quantize_per_tensor(input, 0.1, 10, torch.qint32)
torch.quantized_max_pool3d(input, (1, 1, 1), (1, 1, 1), (0, 0, 0), (-3, 1, 1))  # crash

input = torch.randn([1, 1, 1, 1, 1])
input = torch.quantize_per_tensor(input, 0.1, 10, torch.qint32)
result = torch.nn.functional.max_pool3d(input, (1, 1, 1), (1, 1, 1), (0, 0, 0), (-3, 1, 1))  # crash
```

Result after the fix:

```
RuntimeError: Expected dilation >= 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137845
Approved by: https://github.com/albanD
This commit is contained in:
parent
a8b912f39d
commit
279ddfc6ee
3 changed files with 19 additions and 0 deletions
|
|
@ -478,6 +478,8 @@ void check_maxpool2d_params(
|
|||
"Expected 1d or 2d padding, got ", padding.size());
|
||||
TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
|
||||
"Expected 1d or 2d dilation, got ", dilation.size());
|
||||
TORCH_CHECK(dilation.allMatch([](const auto& ele) { return ele >= 1L; }),
|
||||
"Expected dilation >= 1");
|
||||
}
|
||||
|
||||
void check_maxpool3d_params(
|
||||
|
|
@ -490,6 +492,8 @@ void check_maxpool3d_params(
|
|||
"Expected no strides or 3d strides, got", stride.size());
|
||||
TORCH_CHECK(padding.size() == 3, "Expected 3d padding, got ", padding.size());
|
||||
TORCH_CHECK(dilation.size() == 3, "Expected 1d or 3d dilation, got ", dilation.size());
|
||||
TORCH_CHECK(dilation.allMatch([](const auto& ele) { return ele >= 1L; }),
|
||||
"Expected dilation >= 1");
|
||||
}
|
||||
|
||||
#ifdef USE_PYTORCH_QNNPACK
|
||||
|
|
|
|||
|
|
@ -162,6 +162,11 @@ class ArrayRef final {
|
|||
return reverse_iterator(begin());
|
||||
}
|
||||
|
||||
/// Check if all elements in the array satisfy the given expression
|
||||
constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
|
||||
return std::all_of(cbegin(), cend(), pred);
|
||||
}
|
||||
|
||||
/// empty - Check if the array is empty.
|
||||
constexpr bool empty() const {
|
||||
return Length == 0;
|
||||
|
|
|
|||
|
|
@ -492,6 +492,16 @@ class TestPoolingNN(NNTestCase):
|
|||
with self.assertRaises(RuntimeError):
|
||||
torch.quantized_max_pool1d(temp_tensor, [])
|
||||
|
||||
def test_quantized_max_pool3d(self):
|
||||
# This used to segfault when called with a negative dilation
|
||||
# see https://github.com/pytorch/pytorch/issues/136716
|
||||
input = torch.randn([1, 1, 1, 1, 1])
|
||||
input = torch.quantize_per_tensor(input, -0.1, -10, torch.qint32)
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected dilation >= 1"):
|
||||
torch.quantized_max_pool3d(
|
||||
input, (1, 1, 1), (1, 1, 1), (0, 0, 0), (-3, 1, 1)
|
||||
)
|
||||
|
||||
|
||||
class TestPoolingNNDeviceType(NNTestCase):
|
||||
@onlyNativeDeviceTypes
|
||||
|
|
|
|||
Loading…
Reference in a new issue