Implementing NumPy-like function torch.heaviside() (#42523)

Summary:
- Related to https://github.com/pytorch/pytorch/issues/38349
- Implements the NumPy-like function `torch.heaviside()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42523

Reviewed By: ngimel

Differential Revision: D23416743

Pulled By: mruberry

fbshipit-source-id: 9975bd9c9fa73bd0958fe9879f79a692aeb722d5
kiyosora 2020-08-31 15:43:51 -07:00 committed by Facebook GitHub Bot
parent 7680d87a76
commit 3682df77db
12 changed files with 175 additions and 0 deletions
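
For orientation, a minimal sketch of the semantics this PR adds (mirroring `np.heaviside`: the output is 0 where `input < 0`, `values` where `input == 0`, and 1 where `input > 0`; `values` broadcasts like other binary ops):

    import numpy as np
    import torch

    x = torch.tensor([-1.5, 0.0, 2.0])
    values = torch.tensor([0.5])  # supplies the output wherever x == 0

    print(torch.heaviside(x, values))               # tensor([0.0000, 0.5000, 1.0000])
    print(np.heaviside(x.numpy(), values.numpy()))  # [0.  0.5 1. ]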


@@ -367,6 +367,7 @@ _(aten, hardsigmoid_backward) \
_(aten, hardtanh) \
_(aten, hardtanh_backward) \
_(aten, hardtanh_forward) \
_(aten, heaviside) \
_(aten, hinge_embedding_loss) \
_(aten, histc) \
_(aten, hspmm) \


@@ -47,6 +47,7 @@ DEFINE_DISPATCH(gcd_stub);
DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);

static Tensor wrapped_scalar_tensor(Scalar scalar) {
  auto tensor = scalar_to_tensor(scalar);
@@ -952,6 +953,33 @@ Tensor _test_serialization_subcmul(const Tensor& self, const Tensor& other, Scalar alpha) {
  return self - (other * alpha);
}

Tensor& heaviside_out(Tensor& result, const Tensor& self, const Tensor& values) {
  TORCH_CHECK(!self.is_complex() && !result.is_complex() && !values.is_complex(),
              "heaviside is not yet implemented for complex tensors.");
  TORCH_CHECK(self.dtype() == values.dtype() && result.dtype() == self.dtype(),
              "heaviside is not yet implemented for tensors with different dtypes.");

  auto iter = TensorIterator::binary_op(result, self, values);
  heaviside_stub(iter.device_type(), iter);
  return result;
}

Tensor heaviside(const Tensor& self, const Tensor& values) {
  TORCH_CHECK(!self.is_complex() && !values.is_complex(),
              "heaviside is not yet implemented for complex tensors.");
  TORCH_CHECK(self.dtype() == values.dtype(),
              "heaviside is not yet implemented for tensors with different dtypes.");

  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, values);
  heaviside_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor& heaviside_(Tensor& self, const Tensor& values) {
  return at::heaviside_out(self, self, values);
}

// TODO: Deduplicate this with the TensorIterator logic.  This would
// also fix the TODOs below.
Tensor binary_op_meta(const Tensor& self, const Tensor& other) {
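
All three entry points (`heaviside`, `heaviside_out`, and the in-place `heaviside_`, which simply reuses `self` as the output) funnel into the shared `heaviside_stub` through `TensorIterator::binary_op`. The `TORCH_CHECK`s above reject complex tensors and mixed dtypes up front; a sketch of how that surfaces at the Python level:

    import torch

    a = torch.tensor([-1.0, 0.0, 1.0])

    # Mixed dtypes hit the second TORCH_CHECK.
    try:
        torch.heaviside(a, torch.tensor([1], dtype=torch.int64))
    except RuntimeError as e:
        print(e)  # heaviside is not yet implemented for tensors with different dtypes.

    # Complex inputs hit the first TORCH_CHECK.
    try:
        torch.heaviside(a.to(torch.complex64), torch.tensor([1j]))
    except RuntimeError as e:
        print(e)  # heaviside is not yet implemented for complex tensors.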


@@ -67,5 +67,6 @@ DECLARE_DISPATCH(binary_fn, gcd_stub);
DECLARE_DISPATCH(binary_fn, lcm_stub);
DECLARE_DISPATCH(binary_fn, hypot_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
}} // namespace at::native


@@ -764,6 +764,14 @@ void nextafter_kernel(TensorIterator& iter) {
  });
}

void heaviside_kernel(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
        return a == 0 ? b : static_cast<scalar_t>(a > 0);
    });
  });
}

} // namespace

REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -802,6 +810,7 @@ REGISTER_DISPATCH(gcd_stub, &gcd_kernel);
REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
} // namespace native
} // namespace at
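
The kernel's element-wise rule, `a == 0 ? b : static_cast<scalar_t>(a > 0)`, combined with the `AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, ...)` wrapper, covers all standard dtypes plus half, bool, and bfloat16. A reference sketch of the same rule composed from existing ops (`heaviside_reference` is a hypothetical helper showing equivalent semantics, not the actual kernel):

    import torch

    def heaviside_reference(input, values):
        # values where input == 0, otherwise the 0/1 step (input > 0)
        # cast back to input's dtype -- the rule the kernels apply element-wise.
        return torch.where(input == 0, values, (input > 0).to(input.dtype))

    a = torch.tensor([-2, 0, 3])
    v = torch.tensor([7, 7, 7])
    print(heaviside_reference(a, v))  # tensor([0, 7, 1])
    print(torch.heaviside(a, v))      # same result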


@@ -101,6 +101,14 @@ void nextafter_kernel_cuda(TensorIterator& iter) {
  });
}

void heaviside_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a == 0 ? b : static_cast<scalar_t>(a > 0);
    });
  });
}

REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -110,5 +118,6 @@ REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda);
REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
}} // namespace at::native
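
The CUDA kernel applies the same lambda through `gpu_kernel`, so CPU and CUDA results should agree element for element; a quick parity check (a sketch that assumes a CUDA build and an available device):

    import torch

    a = torch.tensor([-1.0, 0.0, 2.5])
    v = torch.tensor([0.5, 0.5, 0.5])

    if torch.cuda.is_available():
        cpu_out = torch.heaviside(a, v)
        cuda_out = torch.heaviside(a.cuda(), v.cuda())
        assert torch.equal(cpu_out, cuda_out.cpu())  # same step rule on both devices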


@@ -3631,6 +3631,17 @@
  use_c10_dispatcher: full
  variants: function

- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: heaviside_out

- func: heaviside(Tensor self, Tensor values) -> Tensor
  use_c10_dispatcher: full
  variants: function, method

- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
  variants: method

# For C++ only, until we have conversion from C++ numbers to Tensor
- func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
  use_c10_dispatcher: full
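
Together, the three schema entries expose the usual four call forms; roughly (a sketch of the resulting Python surface):

    import torch

    a = torch.tensor([-1.0, 0.0, 1.0])
    v = torch.tensor([0.5, 0.5, 0.5])
    out = torch.empty_like(a)

    torch.heaviside(a, v)           # free function      (variants: function)
    a.heaviside(v)                  # Tensor method      (variants: method)
    torch.heaviside(a, v, out=out)  # out= overload      (heaviside.out)
    a.heaviside_(v)                 # in-place method    (heaviside_)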


@@ -334,6 +334,7 @@ view of a storage and defines numeric operations on it.
   .. automethod:: gt_
   .. automethod:: half
   .. automethod:: hardshrink
   .. automethod:: heaviside
   .. automethod:: histc
   .. automethod:: hypot
   .. automethod:: hypot_


@@ -75,6 +75,7 @@ Creation Ops
    dequantize
    complex
    polar
    heaviside

Indexing, Slicing, Joining, Mutating Ops
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@@ -6276,6 +6276,71 @@ class TestTorchDeviceType(TestCase):
                         torch.bitwise_xor(torch.tensor([True, True, False], device=device),
                                           torch.tensor([False, True, False], device=device)))

    @onlyOnCPUAndCUDA
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    @dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False),
                          torch.testing.get_all_dtypes(include_complex=False))))
    def test_heaviside(self, device, dtypes):
        input_dtype = dtypes[0]
        values_dtype = dtypes[1]

        rng = np.random.default_rng()
        input = np.array(rng.integers(-10, 10, size=10),
                         dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64])
        input[0] = input[3] = input[7] = 0
        values = np.array(rng.integers(-10, 10, size=10),
                          dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64])
        np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype)

        input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
        values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
        out = torch.empty_like(input)

        if input_dtype == values_dtype:
            torch_result = torch.heaviside(input, values)
            self.assertEqual(np_result, torch_result)

            torch_result = input.heaviside(values)
            self.assertEqual(np_result, torch_result)

            torch.heaviside(input, values, out=out)
            self.assertEqual(np_result, out)

            input.heaviside_(values)
            self.assertEqual(np_result, input)
        else:
            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
                torch.heaviside(input, values)
            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
                input.heaviside(values)
            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
                torch.heaviside(input, values, out=out)
            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
                input.heaviside_(values)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    @dtypes(*list(product(torch.testing.get_all_complex_dtypes(),
                          torch.testing.get_all_complex_dtypes())))
    def test_heaviside_complex(self, device, dtypes):
        input_dtype = dtypes[0]
        values_dtype = dtypes[1]

        data = (complex(0, -6), complex(-1, 3), complex(1, 1))
        input = torch.tensor(data, device=device, dtype=input_dtype)
        values = torch.tensor(data, device=device, dtype=values_dtype)
        out = torch.empty_like(input)
        real = input.real

        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
            torch.heaviside(input, real)
        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
            real.heaviside(values)
        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
            input.heaviside_(values)
        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
            torch.heaviside(real, real, out=out)

    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    @dtypes(*torch.testing.get_all_dtypes())
    def test_logical_not(self, device, dtype):


@@ -1535,6 +1535,20 @@ hardshrink(lambd=0.5) -> Tensor

See :func:`torch.nn.functional.hardshrink`
""")

add_docstr_all('heaviside',
               r"""
heaviside(values) -> Tensor

See :func:`torch.heaviside`
""")

add_docstr_all('heaviside_',
               r"""
heaviside_(values) -> Tensor

In-place version of :meth:`~Tensor.heaviside`
""")

add_docstr_all('histc',
               r"""
histc(bins=100, min=0, max=0) -> Tensor


@@ -5787,6 +5787,40 @@ Example::

""".format(**common_args))

add_docstr(torch.heaviside,
           r"""
heaviside(input, values, *, out=None) -> Tensor

Computes the Heaviside step function for each element in :attr:`input`.
The Heaviside step function is defined as:

.. math::
    \text{{heaviside}}(input, values) = \begin{cases}
        0, & \text{if input < 0}\\
        values, & \text{if input == 0}\\
        1, & \text{if input > 0}
    \end{cases}
""" + r"""

Args:
    {input}
    values (Tensor): The values to use where :attr:`input` is zero.

Keyword arguments:
    {out}

Example::

    >>> input = torch.tensor([-1.5, 0, 2.0])
    >>> values = torch.tensor([0.5])
    >>> torch.heaviside(input, values)
    tensor([0.0000, 0.5000, 1.0000])
    >>> values = torch.tensor([1.2, -2.0, 3.5])
    >>> torch.heaviside(input, values)
    tensor([0., -2., 1.])
""".format(**common_args))

add_docstr(torch.rand,
           r"""
rand(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor


@@ -379,6 +379,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
        torch.gt: lambda input, other, out=None: -1,
        torch.hardshrink: lambda input, lambd=0.5: -1,
        torch.heaviside: lambda input, values, out=None: -1,
        torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
        torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
        torch.hspmm: lambda mat1, mat2, out=None: -1,