mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-30 23:18:47 +00:00
Add doc for CEM-RL
This commit is contained in:
parent
81a15414b0
commit
ea3902cd32
3 changed files with 78 additions and 30 deletions
|
|
@ -1,24 +1,27 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
# TODO: add more from https://github.com/hardmaru/estool/blob/master/es.py
|
||||
# or https://github.com/facebookresearch/nevergrad
|
||||
|
||||
class CEM(object):
|
||||
|
||||
"""
|
||||
Cross-entropy method with diagonal covariance (separable CEM)
|
||||
"""
|
||||
Cross-entropy method with diagonal covariance (separable CEM).
|
||||
|
||||
def __init__(self, num_params,
|
||||
mu_init=None,
|
||||
sigma_init=1e-3,
|
||||
pop_size=256,
|
||||
damp=1e-3,
|
||||
damp_limit=1e-5,
|
||||
parents=None,
|
||||
elitism=False,
|
||||
antithetic=False):
|
||||
:param num_params: (int)
|
||||
:param mu_init: (np.ndarray) Initial mean of the population distribution
|
||||
Taken to be zero if None is passed.
|
||||
:param sigma_init: (float) Initial standard deviation of the population distribution
|
||||
:param pop_size: (int) Number of individuals in the population
|
||||
:param damp: (float) Damping for preventing from early convergence.
|
||||
:param damp_limit: (float) Final value of damping
|
||||
:param parents: (int)
|
||||
:param elitism: (bool)
|
||||
:param antithetic: (bool) Use a finite difference like method for sampling
|
||||
(mu + epsilon, mu - epsilon)
|
||||
"""
|
||||
def __init__(self, num_params, mu_init=None, sigma_init=1e-3,
|
||||
pop_size=256, damp=1e-3, damp_limit=1e-5,
|
||||
parents=None, elitism=False, antithetic=False):
|
||||
super(CEM, self).__init__()
|
||||
# misc
|
||||
self.num_params = num_params
|
||||
|
|
@ -31,6 +34,7 @@ class CEM(object):
|
|||
self.sigma = sigma_init
|
||||
self.damp = damp
|
||||
self.damp_limit = damp_limit
|
||||
# Exponential moving average decay for damping
|
||||
self.tau = 0.95
|
||||
self.cov = self.sigma * np.ones(self.num_params)
|
||||
|
||||
|
|
@ -56,6 +60,9 @@ class CEM(object):
|
|||
def ask(self, pop_size):
|
||||
"""
|
||||
Returns a list of candidates parameters
|
||||
|
||||
:param pop_size: (int)
|
||||
:return: ([np.ndarray])
|
||||
"""
|
||||
if self.antithetic and not pop_size % 2:
|
||||
epsilon_half = np.random.randn(pop_size // 2, self.num_params)
|
||||
|
|
@ -64,16 +71,20 @@ class CEM(object):
|
|||
else:
|
||||
epsilon = np.random.randn(pop_size, self.num_params)
|
||||
|
||||
inds = self.mu + epsilon * np.sqrt(self.cov)
|
||||
individuals = self.mu + epsilon * np.sqrt(self.cov)
|
||||
if self.elitism:
|
||||
inds[-1] = self.elite
|
||||
individuals[-1] = self.elite
|
||||
|
||||
return inds
|
||||
return individuals
|
||||
|
||||
def tell(self, solutions, scores):
|
||||
"""
|
||||
Updates the distribution
|
||||
|
||||
:param solutions: ([np.ndarray])
|
||||
:param scores: ([float]) episode reward.
|
||||
"""
|
||||
# Convert rewards (we want to maximize) to cost (we want to minimize)
|
||||
scores = np.array(scores)
|
||||
scores *= -1
|
||||
idx_sorted = np.argsort(scores)
|
||||
|
|
@ -92,7 +103,9 @@ class CEM(object):
|
|||
|
||||
def get_distrib_params(self):
|
||||
"""
|
||||
Returns the parameters of the distrubtion:
|
||||
the mean and sigma
|
||||
Returns the parameters of the distribution:
|
||||
the mean and standard deviation.
|
||||
|
||||
:return: (np.ndarray, np.ndarray)
|
||||
"""
|
||||
return np.copy(self.mu), np.copy(self.cov)
|
||||
|
|
|
|||
|
|
@ -10,17 +10,48 @@ from torchy_baselines.common.vec_env import sync_envs_normalization
|
|||
|
||||
class CEMRL(TD3):
|
||||
"""
|
||||
Implementation of CEM-RL
|
||||
Implementation of CEM-RL, in fact CEM combined with TD3.
|
||||
|
||||
Paper: https://arxiv.org/abs/1810.01222
|
||||
Code: https://github.com/apourchot/CEM-RL
|
||||
"""
|
||||
|
||||
:param policy: (TD3Policy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
|
||||
:param sigma_init: (float) Initial standard deviation of the population distribution
|
||||
:param pop_size: (int) Number of individuals in the population
|
||||
:param damp: (float) Damping for preventing from early convergence.
|
||||
:param damp_limit: (float) Final value of damping
|
||||
:param elitism: (bool)
|
||||
:param n_grad: (int) Number of individuals that will receive a gradient update.
|
||||
Half of the population size in the paper.
|
||||
:param buffer_size: (int) size of the replay buffer
|
||||
:param learning_rate: (float or callable) learning rate for adam optimizer,
|
||||
the same learning rate will be used for all networks (Q-Values and Actor networks)
|
||||
it can be a function of the current progress (from 1 to 0)
|
||||
:param policy_delay: (int) Policy and target networks will only be updated once every policy_delay steps
|
||||
per training steps. The Q values will be updated policy_delay more often (update every training step).
|
||||
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
|
||||
:param gamma: (float) the discount factor
|
||||
:param batch_size: (int) Minibatch size for each gradient update
|
||||
:param tau: (float) the soft update coefficient ("polyak update" of the target networks, between 0 and 1)
|
||||
:param action_noise: (ActionNoise) the action noise type. Cf common.noise for the different action noise type.
|
||||
:param target_policy_noise: (float) Standard deviation of gaussian noise added to target policy
|
||||
(smoothing noise)
|
||||
:param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise.
|
||||
:param create_eval_env: (bool) Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
|
||||
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
|
||||
:param seed: (int) Seed for the pseudo random generators
|
||||
:param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
|
||||
Setting it to auto, the code will be run on the GPU if possible.
|
||||
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
def __init__(self, policy, env, sigma_init=1e-3, pop_size=10,
|
||||
damp=1e-3, damp_limit=1e-5, elitism=False, n_grad=5,
|
||||
policy_delay=2, batch_size=100,
|
||||
buffer_size=int(1e6), learning_rate=1e-3,
|
||||
action_noise=None, learning_starts=100, tau=0.005,
|
||||
buffer_size=int(1e6), learning_rate=1e-3, policy_delay=2,
|
||||
learning_starts=100, gamma=0.99, batch_size=100, tau=0.005,
|
||||
action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
|
||||
n_episodes_rollout=1, update_style='original',
|
||||
tensorboard_log=None, create_eval_env=False,
|
||||
policy_kwargs=None, verbose=0, seed=0, device='auto',
|
||||
|
|
@ -28,13 +59,16 @@ class CEMRL(TD3):
|
|||
|
||||
super(CEMRL, self).__init__(policy, env,
|
||||
buffer_size=buffer_size, learning_rate=learning_rate, seed=seed, device=device,
|
||||
action_noise=action_noise, learning_starts=learning_starts,
|
||||
n_episodes_rollout=n_episodes_rollout, tau=tau,
|
||||
action_noise=action_noise, target_policy_noise=target_policy_noise,
|
||||
target_noise_clip=target_noise_clip, learning_starts=learning_starts,
|
||||
n_episodes_rollout=n_episodes_rollout, tau=tau, gamma=gamma,
|
||||
policy_kwargs=policy_kwargs, verbose=verbose,
|
||||
policy_delay=policy_delay, batch_size=batch_size,
|
||||
create_eval_env=create_eval_env,
|
||||
create_eval_env=create_eval_env, tensorboard_log=tensorboard_log,
|
||||
_init_setup_model=False)
|
||||
|
||||
# Evolution strategy method that follows cma-es interface (ask-tell)
|
||||
# for now, only CEM is implemented
|
||||
self.es = None
|
||||
self.sigma_init = sigma_init
|
||||
self.pop_size = pop_size
|
||||
|
|
@ -79,7 +113,8 @@ class CEMRL(TD3):
|
|||
# set params
|
||||
self.actor.load_from_vector(self.es_params[i])
|
||||
self.actor_target.load_from_vector(self.es_params[i])
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=self.learning_rate(self._current_progress))
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(),
|
||||
lr=self.learning_rate(self._current_progress))
|
||||
|
||||
# In the paper: 2 * actor_steps // self.n_grad
|
||||
# In the original implementation: actor_steps // self.n_grad
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import gym
|
|||
from gym import spaces
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Check if tensorboard is available for pytorch
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
|
@ -192,7 +193,6 @@ class PPO(BaseRLModel):
|
|||
clip_range_vf = self.clip_range_vf(self._current_progress)
|
||||
logger.logkv("clip_range_vf", clip_range_vf)
|
||||
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
approx_kl_divs = []
|
||||
# Sample replay buffer
|
||||
|
|
@ -226,7 +226,6 @@ class PPO(BaseRLModel):
|
|||
# Value loss using the TD(gae_lambda) target
|
||||
value_loss = F.mse_loss(return_batch, values_pred)
|
||||
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
|
|
@ -241,7 +240,8 @@ class PPO(BaseRLModel):
|
|||
approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy())
|
||||
|
||||
if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:
|
||||
print("Early stopping at step {} due to reaching max kl: {:.2f}".format(gradient_step, np.mean(approx_kl_divs)))
|
||||
print("Early stopping at step {} due to reaching max kl: {:.2f}".format(gradient_step,
|
||||
np.mean(approx_kl_divs)))
|
||||
break
|
||||
|
||||
explained_var = explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(),
|
||||
|
|
|
|||
Loading…
Reference in a new issue