mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-02 03:55:39 +00:00
Enable separate feature extraction for SDE
This commit is contained in:
parent
d0003ee4ec
commit
5d6649d92b
4 changed files with 66 additions and 32 deletions
|
|
@ -49,14 +49,15 @@ def test_state_dependent_exploration():
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C])
|
||||
def test_state_dependent_noise(model_class):
|
||||
@pytest.mark.parametrize("sde_net_arch", [None, [64, 64]])
|
||||
def test_state_dependent_noise(model_class, sde_net_arch):
|
||||
env_id = 'MountainCarContinuous-v0'
|
||||
|
||||
env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), norm_reward=True)
|
||||
eval_env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), training=False, norm_reward=False)
|
||||
|
||||
model = model_class('MlpPolicy', env, n_steps=200, use_sde=True, ent_coef=0.00, verbose=1, learning_rate=3e-4,
|
||||
policy_kwargs=dict(log_std_init=0.0, ortho_init=False), seed=None)
|
||||
policy_kwargs=dict(log_std_init=0.0, ortho_init=False, sde_net_arch=sde_net_arch), seed=None)
|
||||
model.learn(total_timesteps=int(1000), log_interval=5, eval_freq=500, eval_env=eval_env)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ class A2C(PPO):
|
|||
# Optimization step
|
||||
self.policy.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Clip grad norm
|
||||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
|
|
|
|||
|
|
@ -241,14 +241,17 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
above zero and prevent it from growing too fast. In practice, `exp()` is usually enough.
|
||||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
this ensures the output stays within bounds.
|
||||
:param learn_features: (bool) Whether to learn features for SDE or not.
|
||||
This will enable gradients to be backpropagated through the features
|
||||
`latent_sde` in the code.
|
||||
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
|
||||
"""
|
||||
def __init__(self, action_dim, full_std=True, use_expln=False,
|
||||
squash_output=False, epsilon=1e-6):
|
||||
squash_output=False, learn_features=False, epsilon=1e-6):
|
||||
super(StateDependentNoiseDistribution, self).__init__()
|
||||
self.distribution = None
|
||||
self.action_dim = action_dim
|
||||
self.latent_dim = None
|
||||
self.latent_sde_dim = None
|
||||
self.mean_actions = None
|
||||
self.log_std = None
|
||||
self.weights_dist = None
|
||||
|
|
@ -256,6 +259,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
self.use_expln = use_expln
|
||||
self.full_std = full_std
|
||||
self.epsilon = epsilon
|
||||
self.learn_features = learn_features
|
||||
if squash_output:
|
||||
print("== Using TanhBijector ===")
|
||||
self.bijector = TanhBijector(epsilon)
|
||||
|
|
@ -284,7 +288,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
if self.full_std:
|
||||
return std
|
||||
# Reduce the number of parameters:
|
||||
return th.ones((self.latent_dim, self.action_dim)).to(log_std.device) * std
|
||||
return th.ones((self.latent_sde_dim, self.action_dim)).to(log_std.device) * std
|
||||
|
||||
def sample_weights(self, log_std):
|
||||
"""
|
||||
|
|
@ -297,29 +301,32 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
self.weights_dist = Normal(th.zeros_like(std), std)
|
||||
self.exploration_mat = self.weights_dist.rsample()
|
||||
|
||||
def proba_distribution_net(self, latent_dim, log_std_init=-2.0):
|
||||
def proba_distribution_net(self, latent_dim, log_std_init=-2.0, latent_sde_dim=None):
|
||||
"""
|
||||
Create the layers and parameter that represent the distribution:
|
||||
one output will be the deterministic action, the other parameter will be the
|
||||
standard deviation of the distribution that controls the weights of the noise matrix.
|
||||
|
||||
:param latent_dim: (int) Dimension og the last layer of the policy (before the action layer)
|
||||
:param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param latent_sde_dim: (int) Dimension of the last layer of the feature extractor
|
||||
for SDE. By default, it is shared with the policy network.
|
||||
:return: (nn.Linear, nn.Parameter)
|
||||
"""
|
||||
# Network for the deterministic action, it represents the mean of the distribution
|
||||
mean_actions_net = nn.Linear(latent_dim, self.action_dim)
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
# When we learn features for the noise, the feature dimension
|
||||
# can be different between the policy and the noise network
|
||||
self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
|
||||
# Reduce the number of parameters if needed
|
||||
log_std = th.ones(latent_dim, self.action_dim) if self.full_std else th.ones(latent_dim, 1)
|
||||
log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1)
|
||||
# Transform it to a parameter so it can be optimized
|
||||
log_std = nn.Parameter(log_std * log_std_init)
|
||||
# Sample an exploration matrix
|
||||
self.sample_weights(log_std)
|
||||
return mean_actions_net, log_std
|
||||
|
||||
def proba_distribution(self, mean_actions, log_std, latent_pi, deterministic=False):
|
||||
def proba_distribution(self, mean_actions, log_std, latent_sde, deterministic=False):
|
||||
"""
|
||||
Create and sample for the distribution given its parameters (mean, std)
|
||||
|
||||
|
|
@ -328,13 +335,15 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
:param deterministic: (bool)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
variance = th.mm(latent_pi.detach() ** 2, self.get_std(log_std) ** 2)
|
||||
# Stop gradient if we don't want to influence the features
|
||||
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
|
||||
variance = th.mm(latent_sde ** 2, self.get_std(log_std) ** 2)
|
||||
self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
|
||||
|
||||
if deterministic:
|
||||
action = self.mode()
|
||||
else:
|
||||
action = self.sample(latent_pi)
|
||||
action = self.sample(latent_sde)
|
||||
return action, self
|
||||
|
||||
def mode(self):
|
||||
|
|
@ -343,11 +352,12 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
return self.bijector.forward(action)
|
||||
return action
|
||||
|
||||
def get_noise(self, latent_pi):
|
||||
return th.mm(latent_pi.detach(), self.exploration_mat)
|
||||
def get_noise(self, latent_sde):
|
||||
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
|
||||
return th.mm(latent_sde, self.exploration_mat)
|
||||
|
||||
def sample(self, latent_pi):
|
||||
noise = self.get_noise(latent_pi)
|
||||
def sample(self, latent_sde):
|
||||
noise = self.get_noise(latent_sde)
|
||||
action = self.distribution.mean + noise
|
||||
if self.bijector is not None:
|
||||
return self.bijector.forward(action)
|
||||
|
|
@ -357,8 +367,8 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
# TODO: account for the squashing?
|
||||
return self.distribution.entropy()
|
||||
|
||||
def log_prob_from_params(self, mean_actions, log_std, latent_pi):
|
||||
action, _ = self.proba_distribution(mean_actions, log_std, latent_pi)
|
||||
def log_prob_from_params(self, mean_actions, log_std, latent_sde):
|
||||
action, _ = self.proba_distribution(mean_actions, log_std, latent_sde)
|
||||
log_prob = self.log_prob(action)
|
||||
return action, log_prob
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, create_mlp
|
||||
from torchy_baselines.common.distributions import make_proba_distribution,\
|
||||
DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution
|
||||
|
||||
|
|
@ -30,7 +30,7 @@ class PPOPolicy(BasePolicy):
|
|||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.Tanh, adam_epsilon=1e-5,
|
||||
ortho_init=True, use_sde=False,
|
||||
log_std_init=0.0, full_std=True):
|
||||
log_std_init=0.0, full_std=True, sde_net_arch=None):
|
||||
super(PPOPolicy, self).__init__(observation_space, action_space, device)
|
||||
self.obs_dim = self.observation_space.shape[0]
|
||||
|
||||
|
|
@ -61,9 +61,13 @@ class PPOPolicy(BasePolicy):
|
|||
dist_kwargs = {
|
||||
'full_std': full_std,
|
||||
'squash_output': False,
|
||||
'use_expln': False
|
||||
'use_expln': False,
|
||||
'learn_features': sde_net_arch is not None
|
||||
}
|
||||
|
||||
self.sde_feature_extractor = None
|
||||
self.sde_net_arch = sde_net_arch
|
||||
|
||||
# Action distribution
|
||||
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
|
||||
|
||||
|
|
@ -79,9 +83,20 @@ class PPOPolicy(BasePolicy):
|
|||
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
|
||||
activation_fn=self.activation_fn, device=self.device)
|
||||
|
||||
# Separate feature extractor for SDE
|
||||
if self.sde_net_arch is not None:
|
||||
latent_sde = create_mlp(self.features_dim, -1, self.sde_net_arch,
|
||||
activation_fn=self.activation_fn, squash_out=False)
|
||||
self.sde_feature_extractor = nn.Sequential(*latent_sde)
|
||||
|
||||
if isinstance(self.action_dist, (DiagGaussianDistribution, StateDependentNoiseDistribution)):
|
||||
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
|
||||
log_std_init=self.log_std_init)
|
||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
latent_sde_dim = self.mlp_extractor.latent_dim_pi if self.sde_net_arch is None else self.sde_net_arch[-1]
|
||||
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
|
||||
latent_sde_dim=latent_sde_dim,
|
||||
log_std_init=self.log_std_init)
|
||||
elif isinstance(self.action_dist, CategoricalDistribution):
|
||||
self.action_net = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi)
|
||||
|
||||
|
|
@ -102,16 +117,23 @@ class PPOPolicy(BasePolicy):
|
|||
def forward(self, obs, deterministic=False):
|
||||
if not isinstance(obs, th.Tensor):
|
||||
obs = th.FloatTensor(obs).to(self.device)
|
||||
latent_pi, latent_vf = self._get_latent(obs)
|
||||
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
|
||||
value = self.value_net(latent_vf)
|
||||
action, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
|
||||
action, action_distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde,
|
||||
deterministic=deterministic)
|
||||
log_prob = action_distribution.log_prob(action)
|
||||
return action, value, log_prob
|
||||
|
||||
def _get_latent(self, obs):
|
||||
return self.mlp_extractor(self.features_extractor(obs))
|
||||
features = self.features_extractor(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
# Features for sde
|
||||
latent_sde = latent_pi
|
||||
if self.sde_feature_extractor is not None:
|
||||
latent_sde = self.sde_feature_extractor(features)
|
||||
return latent_pi, latent_vf, latent_sde
|
||||
|
||||
def _get_action_dist_from_latent(self, latent_pi, deterministic=False):
|
||||
def _get_action_dist_from_latent(self, latent_pi, latent_sde=None, deterministic=False):
|
||||
mean_actions = self.action_net(latent_pi)
|
||||
|
||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||
|
|
@ -121,11 +143,11 @@ class PPOPolicy(BasePolicy):
|
|||
return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic)
|
||||
|
||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi, deterministic=deterministic)
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde, deterministic=deterministic)
|
||||
|
||||
def actor_forward(self, obs, deterministic=False):
|
||||
latent_pi, _ = self._get_latent(obs)
|
||||
action, _ = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
|
||||
latent_pi, _, latent_sde = self._get_latent(obs)
|
||||
action, _ = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
|
||||
return action.detach().cpu().numpy()
|
||||
|
||||
def evaluate_actions(self, obs, action, deterministic=False):
|
||||
|
|
@ -139,14 +161,14 @@ class PPOPolicy(BasePolicy):
|
|||
:return: (th.Tensor, th.Tensor, th.Tensor) estimated value, log likelihood of taking those actions
|
||||
and entropy of the action distribution.
|
||||
"""
|
||||
latent_pi, latent_vf = self._get_latent(obs)
|
||||
_, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
|
||||
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
|
||||
_, action_distribution = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
|
||||
log_prob = action_distribution.log_prob(action)
|
||||
value = self.value_net(latent_vf)
|
||||
return value, log_prob, action_distribution.entropy()
|
||||
|
||||
def value_forward(self, obs):
|
||||
_, latent_vf = self._get_latent(obs)
|
||||
_, latent_vf, _ = self._get_latent(obs)
|
||||
return self.value_net(latent_vf)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue