softshrink nan fixes (#138421)

Fixes #138385.

Currently contains fixes for CPU and CUDA. Will add fixes for MPS as well soon, if my Mac can build it from source. (Had some issues building it on my Linux PC due to limited memory.)
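A minimal repro of the reported behavior (from issue #138385): before this change, NaN inputs fell through the |x| <= lambd branch and came out as zero; with the fix they propagate.

    import torch

    x = torch.tensor([1.0, -1.0, 0.1, float("nan")])
    out = torch.nn.functional.softshrink(x, lambd=0.5)
    # before the fix: tensor([ 0.5000, -0.5000,  0.0000,  0.0000])
    # after the fix:  tensor([ 0.5000, -0.5000,  0.0000,     nan])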

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138421
Approved by: https://github.com/mikaylagawarecki
Author: isalia20
Date: 2024-11-21 23:06:08 +00:00 (committed by PyTorch MergeBot)
Parent: 3b84fb26d0
Commit: 37fe8015ac
5 changed files with 25 additions and 12 deletions

aten/src/ATen/native/cpu/Activation.cpp

@@ -681,12 +681,17 @@ void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a) -> scalar_t {
-          return float(a) > lambd_val ? a - lambd_val : (float(a) < -lambd_val ? a + lambd_val : float(0));
+          return float(a) > lambd_val ? a - lambd_val
+              : (float(a) < -lambd_val ? a + lambd_val : float(a) * float(0));
         },
         [=](Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
           auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
-          auto self_val_t0 = convert_from_float<scalar_t>((self_val0 > lambdVec) & (self_val0 - lambdVec), (self_val1 > lambdVec) & (self_val1 - lambdVec));
-          auto self_val_t1 = convert_from_float<scalar_t>((self_val0 < -lambd_val) & (self_val0 + lambdVec), (self_val1 < -lambd_val) & (self_val1 + lambdVec));
+          auto self_val_t0 = convert_from_float<scalar_t>(
+              ((self_val0 > lambdVec) | (self_val0.isnan())) & (self_val0 - lambdVec),
+              ((self_val1 > lambdVec) | (self_val1.isnan())) & (self_val1 - lambdVec));
+          auto self_val_t1 = convert_from_float<scalar_t>(
+              ((self_val0 < -lambd_val) | (self_val0.isnan())) & (self_val0 + lambdVec),
+              ((self_val1 < -lambd_val) | (self_val1.isnan())) & (self_val1 + lambdVec));
           return (self_val_t0 | self_val_t1);
         });
   });
@@ -697,12 +702,12 @@ void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a) -> scalar_t {
-          return a > lambd_val ? a - lambd_val : (a < -lambd_val ? a + lambd_val : scalar_t(0));
+          return a > lambd_val ? a - lambd_val : (a < -lambd_val ? a + lambd_val : a * scalar_t(0));
         },
         [=](Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
           Vectorized<scalar_t> self_val_t0, self_val_t1;
-          self_val_t0 = (self_val > lambdVec) & (self_val - lambdVec);
-          self_val_t1 = (self_val < -lambd_val) & (self_val + lambdVec);
+          self_val_t0 = ((self_val > lambdVec) | (self_val.isnan())) & (self_val - lambdVec);
+          self_val_t1 = ((self_val < -lambd_val) | (self_val.isnan())) & (self_val + lambdVec);
           return (self_val_t0 | self_val_t1);
         });
   });
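Both CPU hunks use the same blend idiom: (mask) & value keeps value only in lanes where the comparison mask is all ones, and since NaN compares false against everything, the old masks zeroed NaN lanes. OR-ing isnan() into each mask keeps those lanes alive, and NaN ± lambd is still NaN, so the NaN payload survives the final bitwise OR of the two branches. A rough eager-mode Python sketch of the same selection logic (not the actual SIMD code):

    import torch

    def softshrink_blend(x, lambd=0.5):
        # NaN lanes pass both predicates, and nan - lambd / nan + lambd stay nan
        pos = torch.where((x > lambd) | x.isnan(), x - lambd, 0.0)
        neg = torch.where((x < -lambd) | x.isnan(), x + lambd, 0.0)
        return pos + neg  # the kernel bit-ORs the two branch results instead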

aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu

@@ -12,6 +12,7 @@
 #include <ATen/core/TensorBase.h>
 #include <c10/core/Scalar.h>
 #include <c10/cuda/CUDAMathCompat.h>
+#include <ATen/NumericUtils.h>
 #include <ATen/cuda/ApplyGridUtils.cuh>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/native/cuda/Loops.cuh>
@@ -28,7 +29,7 @@ void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) {
       [&]() {
         auto lambd = value.to<scalar_t>();
         gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t a) -> scalar_t {
-          return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0));
+          return at::_isnan(a) ? a : (a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0)));
         });
       });
 }
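The CUDA path is a scalar per-element lambda, so no mask trick is needed: an explicit at::_isnan(a) check (hence the new ATen/NumericUtils.h include) passes NaN through unchanged before the shrink logic runs. Observable from Python (assuming a CUDA device is available):

    import torch

    x = torch.tensor([float("nan"), 2.0, -2.0], device="cuda")
    torch.nn.functional.softshrink(x)
    # tensor([    nan,  1.5000, -1.5000], device='cuda:0')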

aten/src/ATen/native/mps/operations/Activation.mm

@@ -1428,6 +1428,12 @@ TORCH_IMPL_FUNC(softshrink_out_mps)
                                                         name:nil]
                                   falsePredicateTensor:outputTensor
                                                   name:nil];
+        MPSGraphTensor* isNanTensor = [mpsGraph isNaNWithTensor:inputTensor name:nil];
+        outputTensor = [mpsGraph selectWithPredicateTensor:isNanTensor
+                                       truePredicateTensor:inputTensor
+                                      falsePredicateTensor:outputTensor
+                                                      name:nil];
         newCachedGraph->inputTensor_ = inputTensor;
         newCachedGraph->outputTensor_ = outputTensor;
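On MPS the same fix is expressed in graph form: an isNaN predicate tensor selects the raw input over the computed result wherever the input is NaN. Roughly the eager-mode equivalent:

    # conceptual equivalent of the added MPSGraph select
    output = torch.where(x.isnan(), x, output)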

test/test_nn.py

@@ -12248,22 +12248,22 @@ if __name__ == '__main__':
             0.254, -0.24, -0.225, 0.104, 0.002, -0.001, 0.0574, 1.2344,
             0.1748, -0.1797, -0.8125, 0.2051, -1.1328, 1.2344, -0.1562, 2.3554,
             -0.1953, 0.0304, -0.3613, -1.3047, 1.0312, 0.1436, -0.6953, 0.5664,
-            -0.5820, -0.3301, 0.8203, 0.6133, 0.5938],
+            -0.5820, -0.3301, 0.8203, 0.6133, 0.5938, float('nan')],
             [-0.8203, -1.2344, -0.5234, 2.5312, -0.4551, -0.6875, -1.5547, -0.2217,
             -0.3027, 2.6406, 1.3047, 0.2344, -1.6719, 0.2773, -1.3516, 3.4575,
             0.4414, 0.2656, 2.1094, -1.5156, 1.2344, -0.4336, 0.6797, -3.5486,
             0.9766, -0.4062, 1.4844, 0.7500, -1.7578, 0.7461, 1.6094, 8.5458,
-            0.3730, -0.3477, -1.0625, 0.3848, 0.0557]], device=device)
+            0.3730, -0.3477, -1.0625, 0.3848, 0.0557, float('nan')]], device=device)
         expected = torch.tensor([[0.71, 0.06, 0.0001, 0., 0.7357, 0., -0.0001, -0.654,
             0., 0., 0., 0., 0., 0., 0., 0.7344,
             0., 0., -0.3125, 0., -0.6328, 0.7344, 0., 1.8554,
             0., 0., 0., -0.8047, 0.5312, 0., -0.1953, 0.0664,
-            -0.0820, 0.0, 0.3203, 0.1133, 0.0938],
+            -0.0820, 0.0, 0.3203, 0.1133, 0.0938, float('nan')],
             [-0.3203, -0.7344, -0.0234, 2.0312, 0.0, -0.1875, -1.0547, 0.,
             0.0, 2.1406, 0.8047, 0., -1.1719, 0., -0.8516, 2.9575,
             0., 0., 1.6094, -1.0156, 0.7344, 0., 0.1797, -3.0486,
             0.4766, 0., 0.9844, 0.2500, -1.2578, 0.2461, 1.1094, 8.0458,
-            0., 0., -0.5625, 0., 0.]])
+            0., 0., -0.5625, 0., 0., float('nan')]])
         softshrink = torch.nn.Softshrink()
         out = softshrink(x)
         self.assertEqual(out, expected, atol=1e-2, rtol=0)
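These assertions rely on PyTorch's test-harness assertEqual treating NaNs in matching positions as equal, unlike a plain torch.equal comparison; a quick illustration:

    import torch

    a = torch.tensor([float("nan"), 1.0])
    torch.equal(a, a)                                 # False: nan != nan elementwise
    torch.testing.assert_close(a, a, equal_nan=True)  # passes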

torch/_refs/nn/functional/__init__.py

@@ -512,7 +512,8 @@ def softshrink(a: TensorLikeType, lambd: float = 0.5):
     )
     # We implement this in one torch.where to generate better code in the backward
     # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211
-    return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, 0)
+    # We multiply by 0 for dealing with nans
+    return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, a * 0)


 # Losses
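The reference implementation gets the same behavior with the multiply-by-zero trick from the CPU scalar path: for NaN, abs(a) > lambd is false, so torch.where selects the else branch, and a * 0 is NaN while a literal 0 is not. (One side effect: small negative inputs now map to -0. instead of 0., which still compares equal to zero.) For example:

    import torch

    x = torch.tensor([2.0, -0.1, float("nan")])
    lambd = 0.5
    torch.where(torch.abs(x) > lambd, x - torch.sign(x) * lambd, 0)
    # tensor([1.5000, 0.0000, 0.0000])   <- nan silently dropped
    torch.where(torch.abs(x) > lambd, x - torch.sign(x) * lambd, x * 0)
    # tensor([1.5000, -0.0000, nan])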