From d63cef7693e630fe42541471902168e655210375 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 6 Dec 2019 18:32:57 +0100 Subject: [PATCH] Add gradient clipping for SAC --- README.md | 2 +- torchy_baselines/sac/sac.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 388570a..8ed3ae3 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ TODO: - Refactor: buffer with numpy array instead of pytorch - Refactor: remove duplicated code for evaluation - double check the shape of log prob - +- try squashing both mean and output when using SAC + SDE - plotting? -> zoo Later: diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index b5d2a17..e170cd0 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -44,6 +44,7 @@ class SAC(BaseRLModel): :param target_entropy: (str or float) target entropy when learning ent_coef (ent_coef = 'auto') :param action_noise: (ActionNoise) the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. 
+ :param max_grad_norm: (float or None) The maximum norm for the gradient clipping, no clipping is applied when None (None by default) :param gamma: (float) the discount factor :param use_sde: (bool) Whether to use State Dependent Exploration (SDE) instead of action noise exploration (default: False) @@ -61,7 +62,7 @@ class SAC(BaseRLModel): learning_starts=100, batch_size=256, tau=0.005, ent_coef='auto', target_update_interval=1, train_freq=1, gradient_steps=1, n_episodes_rollout=-1, - target_entropy='auto', action_noise=None, + target_entropy='auto', action_noise=None, max_grad_norm=None, gamma=0.99, use_sde=False, tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=0, device='auto', _init_setup_model=True): @@ -89,6 +90,7 @@ class SAC(BaseRLModel): self.action_noise = action_noise self.gamma = gamma self.ent_coef_optimizer = None + self.max_grad_norm = max_grad_norm if _init_setup_model: self._setup_model() @@ -236,6 +238,9 @@ class SAC(BaseRLModel): # Optimize the actor self.actor.optimizer.zero_grad() actor_loss.backward() + # Clip grad norm + if self.max_grad_norm is not None: + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.actor.optimizer.step() # Update target networks