mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Add sde test + fix random seed
This commit is contained in:
parent
925afe784c
commit
72a6f18e43
3 changed files with 22 additions and 12 deletions
|
|
@ -7,24 +7,22 @@ from torchy_baselines import A2C
|
|||
|
||||
|
||||
def test_state_dependent_exploration():
|
||||
n_states = 2
|
||||
state_dim = 3
|
||||
# TODO: fix for action_dim > 1
|
||||
action_dim = 1
|
||||
sigma = th.ones(state_dim, action_dim, requires_grad=True)
|
||||
|
||||
# log_sigma = th.ones(2, 1, requires_grad=True)
|
||||
|
||||
# weights_dist = Normal(th.zeros_like(log_sigma), th.exp(log_sigma))
|
||||
th.manual_seed(2)
|
||||
weights_dist = Normal(th.zeros_like(sigma), sigma)
|
||||
|
||||
weights = weights_dist.rsample()
|
||||
state = th.rand(1, state_dim)
|
||||
# state = (th.ones(state_dim,) * 2).view(1, -1)
|
||||
state = th.rand(n_states, state_dim)
|
||||
mu = th.ones(action_dim)
|
||||
# print(weights.shape, state.shape)
|
||||
noise = th.mm(state, weights)
|
||||
# variance = th.mm(state ** 2, th.exp(log_sigma) ** 2)
|
||||
|
||||
variance = th.mm(state ** 2, sigma ** 2)
|
||||
action_dist = Normal(mu, th.sqrt(variance))
|
||||
|
||||
|
|
@ -35,7 +33,8 @@ def test_state_dependent_exploration():
|
|||
grad = th.zeros_like(sigma)
|
||||
for j in range(action_dim):
|
||||
for i in range(state_dim):
|
||||
grad[i, j] = ((noise[:, j] ** 2 - variance[:, j]) / (variance[:, j] ** 2)) * (state[:, i] ** 2 * sigma[i, j])
|
||||
a = ((noise[:, j] ** 2 - variance[:, j]) / (variance[:, j] ** 2)) * (state[:, i] ** 2 * sigma[i, j])
|
||||
grad[i, j] = a.mean()
|
||||
|
||||
# sigma.grad should be equal to grad
|
||||
assert sigma.grad.allclose(grad)
|
||||
|
|
@ -43,6 +42,16 @@ def test_state_dependent_exploration():
|
|||
|
||||
@pytest.mark.parametrize("model_class", [A2C])
|
||||
def test_state_dependent_noise(model_class):
|
||||
model = model_class('MlpPolicy', 'Pendulum-v0', n_steps=200,
|
||||
use_sde=True, ent_coef=0.0, verbose=1, create_eval_env=True)
|
||||
model.learn(total_timesteps=int(1e6), log_interval=10, eval_freq=10000)
|
||||
import gym
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
|
||||
# env_id = 'Pendulum-v0'
|
||||
env_id = 'MountainCarContinuous-v0'
|
||||
# env_id = 'LunarLanderContinuous-v2'
|
||||
env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), norm_reward=True)
|
||||
eval_env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), training=False, norm_reward=False)
|
||||
model = model_class('MlpPolicy', env, n_steps=200, max_grad_norm=1, use_rms_prop=False,
|
||||
use_sde=True, ent_coef=0.00, verbose=1, create_eval_env=True, learning_rate=3e-4,
|
||||
policy_kwargs=dict(log_std_init=0.0, ortho_init=False, net_arch=[256, dict(pi=[256], vf=[256])]), seed=None)
|
||||
model.learn(total_timesteps=int(20000), log_interval=5, eval_freq=10000, eval_env=eval_env)
|
||||
|
|
|
|||
|
|
@ -282,7 +282,9 @@ class BaseRLModel(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def set_random_seed(self, seed=0):
|
||||
def set_random_seed(self, seed=None):
|
||||
if seed is None:
|
||||
return
|
||||
set_random_seed(seed, using_cuda=self.device == th.device('cuda'))
|
||||
self.action_space.seed(seed)
|
||||
if self.env is not None:
|
||||
|
|
|
|||
|
|
@ -199,9 +199,8 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
self.weights_dist = Normal(th.zeros_like(log_std), self.get_std(log_std))
|
||||
self.exploration_mat = self.weights_dist.rsample()
|
||||
|
||||
def proba_distribution_net(self, latent_dim, log_std_init=-1):
|
||||
def proba_distribution_net(self, latent_dim, log_std_init=0.0):
|
||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||
# TODO: log_std_init depending on the number of layers?
|
||||
log_std = nn.Parameter(th.ones(latent_dim, self.action_dim) * log_std_init)
|
||||
self.sample_weights(log_std)
|
||||
return mean_actions, log_std
|
||||
|
|
|
|||
Loading…
Reference in a new issue