Add expln

This commit is contained in:
Antonin Raffin 2019-10-29 15:15:54 +01:00
parent 0d41bc1356
commit 42d50ed09b
2 changed files with 18 additions and 8 deletions

View file

@ -43,5 +43,6 @@ def test_state_dependent_exploration():
@pytest.mark.parametrize("model_class", [A2C])
def test_state_dependent_noise(model_class):
model = model_class('MlpPolicy', 'Pendulum-v0', n_steps=200, use_sde=True, verbose=1, create_eval_env=True)
model = model_class('MlpPolicy', 'Pendulum-v0', n_steps=200,
use_sde=True, ent_coef=0.0, verbose=1, create_eval_env=True)
model.learn(total_timesteps=int(1e6), log_interval=10, eval_freq=10000)

View file

@ -166,7 +166,7 @@ class CategoricalDistribution(Distribution):
class StateDependentNoiseDistribution(Distribution):
def __init__(self, features_dim, action_dim):
def __init__(self, features_dim, action_dim, use_expln=False):
super(StateDependentNoiseDistribution, self).__init__()
self.distribution = None
self.action_dim = action_dim
@ -175,19 +175,28 @@ class StateDependentNoiseDistribution(Distribution):
self.log_std = None
self.weights_dist = None
self.noise_weights = None
self.use_expln = use_expln
@staticmethod
def get_std(log_std):
# TODO: use expln instead of exp only to avoid sigma growing too fast
return th.exp(log_std)
def get_std(self, log_std):
if self.use_expln:
# From SDE paper, it allows to keep variance
# above zero and prevent it from growing too fast
if log_std <= 0:
return th.exp(log_std)
else:
return th.log(log_std + 1.0) + 1.0
else:
return th.exp(log_std)
def sample_weights(self, log_std):
self.weights_dist = Normal(th.zeros_like(log_std), self.get_std(log_std))
self.noise_weights = self.weights_dist.rsample()
def proba_distribution_net(self, latent_dim, log_std_init=0.0):
def proba_distribution_net(self, latent_dim, log_std_init=-3):
print("Log std init:", log_std_init)
mean_actions = nn.Linear(latent_dim, self.action_dim)
log_std = nn.Parameter(th.zeros(self.features_dim, self.action_dim))
# TODO: log_std_init depending on the number of layers?
log_std = nn.Parameter(th.ones(self.features_dim, self.action_dim) * log_std_init)
self.sample_weights(log_std)
return mean_actions, log_std