Add sde scheduler

Antonin Raffin 2019-11-18 16:03:08 +01:00
parent d8a7556d84
commit ad32aa60f3
2 changed files with 23 additions and 4 deletions


@@ -65,3 +65,13 @@ def test_state_dependent_offpolicy_noise(model_class):
     model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True,
                         verbose=1, policy_kwargs=dict(log_std_init=-2))
     model.learn(total_timesteps=int(1000), eval_freq=500)
+
+
+def test_scheduler():
+    def scheduler(progress):
+        return -2.0 * progress + 1
+
+    model = TD3('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True,
+                verbose=1, sde_log_std_scheduler=scheduler)
+    model.learn(total_timesteps=int(1000), eval_freq=500)
+    assert th.isclose(model.actor.log_std, th.ones_like(model.actor.log_std)).all()
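The test pins down the scheduler contract: the callable receives the training progress and returns the value that log_std is reset to. Since scheduler(0) == 1 and the test asserts that log_std ends at 1, the progress presumably decays from 1.0 at the start of training to 0.0 at the end; that decay is an inference from the assertion, not something the diff states. A minimal standalone sketch of the contract under that assumption:

    # Standalone sketch of the scheduler contract exercised by test_scheduler.
    # Assumption (inferred from the final assertion, not stated in the diff):
    # progress decays linearly from 1.0 at the start of training to 0.0 at the end.
    def scheduler(progress):
        # Anneals log_std from -1 (progress=1.0) to 1 (progress=0.0)
        return -2.0 * progress + 1

    for num_timesteps in (0, 500, 1000):
        progress = 1.0 - num_timesteps / 1000
        print(num_timesteps, scheduler(progress))  # -1.0, then 0.0, then 1.0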


@@ -40,6 +40,9 @@ class TD3(BaseRLModel):
     :param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise.
     :param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
         instead of action noise exploration (default: False)
+    :param sde_max_grad_norm: (float) Maximum norm for the gradient of log_std when it is trained with SDE
+    :param sde_ent_coef: (float) Entropy coefficient for the on-policy update of log_std
+    :param sde_log_std_scheduler: (callable) Function of the training progress that returns the value log_std is set to, bypassing the on-policy update
     :param create_eval_env: (bool) Whether to create a second environment that will be
         used for evaluating the agent periodically. (Only available when passing string for the environment)
     :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
@@ -53,7 +56,7 @@ class TD3(BaseRLModel):
                  policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
                  train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
                  tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
-                 use_sde=False, sde_max_grad_norm=1, sde_ent_coef=0.0,
+                 use_sde=False, sde_max_grad_norm=1, sde_ent_coef=0.0, sde_log_std_scheduler=None,
                  tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0,
                  seed=0, device='auto', _init_setup_model=True):
@@ -78,6 +81,7 @@ class TD3(BaseRLModel):
         self.use_sde = use_sde
         self.sde_max_grad_norm = sde_max_grad_norm
         self.sde_ent_coef = sde_ent_coef
+        self.sde_log_std_scheduler = sde_log_std_scheduler

         if _init_setup_model:
             self._setup_model()
@@ -228,8 +232,7 @@ class TD3(BaseRLModel):
         assert not th.isnan(entropy).any()
         assert not th.isnan(self.actor.log_std.grad).any()
         assert not th.isnan(self.actor.log_std).any()
-        # print(self.actor.log_std.grad.mean().item(), self.actor.log_std.grad.max().item(), self.actor.log_std.grad.min().item())
-        # print(self.actor.log_std.mean().item(), self.actor.log_std.max().item(), self.actor.log_std.min().item())
+        # Clip grad norm
         th.nn.utils.clip_grad_norm_([self.actor.log_std], self.sde_max_grad_norm)
         self.actor.sde_optimizer.step()
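For reference, torch's clip_grad_norm_ rescales the gradient in place whenever its total norm exceeds the threshold, which is what bounds the log_std update above. A self-contained illustration (the tensor values are made up for the example):

    import torch as th

    log_std = th.zeros(4, requires_grad=True)
    # Build a gradient whose norm (20.0) is well above the threshold
    (log_std * th.tensor([10.0, -10.0, 10.0, -10.0])).sum().backward()
    # Rescale the gradient in place so its norm is at most 1.0,
    # mirroring the default sde_max_grad_norm=1
    th.nn.utils.clip_grad_norm_([log_std], max_norm=1.0)
    assert log_std.grad.norm() <= 1.0 + 1e-6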
@@ -270,7 +273,13 @@ class TD3(BaseRLModel):
                               self.num_timesteps, episode_num, episode_timesteps, episode_reward))

                 if self.use_sde:
-                    self.train_sde()
+                    if self.sde_log_std_scheduler is not None:
+                        # Call the scheduler
+                        value = self.sde_log_std_scheduler(self._current_progress)
+                        self.actor.log_std.data = th.ones_like(self.actor.log_std) * value
+                    else:
+                        # On-policy gradient
+                        self.train_sde()

                 gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
                 self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
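Taken together, a hedged usage sketch of the new option; the torchy_baselines import path is an assumption based on the repository this commit belongs to, not something shown in the diff:

    # Hypothetical usage sketch; the import path is an assumption.
    from torchy_baselines import TD3

    def scheduler(progress):
        # progress: 1.0 at the start of training, 0.0 at the end (assumption)
        return -2.0 * progress + 1

    # With a scheduler set, log_std follows the schedule instead of the
    # on-policy train_sde() update after each rollout.
    model = TD3('MlpPolicy', 'Pendulum-v0', use_sde=True, verbose=1,
                sde_log_std_scheduler=scheduler)
    model.learn(total_timesteps=1000)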