mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-05 04:17:59 +00:00
Add gradient clipping for SAC
This commit is contained in:
parent
233f346d53
commit
d63cef7693
2 changed files with 7 additions and 2 deletions
|
|
@ -23,7 +23,7 @@ TODO:
|
|||
- Refactor: buffer with numpy array instead of pytorch
|
||||
- Refactor: remove duplicated code for evaluation
|
||||
- double check the shape of log prob
|
||||
|
||||
- try squashing both mean and output when using SAC + SDE
|
||||
- plotting? -> zoo
|
||||
|
||||
Later:
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ class SAC(BaseRLModel):
|
|||
:param target_entropy: (str or float) target entropy when learning ent_coef (ent_coef = 'auto')
|
||||
:param action_noise: (ActionNoise) the action noise type (None by default), this can help
|
||||
for hard exploration problem. Cf common.noise for the different action noise type.
|
||||
:param max_grad_norm: (float) The maximum value for the gradient clipping (None by default)
|
||||
:param gamma: (float) the discount factor
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
|
||||
instead of action noise exploration (default: False)
|
||||
|
|
@ -61,7 +62,7 @@ class SAC(BaseRLModel):
|
|||
learning_starts=100, batch_size=256,
|
||||
tau=0.005, ent_coef='auto', target_update_interval=1,
|
||||
train_freq=1, gradient_steps=1, n_episodes_rollout=-1,
|
||||
target_entropy='auto', action_noise=None,
|
||||
target_entropy='auto', action_noise=None, max_grad_norm=None,
|
||||
gamma=0.99, use_sde=False, tensorboard_log=None, create_eval_env=False,
|
||||
policy_kwargs=None, verbose=0, seed=0, device='auto',
|
||||
_init_setup_model=True):
|
||||
|
|
@ -89,6 +90,7 @@ class SAC(BaseRLModel):
|
|||
self.action_noise = action_noise
|
||||
self.gamma = gamma
|
||||
self.ent_coef_optimizer = None
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
|
@ -236,6 +238,9 @@ class SAC(BaseRLModel):
|
|||
# Optimize the actor
|
||||
self.actor.optimizer.zero_grad()
|
||||
actor_loss.backward()
|
||||
# Clip grad norm
|
||||
if self.max_grad_norm is not None:
|
||||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
self.actor.optimizer.step()
|
||||
|
||||
# Update target networks
|
||||
|
|
|
|||
Loading…
Reference in a new issue