Enable separate feature extraction for SDE

This commit is contained in:
Antonin Raffin 2019-11-25 14:54:13 +01:00
parent d0003ee4ec
commit 5d6649d92b
4 changed files with 66 additions and 32 deletions

View file

@ -49,14 +49,15 @@ def test_state_dependent_exploration():
@pytest.mark.parametrize("model_class", [A2C])
def test_state_dependent_noise(model_class):
@pytest.mark.parametrize("sde_net_arch", [None, [64, 64]])
def test_state_dependent_noise(model_class, sde_net_arch):
env_id = 'MountainCarContinuous-v0'
env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), norm_reward=True)
eval_env = VecNormalize(DummyVecEnv([lambda: Monitor(gym.make(env_id))]), training=False, norm_reward=False)
model = model_class('MlpPolicy', env, n_steps=200, use_sde=True, ent_coef=0.00, verbose=1, learning_rate=3e-4,
policy_kwargs=dict(log_std_init=0.0, ortho_init=False), seed=None)
policy_kwargs=dict(log_std_init=0.0, ortho_init=False, sde_net_arch=sde_net_arch), seed=None)
model.learn(total_timesteps=int(1000), log_interval=5, eval_freq=500, eval_env=eval_env)

View file

@ -113,6 +113,7 @@ class A2C(PPO):
# Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()

View file

@ -241,14 +241,17 @@ class StateDependentNoiseDistribution(Distribution):
above zero and prevent it from growing too fast. In practice, `exp()` is usually enough.
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows to ensure boundaries.
:param learn_features: (bool) Whether to learn features for SDE or not.
This will enable gradients to be backpropagated through the features
`latent_sde` in the code.
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
"""
def __init__(self, action_dim, full_std=True, use_expln=False,
squash_output=False, epsilon=1e-6):
squash_output=False, learn_features=False, epsilon=1e-6):
super(StateDependentNoiseDistribution, self).__init__()
self.distribution = None
self.action_dim = action_dim
self.latent_dim = None
self.latent_sde_dim = None
self.mean_actions = None
self.log_std = None
self.weights_dist = None
@ -256,6 +259,7 @@ class StateDependentNoiseDistribution(Distribution):
self.use_expln = use_expln
self.full_std = full_std
self.epsilon = epsilon
self.learn_features = learn_features
if squash_output:
print("== Using TanhBijector ===")
self.bijector = TanhBijector(epsilon)
@ -284,7 +288,7 @@ class StateDependentNoiseDistribution(Distribution):
if self.full_std:
return std
# Reduce the number of parameters:
return th.ones((self.latent_dim, self.action_dim)).to(log_std.device) * std
return th.ones((self.latent_sde_dim, self.action_dim)).to(log_std.device) * std
def sample_weights(self, log_std):
"""
@ -297,29 +301,32 @@ class StateDependentNoiseDistribution(Distribution):
self.weights_dist = Normal(th.zeros_like(std), std)
self.exploration_mat = self.weights_dist.rsample()
def proba_distribution_net(self, latent_dim, log_std_init=-2.0):
def proba_distribution_net(self, latent_dim, log_std_init=-2.0, latent_sde_dim=None):
"""
Create the layers and parameter that represent the distribution:
one output will be the deterministic action, the other parameter will be the
standard deviation of the distribution that controls the weights of the noise matrix.
:param latent_dim: (int) Dimension og the last layer of the policy (before the action layer)
:param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
:param log_std_init: (float) Initial value for the log standard deviation
:param latent_sde_dim: (int) Dimension of the last layer of the feature extractor
for SDE. By default, it is shared with the policy network.
:return: (nn.Linear, nn.Parameter)
"""
# Network for the deterministic action, it represents the mean of the distribution
mean_actions_net = nn.Linear(latent_dim, self.action_dim)
self.latent_dim = latent_dim
# When we learn features for the noise, the feature dimension
# can be different between the policy and the noise network
self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
# Reduce the number of parameters if needed
log_std = th.ones(latent_dim, self.action_dim) if self.full_std else th.ones(latent_dim, 1)
log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1)
# Transform it to a parameter so it can be optimized
log_std = nn.Parameter(log_std * log_std_init)
# Sample an exploration matrix
self.sample_weights(log_std)
return mean_actions_net, log_std
def proba_distribution(self, mean_actions, log_std, latent_pi, deterministic=False):
def proba_distribution(self, mean_actions, log_std, latent_sde, deterministic=False):
"""
Create and sample from the distribution given its parameters (mean, std)
@ -328,13 +335,15 @@ class StateDependentNoiseDistribution(Distribution):
:param deterministic: (bool)
:return: (th.Tensor)
"""
variance = th.mm(latent_pi.detach() ** 2, self.get_std(log_std) ** 2)
# Stop gradient if we don't want to influence the features
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
variance = th.mm(latent_sde ** 2, self.get_std(log_std) ** 2)
self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
if deterministic:
action = self.mode()
else:
action = self.sample(latent_pi)
action = self.sample(latent_sde)
return action, self
def mode(self):
@ -343,11 +352,12 @@ class StateDependentNoiseDistribution(Distribution):
return self.bijector.forward(action)
return action
def get_noise(self, latent_pi):
return th.mm(latent_pi.detach(), self.exploration_mat)
def get_noise(self, latent_sde):
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
return th.mm(latent_sde, self.exploration_mat)
def sample(self, latent_pi):
noise = self.get_noise(latent_pi)
def sample(self, latent_sde):
noise = self.get_noise(latent_sde)
action = self.distribution.mean + noise
if self.bijector is not None:
return self.bijector.forward(action)
@ -357,8 +367,8 @@ class StateDependentNoiseDistribution(Distribution):
# TODO: account for the squashing?
return self.distribution.entropy()
def log_prob_from_params(self, mean_actions, log_std, latent_pi):
action, _ = self.proba_distribution(mean_actions, log_std, latent_pi)
def log_prob_from_params(self, mean_actions, log_std, latent_sde):
action, _ = self.proba_distribution(mean_actions, log_std, latent_sde)
log_prob = self.log_prob(action)
return action, log_prob

