diff --git a/tests/test_sde.py b/tests/test_sde.py index 7874ae9..3c5db43 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -43,5 +43,6 @@ def test_state_dependent_exploration(): @pytest.mark.parametrize("model_class", [A2C]) def test_state_dependent_noise(model_class): - model = model_class('MlpPolicy', 'Pendulum-v0', n_steps=200, use_sde=True, verbose=1, create_eval_env=True) + model = model_class('MlpPolicy', 'Pendulum-v0', n_steps=200, + use_sde=True, ent_coef=0.0, verbose=1, create_eval_env=True) model.learn(total_timesteps=int(1e6), log_interval=10, eval_freq=10000) diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py index d420e03..2944d8c 100644 --- a/torchy_baselines/common/distributions.py +++ b/torchy_baselines/common/distributions.py @@ -166,7 +166,7 @@ class CategoricalDistribution(Distribution): class StateDependentNoiseDistribution(Distribution): - def __init__(self, features_dim, action_dim): + def __init__(self, features_dim, action_dim, use_expln=False): super(StateDependentNoiseDistribution, self).__init__() self.distribution = None self.action_dim = action_dim @@ -175,19 +175,28 @@ class StateDependentNoiseDistribution(Distribution): self.log_std = None self.weights_dist = None self.noise_weights = None + self.use_expln = use_expln - @staticmethod - def get_std(log_std): - # TODO: use expln instead of exp only to avoid sigma growing too fast - return th.exp(log_std) + def get_std(self, log_std): + if self.use_expln: + # From SDE paper, it allows to keep variance + # above zero and prevent it from growing too fast + if log_std <= 0: + return th.exp(log_std) + else: + return th.log(log_std + 1.0) + 1.0 + else: + return th.exp(log_std) def sample_weights(self, log_std): self.weights_dist = Normal(th.zeros_like(log_std), self.get_std(log_std)) self.noise_weights = self.weights_dist.rsample() - def proba_distribution_net(self, latent_dim, log_std_init=0.0): + def proba_distribution_net(self, latent_dim, log_std_init=-3): + print("Log std init:", log_std_init) mean_actions = nn.Linear(latent_dim, self.action_dim) - log_std = nn.Parameter(th.zeros(self.features_dim, self.action_dim)) + # TODO: log_std_init depending on the number of layers? + log_std = nn.Parameter(th.ones(self.features_dim, self.action_dim) * log_std_init) self.sample_weights(log_std) return mean_actions, log_std