softshrink NaN fixes (#138421)
Fixes #138385. Currently contains fixes for CPU and CUDA. Will add fixes for MPS as well soon if my Mac can build it from source. (Had some issues building it on my Linux PC due to limited memory.) Pull Request resolved: https://github.com/pytorch/pytorch/pull/138421 Approved by: https://github.com/mikaylagawarecki
parent 3b84fb26d0 · commit 37fe8015ac
5 changed files with 25 additions and 12 deletions
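For context, a minimal repro sketch of the reported behavior (hypothetical input values, not taken from the issue): softshrink's zero branch used to swallow NaN inputs and return 0; with this change NaN propagates.

    import torch

    x = torch.tensor([1.0, -1.0, 0.1, float('nan')])
    out = torch.nn.functional.softshrink(x, lambd=0.5)
    # before this fix: tensor([ 0.5000, -0.5000,  0.0000,  0.0000])
    # with this fix:   tensor([ 0.5000, -0.5000,  0.0000,     nan])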
@@ -681,12 +681,17 @@ void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a) -> scalar_t {
-          return float(a) > lambd_val ? a - lambd_val : (float(a) < -lambd_val ? a + lambd_val : float(0));
+          return float(a) > lambd_val ? a - lambd_val
+              : (float(a) < -lambd_val ? a + lambd_val : float(a) * float(0));
         },
         [=](Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
           auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
-          auto self_val_t0 = convert_from_float<scalar_t>((self_val0 > lambdVec) & (self_val0 - lambdVec), (self_val1 > lambdVec) & (self_val1 - lambdVec));
-          auto self_val_t1 = convert_from_float<scalar_t>((self_val0 < -lambd_val) & (self_val0 + lambdVec), (self_val1 < -lambd_val) & (self_val1 + lambdVec));
+          auto self_val_t0 = convert_from_float<scalar_t>(
+              ((self_val0 > lambdVec) | (self_val0.isnan())) & (self_val0 - lambdVec),
+              ((self_val1 > lambdVec) | (self_val1.isnan())) & (self_val1 - lambdVec));
+          auto self_val_t1 = convert_from_float<scalar_t>(
+              ((self_val0 < -lambd_val) | (self_val0.isnan())) & (self_val0 + lambdVec),
+              ((self_val1 < -lambd_val) | (self_val1.isnan())) & (self_val1 + lambdVec));
           return (self_val_t0 | self_val_t1);
         });
   });
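The scalar change leans on IEEE-754 arithmetic: float(a) * float(0) is 0 for every finite a but NaN when a is NaN, so the shrink-to-zero branch no longer discards NaNs. A scalar Python sketch of the fixed logic (illustrative only, not the kernel itself):

    def softshrink_scalar(a: float, lambd: float) -> float:
        # NaN compares false against both thresholds, so it reaches the
        # final branch, where a * 0.0 yields nan instead of a literal 0.0
        if a > lambd:
            return a - lambd
        if a < -lambd:
            return a + lambd
        return a * 0.0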
@@ -697,12 +702,12 @@ void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a) -> scalar_t {
-          return a > lambd_val ? a - lambd_val : (a < -lambd_val ? a + lambd_val : scalar_t(0));
+          return a > lambd_val ? a - lambd_val : (a < -lambd_val ? a + lambd_val : a * scalar_t(0));
         },
         [=](Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
           Vectorized<scalar_t> self_val_t0, self_val_t1;
-          self_val_t0 = (self_val > lambdVec) & (self_val - lambdVec);
-          self_val_t1 = (self_val < -lambd_val) & (self_val + lambdVec);
+          self_val_t0 = ((self_val > lambdVec) | (self_val.isnan())) & (self_val - lambdVec);
+          self_val_t1 = ((self_val < -lambd_val) | (self_val.isnan())) & (self_val + lambdVec);
           return (self_val_t0 | self_val_t1);
         });
   });
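In both vectorized paths a comparison yields an all-ones or all-zeros lane mask, mask & value keeps or clears each lane, and the two masked results are OR-ed together. NaN compares false against everything, so a NaN lane used to be cleared by both masks and emerge as 0; OR-ing in isnan() forces NaN lanes through the a - lambd / a + lambd terms, which remain NaN. A rough NumPy analogue of that selection logic (illustrative; where stands in for the bitwise masks, and addition for the final OR):

    import numpy as np

    def softshrink_vec(a, lambd):
        nan = np.isnan(a)
        pos = (a > lambd) | nan   # NaN lanes forced into this branch...
        neg = (a < -lambd) | nan  # ...and this one; nan +/- lambd stays nan
        return np.where(pos, a - lambd, 0.0) + np.where(neg, a + lambd, 0.0)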
@@ -12,6 +12,7 @@
 #include <ATen/core/TensorBase.h>
 #include <c10/core/Scalar.h>
 #include <c10/cuda/CUDAMathCompat.h>
+#include <ATen/NumericUtils.h>
 #include <ATen/cuda/ApplyGridUtils.cuh>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/native/cuda/Loops.cuh>
@@ -28,7 +29,7 @@ void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) {
     [&]() {
       auto lambd = value.to<scalar_t>();
       gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0));
+        return at::_isnan(a) ? a : (a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0)));
       });
     });
 }
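The CUDA fix takes a different shape: an explicit at::_isnan early-out (hence the new NumericUtils.h include above) rather than the multiply-by-zero trick. The decision tree, sketched in Python:

    import math

    def softshrink_cuda_style(a: float, lambd: float) -> float:
        if math.isnan(a):  # mirrors the at::_isnan(a) early-out
            return a
        if a > lambd:
            return a - lambd
        return a + lambd if a < -lambd else 0.0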
@@ -1428,6 +1428,12 @@ TORCH_IMPL_FUNC(softshrink_out_mps)
                                               name:nil]
                            falsePredicateTensor:outputTensor
                                            name:nil];
+        MPSGraphTensor* isNanTensor = [mpsGraph isNaNWithTensor:inputTensor name:nil];
+
+        outputTensor = [mpsGraph selectWithPredicateTensor:isNanTensor
+                                      truePredicateTensor:inputTensor
+                                     falsePredicateTensor:outputTensor
+                                                     name:nil];
+
         newCachedGraph->inputTensor_ = inputTensor;
         newCachedGraph->outputTensor_ = outputTensor;
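The MPS change appends one more select node to the graph: wherever the input is NaN, pass the input through; otherwise keep the softshrink output already computed. At the torch level the added nodes amount to (sketch):

    # out = where(isnan(x), x, out)
    out = torch.where(torch.isnan(x), x, out)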
@@ -12248,22 +12248,22 @@ if __name__ == '__main__':
          0.254, -0.24, -0.225, 0.104, 0.002, -0.001, 0.0574, 1.2344,
          0.1748, -0.1797, -0.8125, 0.2051, -1.1328, 1.2344, -0.1562, 2.3554,
          -0.1953, 0.0304, -0.3613, -1.3047, 1.0312, 0.1436, -0.6953, 0.5664,
-         -0.5820, -0.3301, 0.8203, 0.6133, 0.5938],
+         -0.5820, -0.3301, 0.8203, 0.6133, 0.5938, float('nan')],
         [-0.8203, -1.2344, -0.5234, 2.5312, -0.4551, -0.6875, -1.5547, -0.2217,
          -0.3027, 2.6406, 1.3047, 0.2344, -1.6719, 0.2773, -1.3516, 3.4575,
          0.4414, 0.2656, 2.1094, -1.5156, 1.2344, -0.4336, 0.6797, -3.5486,
          0.9766, -0.4062, 1.4844, 0.7500, -1.7578, 0.7461, 1.6094, 8.5458,
-         0.3730, -0.3477, -1.0625, 0.3848, 0.0557]], device=device)
+         0.3730, -0.3477, -1.0625, 0.3848, 0.0557, float('nan')]], device=device)
     expected = torch.tensor([[0.71, 0.06, 0.0001, 0., 0.7357, 0., -0.0001, -0.654,
          0., 0., 0., 0., 0., 0., 0., 0.7344,
          0., 0., -0.3125, 0., -0.6328, 0.7344, 0., 1.8554,
          0., 0., 0., -0.8047, 0.5312, 0., -0.1953, 0.0664,
-         -0.0820, 0.0, 0.3203, 0.1133, 0.0938],
+         -0.0820, 0.0, 0.3203, 0.1133, 0.0938, float('nan')],
         [-0.3203, -0.7344, -0.0234, 2.0312, 0.0, -0.1875, -1.0547, 0.,
          0.0, 2.1406, 0.8047, 0., -1.1719, 0., -0.8516, 2.9575,
          0., 0., 1.6094, -1.0156, 0.7344, 0., 0.1797, -3.0486,
          0.4766, 0., 0.9844, 0.2500, -1.2578, 0.2461, 1.1094, 8.0458,
-         0., 0., -0.5625, 0., 0.]])
+         0., 0., -0.5625, 0., 0., float('nan')]])
     softshrink = torch.nn.Softshrink()
     out = softshrink(x)
     self.assertEqual(out, expected, atol=1e-2, rtol=0)
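Spot-checking the appended test entries: the last finite value of the first row, 0.5938, shrinks to 0.5938 - 0.5 = 0.0938 (matching the expected tensor), and the trailing float('nan') must come through unchanged. This relies on assertEqual comparing matching NaNs as equal, which I understand to be the default in PyTorch's test TestCase:

    x = torch.tensor([0.5938, float('nan')])
    out = torch.nn.Softshrink()(x)  # default lambd=0.5
    # expected: tensor([0.0938, nan])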
@@ -512,7 +512,8 @@ def softshrink(a: TensorLikeType, lambd: float = 0.5):
     )
     # We implement this in one torch.where to generate better code in the backward
     # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211
-    return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, 0)
+    # We multiply by 0 for dealing with nans
+    return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, a * 0)
 
 
 # Losses
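The _refs change is easiest to see on a NaN input, where torch.abs(a) > lambd is False and torch.where takes the second branch (values illustrative):

    import torch

    a = torch.tensor([float('nan')])
    lambd = 0.5
    torch.where(a.abs() > lambd, a - a.sign() * lambd, 0)      # tensor([0.])  old
    torch.where(a.abs() > lambd, a - a.sign() * lambd, a * 0)  # tensor([nan]) new

Using a * 0 keeps the single-where form (chosen earlier for better backward codegen) while still letting NaN propagate.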