View file

@ -4,7 +4,7 @@ import torch as th
import torch.nn as nn
import numpy as np
from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor
from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, create_mlp
from torchy_baselines.common.distributions import make_proba_distribution,\
DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution
@ -30,7 +30,7 @@ class PPOPolicy(BasePolicy):
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.Tanh, adam_epsilon=1e-5,
ortho_init=True, use_sde=False,
log_std_init=0.0, full_std=True):
log_std_init=0.0, full_std=True, sde_net_arch=None):
super(PPOPolicy, self).__init__(observation_space, action_space, device)
self.obs_dim = self.observation_space.shape[0]
@ -61,9 +61,13 @@ class PPOPolicy(BasePolicy):
dist_kwargs = {
'full_std': full_std,
'squash_output': False,
'use_expln': False
'use_expln': False,
'learn_features': sde_net_arch is not None
}
self.sde_feature_extractor = None
self.sde_net_arch = sde_net_arch
# Action distribution
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
@ -79,9 +83,20 @@ class PPOPolicy(BasePolicy):
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
activation_fn=self.activation_fn, device=self.device)
# Separate feature extractor for SDE
if self.sde_net_arch is not None:
latent_sde = create_mlp(self.features_dim, -1, self.sde_net_arch,
activation_fn=self.activation_fn, squash_out=False)
self.sde_feature_extractor = nn.Sequential(*latent_sde)
# Build the action head according to the concrete distribution type.
# StateDependentNoiseDistribution must NOT share the DiagGaussianDistribution
# branch: it additionally needs `latent_sde_dim`, which sizes `log_std` and can
# differ from the policy latent size when a separate `sde_net_arch` is given.
if isinstance(self.action_dist, DiagGaussianDistribution):
    self.action_net, self.log_std = self.action_dist.proba_distribution_net(
        latent_dim=self.mlp_extractor.latent_dim_pi,
        log_std_init=self.log_std_init)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
    # Without a dedicated SDE network the noise features are the policy latent,
    # otherwise their width is the last layer of `sde_net_arch`.
    if self.sde_net_arch is None:
        latent_sde_dim = self.mlp_extractor.latent_dim_pi
    else:
        latent_sde_dim = self.sde_net_arch[-1]
    self.action_net, self.log_std = self.action_dist.proba_distribution_net(
        latent_dim=self.mlp_extractor.latent_dim_pi,
        latent_sde_dim=latent_sde_dim,
        log_std_init=self.log_std_init)
elif isinstance(self.action_dist, CategoricalDistribution):
    # Discrete actions: only an action (logits) layer, no std parameter.
    self.action_net = self.action_dist.proba_distribution_net(
        latent_dim=self.mlp_extractor.latent_dim_pi)
@ -102,16 +117,23 @@ class PPOPolicy(BasePolicy):
def forward(self, obs, deterministic=False):
if not isinstance(obs, th.Tensor):
obs = th.FloatTensor(obs).to(self.device)
latent_pi, latent_vf = self._get_latent(obs)
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
value = self.value_net(latent_vf)
action, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
action, action_distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde,
deterministic=deterministic)
log_prob = action_distribution.log_prob(action)
return action, value, log_prob
def _get_latent(self, obs):
return self.mlp_extractor(self.features_extractor(obs))
features = self.features_extractor(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
# Features for sde
latent_sde = latent_pi
if self.sde_feature_extractor is not None:
latent_sde = self.sde_feature_extractor(features)
return latent_pi, latent_vf, latent_sde
def _get_action_dist_from_latent(self, latent_pi, deterministic=False):
def _get_action_dist_from_latent(self, latent_pi, latent_sde=None, deterministic=False):
mean_actions = self.action_net(latent_pi)
if isinstance(self.action_dist, DiagGaussianDistribution):
@ -121,11 +143,11 @@ class PPOPolicy(BasePolicy):
return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi, deterministic=deterministic)
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde, deterministic=deterministic)
def actor_forward(self, obs, deterministic=False):
latent_pi, _ = self._get_latent(obs)
action, _ = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
latent_pi, _, latent_sde = self._get_latent(obs)
action, _ = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
return action.detach().cpu().numpy()
def evaluate_actions(self, obs, action, deterministic=False):
@ -139,14 +161,14 @@ class PPOPolicy(BasePolicy):
:return: (th.Tensor, th.Tensor, th.Tensor) estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
latent_pi, latent_vf = self._get_latent(obs)
_, action_distribution = self._get_action_dist_from_latent(latent_pi, deterministic=deterministic)
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
_, action_distribution = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
log_prob = action_distribution.log_prob(action)
value = self.value_net(latent_vf)
return value, log_prob, action_distribution.entropy()
def value_forward(self, obs):
_, latent_vf = self._get_latent(obs)
_, latent_vf, _ = self._get_latent(obs)
return self.value_net(latent_vf)