From 2ce31c1e210156fc4a671ea86a7ba199bdab5ec3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 11 Feb 2020 17:22:03 +0100 Subject: [PATCH 1/2] Fix entropy loss for squashed Gaussian and VecEnv seeding --- docs/misc/changelog.rst | 2 ++ tests/test_distributions.py | 9 ++++++++ torchy_baselines/a2c/a2c.py | 10 +++++--- torchy_baselines/common/distributions.py | 23 ++++++++++++++----- .../common/vec_env/base_vec_env.py | 3 ++- torchy_baselines/ppo/ppo.py | 22 +++++++++++------- torchy_baselines/td3/td3.py | 23 +++++++++++++------ 7 files changed, 67 insertions(+), 25 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5ab772c..30f5a60 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -24,6 +24,8 @@ Bug Fixes: ^^^^^^^^^^ - Fix loading model on CPU that were trained on GPU - Fix `reset_num_timesteps` that was not used +- Fix entropy computation for squashed Gaussian (approximate it now) +- Fix seeding when using multiple environments (different seed per env) Deprecations: ^^^^^^^^^^^^^ diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 7e24cda..4b5d792 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -1,6 +1,7 @@ import pytest import torch as th +from torchy_baselines import A2C, PPO from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \ StateDependentNoiseDistribution from torchy_baselines.common.utils import set_random_seed @@ -21,6 +22,14 @@ def test_bijector(): # Check the inverse method assert th.isclose(TanhBijector.inverse(squashed_actions), actions).all() +@pytest.mark.parametrize("model_class", [A2C, PPO]) +def test_squashed_gaussian(model_class): + """ + Test run with squashed Gaussian (notably entropy computation) + """ + model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True)) + model.learn(500) + def test_sde_distribution(): n_samples = int(5e6) diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index f837e9c..01cbd39 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -77,7 +77,7 @@ class A2C(PPO): lr=self.learning_rate(1), alpha=0.99, eps=self.rms_prop_eps, weight_decay=0) - def train(self, gradient_steps, batch_size=None): + def train(self, gradient_steps: int, batch_size=None): # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) # A2C with gradient_steps > 1 does not make sense @@ -107,7 +107,11 @@ class A2C(PPO): value_loss = F.mse_loss(return_batch, values) # Entropy loss favor exploration - entropy_loss = -th.mean(entropy) + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -log_prob.mean() + else: + entropy_loss = -th.mean(entropy) loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss @@ -123,7 +127,7 @@ class A2C(PPO): self.rollout_buffer.values.flatten()) logger.logkv("explained_variance", explained_var) - logger.logkv("entropy", entropy.mean().item()) + logger.logkv("entropy_loss", entropy_loss.item()) logger.logkv("policy_loss", policy_loss.item()) logger.logkv("value_loss", value_loss.item()) if hasattr(self.policy, 'log_std'): diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py index 6ede43f..451535c 100644 --- a/torchy_baselines/common/distributions.py +++ b/torchy_baselines/common/distributions.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch as th import torch.nn as nn from torch.distributions import Normal, Categorical @@ -8,24 +10,25 @@ class Distribution(object): def __init__(self): super(Distribution, self).__init__() - def log_prob(self, x): + def log_prob(self, x: th.Tensor) -> th.Tensor: """ returns the log likelihood - :param x: (object) the taken action + :param x: (th.Tensor) the taken action :return: (th.Tensor) The log likelihood of the distribution """ raise NotImplementedError - def entropy(self): + def entropy(self) -> Optional[th.Tensor]: """ Returns shannon's entropy of the probability - :return: (th.Tensor) the entropy + :return: (Optional[th.Tensor]) the entropy, + return None if no analytical form is known """ raise NotImplementedError - def sample(self): + def sample(self) -> th.Tensor: """ returns a sample from the probabilty distribution @@ -145,6 +148,11 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): # Squash the output return th.tanh(self.gaussian_action) + def entropy(self): + # No analytical form, + # entropy needs to be estimated using -log_prob.mean() + return None + def sample(self): self.gaussian_action = self.distribution.rsample() return th.tanh(self.gaussian_action) @@ -371,7 +379,10 @@ class StateDependentNoiseDistribution(Distribution): return action def entropy(self): - # TODO: account for the squashing? + # No analytical form, + # entropy needs to be estimated using -log_prob.mean() + if self.bijector is not None: + return None return self.distribution.entropy() def log_prob_from_params(self, mean_actions, log_std, latent_sde): diff --git a/torchy_baselines/common/vec_env/base_vec_env.py b/torchy_baselines/common/vec_env/base_vec_env.py index e29c35e..18e8a59 100644 --- a/torchy_baselines/common/vec_env/base_vec_env.py +++ b/torchy_baselines/common/vec_env/base_vec_env.py @@ -152,8 +152,9 @@ class VecEnv(ABC): :param indices: ([int]) """ indices = self._get_indices(indices) + # Different seed per environment if not hasattr(seed, 'len'): - seed = [seed] * len(indices) + seed = [seed + i for i in range(len(indices))] assert len(seed) == len(indices) return [self.env_method('seed', seed[i], indices=i) for i in indices] diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 58c799f..b17f4f8 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -116,9 +116,7 @@ class PPO(BaseRLModel): # Action is a scalar action_dim = 1 - # TODO: different seed for each env when n_envs > 1 - if self.n_envs == 1: - self.set_random_seed(self.seed) + self.set_random_seed(self.seed) self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs) @@ -208,15 +206,14 @@ class PPO(BaseRLModel): return obs, continue_training - def train(self, gradient_steps, batch_size=64): + def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) # Compute current clip range clip_range = self.clip_range(self._current_progress) - logger.logkv("clip_range", clip_range) + # Optional: clip range for the value function if self.clip_range_vf is not None: clip_range_vf = self.clip_range_vf(self._current_progress) - logger.logkv("clip_range_vf", clip_range_vf) for gradient_step in range(gradient_steps): approx_kl_divs = [] @@ -258,7 +255,11 @@ class PPO(BaseRLModel): value_loss = F.mse_loss(return_batch, values_pred) # Entropy loss favor exploration - entropy_loss = -th.mean(entropy) + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -log_prob.mean() + else: + entropy_loss = -th.mean(entropy) loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss @@ -278,9 +279,14 @@ class PPO(BaseRLModel): explained_var = explained_variance(self.rollout_buffer.returns.flatten(), self.rollout_buffer.values.flatten()) + logger.logkv("clip_range", clip_range) + if self.clip_range_vf is not None: + logger.logkv("clip_range_vf", clip_range_vf) + + logger.logkv("explained_variance", explained_var) # TODO: gather stats for the entropy and other losses? - logger.logkv("entropy", entropy.mean().item()) + logger.logkv("entropy_loss", entropy_loss.item()) logger.logkv("policy_loss", policy_loss.item()) logger.logkv("value_loss", value_loss.item()) if hasattr(self.policy, 'log_std'): diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index c485b4d..6b731e3 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Optional import torch as th import torch.nn.functional as F @@ -132,7 +132,10 @@ class TD3(OffPolicyRLModel): """ return self.unscale_action(self.select_action(observation, deterministic=deterministic)) - def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0.0): + def train_critic(self, gradient_steps: int = 1, + batch_size: int = 100, + replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None, + tau: float = 0.0): # Update optimizer learning rate self._update_learning_rate(self.critic.optimizer) @@ -171,9 +174,11 @@ class TD3(OffPolicyRLModel): for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) - def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005, - tau_critic=0.005, - replay_data=None): + def train_actor(self, gradient_steps: int = 1, + batch_size: int = 100, + tau_actor: float = 0.005, + tau_critic: float = 0.005, + replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None): # Update optimizer learning rate self._update_learning_rate(self.actor.optimizer) @@ -200,7 +205,7 @@ class TD3(OffPolicyRLModel): for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): target_param.data.copy_(tau_actor * param.data + (1 - tau_actor) * target_param.data) - def train(self, gradient_steps, batch_size=100, policy_delay=2): + def train(self, gradient_steps: int, batch_size: int = 100, policy_delay: int = 2): for gradient_step in range(gradient_steps): @@ -234,7 +239,11 @@ class TD3(OffPolicyRLModel): policy_loss = -(advantage * log_prob).mean() # Entropy loss favor exploration - entropy_loss = -th.mean(entropy) + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -log_prob.mean() + else: + entropy_loss = -th.mean(entropy) vf_coef = 0.5 loss = policy_loss + self.sde_ent_coef * entropy_loss + vf_coef * value_loss From 240833ffef2867b6f4d1c6fcb9335969a20b8bcd Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 11 Feb 2020 17:33:22 +0100 Subject: [PATCH 2/2] Add type aliases for buffer samples --- torchy_baselines/common/buffers.py | 7 ++++--- torchy_baselines/common/type_aliases.py | 8 ++++++-- torchy_baselines/td3/td3.py | 5 +++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py index 0ac3e2f..c6e9d5c 100644 --- a/torchy_baselines/common/buffers.py +++ b/torchy_baselines/common/buffers.py @@ -4,6 +4,7 @@ import numpy as np import torch as th from torchy_baselines.common.vec_env import VecNormalize +from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples class BaseBuffer(object): @@ -177,7 +178,7 @@ class ReplayBuffer(BaseBuffer): def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None - ) -> Tuple[th.Tensor, ...]: + ) -> ReplayBufferSamples: data = (self._normalize_obs(self.observations[batch_inds, 0, :], env), self.actions[batch_inds, 0, :], self._normalize_obs(self.next_observations[batch_inds, 0, :], env), @@ -305,7 +306,7 @@ class RolloutBuffer(BaseBuffer): if self.pos == self.buffer_size: self.full = True - def get(self, batch_size: Optional[int] = None) -> Generator[Tuple[th.Tensor, ...], None, None]: + def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]: assert self.full, '' indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data @@ -325,7 +326,7 @@ class RolloutBuffer(BaseBuffer): start_idx += batch_size def _get_samples(self, batch_inds: np.ndarray, - env: Optional[VecNormalize] = None) -> Tuple[th.Tensor, ...]: + env: Optional[VecNormalize] = None) -> RolloutBufferSamples: data = (self.observations[batch_inds], self.actions[batch_inds], self.values[batch_inds].flatten(), diff --git a/torchy_baselines/common/type_aliases.py b/torchy_baselines/common/type_aliases.py index 8378647..b9035db 100644 --- a/torchy_baselines/common/type_aliases.py +++ b/torchy_baselines/common/type_aliases.py @@ -3,12 +3,16 @@ Common aliases for type hing """ from typing import Union, Type, Optional, Dict, Any, List, Tuple -import torch +import torch as th import gym from torchy_baselines.common.vec_env import VecEnv GymEnv = Union[gym.Env, VecEnv] -TensorDict = Dict[str, torch.Tensor] +TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] +# obs, action, old_values, old_log_prob, advantage, return_batch +RolloutBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor] +# obs, action, next_obs, done, reward +ReplayBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor] diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 6b731e3..05f1b4f 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -6,6 +6,7 @@ import numpy as np from torchy_baselines.common.base_class import OffPolicyRLModel from torchy_baselines.common.buffers import ReplayBuffer +from torchy_baselines.common.type_aliases import ReplayBufferSamples from torchy_baselines.td3.policies import TD3Policy @@ -134,7 +135,7 @@ class TD3(OffPolicyRLModel): def train_critic(self, gradient_steps: int = 1, batch_size: int = 100, - replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None, + replay_data: Optional[ReplayBufferSamples] = None, tau: float = 0.0): # Update optimizer learning rate self._update_learning_rate(self.critic.optimizer) @@ -178,7 +179,7 @@ class TD3(OffPolicyRLModel): batch_size: int = 100, tau_actor: float = 0.005, tau_critic: float = 0.005, - replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None): + replay_data: Optional[ReplayBufferSamples] = None): # Update optimizer learning rate self._update_learning_rate(self.actor.optimizer)