From f65ab89eddc5cf449b94fdbddb4715fb8f57d2d6 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Mon, 5 Oct 2020 01:36:53 -0700
Subject: [PATCH] [numpy] Add torch.nan_to_num (#44592)

Summary:
Reference https://github.com/pytorch/pytorch/issues/42515

TODO:
* [x] Add tests
* [x] Add docs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44592

Reviewed By: colesbury

Differential Revision: D24079472

Pulled By: mruberry

fbshipit-source-id: 2b67d36cba46eaa7ca16cd72671b57750bd568bc
---
 aten/src/ATen/NumericUtils.h                       |  4 +--
 aten/src/ATen/core/aten_interned_strings.h         |  1 +
 aten/src/ATen/native/UnaryOps.cpp                  | 36 +++++++++++++++++++
 aten/src/ATen/native/UnaryOps.h                    |  7 ++++
 aten/src/ATen/native/cpu/UnaryOpsKernel.cpp        | 28 +++++++++++++++
 aten/src/ATen/native/cuda/UnaryOpsKernel.cu        | 28 +++++++++++++++
 aten/src/ATen/native/native_functions.yaml         | 10 ++++++
 docs/source/tensors.rst                            |  2 ++
 docs/source/torch.rst                              |  1 +
 test/test_autograd.py                              | 27 ++++++++++++++
 test/test_unary_ufuncs.py                          | 36 +++++++++++++++++++
 tools/autograd/derivatives.yaml                    |  3 ++
 tools/autograd/gen_variable_type.py                |  4 ++-
 torch/_tensor_docs.py                              | 12 +++++++
 torch/_torch_docs.py                               | 35 ++++++++++++++++++
 torch/overrides.py                                 |  1 +
 .../_internal/common_methods_invocations.py        |  7 +++-
 17 files changed, 238 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/NumericUtils.h b/aten/src/ATen/NumericUtils.h
index 6cbd974f51d..d691fec1aa3 100644
--- a/aten/src/ATen/NumericUtils.h
+++ b/aten/src/ATen/NumericUtils.h
@@ -42,12 +42,12 @@ inline bool _isnan(T val) {
 
 template <typename T,
           typename std::enable_if<std::is_same<T, at::Half>::value, int>::type = 0>
 inline C10_HOST_DEVICE bool _isnan(T val) {
-  return at::_isnan(float(val));
+  return at::_isnan(static_cast<float>(val));
 }
 
 inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
-  return at::_isnan(float(val));
+  return at::_isnan(static_cast<float>(val));
 }
 
 template <typename T,
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 4fa49302240..54481814be5 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -502,6 +502,7 @@ _(aten, multinomial) \
 _(aten, mv) \
 _(aten, mvlgamma) \
 _(aten, nansum) \
+_(aten, nan_to_num) \
 _(aten, narrow) \
 _(aten, narrow_copy) \
 _(aten, native_batch_norm) \
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 2c66a197086..e2b5639f8dc 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -387,6 +387,41 @@ Tensor& logit_(Tensor& self, c10::optional<double> eps) {
   return at::logit_out(self, self, eps);
 }
 
+Tensor& nan_to_num_out(
+    Tensor& result,
+    const Tensor& self,
+    c10::optional<double> nan,
+    c10::optional<double> pos_inf,
+    c10::optional<double> neg_inf) {
+
+  if (c10::isIntegralType(self.scalar_type())) {
+    result.resize_as_(self);
+    result.copy_(self);
+    return result;
+  }
+
+  auto iter = TensorIterator::unary_op(result, self);
+  nan_to_num_stub(iter.device_type(), iter, nan, pos_inf, neg_inf);
+  return result;
+}
+
+Tensor nan_to_num(
+    const Tensor& self,
+    c10::optional<double> nan,
+    c10::optional<double> pos_inf,
+    c10::optional<double> neg_inf) {
+  auto result = at::empty_like(self);
+  return at::nan_to_num_out(result, self, nan, pos_inf, neg_inf);
+}
+
+Tensor& nan_to_num_(
+    Tensor& self,
+    c10::optional<double> nan,
+    c10::optional<double> pos_inf,
+    c10::optional<double> neg_inf) {
+  return at::nan_to_num_out(self, self, nan, pos_inf, neg_inf);
+}
+
 Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tanh_stub); }
 Tensor tanh(const Tensor& self) { return unary_op_impl(self, at::tanh_out); }
 Tensor& tanh_(Tensor& self) { return unary_op_impl_(self, at::tanh_out); }
@@ -645,6 +680,7 @@ DEFINE_DISPATCH(log1p_stub);
 DEFINE_DISPATCH(log2_stub);
 DEFINE_DISPATCH(logical_not_stub);
 DEFINE_DISPATCH(neg_stub);
+DEFINE_DISPATCH(nan_to_num_stub);
 DEFINE_DISPATCH(polygamma_stub);
 DEFINE_DISPATCH(reciprocal_stub);
 DEFINE_DISPATCH(round_stub);
diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h
index 0dcd5a0b947..a6db47f1715 100644
--- a/aten/src/ATen/native/UnaryOps.h
+++ b/aten/src/ATen/native/UnaryOps.h
@@ -77,6 +77,13 @@ DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional<Generator>), random_stub);
 DECLARE_DISPATCH(void(*)(TensorIterator&, const int64_t), polygamma_stub);
 DECLARE_DISPATCH(void(*)(TensorIterator&, Scalar a, Scalar b), clamp_stub);
 DECLARE_DISPATCH(void(*)(Tensor&, const Tensor&, int64_t, bool, c10::optional<Generator>), multinomial_stub);
+DECLARE_DISPATCH(
+    void (*)(
+        TensorIterator&,
+        c10::optional<double>,
+        c10::optional<double>,
+        c10::optional<double>),
+    nan_to_num_stub);
 
 // Missing unary functions
 // digamma
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index cc9eedebff5..84c3ceed3a2 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -383,6 +383,33 @@ static void polygamma_kernel(TensorIterator& iter, int64_t n) {
   }
 }
 
+static void nan_to_num_kernel(
+    TensorIterator& iter,
+    c10::optional<double> nan,
+    c10::optional<double> pos_inf,
+    c10::optional<double> neg_inf) {
+  AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "nan_to_num", [&]() {
+    scalar_t nan_replacement = static_cast<scalar_t>(nan.value_or(0.));
+    scalar_t pos_inf_replacement = pos_inf.has_value()
+        ? static_cast<scalar_t>(pos_inf.value())
+        : std::numeric_limits<scalar_t>::max();
+    scalar_t neg_inf_replacement = neg_inf.has_value()
+        ? static_cast<scalar_t>(neg_inf.value())
+        : std::numeric_limits<scalar_t>::lowest();
+
+    cpu_kernel(iter, [=](scalar_t a) -> scalar_t {
+      return (
+          at::_isnan(a)
+              ? nan_replacement
+              : (a == std::numeric_limits<scalar_t>::infinity()
+                     ? pos_inf_replacement
+                     : (a == -std::numeric_limits<scalar_t>::infinity()
+                            ? neg_inf_replacement
+                            : a)));
+    });
+  });
+}
+
 static void clamp_kernel(TensorIterator& iter, Scalar min_scalar, Scalar max_scalar) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_cpu", [&]() {
     c10::scalar_value_type<scalar_t>::type (*zabs_)(scalar_t) = zabs<scalar_t>;
@@ -648,6 +675,7 @@ REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel);
 REGISTER_DISPATCH(logical_not_stub, &logical_not_kernel);
 REGISTER_DISPATCH(frac_stub, &frac_kernel);
 REGISTER_DISPATCH(reciprocal_stub, &reciprocal_kernel);
+REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel);
 REGISTER_DISPATCH(neg_stub, &neg_kernel);
 REGISTER_DISPATCH(sign_stub, &sign_kernel);
 REGISTER_DISPATCH(signbit_stub, &signbit_kernel);
diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
index 1067d7c61bc..5b545471fb3 100644
--- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
+++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -10,6 +10,7 @@
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/Math.h>
 #include <ATen/native/TensorIterator.h>
+#include <ATen/NumericUtils.h>
 #include <c10/cuda/CUDAMathCompat.h>
 #include <c10/util/complex.h>
 
@@ -180,6 +181,32 @@ void clamp_max_kernel_cuda(TensorIterator& iter, Scalar max_value) {
   });
 }
 
+void nan_to_num_kernel_cuda(
+    TensorIterator& iter,
+    c10::optional<double> nan,
+    c10::optional<double> pos_inf,
+    c10::optional<double> neg_inf) {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "nan_to_num_cuda", [&]() {
+    scalar_t nan_replacement = static_cast<scalar_t>(nan.value_or(0.));
+    scalar_t pos_inf_replacement = pos_inf.has_value()
+        ? static_cast<scalar_t>(pos_inf.value())
+        : std::numeric_limits<scalar_t>::max();
+    scalar_t neg_inf_replacement = neg_inf.has_value()
+        ? static_cast<scalar_t>(neg_inf.value())
+        : std::numeric_limits<scalar_t>::lowest();
+    gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return (
+          at::_isnan(a)
+              ? nan_replacement
+              : (a == std::numeric_limits<scalar_t>::infinity()
+                     ? pos_inf_replacement
+                     : (a == -std::numeric_limits<scalar_t>::infinity()
+                            ? neg_inf_replacement
+                            : a)));
+    });
+  });
+}
+
 void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta){
   AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
     AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "kaiser_window_cuda", [&] {
@@ -206,6 +233,7 @@ REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
 REGISTER_DISPATCH(clamp_stub, &clamp_kernel_cuda);
 REGISTER_DISPATCH(clamp_min_stub, &clamp_min_kernel_cuda);
 REGISTER_DISPATCH(clamp_max_stub, &clamp_max_kernel_cuda);
+REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel_cuda);
 REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
 
 } // namespace native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 92a20d8625b..de5e9803727 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1981,6 +1981,16 @@
     CPU: layer_norm_backward_cpu
     CUDA: layer_norm_backward_cuda
 
+- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!)
+
 - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
   use_c10_dispatcher: full
   python_module: nn
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 7cd1a88f82b..94b1fb25f58 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -453,6 +453,8 @@ view of a storage and defines numeric operations on it.
    .. automethod:: narrow
    .. automethod:: narrow_copy
    .. automethod:: ndimension
+   .. automethod:: nan_to_num
+   .. automethod:: nan_to_num_
    .. automethod:: ne
    .. automethod:: ne_
   .. automethod:: not_equal
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 0063c6cc8db..d0537947d4f 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -312,6 +312,7 @@ Pointwise Ops
     mul
     multiply
     mvlgamma
+    nan_to_num
     neg
     negative
     nextafter
diff --git a/test/test_autograd.py b/test/test_autograd.py
index d6661b4662f..e92fbcbf21b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4641,6 +4641,33 @@ for shape in [(1,), ()]:
         test(inp, torch.float, torch.double)
         test(inp, torch.double, torch.float)
 
+    def test_nan_to_num(self):
+        a = torch.randn(3, 3, 3, 3)
+        with torch.no_grad():
+            a[torch.rand_like(a) < 0.2] = float('nan')
+            a[torch.rand_like(a) < 0.2] = float('inf')
+            a[torch.rand_like(a) < 0.2] = -float('inf')
+
+        a.requires_grad = True
+
+        gradcheck(lambda x: x.nan_to_num(), a)
+        gradgradcheck(lambda x: x.nan_to_num(), a)
+
+        gradcheck(lambda x: x.nan_to_num(nan=1.2), a)
+        gradgradcheck(lambda x: x.nan_to_num(nan=1.2), a)
+
+        gradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0), a)
+        gradgradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0), a)
+
+        gradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0, neginf=-2.0), a)
+        gradgradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0, neginf=-2.0), a)
+
+        gradcheck(lambda x: x.nan_to_num(posinf=2.0, neginf=-2.0), a)
+        gradgradcheck(lambda x: x.nan_to_num(posinf=2.0, neginf=-2.0), a)
+
+        gradcheck(lambda x: x.nan_to_num(neginf=-2.0), a)
+        gradgradcheck(lambda x: x.nan_to_num(neginf=-2.0), a)
+
     def test_custom_function_error(self):
         class BadFw(Function):
             @staticmethod
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 09a3cbd583a..ddc735199f2 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -1,6 +1,7 @@
 import math
 from itertools import product, chain
 from numbers import Number
+import random
 import unittest
 
@@ -377,6 +378,41 @@ class TestUnaryUfuncs(TestCase):
 
         self.assertEqual(actual, expected)
 
+    @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
+    def test_nan_to_num(self, device, dtype):
+        for contiguous in [False, True]:
+            x = make_tensor((64, 64), low=0., high=100., dtype=dtype, device=device)
+
+            if dtype.is_floating_point:
+                # Add extremal values.
+                extremals = [float('nan'), float('inf'), -float('inf')]
+                for idx, extremal in zip(torch.randint(0, 63, (3,)), extremals):
+                    x[idx, :] = extremal
+
+            if not contiguous:
+                x = x.T
+
+            # With args
+            nan = random.random()
+            posinf = random.random() * 5
+            neginf = random.random() * 10
+
+            self.compare_with_numpy(lambda x: x.nan_to_num(nan=nan, posinf=posinf),
+                                    lambda x: np.nan_to_num(x, nan=nan, posinf=posinf),
+                                    x)
+            self.compare_with_numpy(lambda x: x.nan_to_num(posinf=posinf, neginf=neginf),
+                                    lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf),
+                                    x)
+
+            # Out Variant
+            out = torch.empty_like(x)
+            result = torch.nan_to_num(x)
+            torch.nan_to_num(x, out=out)
+            self.assertEqual(result, out)
+
+            result = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
+            torch.nan_to_num(x, out=out, nan=nan, posinf=posinf, neginf=neginf)
+            self.assertEqual(result, out)
+
 instantiate_device_type_tests(TestUnaryUfuncs, globals())
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 92ee277e9ec..2af8ee81604 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -749,6 +749,9 @@
 - name: mvlgamma(Tensor self, int p) -> Tensor
   self: mvlgamma_backward(grad, self, p)
 
+- name: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
+  self: grad * at::isfinite(self)
+
 - name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
   input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index b7e61b5b7a8..6e0dc0721ae 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -145,7 +145,9 @@ DONT_REQUIRE_DERIVATIVE = {
     'quantize_per_tensor', 'quantize_per_channel',
     # Functions that return integers should not have output that require gradients
     'argmax', 'argmin', 'argsort', 'searchsorted',
-    'bucketize'
+    'bucketize',
+    # Functions that return booleans are not differentiable
+    'isnan', 'isposinf', 'isneginf', 'isinf'
 }
 
 # The C -> R functions at the time of adding this are still being audited and tested
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 55c5613cdcc..7caceff4a1d 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2342,6 +2342,18 @@ ndimension() -> int
 Alias for :meth:`~Tensor.dim()`
 """)
 
+add_docstr_all('nan_to_num', r"""
+nan_to_num(nan=0.0, posinf=None, neginf=None) -> Tensor
+
+See :func:`torch.nan_to_num`.
+""")
+
+add_docstr_all('nan_to_num_', r"""
+nan_to_num_(nan=0.0, posinf=None, neginf=None) -> Tensor
+
+In-place version of :meth:`~Tensor.nan_to_num`.
+""")
+
 add_docstr_all('ne', r"""
 ne(other) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index a939eacad1f..6c641c3df14 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5301,6 +5301,41 @@ Example::
             [ 8, 9]])
 """)
 
+add_docstr(torch.nan_to_num,
+           r"""
+nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None) -> Tensor
+
+Replaces :literal:`NaN`, positive infinity, and negative infinity values in :attr:`input`
+with the values specified by :attr:`nan`, :attr:`posinf`, and :attr:`neginf`, respectively.
+By default, :literal:`NaN`\ s are replaced with zero, positive infinity is replaced with the
+greatest finite value representable by :attr:`input`'s dtype, and negative infinity
+is replaced with the least finite value representable by :attr:`input`'s dtype.
+
+Args:
+    {input}
+    nan (Number, optional): the value to replace :literal:`NaN`\ s with. Default is zero.
+    posinf (Number, optional): if a Number, the value to replace positive infinity values with.
+        If None, positive infinity values are replaced with the greatest finite value
+        representable by :attr:`input`'s dtype. Default is None.
+    neginf (Number, optional): if a Number, the value to replace negative infinity values with.
+        If None, negative infinity values are replaced with the least finite value
+        representable by :attr:`input`'s dtype. Default is None.
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
+    >>> torch.nan_to_num(x)
+    tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38,  3.1400e+00])
+    >>> torch.nan_to_num(x, nan=2.0)
+    tensor([ 2.0000e+00,  3.4028e+38, -3.4028e+38,  3.1400e+00])
+    >>> torch.nan_to_num(x, nan=2.0, posinf=1.0)
+    tensor([ 2.0000e+00,  1.0000e+00, -3.4028e+38,  3.1400e+00])
+
+""".format(**common_args))
+
 add_docstr(torch.ne, r"""
 ne(input, other, *, out=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index bab17c1e961..43efda1da86 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -490,6 +490,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.mv: lambda input, vec, out=None: -1,
     torch.mvlgamma: lambda input, p: -1,
     torch.narrow: lambda input, dim, start, length: -1,
+    torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
     torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
     torch.native_layer_norm: lambda input, weight, bias, M, N, eps: -1,
     torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 290645fd0d3..46ba17f61d8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -12,7 +12,7 @@ from typing import List, Tuple, Dict, Any
 from torch.testing import \
     (make_non_contiguous, _dispatch_dtypes,
      floating_types, floating_types_and, floating_and_complex_types,
-     floating_and_complex_types_and, all_types_and_complex_and)
+     floating_and_complex_types_and, all_types_and_complex_and, all_types_and)
 from torch.testing._internal.common_device_type import \
     (skipCUDAIfNoMagma, skipCPUIfNoLapack, expectedFailureCUDA,
      expectedAlertNondeterministic, precisionOverride)
@@ -389,6 +389,11 @@ op_db = [
                    ref=np.exp2,
                    dtypes=floating_types_and(torch.half),
                    dtypesIfCPU=None,
-                   dtypesIfCUDA=None)
+                   dtypesIfCUDA=None),
+    UnaryUfuncInfo('nan_to_num',
+                   ref=np.nan_to_num,
+                   dtypes=all_types_and(torch.half),
+                   dtypesIfCPU=None,
+                   dtypesIfCUDA=None)
 ]
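
Notes for reviewers (not part of the patch). The kernel semantics above can be summarized with a small, self-contained Python sketch. This is illustrative only: the name `nan_to_num_reference` is ours, and the finfo-based defaults stand in for the `std::numeric_limits<scalar_t>::max()/lowest()` fallbacks in the CPU and CUDA kernels. It also mirrors the integral passthrough in `nan_to_num_out`.

    import torch

    def nan_to_num_reference(x, nan=0.0, posinf=None, neginf=None):
        # Integral tensors cannot hold NaN/inf; nan_to_num_out just copies them.
        if not x.is_floating_point():
            return x.clone()
        finfo = torch.finfo(x.dtype)
        # Defaults mirror std::numeric_limits<scalar_t>::max()/lowest().
        posinf = finfo.max if posinf is None else posinf
        neginf = finfo.min if neginf is None else neginf
        out = torch.where(torch.isnan(x), torch.full_like(x, nan), x)
        out = torch.where(out == float('inf'), torch.full_like(out, posinf), out)
        out = torch.where(out == float('-inf'), torch.full_like(out, neginf), out)
        return out

    x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
    assert torch.equal(nan_to_num_reference(x), torch.nan_to_num(x))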
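The derivatives.yaml entry, `self: grad * at::isfinite(self)`, makes the backward pass a mask: the upstream gradient flows through only where the input was already finite and is zeroed wherever a replacement constant was substituted. A quick check of that behavior (assumes this patch is applied):

    import torch

    a = torch.tensor([1.0, float('inf'), float('nan'), -float('inf')], requires_grad=True)
    torch.nan_to_num(a, nan=0.0, posinf=10.0, neginf=-10.0).sum().backward()
    print(a.grad)  # tensor([1., 0., 0., 0.]) -- i.e. grad * isfinite(input)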
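Finally, the `UnaryUfuncInfo('nan_to_num', ref=np.nan_to_num, ...)` entry relies on NumPy accepting the same keyword arguments; the `nan`, `posinf`, and `neginf` kwargs of `np.nan_to_num` exist only from NumPy 1.17 onward. A direct parity spot-check, matching what the generic unary-ufunc tests exercise:

    import numpy as np
    import torch

    x = torch.tensor([float('nan'), float('inf'), -float('inf'), 1.5])
    expected = np.nan_to_num(x.numpy(), nan=1.2, posinf=2.0, neginf=-2.0)
    actual = torch.nan_to_num(x, nan=1.2, posinf=2.0, neginf=-2.0)
    assert np.allclose(actual.numpy(), expected)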