diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index 984e60056af..0b71379bcf1 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -1,3 +1,4 @@ +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -178,6 +179,7 @@ static void aminmax_kernel( " but got ", min_result.scalar_type(), " and ", max_result.scalar_type()); if (self.numel() == 1 && self.ndimension() == 0) { + TORCH_CHECK(!self.is_complex(), "aminmax not implemented for ", self.scalar_type()); min_result.resize_({}); max_result.resize_({}); min_result.fill_(self); diff --git a/test/test_reductions.py b/test/test_reductions.py index 90f192199d2..8e07692cb58 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -1228,6 +1228,12 @@ class TestReductions(TestCase): self._test_minmax_helper(_amin_wrapper, np.amin, device, dtype) self._test_minmax_helper(_amax_wrapper, np.amax, device, dtype) + @onlyNativeDeviceTypes + @dtypes(*complex_types()) + def test_invalid_0dim_aminmax(self, device, dtype): + with self.assertRaisesRegex(RuntimeError, 'not implemented'): + torch.aminmax(torch.tensor(1., dtype=dtype, device=device), dim=0) + # TODO: bincount isn't a classic reduction -- maybe this test suite is # reductions and summary ops? def test_bincount(self, device):