Bug fix + add test for sde net arch

This commit is contained in:
Antonin Raffin 2019-12-02 14:14:48 +01:00
parent 8e9802784c
commit 3cdd5f20af
2 changed files with 7 additions and 19 deletions

View file

@ -9,7 +9,7 @@ from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize
from torchy_baselines.common.monitor import Monitor
def test_state_dependent_exploration():
def test_state_dependent_exploration_grad():
"""
Check that the gradient correspond to the expected one
"""
@ -19,7 +19,6 @@ def test_state_dependent_exploration():
sigma_hat = th.ones(state_dim, action_dim, requires_grad=True)
# Reduce the number of parameters
# sigma_ = th.ones(state_dim, action_dim) * sigma_
# weights_dist = Normal(th.zeros_like(log_sigma), th.exp(log_sigma))
th.manual_seed(2)
weights_dist = Normal(th.zeros_like(sigma_hat), sigma_hat)
@ -60,23 +59,11 @@ def test_state_dependent_exploration():
assert sigma_hat.grad.allclose(grad)
@pytest.mark.parametrize("model_class", [A2C])
@pytest.mark.parametrize("sde_net_arch", [None, [32, 16]])
def test_state_dependent_noise(model_class, sde_net_arch):
env_id = 'MountainCarContinuous-v0'
env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), norm_reward=True)
eval_env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), training=False, norm_reward=False)
model = model_class('MlpPolicy', env, n_steps=200, use_sde=True, ent_coef=0.00, verbose=1, learning_rate=3e-4,
policy_kwargs=dict(log_std_init=0.0, ortho_init=False, sde_net_arch=sde_net_arch), seed=None)
model.learn(total_timesteps=int(1000), log_interval=5, eval_freq=500, eval_env=eval_env)
@pytest.mark.parametrize("model_class", [TD3, SAC])
def test_state_dependent_offpolicy_noise(model_class):
@pytest.mark.parametrize("model_class", [TD3, SAC, A2C])
@pytest.mark.parametrize("sde_net_arch", [None, [32, 16], []])
def test_state_dependent_offpolicy_noise(model_class, sde_net_arch):
model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True,
verbose=1, policy_kwargs=dict(log_std_init=-2))
verbose=1, policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch))
model.learn(total_timesteps=int(1000), eval_freq=500)

View file

@ -4,7 +4,7 @@ import torch as th
import torch.nn as nn
import numpy as np
from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, create_mlp
from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, create_mlp, create_sde_feature_extractor
from torchy_baselines.common.distributions import make_proba_distribution,\
DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution
@ -143,6 +143,7 @@ class PPOPolicy(BasePolicy):
return self.action_dist.proba_distribution(mean_actions, self.log_std, deterministic=deterministic)
elif isinstance(self.action_dist, CategoricalDistribution):
# Here mean_actions are the logits before the softmax
return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):