mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
Add sde scheduler
This commit is contained in:
parent
d8a7556d84
commit
ad32aa60f3
2 changed files with 23 additions and 4 deletions
|
|
@ -65,3 +65,13 @@ def test_state_dependent_offpolicy_noise(model_class):
|
|||
model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True,
|
||||
verbose=1, policy_kwargs=dict(log_std_init=-2))
|
||||
model.learn(total_timesteps=int(1000), eval_freq=500)
|
||||
|
||||
|
||||
def test_scheduler():
|
||||
def scheduler(progress):
|
||||
return -2.0 * progress + 1
|
||||
|
||||
model = TD3('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True,
|
||||
verbose=1, sde_log_std_scheduler=scheduler)
|
||||
model.learn(total_timesteps=int(1000), eval_freq=500)
|
||||
assert th.isclose(model.actor.log_std, th.ones_like(model.actor.log_std)).all()
|
||||
|
|
|
|||
|
|
@ -40,6 +40,9 @@ class TD3(BaseRLModel):
|
|||
:param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise.
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_max_grad_norm: (float)
|
||||
:param sde_ent_coef: (float)
|
||||
:param sde_log_std_scheduler: (callable)
|
||||
:param create_eval_env: (bool) Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
|
||||
|
|
@ -53,7 +56,7 @@ class TD3(BaseRLModel):
|
|||
policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
|
||||
train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
|
||||
tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
|
||||
use_sde=False, sde_max_grad_norm=1, sde_ent_coef=0.0,
|
||||
use_sde=False, sde_max_grad_norm=1, sde_ent_coef=0.0, sde_log_std_scheduler=None,
|
||||
tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0,
|
||||
seed=0, device='auto', _init_setup_model=True):
|
||||
|
||||
|
|
@ -78,6 +81,7 @@ class TD3(BaseRLModel):
|
|||
self.use_sde = use_sde
|
||||
self.sde_max_grad_norm = sde_max_grad_norm
|
||||
self.sde_ent_coef = sde_ent_coef
|
||||
self.sde_log_std_scheduler = sde_log_std_scheduler
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
|
@ -228,8 +232,7 @@ class TD3(BaseRLModel):
|
|||
assert not th.isnan(entropy).any()
|
||||
assert not th.isnan(self.actor.log_std.grad).any()
|
||||
assert not th.isnan(self.actor.log_std).any()
|
||||
# print(self.actor.log_std.grad.mean().item(), self.actor.log_std.grad.max().item(), self.actor.log_std.grad.min().item())
|
||||
# print(self.actor.log_std.mean().item(), self.actor.log_std.max().item(), self.actor.log_std.min().item())
|
||||
|
||||
# Clip grad norm
|
||||
th.nn.utils.clip_grad_norm_([self.actor.log_std], self.sde_max_grad_norm)
|
||||
self.actor.sde_optimizer.step()
|
||||
|
|
@ -270,7 +273,13 @@ class TD3(BaseRLModel):
|
|||
self.num_timesteps, episode_num, episode_timesteps, episode_reward))
|
||||
|
||||
if self.use_sde:
|
||||
self.train_sde()
|
||||
if self.sde_log_std_scheduler is not None:
|
||||
# Call the scheduler
|
||||
value = self.sde_log_std_scheduler(self._current_progress)
|
||||
self.actor.log_std.data = th.ones_like(self.actor.log_std) * value
|
||||
else:
|
||||
# On-policy gradient
|
||||
self.train_sde()
|
||||
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
|
||||
self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
|
||||
|
|
|
|||
Loading…
Reference in a new issue