mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-19 21:40:19 +00:00
Merge pull request #51 from Antonin-Raffin/fix/entropy-squashed
Fix entropy loss for squashed Gaussian and VecEnv seeding
This commit is contained in:
commit
cbb0843201
9 changed files with 78 additions and 30 deletions
|
|
@ -24,6 +24,8 @@ Bug Fixes:
|
|||
^^^^^^^^^^
|
||||
- Fix loading model on CPU that were trained on GPU
|
||||
- Fix `reset_num_timesteps` that was not used
|
||||
- Fix entropy computation for squashed Gaussian (approximate it now)
|
||||
- Fix seeding when using multiple environments (different seed per env)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
import torch as th
|
||||
|
||||
from torchy_baselines import A2C, PPO
|
||||
from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \
|
||||
StateDependentNoiseDistribution
|
||||
from torchy_baselines.common.utils import set_random_seed
|
||||
|
|
@ -21,6 +22,14 @@ def test_bijector():
|
|||
# Check the inverse method
|
||||
assert th.isclose(TanhBijector.inverse(squashed_actions), actions).all()
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
def test_squashed_gaussian(model_class):
|
||||
"""
|
||||
Test run with squashed Gaussian (notably entropy computation)
|
||||
"""
|
||||
model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True))
|
||||
model.learn(500)
|
||||
|
||||
|
||||
def test_sde_distribution():
|
||||
n_samples = int(5e6)
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class A2C(PPO):
|
|||
lr=self.learning_rate(1), alpha=0.99,
|
||||
eps=self.rms_prop_eps, weight_decay=0)
|
||||
|
||||
def train(self, gradient_steps, batch_size=None):
|
||||
def train(self, gradient_steps: int, batch_size=None):
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# A2C with gradient_steps > 1 does not make sense
|
||||
|
|
@ -107,7 +107,11 @@ class A2C(PPO):
|
|||
value_loss = F.mse_loss(return_batch, values)
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
if entropy is None:
|
||||
# Approximate entropy when no analytical form
|
||||
entropy_loss = -log_prob.mean()
|
||||
else:
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
||||
|
||||
|
|
@ -123,7 +127,7 @@ class A2C(PPO):
|
|||
self.rollout_buffer.values.flatten())
|
||||
|
||||
logger.logkv("explained_variance", explained_var)
|
||||
logger.logkv("entropy", entropy.mean().item())
|
||||
logger.logkv("entropy_loss", entropy_loss.item())
|
||||
logger.logkv("policy_loss", policy_loss.item())
|
||||
logger.logkv("value_loss", value_loss.item())
|
||||
if hasattr(self.policy, 'log_std'):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import numpy as np
|
|||
import torch as th
|
||||
|
||||
from torchy_baselines.common.vec_env import VecNormalize
|
||||
from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
|
||||
|
||||
|
||||
class BaseBuffer(object):
|
||||
|
|
@ -177,7 +178,7 @@ class ReplayBuffer(BaseBuffer):
|
|||
def _get_samples(self,
|
||||
batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None
|
||||
) -> Tuple[th.Tensor, ...]:
|
||||
) -> ReplayBufferSamples:
|
||||
data = (self._normalize_obs(self.observations[batch_inds, 0, :], env),
|
||||
self.actions[batch_inds, 0, :],
|
||||
self._normalize_obs(self.next_observations[batch_inds, 0, :], env),
|
||||
|
|
@ -305,7 +306,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
if self.pos == self.buffer_size:
|
||||
self.full = True
|
||||
|
||||
def get(self, batch_size: Optional[int] = None) -> Generator[Tuple[th.Tensor, ...], None, None]:
|
||||
def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
|
||||
assert self.full, ''
|
||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||
# Prepare the data
|
||||
|
|
@ -325,7 +326,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
start_idx += batch_size
|
||||
|
||||
def _get_samples(self, batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None) -> Tuple[th.Tensor, ...]:
|
||||
env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
|
||||
data = (self.observations[batch_inds],
|
||||
self.actions[batch_inds],
|
||||
self.values[batch_inds].flatten(),
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Normal, Categorical
|
||||
|
|
@ -8,24 +10,25 @@ class Distribution(object):
|
|||
def __init__(self):
|
||||
super(Distribution, self).__init__()
|
||||
|
||||
def log_prob(self, x):
|
||||
def log_prob(self, x: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
returns the log likelihood
|
||||
|
||||
:param x: (object) the taken action
|
||||
:param x: (th.Tensor) the taken action
|
||||
:return: (th.Tensor) The log likelihood of the distribution
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def entropy(self):
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
"""
|
||||
Returns shannon's entropy of the probability
|
||||
|
||||
:return: (th.Tensor) the entropy
|
||||
:return: (Optional[th.Tensor]) the entropy,
|
||||
return None if no analytical form is known
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sample(self):
|
||||
def sample(self) -> th.Tensor:
|
||||
"""
|
||||
returns a sample from the probabilty distribution
|
||||
|
||||
|
|
@ -145,6 +148,11 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
# Squash the output
|
||||
return th.tanh(self.gaussian_action)
|
||||
|
||||
def entropy(self):
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
return None
|
||||
|
||||
def sample(self):
|
||||
self.gaussian_action = self.distribution.rsample()
|
||||
return th.tanh(self.gaussian_action)
|
||||
|
|
@ -371,7 +379,10 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
return action
|
||||
|
||||
def entropy(self):
|
||||
# TODO: account for the squashing?
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
if self.bijector is not None:
|
||||
return None
|
||||
return self.distribution.entropy()
|
||||
|
||||
def log_prob_from_params(self, mean_actions, log_std, latent_sde):
|
||||
|
|
|
|||
|
|
@ -3,12 +3,16 @@ Common aliases for type hing
|
|||
"""
|
||||
from typing import Union, Type, Optional, Dict, Any, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch as th
|
||||
import gym
|
||||
|
||||
from torchy_baselines.common.vec_env import VecEnv
|
||||
|
||||
|
||||
GymEnv = Union[gym.Env, VecEnv]
|
||||
TensorDict = Dict[str, torch.Tensor]
|
||||
TensorDict = Dict[str, th.Tensor]
|
||||
OptimizerStateDict = Dict[str, Any]
|
||||
# obs, action, old_values, old_log_prob, advantage, return_batch
|
||||
RolloutBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
|
||||
# obs, action, next_obs, done, reward
|
||||
ReplayBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
|
||||
|
|
|
|||
|
|
@ -152,8 +152,9 @@ class VecEnv(ABC):
|
|||
:param indices: ([int])
|
||||
"""
|
||||
indices = self._get_indices(indices)
|
||||
# Different seed per environment
|
||||
if not hasattr(seed, 'len'):
|
||||
seed = [seed] * len(indices)
|
||||
seed = [seed + i for i in range(len(indices))]
|
||||
assert len(seed) == len(indices)
|
||||
return [self.env_method('seed', seed[i], indices=i) for i in indices]
|
||||
|
||||
|
|
|
|||
|
|
@ -116,9 +116,7 @@ class PPO(BaseRLModel):
|
|||
# Action is a scalar
|
||||
action_dim = 1
|
||||
|
||||
# TODO: different seed for each env when n_envs > 1
|
||||
if self.n_envs == 1:
|
||||
self.set_random_seed(self.seed)
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
|
||||
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
|
||||
|
|
@ -208,15 +206,14 @@ class PPO(BaseRLModel):
|
|||
|
||||
return obs, continue_training
|
||||
|
||||
def train(self, gradient_steps, batch_size=64):
|
||||
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# Compute current clip range
|
||||
clip_range = self.clip_range(self._current_progress)
|
||||
logger.logkv("clip_range", clip_range)
|
||||
# Optional: clip range for the value function
|
||||
if self.clip_range_vf is not None:
|
||||
clip_range_vf = self.clip_range_vf(self._current_progress)
|
||||
logger.logkv("clip_range_vf", clip_range_vf)
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
approx_kl_divs = []
|
||||
|
|
@ -258,7 +255,11 @@ class PPO(BaseRLModel):
|
|||
value_loss = F.mse_loss(return_batch, values_pred)
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
if entropy is None:
|
||||
# Approximate entropy when no analytical form
|
||||
entropy_loss = -log_prob.mean()
|
||||
else:
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
||||
|
||||
|
|
@ -278,9 +279,14 @@ class PPO(BaseRLModel):
|
|||
explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
|
||||
self.rollout_buffer.values.flatten())
|
||||
|
||||
logger.logkv("clip_range", clip_range)
|
||||
if self.clip_range_vf is not None:
|
||||
logger.logkv("clip_range_vf", clip_range_vf)
|
||||
|
||||
|
||||
logger.logkv("explained_variance", explained_var)
|
||||
# TODO: gather stats for the entropy and other losses?
|
||||
logger.logkv("entropy", entropy.mean().item())
|
||||
logger.logkv("entropy_loss", entropy_loss.item())
|
||||
logger.logkv("policy_loss", policy_loss.item())
|
||||
logger.logkv("value_loss", value_loss.item())
|
||||
if hasattr(self.policy, 'log_std'):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -6,6 +6,7 @@ import numpy as np
|
|||
|
||||
from torchy_baselines.common.base_class import OffPolicyRLModel
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
from torchy_baselines.common.type_aliases import ReplayBufferSamples
|
||||
from torchy_baselines.td3.policies import TD3Policy
|
||||
|
||||
|
||||
|
|
@ -132,7 +133,10 @@ class TD3(OffPolicyRLModel):
|
|||
"""
|
||||
return self.unscale_action(self.select_action(observation, deterministic=deterministic))
|
||||
|
||||
def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0.0):
|
||||
def train_critic(self, gradient_steps: int = 1,
|
||||
batch_size: int = 100,
|
||||
replay_data: Optional[ReplayBufferSamples] = None,
|
||||
tau: float = 0.0):
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.critic.optimizer)
|
||||
|
||||
|
|
@ -171,9 +175,11 @@ class TD3(OffPolicyRLModel):
|
|||
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
|
||||
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
|
||||
|
||||
def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005,
|
||||
tau_critic=0.005,
|
||||
replay_data=None):
|
||||
def train_actor(self, gradient_steps: int = 1,
|
||||
batch_size: int = 100,
|
||||
tau_actor: float = 0.005,
|
||||
tau_critic: float = 0.005,
|
||||
replay_data: Optional[ReplayBufferSamples] = None):
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.actor.optimizer)
|
||||
|
||||
|
|
@ -200,7 +206,7 @@ class TD3(OffPolicyRLModel):
|
|||
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
|
||||
target_param.data.copy_(tau_actor * param.data + (1 - tau_actor) * target_param.data)
|
||||
|
||||
def train(self, gradient_steps, batch_size=100, policy_delay=2):
|
||||
def train(self, gradient_steps: int, batch_size: int = 100, policy_delay: int = 2):
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
|
||||
|
|
@ -234,7 +240,11 @@ class TD3(OffPolicyRLModel):
|
|||
policy_loss = -(advantage * log_prob).mean()
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
if entropy is None:
|
||||
# Approximate entropy when no analytical form
|
||||
entropy_loss = -log_prob.mean()
|
||||
else:
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
vf_coef = 0.5
|
||||
loss = policy_loss + self.sde_ent_coef * entropy_loss + vf_coef * value_loss
|
||||
|
|
|
|||
Loading…
Reference in a new issue