From ad32aa60f3659b8bbb72fa44cdc4addaa73e85ad Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 18 Nov 2019 16:03:08 +0100 Subject: [PATCH] Add sde scheduler --- tests/test_sde.py | 10 ++++++++++ torchy_baselines/td3/td3.py | 17 +++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tests/test_sde.py b/tests/test_sde.py index 7c6f5ff..09b48e8 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -65,3 +65,13 @@ def test_state_dependent_offpolicy_noise(model_class): model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True, verbose=1, policy_kwargs=dict(log_std_init=-2)) model.learn(total_timesteps=int(1000), eval_freq=500) + + +def test_scheduler(): + def scheduler(progress): + return -2.0 * progress + 1 + + model = TD3('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True, + verbose=1, sde_log_std_scheduler=scheduler) + model.learn(total_timesteps=int(1000), eval_freq=500) + assert th.isclose(model.actor.log_std, th.ones_like(model.actor.log_std)).all() diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 6c1dbcf..ddd7e9d 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -40,6 +40,9 @@ class TD3(BaseRLModel): :param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise. :param use_sde: (bool) Whether to use State Dependent Exploration (SDE) instead of action noise exploration (default: False) + :param sde_max_grad_norm: (float) + :param sde_ent_coef: (float) + :param sde_log_std_scheduler: (callable) :param create_eval_env: (bool) Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation @@ -53,7 +56,7 @@ class TD3(BaseRLModel): policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100, train_freq=-1, gradient_steps=-1, n_episodes_rollout=1, tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5, - use_sde=False, sde_max_grad_norm=1, sde_ent_coef=0.0, + use_sde=False, sde_max_grad_norm=1, sde_ent_coef=0.0, sde_log_std_scheduler=None, tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=0, device='auto', _init_setup_model=True): @@ -78,6 +81,7 @@ class TD3(BaseRLModel): self.use_sde = use_sde self.sde_max_grad_norm = sde_max_grad_norm self.sde_ent_coef = sde_ent_coef + self.sde_log_std_scheduler = sde_log_std_scheduler if _init_setup_model: self._setup_model() @@ -228,8 +232,7 @@ class TD3(BaseRLModel): assert not th.isnan(entropy).any() assert not th.isnan(self.actor.log_std.grad).any() assert not th.isnan(self.actor.log_std).any() - # print(self.actor.log_std.grad.mean().item(), self.actor.log_std.grad.max().item(), self.actor.log_std.grad.min().item()) - # print(self.actor.log_std.mean().item(), self.actor.log_std.max().item(), self.actor.log_std.min().item()) + # Clip grad norm th.nn.utils.clip_grad_norm_([self.actor.log_std], self.sde_max_grad_norm) self.actor.sde_optimizer.step() @@ -270,7 +273,13 @@ class TD3(BaseRLModel): self.num_timesteps, episode_num, episode_timesteps, episode_reward)) if self.use_sde: - self.train_sde() + if self.sde_log_std_scheduler is not None: + # Call the scheduler + value = self.sde_log_std_scheduler(self._current_progress) + self.actor.log_std.data = th.ones_like(self.actor.log_std) * value + else: + # On-policy gradient + self.train_sde() gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)