mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-29 23:07:07 +00:00
Enable kwargs for proba dist
This commit is contained in:
parent
5bbb14188d
commit
d0003ee4ec
4 changed files with 26 additions and 12 deletions
|
|
@ -21,8 +21,6 @@ TODO:
|
|||
- save/load
|
||||
- better predict
|
||||
- complete logger
|
||||
- SDE: reduce the number of parameters (only n_features instead of n_features x n_actions) for A2C
|
||||
(done for TD3)
|
||||
- SDE: learn the feature extractor?
|
||||
- Refactor: buffer with numpy array instead of pytorch
|
||||
- Refactor: remove duplicated code for evaluation
|
||||
|
|
|
|||
|
|
@ -297,7 +297,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
self.weights_dist = Normal(th.zeros_like(std), std)
|
||||
self.exploration_mat = self.weights_dist.rsample()
|
||||
|
||||
def proba_distribution_net(self, latent_dim, log_std_init=0.0):
|
||||
def proba_distribution_net(self, latent_dim, log_std_init=-2.0):
|
||||
"""
|
||||
Create the layers and parameter that represent the distribution:
|
||||
one output will be the deterministic action, the other parameter will be the
|
||||
|
|
@ -423,26 +423,30 @@ class TanhBijector(object):
|
|||
return th.log(1 - th.tanh(x) ** 2 + self.epsilon)
|
||||
|
||||
|
||||
def make_proba_distribution(action_space, use_sde=False):
|
||||
def make_proba_distribution(action_space, use_sde=False, dist_kwargs=None):
|
||||
"""
|
||||
Return an instance of Distribution for the correct type of action space
|
||||
|
||||
:param action_space: (Gym Space) the input action space
|
||||
:param use_sde: (bool) Force the use of StateDependentNoiseDistribution
|
||||
instead of DiagGaussianDistribution
|
||||
:param dist_kwargs: (dict) Keyword arguments to pass to the probability distribution
|
||||
:return: (Distribution) the appropriate Distribution object
|
||||
"""
|
||||
if dist_kwargs is None:
|
||||
dist_kwargs = {}
|
||||
|
||||
if isinstance(action_space, spaces.Box):
|
||||
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
|
||||
if use_sde:
|
||||
return StateDependentNoiseDistribution(action_space.shape[0])
|
||||
return DiagGaussianDistribution(action_space.shape[0])
|
||||
return StateDependentNoiseDistribution(action_space.shape[0], **dist_kwargs)
|
||||
return DiagGaussianDistribution(action_space.shape[0], **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n)
|
||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||
# elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
# return MultiCategoricalDistribution(action_space.nvec)
|
||||
# return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
|
||||
# elif isinstance(action_space, spaces.MultiBinary):
|
||||
# return BernoulliDistribution(action_space.n)
|
||||
# return BernoulliDistribution(action_space.n, **dist_kwargs)
|
||||
else:
|
||||
raise NotImplementedError("Error: probability distribution, not implemented for action space of type {}."
|
||||
.format(type(action_space)) +
|
||||
|
|
|
|||
|
|
@ -23,11 +23,14 @@ class PPOPolicy(BasePolicy):
|
|||
:param ortho_init: (bool) Whether to use or not orthogonal initialization
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using SDE
|
||||
"""
|
||||
def __init__(self, observation_space, action_space,
|
||||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.Tanh, adam_epsilon=1e-5,
|
||||
ortho_init=True, use_sde=False, log_std_init=0.0):
|
||||
ortho_init=True, use_sde=False,
|
||||
log_std_init=0.0, full_std=True):
|
||||
super(PPOPolicy, self).__init__(observation_space, action_space, device)
|
||||
self.obs_dim = self.observation_space.shape[0]
|
||||
|
||||
|
|
@ -52,8 +55,17 @@ class PPOPolicy(BasePolicy):
|
|||
self.features_extractor = nn.Flatten()
|
||||
self.features_dim = self.obs_dim
|
||||
self.log_std_init = log_std_init
|
||||
dist_kwargs = None
|
||||
# Keyword arguments for SDE distribution
|
||||
if use_sde:
|
||||
dist_kwargs = {
|
||||
'full_std': full_std,
|
||||
'squash_output': False,
|
||||
'use_expln': False
|
||||
}
|
||||
|
||||
# Action distribution
|
||||
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde)
|
||||
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
|
||||
|
||||
self._build(learning_rate)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class Actor(BaseNetwork):
|
|||
:param clip_noise: (float) Clip the magnitude of the noise
|
||||
:param lr_sde: (float) Learning rate for the standard deviation of the noise
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,)
|
||||
for the std instead of only (n_features,) when using SDE.
|
||||
"""
|
||||
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
|
||||
use_sde=False, log_std_init=-2, clip_noise=None,
|
||||
|
|
|
|||
Loading…
Reference in a new issue