From 03e853997a5c78cab38ca69989a5ee73c3e2ed58 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 15 Jan 2020 13:21:20 +0100 Subject: [PATCH] Add squash_output and expln as policy param for ppo and a2c --- torchy_baselines/ppo/policies.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 8c4df3f..d3b897e 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -29,12 +29,18 @@ class PPOPolicy(BasePolicy): :param sde_net_arch: ([int]) Network architecture for extracting features when using SDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. + :param use_expln: (bool) Use `expln()` function instead of `exp()` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, `exp()` is usually enough. + :param squash_output: (bool) Whether to squash the output using a tanh function, + this allows to ensure boundaries when using SDE. """ def __init__(self, observation_space, action_space, learning_rate, net_arch=None, device='cpu', activation_fn=nn.Tanh, adam_epsilon=1e-5, ortho_init=True, use_sde=False, - log_std_init=0.0, full_std=True, sde_net_arch=None): + log_std_init=0.0, full_std=True, + sde_net_arch=None, use_expln=False, squash_output=False): super(PPOPolicy, self).__init__(observation_space, action_space, device) self.obs_dim = self.observation_space.shape[0] @@ -63,8 +69,8 @@ class PPOPolicy(BasePolicy): if use_sde: dist_kwargs = { 'full_std': full_std, - 'squash_output': False, - 'use_expln': False, + 'squash_output': squash_output, + 'use_expln': use_expln, 'learn_features': sde_net_arch is not None }