From 3cdd5f20afc766a89fd2fa0a06b33292a36a2d64 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 2 Dec 2019 14:14:48 +0100 Subject: [PATCH] Bug fix + add test for sde net arch --- tests/test_sde.py | 23 +++++------------------ torchy_baselines/ppo/policies.py | 3 ++- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/tests/test_sde.py b/tests/test_sde.py index 668bc9e..851ca42 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -9,7 +9,7 @@ from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize from torchy_baselines.common.monitor import Monitor -def test_state_dependent_exploration(): +def test_state_dependent_exploration_grad(): """ Check that the gradient correspond to the expected one """ @@ -19,7 +19,6 @@ def test_state_dependent_exploration(): sigma_hat = th.ones(state_dim, action_dim, requires_grad=True) # Reduce the number of parameters # sigma_ = th.ones(state_dim, action_dim) * sigma_ - # weights_dist = Normal(th.zeros_like(log_sigma), th.exp(log_sigma)) th.manual_seed(2) weights_dist = Normal(th.zeros_like(sigma_hat), sigma_hat) @@ -60,23 +59,11 @@ def test_state_dependent_exploration(): assert sigma_hat.grad.allclose(grad) -@pytest.mark.parametrize("model_class", [A2C]) -@pytest.mark.parametrize("sde_net_arch", [None, [32, 16]]) -def test_state_dependent_noise(model_class, sde_net_arch): - env_id = 'MountainCarContinuous-v0' - - env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), norm_reward=True) - eval_env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), training=False, norm_reward=False) - - model = model_class('MlpPolicy', env, n_steps=200, use_sde=True, ent_coef=0.00, verbose=1, learning_rate=3e-4, - policy_kwargs=dict(log_std_init=0.0, ortho_init=False, sde_net_arch=sde_net_arch), seed=None) - model.learn(total_timesteps=int(1000), log_interval=5, eval_freq=500, eval_env=eval_env) - - -@pytest.mark.parametrize("model_class", [TD3, SAC]) -def test_state_dependent_offpolicy_noise(model_class): +@pytest.mark.parametrize("model_class", [TD3, SAC, A2C]) +@pytest.mark.parametrize("sde_net_arch", [None, [32, 16], []]) +def test_state_dependent_offpolicy_noise(model_class, sde_net_arch): model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True, - verbose=1, policy_kwargs=dict(log_std_init=-2)) + verbose=1, policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch)) model.learn(total_timesteps=int(1000), eval_freq=500) diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 9b38317..886d958 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -4,7 +4,7 @@ import torch as th import torch.nn as nn import numpy as np -from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, create_mlp +from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, create_mlp, create_sde_feature_extractor from torchy_baselines.common.distributions import make_proba_distribution,\ DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution @@ -143,6 +143,7 @@ class PPOPolicy(BasePolicy): return self.action_dist.proba_distribution(mean_actions, self.log_std, deterministic=deterministic) elif isinstance(self.action_dist, CategoricalDistribution): + # Here mean_actions are the logits before the softmax return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic) elif isinstance(self.action_dist, StateDependentNoiseDistribution):