mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Implementing NumPy-like function torch.heaviside() (#42523)
Summary: - Related with https://github.com/pytorch/pytorch/issues/38349 - Implementing the NumPy-like function `torch.heaviside()` . Pull Request resolved: https://github.com/pytorch/pytorch/pull/42523 Reviewed By: ngimel Differential Revision: D23416743 Pulled By: mruberry fbshipit-source-id: 9975bd9c9fa73bd0958fe9879f79a692aeb722d5
This commit is contained in:
parent
7680d87a76
commit
3682df77db
12 changed files with 175 additions and 0 deletions
|
|
@ -367,6 +367,7 @@ _(aten, hardsigmoid_backward) \
|
|||
_(aten, hardtanh) \
|
||||
_(aten, hardtanh_backward) \
|
||||
_(aten, hardtanh_forward) \
|
||||
_(aten, heaviside) \
|
||||
_(aten, hinge_embedding_loss) \
|
||||
_(aten, histc) \
|
||||
_(aten, hspmm) \
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ DEFINE_DISPATCH(gcd_stub);
|
|||
DEFINE_DISPATCH(lcm_stub);
|
||||
DEFINE_DISPATCH(hypot_stub);
|
||||
DEFINE_DISPATCH(nextafter_stub);
|
||||
DEFINE_DISPATCH(heaviside_stub);
|
||||
|
||||
static Tensor wrapped_scalar_tensor(Scalar scalar) {
|
||||
auto tensor = scalar_to_tensor(scalar);
|
||||
|
|
@ -952,6 +953,33 @@ Tensor _test_serialization_subcmul(const Tensor& self, const Tensor& other, Scal
|
|||
return self - (other * alpha);
|
||||
}
|
||||
|
||||
Tensor& heaviside_out(Tensor& result, const Tensor& self, const Tensor& values) {
|
||||
TORCH_CHECK(!self.is_complex() && !result.is_complex() && !values.is_complex(),
|
||||
"heaviside is not yet implemented for complex tensors.");
|
||||
TORCH_CHECK(self.dtype() == values.dtype() && result.dtype() == self.dtype(),
|
||||
"heaviside is not yet implemented for tensors with different dtypes.");
|
||||
|
||||
auto iter = TensorIterator::binary_op(result, self, values);
|
||||
heaviside_stub(iter.device_type(), iter);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor heaviside(const Tensor& self, const Tensor& values) {
|
||||
TORCH_CHECK(!self.is_complex() && !values.is_complex(),
|
||||
"heaviside is not yet implemented for complex tensors.");
|
||||
TORCH_CHECK(self.dtype() == values.dtype(),
|
||||
"heaviside is not yet implemented for tensors with different dtypes.");
|
||||
|
||||
Tensor result;
|
||||
auto iter = TensorIterator::binary_op(result, self, values);
|
||||
heaviside_stub(iter.device_type(), iter);
|
||||
return iter.output();
|
||||
}
|
||||
|
||||
Tensor& heaviside_(Tensor& self, const Tensor& values) {
|
||||
return at::heaviside_out(self, self, values);
|
||||
}
|
||||
|
||||
// TODO: Deduplicate this with the TensorIterator logic. This would
|
||||
// also fix the TODOs below.
|
||||
Tensor binary_op_meta(const Tensor& self, const Tensor& other) {
|
||||
|
|
|
|||
|
|
@ -67,5 +67,6 @@ DECLARE_DISPATCH(binary_fn, gcd_stub);
|
|||
DECLARE_DISPATCH(binary_fn, lcm_stub);
|
||||
DECLARE_DISPATCH(binary_fn, hypot_stub);
|
||||
DECLARE_DISPATCH(binary_fn, nextafter_stub);
|
||||
DECLARE_DISPATCH(binary_fn, heaviside_stub);
|
||||
|
||||
}} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -764,6 +764,14 @@ void nextafter_kernel(TensorIterator& iter) {
|
|||
});
|
||||
}
|
||||
|
||||
void heaviside_kernel(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cpu", [&]() {
|
||||
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
|
||||
return a == 0 ? b : static_cast<scalar_t>(a > 0);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
REGISTER_DISPATCH(add_stub, &add_kernel);
|
||||
|
|
@ -802,6 +810,7 @@ REGISTER_DISPATCH(gcd_stub, &gcd_kernel);
|
|||
REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
|
||||
REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
|
||||
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
|
||||
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -101,6 +101,14 @@ void nextafter_kernel_cuda(TensorIterator& iter) {
|
|||
});
|
||||
}
|
||||
|
||||
void heaviside_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
|
||||
return a == 0 ? b : static_cast<scalar_t>(a > 0);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
|
||||
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
|
||||
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
|
||||
|
|
@ -110,5 +118,6 @@ REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda);
|
|||
REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
|
||||
REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
|
||||
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
|
||||
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
|
||||
|
||||
}} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -3631,6 +3631,17 @@
|
|||
use_c10_dispatcher: full
|
||||
variants: function
|
||||
|
||||
- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: heaviside_out
|
||||
|
||||
- func: heaviside(Tensor self, Tensor values) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
variants: function, method
|
||||
|
||||
- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
|
||||
variants: method
|
||||
|
||||
# For C++ only, until we have conversion from C++ numbers to Tensor
|
||||
- func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
|
|
|
|||
|
|
@ -334,6 +334,7 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: gt_
|
||||
.. automethod:: half
|
||||
.. automethod:: hardshrink
|
||||
.. automethod:: heaviside
|
||||
.. automethod:: histc
|
||||
.. automethod:: hypot
|
||||
.. automethod:: hypot_
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ Creation Ops
|
|||
dequantize
|
||||
complex
|
||||
polar
|
||||
heaviside
|
||||
|
||||
Indexing, Slicing, Joining, Mutating Ops
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
|
|||
|
|
@ -6276,6 +6276,71 @@ class TestTorchDeviceType(TestCase):
|
|||
torch.bitwise_xor(torch.tensor([True, True, False], device=device),
|
||||
torch.tensor([False, True, False], device=device)))
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
||||
@dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False),
|
||||
torch.testing.get_all_dtypes(include_complex=False))))
|
||||
def test_heaviside(self, device, dtypes):
|
||||
input_dtype = dtypes[0]
|
||||
values_dtype = dtypes[1]
|
||||
|
||||
rng = np.random.default_rng()
|
||||
input = np.array(rng.integers(-10, 10, size=10),
|
||||
dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64])
|
||||
input[0] = input[3] = input[7] = 0
|
||||
values = np.array(rng.integers(-10, 10, size=10),
|
||||
dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64])
|
||||
np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype)
|
||||
|
||||
input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
|
||||
values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
|
||||
out = torch.empty_like(input)
|
||||
|
||||
if input_dtype == values_dtype:
|
||||
torch_result = torch.heaviside(input, values)
|
||||
self.assertEqual(np_result, torch_result)
|
||||
|
||||
torch_result = input.heaviside(values)
|
||||
self.assertEqual(np_result, torch_result)
|
||||
|
||||
torch.heaviside(input, values, out=out)
|
||||
self.assertEqual(np_result, out)
|
||||
|
||||
input.heaviside_(values)
|
||||
self.assertEqual(np_result, input)
|
||||
else:
|
||||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
|
||||
torch.heaviside(input, values)
|
||||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
|
||||
input.heaviside(values)
|
||||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
|
||||
torch.heaviside(input, values, out=out)
|
||||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
|
||||
input.heaviside_(values)
|
||||
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
||||
@dtypes(*list(product(torch.testing.get_all_complex_dtypes(),
|
||||
torch.testing.get_all_complex_dtypes())))
|
||||
def test_heaviside_complex(self, device, dtypes):
|
||||
input_dtype = dtypes[0]
|
||||
values_dtype = dtypes[1]
|
||||
|
||||
data = (complex(0, -6), complex(-1, 3), complex(1, 1))
|
||||
input = torch.tensor(data, device=device, dtype=input_dtype)
|
||||
values = torch.tensor(data, device=device, dtype=values_dtype)
|
||||
out = torch.empty_like(input)
|
||||
real = input.real
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
|
||||
torch.heaviside(input, real)
|
||||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
|
||||
real.heaviside(values)
|
||||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
|
||||
input.heaviside_(values)
|
||||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
|
||||
torch.heaviside(real, real, out=out)
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
|
||||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_logical_not(self, device, dtype):
|
||||
|
|
|
|||
|
|
@ -1535,6 +1535,20 @@ hardshrink(lambd=0.5) -> Tensor
|
|||
See :func:`torch.nn.functional.hardshrink`
|
||||
""")
|
||||
|
||||
add_docstr_all('heaviside',
|
||||
r"""
|
||||
heaviside(values) -> Tensor
|
||||
|
||||
See :func:`torch.heaviside`
|
||||
""")
|
||||
|
||||
add_docstr_all('heaviside_',
|
||||
r"""
|
||||
heaviside_(values) -> Tensor
|
||||
|
||||
In-place version of :meth:`~Tensor.heaviside`
|
||||
""")
|
||||
|
||||
add_docstr_all('histc',
|
||||
r"""
|
||||
histc(bins=100, min=0, max=0) -> Tensor
|
||||
|
|
|
|||
|
|
@ -5787,6 +5787,40 @@ Example::
|
|||
|
||||
""".format(**common_args))
|
||||
|
||||
add_docstr(torch.heaviside,
|
||||
r"""
|
||||
heaviside(input, values, *, out=None) -> Tensor
|
||||
|
||||
Computes the Heaviside step function for each element in :attr:`input`.
|
||||
The Heaviside step function is defined as:
|
||||
|
||||
.. math::
|
||||
\text{{heaviside}}(input, values) = \begin{cases}
|
||||
\0, & \text{if input < 0}\\
|
||||
\values, & \text{if input == 0}\\
|
||||
\1, & \text{if input > 0}
|
||||
\end{cases}
|
||||
""" + r"""
|
||||
|
||||
Args:
|
||||
{input}
|
||||
values (Tensor): The values to use where :attr:`input` is zero.
|
||||
|
||||
Keyword arguments:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = torch.tensor([-1.5, 0, 2.0])
|
||||
>>> values = torch.tensor([0.5])
|
||||
>>> torch.heaviside(input, values)
|
||||
tensor([0.0000, 0.5000, 1.0000])
|
||||
>>> values = torch.tensor([1.2, -2.0, 3.5])
|
||||
>>> torch.heaviside(input, values)
|
||||
tensor([0., -2., 1.])
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
add_docstr(torch.rand,
|
||||
r"""
|
||||
rand(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
|
||||
|
|
|
|||
|
|
@ -379,6 +379,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
|
||||
torch.gt: lambda input, other, out=None: -1,
|
||||
torch.hardshrink: lambda input, lambd=0.5: -1,
|
||||
torch.heaviside: lambda input, values, out=None: -1,
|
||||
torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
|
||||
torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
|
||||
torch.hspmm: lambda mat1, mat2, out=None: -1,
|
||||
|
|
|
|||
Loading…
Reference in a new issue