mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-24 02:48:02 +00:00
Add squash_output and expln as policy param for ppo and a2c
This commit is contained in:
parent
60d5f4463d
commit
03e853997a
1 changed file with 9 additions and 3 deletions
|
|
@@ -29,12 +29,18 @@ class PPOPolicy(BasePolicy):
|
|||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: (bool) Use `expln()` function instead of `exp()` to ensure
|
||||
a positive standard deviation (cf. paper). It keeps the variance
|
||||
above zero and prevents it from growing too fast. In practice, `exp()` is usually enough.
|
||||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
which makes it possible to ensure boundaries when using SDE.
|
||||
"""
|
||||
def __init__(self, observation_space, action_space,
|
||||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.Tanh, adam_epsilon=1e-5,
|
||||
ortho_init=True, use_sde=False,
|
||||
log_std_init=0.0, full_std=True, sde_net_arch=None):
|
||||
log_std_init=0.0, full_std=True,
|
||||
sde_net_arch=None, use_expln=False, squash_output=False):
|
||||
super(PPOPolicy, self).__init__(observation_space, action_space, device)
|
||||
self.obs_dim = self.observation_space.shape[0]
|
||||
|
||||
|
|
@@ -63,8 +69,8 @@ class PPOPolicy(BasePolicy):
|
|||
if use_sde:
|
||||
dist_kwargs = {
|
||||
'full_std': full_std,
|
||||
'squash_output': False,
|
||||
'use_expln': False,
|
||||
'squash_output': squash_output,
|
||||
'use_expln': use_expln,
|
||||
'learn_features': sde_net_arch is not None
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue