diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 8c4df3f..d3b897e 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -29,12 +29,18 @@ class PPOPolicy(BasePolicy): :param sde_net_arch: ([int]) Network architecture for extracting features when using SDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. + :param use_expln: (bool) Use `expln()` function instead of `exp()` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, `exp()` is usually enough. + :param squash_output: (bool) Whether to squash the output using a tanh function, + this allows to ensure boundaries when using SDE. """ def __init__(self, observation_space, action_space, learning_rate, net_arch=None, device='cpu', activation_fn=nn.Tanh, adam_epsilon=1e-5, ortho_init=True, use_sde=False, - log_std_init=0.0, full_std=True, sde_net_arch=None): + log_std_init=0.0, full_std=True, + sde_net_arch=None, use_expln=False, squash_output=False): super(PPOPolicy, self).__init__(observation_space, action_space, device) self.obs_dim = self.observation_space.shape[0] @@ -63,8 +69,8 @@ class PPOPolicy(BasePolicy): if use_sde: dist_kwargs = { 'full_std': full_std, - 'squash_output': False, - 'use_expln': False, + 'squash_output': squash_output, + 'use_expln': use_expln, 'learn_features': sde_net_arch is not None }