Errors when 0-dim tensor of complex or bool type passed to aminmax. (#128404)

Fixes #126742

Added errors for the case of 0-dim tensors of complex or bool types passed to aminmax.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128404
Approved by: https://github.com/janeyx99
This commit is contained in:
ajbrent 2024-06-24 21:46:49 +00:00 committed by PyTorch MergeBot
parent 18fdc0ae5b
commit 30bfdf1afc
2 changed files with 8 additions and 0 deletions

View file

@ -1,3 +1,4 @@
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/ReduceOps.h>
@ -178,6 +179,7 @@ static void aminmax_kernel(
" but got ", min_result.scalar_type(), " and ", max_result.scalar_type());
if (self.numel() == 1 && self.ndimension() == 0) {
TORCH_CHECK(!self.is_complex(), "aminmax not implemented for ", self.scalar_type());
min_result.resize_({});
max_result.resize_({});
min_result.fill_(self);

View file

@ -1228,6 +1228,12 @@ class TestReductions(TestCase):
self._test_minmax_helper(_amin_wrapper, np.amin, device, dtype)
self._test_minmax_helper(_amax_wrapper, np.amax, device, dtype)
@onlyNativeDeviceTypes
@dtypes(*complex_types())
def test_invalid_0dim_aminmax(self, device, dtype):
with self.assertRaisesRegex(RuntimeError, 'not implemented'):
torch.aminmax(torch.tensor(1., dtype=dtype, device=device), dim=0)
# TODO: bincount isn't a classic reduction -- maybe this test suite is
# reductions and summary ops?
def test_bincount(self, device):