Merge pull request #51 from Antonin-Raffin/fix/entropy-squashed

Fix entropy loss for squashed Gaussian and VecEnv seeding
This commit is contained in:
Raffin, Antonin 2020-02-11 17:46:56 +01:00 committed by GitHub Enterprise
commit cbb0843201
9 changed files with 78 additions and 30 deletions

View file

@ -24,6 +24,8 @@ Bug Fixes:
^^^^^^^^^^
- Fix loading model on CPU that were trained on GPU
- Fix `reset_num_timesteps` that was not used
- Fix entropy computation for squashed Gaussian (approximate it now)
- Fix seeding when using multiple environments (different seed per env)
Deprecations:
^^^^^^^^^^^^^

View file

@ -1,6 +1,7 @@
import pytest
import torch as th
from torchy_baselines import A2C, PPO
from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \
StateDependentNoiseDistribution
from torchy_baselines.common.utils import set_random_seed
@ -21,6 +22,14 @@ def test_bijector():
# Check the inverse method
assert th.isclose(TanhBijector.inverse(squashed_actions), actions).all()
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_squashed_gaussian(model_class):
"""
Test run with squashed Gaussian (notably entropy computation)
"""
model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True))
model.learn(500)
def test_sde_distribution():
n_samples = int(5e6)

View file

@ -77,7 +77,7 @@ class A2C(PPO):
lr=self.learning_rate(1), alpha=0.99,
eps=self.rms_prop_eps, weight_decay=0)
def train(self, gradient_steps, batch_size=None):
def train(self, gradient_steps: int, batch_size=None):
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# A2C with gradient_steps > 1 does not make sense
@ -107,7 +107,11 @@ class A2C(PPO):
value_loss = F.mse_loss(return_batch, values)
# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -log_prob.mean()
else:
entropy_loss = -th.mean(entropy)
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
@ -123,7 +127,7 @@ class A2C(PPO):
self.rollout_buffer.values.flatten())
logger.logkv("explained_variance", explained_var)
logger.logkv("entropy", entropy.mean().item())
logger.logkv("entropy_loss", entropy_loss.item())
logger.logkv("policy_loss", policy_loss.item())
logger.logkv("value_loss", value_loss.item())
if hasattr(self.policy, 'log_std'):

View file

