mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-08 00:23:22 +00:00
Cleanup CEM, rename variables + add comments
This commit is contained in:
parent
c47be0086e
commit
c56865e10d
2 changed files with 34 additions and 23 deletions
|
|
@ -8,36 +8,39 @@ class CEM(object):
|
|||
"""
|
||||
Cross-entropy method with diagonal covariance (separable CEM).
|
||||
|
||||
:param num_params: (int)
|
||||
:param num_params: (int) Number of parameters per individual (dimension of the problem)
|
||||
:param mu_init: (np.ndarray) Initial mean of the population distribution
|
||||
Taken to be zero if None is passed.
|
||||
:param sigma_init: (float) Initial standard deviation of the population distribution
|
||||
:param pop_size: (int) Number of individuals in the population
|
||||
:param damp: (float) Damping for preventing from early convergence.
|
||||
:param damp_limit: (float) Final value of damping
|
||||
:param parents: (int)
|
||||
:param elitism: (bool)
|
||||
:param damping_init: (float) Initial value of the damping, used to prevent early convergence.
|
||||
:param damping_final: (float) Final value of damping
|
||||
:param parents: (int) Number of parents used to compute the new distribution
|
||||
of individuals.
|
||||
:param elitism: (bool) Keep the best known individual in the population
|
||||
:param antithetic: (bool) Use a finite-difference-like method for sampling
|
||||
(mu + epsilon, mu - epsilon)
|
||||
"""
|
||||
|
||||
def __init__(self, num_params, mu_init=None, sigma_init=1e-3,
|
||||
pop_size=256, damp=1e-3, damp_limit=1e-5,
|
||||
pop_size=256, damping_init=1e-3, damping_final=1e-5,
|
||||
parents=None, elitism=False, antithetic=False):
|
||||
super(CEM, self).__init__()
|
||||
# misc
|
||||
|
||||
self.num_params = num_params
|
||||
|
||||
# distribution parameters
|
||||
# Distribution parameters
|
||||
if mu_init is None:
|
||||
self.mu = np.zeros(self.num_params)
|
||||
else:
|
||||
self.mu = np.array(mu_init)
|
||||
|
||||
self.sigma = sigma_init
|
||||
self.damp = damp
|
||||
self.damp_limit = damp_limit
|
||||
# Damping parameters
|
||||
self.damping = damping_init
|
||||
self.damping_final = damping_final
|
||||
# Exponential moving average decay for damping
|
||||
self.tau = 0.95
|
||||
# Covariance matrix, here only the diagonal
|
||||
self.cov = self.sigma * np.ones(self.num_params)
|
||||
|
||||
# elite stuff
|
||||
|
|
@ -45,16 +48,20 @@ class CEM(object):
|
|||
self.elite = np.sqrt(self.sigma) * np.random.rand(self.num_params)
|
||||
self.elite_score = None
|
||||
|
||||
# sampling stuff
|
||||
# sampling parameters
|
||||
self.pop_size = pop_size
|
||||
self.antithetic = antithetic
|
||||
|
||||
if self.antithetic:
|
||||
assert (self.pop_size % 2 == 0), "Population size must be even"
|
||||
|
||||
if parents is None or parents <= 0:
|
||||
self.parents = pop_size // 2
|
||||
else:
|
||||
self.parents = parents
|
||||
|
||||
# Weighting for computing the new mean of the distributions
|
||||
# from the parents. The better the individual, the higher the weight
|
||||
self.weights = np.array([np.log((self.parents + 1) / i)
|
||||
for i in range(1, self.parents + 1)])
|
||||
self.weights /= self.weights.sum()
|
||||
|
|
@ -69,11 +76,12 @@ class CEM(object):
|
|||
if self.antithetic and not pop_size % 2:
|
||||
epsilon_half = np.random.randn(pop_size // 2, self.num_params)
|
||||
epsilon = np.concatenate([epsilon_half, - epsilon_half])
|
||||
|
||||
else:
|
||||
epsilon = np.random.randn(pop_size, self.num_params)
|
||||
|
||||
individuals = self.mu + epsilon * np.sqrt(self.cov)
|
||||
|
||||
# Keep the best known individual in the population
|
||||
if self.elitism:
|
||||
individuals[-1] = self.elite
|
||||
|
||||
|
|
@ -89,19 +97,22 @@ class CEM(object):
|
|||
# Convert rewards (we want to maximize) to cost (we want to minimize)
|
||||
scores = np.array(scores)
|
||||
scores *= -1
|
||||
# Sort the individuals by fitness
|
||||
idx_sorted = np.argsort(scores)
|
||||
|
||||
old_mu = self.mu
|
||||
self.damp = self.damp * self.tau + (1 - self.tau) * self.damp_limit
|
||||
# Update damping using a moving average
|
||||
self.damping = self.damping * self.tau + (1 - self.tau) * self.damping_final
|
||||
# self.mu = self.weights @ solutions[idx_sorted[:self.parents]]
|
||||
self.mu = self.weights.dot(solutions[idx_sorted[:self.parents]])
|
||||
|
||||
# CMA-ES style would be to use the new mean here
|
||||
z = (solutions[idx_sorted[:self.parents]] - old_mu)
|
||||
self.cov = 1 / self.parents * self.weights.dot(z * z) + self.damp * np.ones(self.num_params)
|
||||
self.cov = 1 / self.parents * self.weights.dot(z * z) + self.damping * np.ones(self.num_params)
|
||||
|
||||
# Retrieve the best individual
|
||||
self.elite = solutions[idx_sorted[0]]
|
||||
self.elite_score = scores[idx_sorted[0]]
|
||||
# print(self.cov)
|
||||
|
||||
def get_distrib_params(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -19,9 +19,9 @@ class CEMRL(TD3):
|
|||
:param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
|
||||
:param sigma_init: (float) Initial standard deviation of the population distribution
|
||||
:param pop_size: (int) Number of individuals in the population
|
||||
:param damp: (float) Damping for preventing from early convergence.
|
||||
:param damp_limit: (float) Final value of damping
|
||||
:param elitism: (bool)
|
||||
:param damping_init: (float) Initial value of the damping, used to prevent early convergence.
|
||||
:param damping_final: (float) Final value of damping
|
||||
:param elitism: (bool) Keep the best known individual in the population
|
||||
:param n_grad: (int) Number of individuals that will receive a gradient update.
|
||||
Half of the population size in the paper.
|
||||
:param buffer_size: (int) size of the replay buffer
|
||||
|
|
@ -48,7 +48,7 @@ class CEMRL(TD3):
|
|||
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
def __init__(self, policy, env, sigma_init=1e-3, pop_size=10,
|
||||
damp=1e-3, damp_limit=1e-5, elitism=False, n_grad=5,
|
||||
damping_init=1e-3, damping_final=1e-5, elitism=False, n_grad=5,
|
||||
buffer_size=int(1e6), learning_rate=1e-3, policy_delay=2,
|
||||
learning_starts=100, gamma=0.99, batch_size=100, tau=0.005,
|
||||
action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
|
||||
|
|
@ -72,8 +72,8 @@ class CEMRL(TD3):
|
|||
self.es = None
|
||||
self.sigma_init = sigma_init
|
||||
self.pop_size = pop_size
|
||||
self.damp = damp
|
||||
self.damp_limit = damp_limit
|
||||
self.damping_init = damping_init
|
||||
self.damping_final = damping_final
|
||||
self.elitism = elitism
|
||||
self.n_grad = n_grad
|
||||
self.es_params = None
|
||||
|
|
@ -87,7 +87,7 @@ class CEMRL(TD3):
|
|||
super(CEMRL, self)._setup_model()
|
||||
params_vector = self.actor.parameters_to_vector()
|
||||
self.es = CEM(len(params_vector), mu_init=params_vector,
|
||||
sigma_init=self.sigma_init, damp=self.damp, damp_limit=self.damp_limit,
|
||||
sigma_init=self.sigma_init, damping_init=self.damping_init, damping_final=self.damping_final,
|
||||
pop_size=self.pop_size, antithetic=not self.pop_size % 2, parents=self.pop_size // 2,
|
||||
elitism=self.elitism)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue