mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-02 23:40:09 +00:00
Fix entropy loss for squashed Gaussian and VecEnv seeding
This commit is contained in:
parent
02a080f647
commit
2ce31c1e21
7 changed files with 67 additions and 25 deletions
|
|
@ -24,6 +24,8 @@ Bug Fixes:
|
|||
^^^^^^^^^^
|
||||
- Fix loading model on CPU that were trained on GPU
|
||||
- Fix `reset_num_timesteps` that was not used
|
||||
- Fix entropy computation for squashed Gaussian (approximate it now)
|
||||
- Fix seeding when using multiple environments (different seed per env)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
import torch as th
|
||||
|
||||
from torchy_baselines import A2C, PPO
|
||||
from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \
|
||||
StateDependentNoiseDistribution
|
||||
from torchy_baselines.common.utils import set_random_seed
|
||||
|
|
@ -21,6 +22,14 @@ def test_bijector():
|
|||
# Check the inverse method
|
||||
assert th.isclose(TanhBijector.inverse(squashed_actions), actions).all()
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
def test_squashed_gaussian(model_class):
|
||||
"""
|
||||
Test run with squashed Gaussian (notably entropy computation)
|
||||
"""
|
||||
model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True))
|
||||
model.learn(500)
|
||||
|
||||
|
||||
def test_sde_distribution():
|
||||
n_samples = int(5e6)
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class A2C(PPO):
|
|||
lr=self.learning_rate(1), alpha=0.99,
|
||||
eps=self.rms_prop_eps, weight_decay=0)
|
||||
|
||||
def train(self, gradient_steps, batch_size=None):
|
||||
def train(self, gradient_steps: int, batch_size=None):
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# A2C with gradient_steps > 1 does not make sense
|
||||
|
|
@ -107,7 +107,11 @@ class A2C(PPO):
|
|||
value_loss = F.mse_loss(return_batch, values)
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
if entropy is None:
|
||||
# Approximate entropy when no analytical form
|
||||
entropy_loss = -log_prob.mean()
|
||||
else:
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
||||
|
||||
|
|
@ -123,7 +127,7 @@ class A2C(PPO):
|
|||
self.rollout_buffer.values.flatten())
|
||||
|
||||
logger.logkv("explained_variance", explained_var)
|
||||
logger.logkv("entropy", entropy.mean().item())
|
||||
logger.logkv("entropy_loss", entropy_loss.item())
|
||||
logger.logkv("policy_loss", policy_loss.item())
|
||||
logger.logkv("value_loss", value_loss.item())
|
||||
if hasattr(self.policy, 'log_std'):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Normal, Categorical
|
||||
|
|
@ -8,24 +10,25 @@ class Distribution(object):
|
|||
def __init__(self):
|
||||
super(Distribution, self).__init__()
|
||||
|
||||
def log_prob(self, x):
|
||||
def log_prob(self, x: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
returns the log likelihood
|
||||
|
||||
:param x: (object) the taken action
|
||||
:param x: (th.Tensor) the taken action
|
||||
:return: (th.Tensor) The log likelihood of the distribution
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def entropy(self):
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
"""
|
||||
Returns shannon's entropy of the probability
|
||||
|
||||
:return: (th.Tensor) the entropy
|
||||
:return: (Optional[th.Tensor]) the entropy,
|
||||
return None if no analytical form is known
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sample(self):
|
||||
def sample(self) -> th.Tensor:
|
||||
"""
|
||||
returns a sample from the probabilty distribution
|
||||
|
||||
|
|
@ -145,6 +148,11 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
# Squash the output
|
||||
return th.tanh(self.gaussian_action)
|
||||
|
||||
def entropy(self):
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
return None
|
||||
|
||||
def sample(self):
|
||||
self.gaussian_action = self.distribution.rsample()
|
||||
return th.tanh(self.gaussian_action)
|
||||
|
|
@ -371,7 +379,10 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
return action
|
||||
|
||||
def entropy(self):
|
||||
# TODO: account for the squashing?
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
if self.bijector is not None:
|
||||
return None
|
||||
return self.distribution.entropy()
|
||||
|
||||
def log_prob_from_params(self, mean_actions, log_std, latent_sde):
|
||||
|
|
|
|||
|
|
@ -152,8 +152,9 @@ class VecEnv(ABC):
|
|||
:param indices: ([int])
|
||||
"""
|
||||
indices = self._get_indices(indices)
|
||||
# Different seed per environment
|
||||
if not hasattr(seed, 'len'):
|
||||
seed = [seed] * len(indices)
|
||||
seed = [seed + i for i in range(len(indices))]
|
||||
assert len(seed) == len(indices)
|
||||
return [self.env_method('seed', seed[i], indices=i) for i in indices]
|
||||
|
||||
|
|
|
|||
|
|
@ -116,9 +116,7 @@ class PPO(BaseRLModel):
|
|||
# Action is a scalar
|
||||
action_dim = 1
|
||||
|
||||
# TODO: different seed for each env when n_envs > 1
|
||||
if self.n_envs == 1:
|
||||
self.set_random_seed(self.seed)
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
|
||||
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
|
||||
|
|
@ -208,15 +206,14 @@ class PPO(BaseRLModel):
|
|||
|
||||
return obs, continue_training
|
||||
|
||||
def train(self, gradient_steps, batch_size=64):
|
||||
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# Compute current clip range
|
||||
clip_range = self.clip_range(self._current_progress)
|
||||
logger.logkv("clip_range", clip_range)
|
||||
# Optional: clip range for the value function
|
||||
if self.clip_range_vf is not None:
|
||||
clip_range_vf = self.clip_range_vf(self._current_progress)
|
||||
logger.logkv("clip_range_vf", clip_range_vf)
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
approx_kl_divs = []
|
||||
|
|
@ -258,7 +255,11 @@ class PPO(BaseRLModel):
|
|||
value_loss = F.mse_loss(return_batch, values_pred)
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
if entropy is None:
|
||||
# Approximate entropy when no analytical form
|
||||
entropy_loss = -log_prob.mean()
|
||||
else:
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
||||
|
||||
|
|
@ -278,9 +279,14 @@ class PPO(BaseRLModel):
|
|||
explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
|
||||
self.rollout_buffer.values.flatten())
|
||||
|
||||
logger.logkv("clip_range", clip_range)
|
||||
if self.clip_range_vf is not None:
|
||||
logger.logkv("clip_range_vf", clip_range_vf)
|
||||
|
||||
|
||||
logger.logkv("explained_variance", explained_var)
|
||||
# TODO: gather stats for the entropy and other losses?
|
||||
logger.logkv("entropy", entropy.mean().item())
|
||||
logger.logkv("entropy_loss", entropy_loss.item())
|
||||
logger.logkv("policy_loss", policy_loss.item())
|
||||
logger.logkv("value_loss", value_loss.item())
|
||||
if hasattr(self.policy, 'log_std'):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -132,7 +132,10 @@ class TD3(OffPolicyRLModel):
|
|||
"""
|
||||
return self.unscale_action(self.select_action(observation, deterministic=deterministic))
|
||||
|
||||
def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0.0):
|
||||
def train_critic(self, gradient_steps: int = 1,
|
||||
batch_size: int = 100,
|
||||
replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None,
|
||||
tau: float = 0.0):
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.critic.optimizer)
|
||||
|
||||
|
|
@ -171,9 +174,11 @@ class TD3(OffPolicyRLModel):
|
|||
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
|
||||
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
|
||||
|
||||
def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005,
|
||||
tau_critic=0.005,
|
||||
replay_data=None):
|
||||
def train_actor(self, gradient_steps: int = 1,
|
||||
batch_size: int = 100,
|
||||
tau_actor: float = 0.005,
|
||||
tau_critic: float = 0.005,
|
||||
replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None):
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.actor.optimizer)
|
||||
|
||||
|
|
@ -200,7 +205,7 @@ class TD3(OffPolicyRLModel):
|
|||
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
|
||||
target_param.data.copy_(tau_actor * param.data + (1 - tau_actor) * target_param.data)
|
||||
|
||||
def train(self, gradient_steps, batch_size=100, policy_delay=2):
|
||||
def train(self, gradient_steps: int, batch_size: int = 100, policy_delay: int = 2):
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
|
||||
|
|
@ -234,7 +239,11 @@ class TD3(OffPolicyRLModel):
|
|||
policy_loss = -(advantage * log_prob).mean()
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
if entropy is None:
|
||||
# Approximate entropy when no analytical form
|
||||
entropy_loss = -log_prob.mean()
|
||||
else:
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
vf_coef = 0.5
|
||||
loss = policy_loss + self.sde_ent_coef * entropy_loss + vf_coef * value_loss
|
||||
|
|
|
|||
Loading…
Reference in a new issue