@ -4,6 +4,7 @@ import numpy as np
import torch as th
from torchy_baselines.common.vec_env import VecNormalize
from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
class BaseBuffer(object):
@ -177,7 +178,7 @@ class ReplayBuffer(BaseBuffer):
def _get_samples(self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None
) -> Tuple[th.Tensor, ...]:
) -> ReplayBufferSamples:
data = (self._normalize_obs(self.observations[batch_inds, 0, :], env),
self.actions[batch_inds, 0, :],
self._normalize_obs(self.next_observations[batch_inds, 0, :], env),
@ -305,7 +306,7 @@ class RolloutBuffer(BaseBuffer):
if self.pos == self.buffer_size:
self.full = True
def get(self, batch_size: Optional[int] = None) -> Generator[Tuple[th.Tensor, ...], None, None]:
def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
assert self.full, ''
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
@ -325,7 +326,7 @@ class RolloutBuffer(BaseBuffer):
start_idx += batch_size
def _get_samples(self, batch_inds: np.ndarray,
env: Optional[VecNormalize] = None) -> Tuple[th.Tensor, ...]:
env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
data = (self.observations[batch_inds],
self.actions[batch_inds],
self.values[batch_inds].flatten(),

View file

@ -1,3 +1,5 @@
from typing import Optional
import torch as th
import torch.nn as nn
from torch.distributions import Normal, Categorical
@ -8,24 +10,25 @@ class Distribution(object):
def __init__(self):
super(Distribution, self).__init__()
def log_prob(self, x):
def log_prob(self, x: th.Tensor) -> th.Tensor:
"""
returns the log likelihood
:param x: (object) the taken action
:param x: (th.Tensor) the taken action
:return: (th.Tensor) The log likelihood of the distribution
"""
raise NotImplementedError
def entropy(self):
def entropy(self) -> Optional[th.Tensor]:
"""
Returns shannon's entropy of the probability
:return: (th.Tensor) the entropy
:return: (Optional[th.Tensor]) the entropy,
return None if no analytical form is known
"""
raise NotImplementedError
def sample(self):
def sample(self) -> th.Tensor:
"""
returns a sample from the probabilty distribution
@ -145,6 +148,11 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
# Squash the output
return th.tanh(self.gaussian_action)
def entropy(self):
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
return None
def sample(self):
self.gaussian_action = self.distribution.rsample()
return th.tanh(self.gaussian_action)
@ -371,7 +379,10 @@ class StateDependentNoiseDistribution(Distribution):
return action
def entropy(self):
# TODO: account for the squashing?
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
if self.bijector is not None:
return None
return self.distribution.entropy()
def log_prob_from_params(self, mean_actions, log_std, latent_sde):

View file

@ -3,12 +3,16 @@ Common aliases for type hing
"""
from typing import Union, Type, Optional, Dict, Any, List, Tuple
import torch
import torch as th
import gym
from torchy_baselines.common.vec_env import VecEnv
GymEnv = Union[gym.Env, VecEnv]
TensorDict = Dict[str, torch.Tensor]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
# obs, action, old_values, old_log_prob, advantage, return_batch
RolloutBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
# obs, action, next_obs, done, reward
ReplayBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]

View file

@ -152,8 +152,9 @@ class VecEnv(ABC):
:param indices: ([int])
"""
indices = self._get_indices(indices)
# Different seed per environment
if not hasattr(seed, 'len'):
seed = [seed] * len(indices)
seed = [seed + i for i in range(len(indices))]
assert len(seed) == len(indices)
return [self.env_method('seed', seed[i], indices=i) for i in indices]

View file

@ -116,9 +116,7 @@ class PPO(BaseRLModel):
# Action is a scalar
action_dim = 1
# TODO: different seed for each env when n_envs > 1
if self.n_envs == 1:
self.set_random_seed(self.seed)
self.set_random_seed(self.seed)
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
@ -208,15 +206,14 @@ class PPO(BaseRLModel):
return obs, continue_training
def train(self, gradient_steps, batch_size=64):
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# Compute current clip range
clip_range = self.clip_range(self._current_progress)
logger.logkv("clip_range", clip_range)
# Optional: clip range for the value function
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress)
logger.logkv("clip_range_vf", clip_range_vf)
for gradient_step in range(gradient_steps):
approx_kl_divs = []
@ -258,7 +255,11 @@ class PPO(BaseRLModel):
value_loss = F.mse_loss(return_batch, values_pred)
# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -log_prob.mean()
else:
entropy_loss = -th.mean(entropy)
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
@ -278,9 +279,14 @@ class PPO(BaseRLModel):
explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
self.rollout_buffer.values.flatten())
logger.logkv("clip_range", clip_range)
if self.clip_range_vf is not None:
logger.logkv("clip_range_vf", clip_range_vf)
logger.logkv("explained_variance", explained_var)
# TODO: gather stats for the entropy and other losses?
logger.logkv("entropy", entropy.mean().item())
logger.logkv("entropy_loss", entropy_loss.item())
logger.logkv("policy_loss", policy_loss.item())
logger.logkv("value_loss", value_loss.item())
if hasattr(self.policy, 'log_std'):

View file

@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Optional
import torch as th
import torch.nn.functional as F
@ -6,6 +6,7 @@ import numpy as np
from torchy_baselines.common.base_class import OffPolicyRLModel
from torchy_baselines.common.buffers import ReplayBuffer
from torchy_baselines.common.type_aliases import ReplayBufferSamples
from torchy_baselines.td3.policies import TD3Policy
@ -132,7 +133,10 @@ class TD3(OffPolicyRLModel):
"""
return self.unscale_action(self.select_action(observation, deterministic=deterministic))
def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0.0):
def train_critic(self, gradient_steps: int = 1,
batch_size: int = 100,
replay_data: Optional[ReplayBufferSamples] = None,
tau: float = 0.0):
# Update optimizer learning rate
self._update_learning_rate(self.critic.optimizer)
@ -171,9 +175,11 @@ class TD3(OffPolicyRLModel):
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005,
tau_critic=0.005,
replay_data=None):
def train_actor(self, gradient_steps: int = 1,
batch_size: int = 100,
tau_actor: float = 0.005,
tau_critic: float = 0.005,
replay_data: Optional[ReplayBufferSamples] = None):
# Update optimizer learning rate
self._update_learning_rate(self.actor.optimizer)
@ -200,7 +206,7 @@ class TD3(OffPolicyRLModel):
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(tau_actor * param.data + (1 - tau_actor) * target_param.data)
def train(self, gradient_steps, batch_size=100, policy_delay=2):
def train(self, gradient_steps: int, batch_size: int = 100, policy_delay: int = 2):
for gradient_step in range(gradient_steps):
@ -234,7 +240,11 @@ class TD3(OffPolicyRLModel):
policy_loss = -(advantage * log_prob).mean()
# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -log_prob.mean()
else:
entropy_loss = -th.mean(entropy)
vf_coef = 0.5
loss = policy_loss + self.sde_ent_coef * entropy_loss + vf_coef * value_loss