diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py index e188667..90258cf 100644 --- a/torchy_baselines/common/distributions.py +++ b/torchy_baselines/common/distributions.py @@ -103,6 +103,8 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): # Inverse tanh # Naive implementation (not stable): 0.5 * torch.log((1 + x ) / (1 - x)) # We use numpy to avoid numerical instability + # TODO: store the gaussian action because of the action added + # this would also avoid inverting the tanh if gaussian_action is None: gaussian_action = th.from_numpy(np.arctanh(action.cpu().numpy())).to(action.device)