Fix entropy loss for squashed Gaussian and VecEnv seeding

This commit is contained in:
Antonin Raffin 2020-02-11 17:22:03 +01:00
parent 02a080f647
commit 2ce31c1e21
7 changed files with 67 additions and 25 deletions

View file

@ -24,6 +24,8 @@ Bug Fixes:
^^^^^^^^^^
- Fix loading model on CPU that were trained on GPU
- Fix `reset_num_timesteps` that was not used
- Fix entropy computation for squashed Gaussian (approximate it now)
- Fix seeding when using multiple environments (different seed per env)
Deprecations:
^^^^^^^^^^^^^

View file

@ -1,6 +1,7 @@
import pytest
import torch as th
from torchy_baselines import A2C, PPO
from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \
StateDependentNoiseDistribution
from torchy_baselines.common.utils import set_random_seed
@ -21,6 +22,14 @@ def test_bijector():
# Check the inverse method
assert th.isclose(TanhBijector.inverse(squashed_actions), actions).all()
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_squashed_gaussian(model_class):
"""
Test run with squashed Gaussian (notably entropy computation)
"""
model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True))
model.learn(500)
def test_sde_distribution():
n_samples = int(5e6)

View file

@ -77,7 +77,7 @@ class A2C(PPO):
lr=self.learning_rate(1), alpha=0.99,
eps=self.rms_prop_eps, weight_decay=0)
def train(self, gradient_steps, batch_size=None):
def train(self, gradient_steps: int, batch_size=None):
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# A2C with gradient_steps > 1 does not make sense
@ -107,7 +107,11 @@ class A2C(PPO):
value_loss = F.mse_loss(return_batch, values)
# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -log_prob.mean()
else:
entropy_loss = -th.mean(entropy)
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
@ -123,7 +127,7 @@ class A2C(PPO):
self.rollout_buffer.values.flatten())
logger.logkv("explained_variance", explained_var)
logger.logkv("entropy", entropy.mean().item())
logger.logkv("entropy_loss", entropy_loss.item())
logger.logkv("policy_loss", policy_loss.item())
logger.logkv("value_loss", value_loss.item())
if hasattr(self.policy, 'log_std'):

View file

@ -1,3 +1,5 @@
from typing import Optional
import torch as th
import torch.nn as nn
from torch.distributions import Normal, Categorical
@ -8,24 +10,25 @@ class Distribution(object):
def __init__(self):
super(Distribution, self).__init__()
def log_prob(self, x):
def log_prob(self, x: th.Tensor) -> th.Tensor:
"""
returns the log likelihood
:param x: (object) the taken action
:param x: (th.Tensor) the taken action
:return: (th.Tensor) The log likelihood of the distribution
"""
raise NotImplementedError
def entropy(self):
def entropy(self) -> Optional[th.Tensor]:
"""
Returns shannon's entropy of the probability
:return: (th.Tensor) the entropy
:return: (Optional[th.Tensor]) the entropy,
return None if no analytical form is known
"""
raise NotImplementedError
def sample(self):
def sample(self) -> th.Tensor:
"""
returns a sample from the probabilty distribution
@ -145,6 +148,11 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
# Squash the output
return th.tanh(self.gaussian_action)
def entropy(self):
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
return None
def sample(self):
self.gaussian_action = self.distribution.rsample()
return th.tanh(self.gaussian_action)
@ -371,7 +379,10 @@ class StateDependentNoiseDistribution(Distribution):
return action
def entropy(self):
# TODO: account for the squashing?
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
if self.bijector is not None:
return None
return self.distribution.entropy()
def log_prob_from_params(self, mean_actions, log_std, latent_sde):

View file

@ -152,8 +152,9 @@ class VecEnv(ABC):
:param indices: ([int])
"""
indices = self._get_indices(indices)
# Different seed per environment
if not hasattr(seed, 'len'):
seed = [seed] * len(indices)
seed = [seed + i for i in range(len(indices))]
assert len(seed) == len(indices)
return [self.env_method('seed', seed[i], indices=i) for i in indices]

View file

@ -116,9 +116,7 @@ class PPO(BaseRLModel):
# Action is a scalar
action_dim = 1
# TODO: different seed for each env when n_envs > 1
if self.n_envs == 1:
self.set_random_seed(self.seed)
self.set_random_seed(self.seed)
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
@ -208,15 +206,14 @@ class PPO(BaseRLModel):
return obs, continue_training
def train(self, gradient_steps, batch_size=64):
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# Compute current clip range
clip_range = self.clip_range(self._current_progress)
logger.logkv("clip_range", clip_range)
# Optional: clip range for the value function
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress)
logger.logkv("clip_range_vf", clip_range_vf)
for gradient_step in range(gradient_steps):
approx_kl_divs = []
@ -258,7 +255,11 @@ class PPO(BaseRLModel):
value_loss = F.mse_loss(return_batch, values_pred)
# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -log_prob.mean()
else:
entropy_loss = -th.mean(entropy)
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
@ -278,9 +279,14 @@ class PPO(BaseRLModel):
explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
self.rollout_buffer.values.flatten())
logger.logkv("clip_range", clip_range)
if self.clip_range_vf is not None:
logger.logkv("clip_range_vf", clip_range_vf)
logger.logkv("explained_variance", explained_var)
# TODO: gather stats for the entropy and other losses?
logger.logkv("entropy", entropy.mean().item())
logger.logkv("entropy_loss", entropy_loss.item())
logger.logkv("policy_loss", policy_loss.item())
logger.logkv("value_loss", value_loss.item())
if hasattr(self.policy, 'log_std'):

View file

@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Optional
import torch as th
import torch.nn.functional as F
@ -132,7 +132,10 @@ class TD3(OffPolicyRLModel):
"""
return self.unscale_action(self.select_action(observation, deterministic=deterministic))
def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0.0):
def train_critic(self, gradient_steps: int = 1,
batch_size: int = 100,
replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None,
tau: float = 0.0):
# Update optimizer learning rate
self._update_learning_rate(self.critic.optimizer)
@ -171,9 +174,11 @@ class TD3(OffPolicyRLModel):
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005,
tau_critic=0.005,
replay_data=None):
def train_actor(self, gradient_steps: int = 1,
batch_size: int = 100,
tau_actor: float = 0.005,
tau_critic: float = 0.005,
replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None):
# Update optimizer learning rate
self._update_learning_rate(self.actor.optimizer)
@ -200,7 +205,7 @@ class TD3(OffPolicyRLModel):
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(tau_actor * param.data + (1 - tau_actor) * target_param.data)
def train(self, gradient_steps, batch_size=100, policy_delay=2):
def train(self, gradient_steps: int, batch_size: int = 100, policy_delay: int = 2):
for gradient_step in range(gradient_steps):
@ -234,7 +239,11 @@ class TD3(OffPolicyRLModel):
policy_loss = -(advantage * log_prob).mean()
# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -log_prob.mean()
else:
entropy_loss = -th.mean(entropy)
vf_coef = 0.5
loss = policy_loss + self.sde_ent_coef * entropy_loss + vf_coef * value_loss