mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-22 22:10:16 +00:00
Bug fix + add test for sde net arch
This commit is contained in:
parent
8e9802784c
commit
3cdd5f20af
2 changed files with 7 additions and 19 deletions
|
|
@ -9,7 +9,7 @@ from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize
|
|||
from torchy_baselines.common.monitor import Monitor
|
||||
|
||||
|
||||
def test_state_dependent_exploration():
|
||||
def test_state_dependent_exploration_grad():
|
||||
"""
|
||||
Check that the gradient corresponds to the expected one
|
||||
"""
|
||||
|
|
@ -19,7 +19,6 @@ def test_state_dependent_exploration():
|
|||
sigma_hat = th.ones(state_dim, action_dim, requires_grad=True)
|
||||
# Reduce the number of parameters
|
||||
# sigma_ = th.ones(state_dim, action_dim) * sigma_
|
||||
|
||||
# weights_dist = Normal(th.zeros_like(log_sigma), th.exp(log_sigma))
|
||||
th.manual_seed(2)
|
||||
weights_dist = Normal(th.zeros_like(sigma_hat), sigma_hat)
|
||||
|
|
@ -60,23 +59,11 @@ def test_state_dependent_exploration():
|
|||
assert sigma_hat.grad.allclose(grad)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C])
|
||||
@pytest.mark.parametrize("sde_net_arch", [None, [32, 16]])
|
||||
def test_state_dependent_noise(model_class, sde_net_arch):
|
||||
env_id = 'MountainCarContinuous-v0'
|
||||
|
||||
env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), norm_reward=True)
|
||||
eval_env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), training=False, norm_reward=False)
|
||||
|
||||
model = model_class('MlpPolicy', env, n_steps=200, use_sde=True, ent_coef=0.00, verbose=1, learning_rate=3e-4,
|
||||
policy_kwargs=dict(log_std_init=0.0, ortho_init=False, sde_net_arch=sde_net_arch), seed=None)
|
||||
model.learn(total_timesteps=int(1000), log_interval=5, eval_freq=500, eval_env=eval_env)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TD3, SAC])
|
||||
def test_state_dependent_offpolicy_noise(model_class):
|
||||
@pytest.mark.parametrize("model_class", [TD3, SAC, A2C])
|
||||
@pytest.mark.parametrize("sde_net_arch", [None, [32, 16], []])
|
||||
def test_state_dependent_offpolicy_noise(model_class, sde_net_arch):
|
||||
model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True,
|
||||
verbose=1, policy_kwargs=dict(log_std_init=-2))
|
||||
verbose=1, policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch))
|
||||
model.learn(total_timesteps=int(1000), eval_freq=500)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, create_mlp
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, create_mlp, create_sde_feature_extractor
|
||||
from torchy_baselines.common.distributions import make_proba_distribution,\
|
||||
DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution
|
||||
|
||||
|
|
@ -143,6 +143,7 @@ class PPOPolicy(BasePolicy):
|
|||
return self.action_dist.proba_distribution(mean_actions, self.log_std, deterministic=deterministic)
|
||||
|
||||
elif isinstance(self.action_dist, CategoricalDistribution):
|
||||
# Here mean_actions are the logits before the softmax
|
||||
return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic)
|
||||
|
||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
|
|
|
|||
Loading…
Reference in a new issue