diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index 610b59e..61080c0 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -49,7 +49,9 @@ class Actor(BaseNetwork): latent_sde_dim=latent_sde_dim, log_std_init=log_std_init) # Avoid saturation by limiting the mean of the gaussian to be in [-1, 1] - self.mu = nn.Sequential(self.mu, nn.Tanh()) + # self.mu = nn.Sequential(self.mu, nn.Tanh()) + # TODO: test with small positive slope to have non zero gradient + self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-2.0, max_val=2.0)) else: self.action_dist = SquashedDiagGaussianDistribution(action_dim) self.mu = nn.Linear(net_arch[-1], action_dim)