# stable-baselines3/tests/test_sde.py
import pytest
import torch as th
from torch.distributions import Normal

from stable_baselines3 import A2C, PPO, SAC

def test_state_dependent_exploration_grad():
    """
    Check that the gradient corresponds to the expected one
    """
    n_states = 2
    state_dim = 3
    action_dim = 10
    sigma_hat = th.ones(state_dim, action_dim, requires_grad=True)
    # Reduce the number of parameters
    # sigma_ = th.ones(state_dim, action_dim) * sigma_

    # weights_dist = Normal(th.zeros_like(log_sigma), th.exp(log_sigma))
    th.manual_seed(2)
    weights_dist = Normal(th.zeros_like(sigma_hat), sigma_hat)
    weights = weights_dist.rsample()

    state = th.rand(n_states, state_dim)
    mu = th.ones(action_dim)
    noise = th.mm(state, weights)

    action = mu + noise

    variance = th.mm(state**2, sigma_hat**2)
    action_dist = Normal(mu, th.sqrt(variance))

    # Sum over the action dimension because we assume they are independent
    loss = action_dist.log_prob(action.detach()).sum(dim=-1).mean()
    loss.backward()

    # From the Rueckstiess paper: check that the computed gradient
    # corresponds to the analytical form
    grad = th.zeros_like(sigma_hat)
    for j in range(action_dim):
        # sigma_hat is the std of the Gaussian distribution of the noise matrix weights
        # sigma_j**2 = sum_i(state_i**2 * sigma_hat_ij**2)
        # sigma_j is the standard deviation of the policy Gaussian distribution
        sigma_j = th.sqrt(variance[:, j])
        for i in range(state_dim):
            # Derivative of the log probability of the jth component of the action
            # w.r.t. the standard deviation sigma_j
            d_log_policy_j = (noise[:, j] ** 2 - sigma_j**2) / sigma_j**3
            # Derivative of sigma_j w.r.t. sigma_hat_ij
            d_sigma_j = (state[:, i] ** 2 * sigma_hat[i, j]) / sigma_j
            # Chain rule, average over the minibatch
            grad[i, j] = (d_log_policy_j * d_sigma_j).mean()

    # sigma_hat.grad should be equal to grad
    assert sigma_hat.grad.allclose(grad)
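

# A vectorized form of the analytical gradient checked in
# test_state_dependent_exploration_grad, offered as a minimal sketch rather
# than stable-baselines3 code (the helper name `_analytical_sde_grad` is
# hypothetical). It assumes `state` has shape (n_states, state_dim),
# `sigma_hat` has shape (state_dim, action_dim) and `noise = state @ weights`
# has shape (n_states, action_dim). Because
#     grad_ij = mean_n[d_log_policy_nj * state_ni**2 * sigma_hat_ij / sigma_nj],
# the double loop over (i, j) collapses into a single matrix product, e.g.
#     assert sigma_hat.grad.allclose(_analytical_sde_grad(state, sigma_hat, noise.detach()))
def _analytical_sde_grad(state, sigma_hat, noise):
    import torch as th

    with th.no_grad():
        # Policy std for each sample and action dimension:
        # sigma_nj = sqrt(sum_i state_ni**2 * sigma_hat_ij**2)
        sigma = th.sqrt(th.mm(state**2, sigma_hat**2))
        # d log N(mu + noise; mu, sigma) / d sigma, shape (n_states, action_dim)
        d_log_policy = (noise**2 - sigma**2) / sigma**3
        # Chain rule: (state_dim, n_states) @ (n_states, action_dim) sums over the
        # batch; divide by the batch size to average, then scale by sigma_hat_ij
        return sigma_hat * th.mm((state**2).t(), d_log_policy / sigma) / state.shape[0]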


def test_sde_check():
    # gSDE is only supported for continuous (Box) action spaces,
    # so an environment with discrete actions must be rejected
    with pytest.raises(ValueError):
        PPO("MlpPolicy", "CartPole-v1", use_sde=True)


@pytest.mark.parametrize("model_class", [SAC, A2C, PPO])
@pytest.mark.parametrize("use_expln", [False, True])
def test_state_dependent_noise(model_class, use_expln):
    kwargs = {"learning_starts": 0} if model_class == SAC else {"n_steps": 64}
    model = model_class(
        "MlpPolicy",
        "Pendulum-v1",
        use_sde=True,
        seed=None,
        verbose=1,
        policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]),
        **kwargs,
    )
    model.learn(total_timesteps=255)
    # Resample the gSDE exploration matrix
    model.policy.reset_noise()
    if model_class == SAC:
        model.policy.actor.get_std()
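

# A minimal sketch, in plain PyTorch, of the two ideas exercised by
# test_state_dependent_noise. The helpers `_expln` and
# `_sample_exploration_mat` are hypothetical illustrations, not the
# stable-baselines3 implementation.
# - `use_expln=True` swaps exp(log_std) for the "expln" function from the gSDE
#   paper: exp(x) for x <= 0 and log1p(x) + 1 for x > 0, so the std stays
#   positive but grows only logarithmically for large log_std.
# - With gSDE the exploration matrix is sampled once and reused until the noise
#   is reset (cf. `reset_noise()`), so for fixed features the noise
#   `features @ exploration_mat` does not change between resets.
def _expln(log_std):
    import torch as th

    # Both branches are clamped so that the branch discarded by th.where never
    # produces inf/nan values
    below = th.exp(log_std.clamp(max=0.0))
    above = th.log1p(log_std.clamp(min=0.0)) + 1.0
    return th.where(log_std <= 0, below, above)


def _sample_exploration_mat(log_std, use_expln=False, generator=None):
    import torch as th

    std = _expln(log_std) if use_expln else th.exp(log_std)
    # One independent Gaussian weight per (feature, action) pair
    return th.randn(std.shape, generator=generator, dtype=std.dtype) * std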