From d0003ee4ecd7bb447ec945517d1c35e8b51abc28 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 25 Nov 2019 14:00:21 +0100 Subject: [PATCH] Enable kwargs for proba dist --- README.md | 2 -- torchy_baselines/common/distributions.py | 18 +++++++++++------- torchy_baselines/ppo/policies.py | 16 ++++++++++++++-- torchy_baselines/td3/policies.py | 2 +- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index cc9f015..0b19832 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,6 @@ TODO: - save/load - better predict - complete logger -- SDE: reduce the number of parameters (only n_features instead of n_features x n_actions) for A2C -(done for TD3) - SDE: learn the feature extractor? - Refactor: buffer with numpy array instead of pytorch - Refactor: remove duplicated code for evaluation diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py index ad02734..f90883a 100644 --- a/torchy_baselines/common/distributions.py +++ b/torchy_baselines/common/distributions.py @@ -297,7 +297,7 @@ class StateDependentNoiseDistribution(Distribution): self.weights_dist = Normal(th.zeros_like(std), std) self.exploration_mat = self.weights_dist.rsample() - def proba_distribution_net(self, latent_dim, log_std_init=0.0): + def proba_distribution_net(self, latent_dim, log_std_init=-2.0): """ Create the layers and parameter that represent the distribution: one output will be the deterministic action, the other parameter will be the @@ -423,26 +423,30 @@ class TanhBijector(object): return th.log(1 - th.tanh(x) ** 2 + self.epsilon) -def make_proba_distribution(action_space, use_sde=False): +def make_proba_distribution(action_space, use_sde=False, dist_kwargs=None): """ Return an instance of Distribution for the correct type of action space :param action_space: (Gym Space) the input action space :param use_sde: (bool) Force the use of StateDependentNoiseDistribution instead of DiagGaussianDistribution + :param 
dist_kwargs: (dict) Keyword arguments to pass to the probability distribution :return: (Distribution) the approriate Distribution object """ + if dist_kwargs is None: + dist_kwargs = {} + if isinstance(action_space, spaces.Box): assert len(action_space.shape) == 1, "Error: the action space must be a vector" if use_sde: - return StateDependentNoiseDistribution(action_space.shape[0]) - return DiagGaussianDistribution(action_space.shape[0]) + return StateDependentNoiseDistribution(action_space.shape[0], **dist_kwargs) + return DiagGaussianDistribution(action_space.shape[0], **dist_kwargs) elif isinstance(action_space, spaces.Discrete): - return CategoricalDistribution(action_space.n) + return CategoricalDistribution(action_space.n, **dist_kwargs) # elif isinstance(action_space, spaces.MultiDiscrete): - # return MultiCategoricalDistribution(action_space.nvec) + # return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs) # elif isinstance(action_space, spaces.MultiBinary): - # return BernoulliDistribution(action_space.n) + # return BernoulliDistribution(action_space.n, **dist_kwargs) else: raise NotImplementedError("Error: probability distribution, not implemented for action space of type {}." 
.format(type(action_space)) + diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index adcf239..46fe75c 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -23,11 +23,14 @@ class PPOPolicy(BasePolicy): :param ortho_init: (bool) Whether to use or not orthogonal initialization :param use_sde: (bool) Whether to use State Dependent Exploration or not :param log_std_init: (float) Initial value for the log standard deviation + :param full_std: (bool) Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using SDE """ def __init__(self, observation_space, action_space, learning_rate, net_arch=None, device='cpu', activation_fn=nn.Tanh, adam_epsilon=1e-5, - ortho_init=True, use_sde=False, log_std_init=0.0): + ortho_init=True, use_sde=False, + log_std_init=0.0, full_std=True): super(PPOPolicy, self).__init__(observation_space, action_space, device) self.obs_dim = self.observation_space.shape[0] @@ -52,8 +55,17 @@ class PPOPolicy(BasePolicy): self.features_extractor = nn.Flatten() self.features_dim = self.obs_dim self.log_std_init = log_std_init + dist_kwargs = None + # Keyword arguments for SDE distribution + if use_sde: + dist_kwargs = { + 'full_std': full_std, + 'squash_output': False, + 'use_expln': False + } + # Action distribution - self.action_dist = make_proba_distribution(action_space, use_sde=use_sde) + self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs) self._build(learning_rate) diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index 97d5862..cd72950 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -18,7 +18,7 @@ class Actor(BaseNetwork): :param clip_noise: (float) Clip the magnitude of the noise :param lr_sde: (float) Learning rate for the standard deviation of the noise :param full_std: (bool) Whether to use (n_features x n_actions) 
parameters - for the std instead of only (n_features,) + for the std instead of only (n_features,) when using SDE. """ def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU, use_sde=False, log_std_init=-2, clip_noise=None,