Enable kwargs for proba dist

This commit is contained in:
Antonin Raffin 2019-11-25 14:00:21 +01:00
parent 5bbb14188d
commit d0003ee4ec
4 changed files with 26 additions and 12 deletions

View file

@ -21,8 +21,6 @@ TODO:
- save/load
- better predict
- complete logger
- SDE: reduce the number of parameters (only n_features instead of n_features x n_actions) for A2C
(done for TD3)
- SDE: learn the feature extractor?
- Refactor: buffer with numpy array instead of pytorch
- Refactor: remove duplicated code for evaluation

View file

@ -297,7 +297,7 @@ class StateDependentNoiseDistribution(Distribution):
self.weights_dist = Normal(th.zeros_like(std), std)
self.exploration_mat = self.weights_dist.rsample()
def proba_distribution_net(self, latent_dim, log_std_init=0.0):
def proba_distribution_net(self, latent_dim, log_std_init=-2.0):
"""
Create the layers and parameter that represent the distribution:
one output will be the deterministic action, the other parameter will be the
@ -423,26 +423,30 @@ class TanhBijector(object):
return th.log(1 - th.tanh(x) ** 2 + self.epsilon)
def make_proba_distribution(action_space, use_sde=False):
def make_proba_distribution(action_space, use_sde=False, dist_kwargs=None):
"""
Return an instance of Distribution for the correct type of action space
:param action_space: (Gym Space) the input action space
:param use_sde: (bool) Force the use of StateDependentNoiseDistribution
instead of DiagGaussianDistribution
:param dist_kwargs: (dict) Keyword arguments to pass to the probabilty distribution
:return: (Distribution) the approriate Distribution object
"""
if dist_kwargs is None:
dist_kwargs = {}
if isinstance(action_space, spaces.Box):
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
if use_sde:
return StateDependentNoiseDistribution(action_space.shape[0])
return DiagGaussianDistribution(action_space.shape[0])
return StateDependentNoiseDistribution(action_space.shape[0], **dist_kwargs)
return DiagGaussianDistribution(action_space.shape[0], **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n)
return CategoricalDistribution(action_space.n, **dist_kwargs)
# elif isinstance(action_space, spaces.MultiDiscrete):
# return MultiCategoricalDistribution(action_space.nvec)
# return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
# elif isinstance(action_space, spaces.MultiBinary):
# return BernoulliDistribution(action_space.n)
# return BernoulliDistribution(action_space.n, **dist_kwargs)
else:
raise NotImplementedError("Error: probability distribution, not implemented for action space of type {}."
.format(type(action_space)) +

View file

@ -23,11 +23,14 @@ class PPOPolicy(BasePolicy):
:param ortho_init: (bool) Whether to use or not orthogonal initialization
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using SDE
"""
def __init__(self, observation_space, action_space,
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.Tanh, adam_epsilon=1e-5,
ortho_init=True, use_sde=False, log_std_init=0.0):
ortho_init=True, use_sde=False,
log_std_init=0.0, full_std=True):
super(PPOPolicy, self).__init__(observation_space, action_space, device)
self.obs_dim = self.observation_space.shape[0]
@ -52,8 +55,17 @@ class PPOPolicy(BasePolicy):
self.features_extractor = nn.Flatten()
self.features_dim = self.obs_dim
self.log_std_init = log_std_init
dist_kwargs = None
# Keyword arguments for SDE distribution
if use_sde:
dist_kwargs = {
'full_std': full_std,
'squash_output': False,
'use_expln': False
}
# Action distribution
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde)
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
self._build(learning_rate)

View file

@ -18,7 +18,7 @@ class Actor(BaseNetwork):
:param clip_noise: (float) Clip the magnitude of the noise
:param lr_sde: (float) Learning rate for the standard deviation of the noise
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,)
for the std instead of only (n_features,) when using SDE.
"""
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
use_sde=False, log_std_init=-2, clip_noise=None,