From 3682df77dbd47b5f949caae85be73963d544d527 Mon Sep 17 00:00:00 2001 From: kiyosora Date: Mon, 31 Aug 2020 15:43:51 -0700 Subject: [PATCH] Implementing NumPy-like function torch.heaviside() (#42523) Summary: - Related with https://github.com/pytorch/pytorch/issues/38349 - Implementing the NumPy-like function `torch.heaviside()` . Pull Request resolved: https://github.com/pytorch/pytorch/pull/42523 Reviewed By: ngimel Differential Revision: D23416743 Pulled By: mruberry fbshipit-source-id: 9975bd9c9fa73bd0958fe9879f79a692aeb722d5 --- aten/src/ATen/core/aten_interned_strings.h | 1 + aten/src/ATen/native/BinaryOps.cpp | 28 ++++++++ aten/src/ATen/native/BinaryOps.h | 1 + aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 9 +++ .../ATen/native/cuda/BinaryMiscOpsKernels.cu | 9 +++ aten/src/ATen/native/native_functions.yaml | 11 ++++ docs/source/tensors.rst | 1 + docs/source/torch.rst | 1 + test/test_torch.py | 65 +++++++++++++++++++ torch/_tensor_docs.py | 14 ++++ torch/_torch_docs.py | 34 ++++++++++ torch/overrides.py | 1 + 12 files changed, 175 insertions(+) diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 57a36b135ea..c4b0e7498c7 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -367,6 +367,7 @@ _(aten, hardsigmoid_backward) \ _(aten, hardtanh) \ _(aten, hardtanh_backward) \ _(aten, hardtanh_forward) \ +_(aten, heaviside) \ _(aten, hinge_embedding_loss) \ _(aten, histc) \ _(aten, hspmm) \ diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index bccba591a52..91c71722f10 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -47,6 +47,7 @@ DEFINE_DISPATCH(gcd_stub); DEFINE_DISPATCH(lcm_stub); DEFINE_DISPATCH(hypot_stub); DEFINE_DISPATCH(nextafter_stub); +DEFINE_DISPATCH(heaviside_stub); static Tensor wrapped_scalar_tensor(Scalar scalar) { auto tensor = scalar_to_tensor(scalar); @@ -952,6 +953,33 @@ Tensor _test_serialization_subcmul(const Tensor& self, const Tensor& other, Scal return self - (other * alpha); } +Tensor& heaviside_out(Tensor& result, const Tensor& self, const Tensor& values) { + TORCH_CHECK(!self.is_complex() && !result.is_complex() && !values.is_complex(), + "heaviside is not yet implemented for complex tensors."); + TORCH_CHECK(self.dtype() == values.dtype() && result.dtype() == self.dtype(), + "heaviside is not yet implemented for tensors with different dtypes."); + + auto iter = TensorIterator::binary_op(result, self, values); + heaviside_stub(iter.device_type(), iter); + return result; +} + +Tensor heaviside(const Tensor& self, const Tensor& values) { + TORCH_CHECK(!self.is_complex() && !values.is_complex(), + "heaviside is not yet implemented for complex tensors."); + TORCH_CHECK(self.dtype() == values.dtype(), + "heaviside is not yet implemented for tensors with different dtypes."); + + Tensor result; + auto iter = TensorIterator::binary_op(result, self, values); + heaviside_stub(iter.device_type(), iter); + return iter.output(); +} + +Tensor& heaviside_(Tensor& self, const Tensor& values) { + return at::heaviside_out(self, self, values); +} + // TODO: Deduplicate this with the TensorIterator logic. This would // also fix the TODOs below. Tensor binary_op_meta(const Tensor& self, const Tensor& other) { diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index d9da3a0c14b..e2dad35eb7e 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -67,5 +67,6 @@ DECLARE_DISPATCH(binary_fn, gcd_stub); DECLARE_DISPATCH(binary_fn, lcm_stub); DECLARE_DISPATCH(binary_fn, hypot_stub); DECLARE_DISPATCH(binary_fn, nextafter_stub); +DECLARE_DISPATCH(binary_fn, heaviside_stub); }} // namespace at::native diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 133d51bd99a..09847a010ee 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -764,6 +764,14 @@ void nextafter_kernel(TensorIterator& iter) { }); } +void heaviside_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cpu", [&]() { + cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { + return a == 0 ? b : static_cast(a > 0); + }); + }); +} + } // namespace REGISTER_DISPATCH(add_stub, &add_kernel); @@ -802,6 +810,7 @@ REGISTER_DISPATCH(gcd_stub, &gcd_kernel); REGISTER_DISPATCH(lcm_stub, &lcm_kernel); REGISTER_DISPATCH(hypot_stub, &hypot_kernel); REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel); +REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu index 2f1b5540841..4514083bf97 100644 --- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu @@ -101,6 +101,14 @@ void nextafter_kernel_cuda(TensorIterator& iter) { }); } +void heaviside_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a == 0 ? b : static_cast(a > 0); + }); + }); +} + REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda); REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda); REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda); @@ -110,5 +118,6 @@ REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda); REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda); REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda); REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda); +REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda); }} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 436619da4b4..326af4d7d64 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3631,6 +3631,17 @@ use_c10_dispatcher: full variants: function +- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: heaviside_out + +- func: heaviside(Tensor self, Tensor values) -> Tensor + use_c10_dispatcher: full + variants: function, method + +- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!) + variants: method + # For C++ only, until we have conversion from C++ numbers to Tensor - func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor use_c10_dispatcher: full diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 31d0ce34a2a..1cd67d48a98 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -334,6 +334,7 @@ view of a storage and defines numeric operations on it. .. automethod:: gt_ .. automethod:: half .. automethod:: hardshrink + .. automethod:: heaviside .. automethod:: histc .. automethod:: hypot .. automethod:: hypot_ diff --git a/docs/source/torch.rst b/docs/source/torch.rst index df41d7c246f..ac59369b686 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -75,6 +75,7 @@ Creation Ops dequantize complex polar + heaviside Indexing, Slicing, Joining, Mutating Ops ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_torch.py b/test/test_torch.py index b0446812593..f88b92b7f82 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6276,6 +6276,71 @@ class TestTorchDeviceType(TestCase): torch.bitwise_xor(torch.tensor([True, True, False], device=device), torch.tensor([False, True, False], device=device))) + @onlyOnCPUAndCUDA + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False), + torch.testing.get_all_dtypes(include_complex=False)))) + def test_heaviside(self, device, dtypes): + input_dtype = dtypes[0] + values_dtype = dtypes[1] + + rng = np.random.default_rng() + input = np.array(rng.integers(-10, 10, size=10), + dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64]) + input[0] = input[3] = input[7] = 0 + values = np.array(rng.integers(-10, 10, size=10), + dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64]) + np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype) + + input = torch.from_numpy(input).to(device=device, dtype=input_dtype) + values = torch.from_numpy(values).to(device=device, dtype=values_dtype) + out = torch.empty_like(input) + + if input_dtype == values_dtype: + torch_result = torch.heaviside(input, values) + self.assertEqual(np_result, torch_result) + + torch_result = input.heaviside(values) + self.assertEqual(np_result, torch_result) + + torch.heaviside(input, values, out=out) + self.assertEqual(np_result, out) + + input.heaviside_(values) + self.assertEqual(np_result, input) + else: + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + torch.heaviside(input, values) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + input.heaviside(values) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + torch.heaviside(input, values, out=out) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + input.heaviside_(values) + + + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @dtypes(*list(product(torch.testing.get_all_complex_dtypes(), + torch.testing.get_all_complex_dtypes()))) + def test_heaviside_complex(self, device, dtypes): + input_dtype = dtypes[0] + values_dtype = dtypes[1] + + data = (complex(0, -6), complex(-1, 3), complex(1, 1)) + input = torch.tensor(data, device=device, dtype=input_dtype) + values = torch.tensor(data, device=device, dtype=values_dtype) + out = torch.empty_like(input) + real = input.real + + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + torch.heaviside(input, real) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + real.heaviside(values) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + input.heaviside_(values) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + torch.heaviside(real, real, out=out) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') @dtypes(*torch.testing.get_all_dtypes()) def test_logical_not(self, device, dtype): diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 808cd949c03..9e85eecd49b 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1535,6 +1535,20 @@ hardshrink(lambd=0.5) -> Tensor See :func:`torch.nn.functional.hardshrink` """) +add_docstr_all('heaviside', + r""" +heaviside(values) -> Tensor + +See :func:`torch.heaviside` +""") + +add_docstr_all('heaviside_', + r""" +heaviside_(values) -> Tensor + +In-place version of :meth:`~Tensor.heaviside` +""") + add_docstr_all('histc', r""" histc(bins=100, min=0, max=0) -> Tensor diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 78b474df9d5..301511f5fbd 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -5787,6 +5787,40 @@ Example:: """.format(**common_args)) +add_docstr(torch.heaviside, + r""" +heaviside(input, values, *, out=None) -> Tensor + +Computes the Heaviside step function for each element in :attr:`input`. +The Heaviside step function is defined as: + +.. math:: + \text{{heaviside}}(input, values) = \begin{cases} + \0, & \text{if input < 0}\\ + \values, & \text{if input == 0}\\ + \1, & \text{if input > 0} + \end{cases} +""" + r""" + +Args: + {input} + values (Tensor): The values to use where :attr:`input` is zero. + +Keyword arguments: + {out} + +Example:: + + >>> input = torch.tensor([-1.5, 0, 2.0]) + >>> values = torch.tensor([0.5]) + >>> torch.heaviside(input, values) + tensor([0.0000, 0.5000, 1.0000]) + >>> values = torch.tensor([1.2, -2.0, 3.5]) + >>> torch.heaviside(input, values) + tensor([0., -2., 1.]) + +""".format(**common_args)) + add_docstr(torch.rand, r""" rand(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor diff --git a/torch/overrides.py b/torch/overrides.py index e9db421c03b..492111fef37 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -379,6 +379,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, torch.gt: lambda input, other, out=None: -1, torch.hardshrink: lambda input, lambd=0.5: -1, + torch.heaviside: lambda input, values, out=None: -1, torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1, torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1, torch.hspmm: lambda mat1, mat2, out=None: -1,