Add squash_output and expln as policy param for ppo and a2c

This commit is contained in:
Antonin Raffin 2020-01-15 13:21:20 +01:00
parent 60d5f4463d
commit 03e853997a

View file

@ -29,12 +29,18 @@ class PPOPolicy(BasePolicy):
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: (bool) Use `expln()` function instead of `exp()` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, `exp()` is usually enough.
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows to ensure boundaries when using SDE.
"""
def __init__(self, observation_space, action_space,
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.Tanh, adam_epsilon=1e-5,
ortho_init=True, use_sde=False,
log_std_init=0.0, full_std=True, sde_net_arch=None):
log_std_init=0.0, full_std=True,
sde_net_arch=None, use_expln=False, squash_output=False):
super(PPOPolicy, self).__init__(observation_space, action_space, device)
self.obs_dim = self.observation_space.shape[0]
@ -63,8 +69,8 @@ class PPOPolicy(BasePolicy):
if use_sde:
dist_kwargs = {
'full_std': full_std,
'squash_output': False,
'use_expln': False,
'squash_output': squash_output,
'use_expln': use_expln,
'learn_features': sde_net_arch is not None
}