mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-01 23:30:53 +00:00
Update A2C params
This commit is contained in:
parent
0ad743c85d
commit
f8bcb8ee16
1 changed files with 9 additions and 11 deletions
|
|
@ -22,13 +22,12 @@ class A2C(PPO):
|
|||
:param learning_rate: (float or callable) The learning rate, it can be a function
|
||||
:param n_steps: (int) The number of steps to run for each environment per update
|
||||
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
|
||||
:param batch_size: (int) Minibatch size
|
||||
:param n_epochs: (int) Number of epoch when optimizing the surrogate loss
|
||||
:param gamma: (float) Discount factor
|
||||
:param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||
:param ent_coef: (float) Entropy coefficient for the loss calculation
|
||||
:param vf_coef: (float) Value function coefficient for the loss calculation
|
||||
:param max_grad_norm: (float) The maximum value for the gradient clipping
|
||||
:param normalize_advantage: (bool) Whether to normalize or not the advantage
|
||||
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: (bool) Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
|
|
@ -41,23 +40,22 @@ class A2C(PPO):
|
|||
"""
|
||||
|
||||
def __init__(self, policy, env, learning_rate=3e-4,
|
||||
n_steps=2048, batch_size=64, n_epochs=1,
|
||||
gamma=0.99, gae_lambda=0.95,
|
||||
n_steps=5, gamma=0.99, gae_lambda=0.95,
|
||||
ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5,
|
||||
tensorboard_log=None, create_eval_env=False,
|
||||
normalize_advantage=True, tensorboard_log=None, create_eval_env=False,
|
||||
policy_kwargs=None, verbose=0, seed=0, device='auto',
|
||||
_init_setup_model=True):
|
||||
|
||||
super(A2C, self).__init__(policy, env, learning_rate=learning_rate,
|
||||
n_steps=n_steps, batch_size=batch_size, n_epochs=n_epochs,
|
||||
n_steps=n_steps, batch_size=n_steps, n_epochs=1,
|
||||
gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef,
|
||||
vf_coef=vf_coef, max_grad_norm=max_grad_norm,
|
||||
tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs,
|
||||
verbose=verbose, device=device, create_eval_env=create_eval_env,
|
||||
seed=seed, _init_setup_model=False)
|
||||
|
||||
self.batch_size = n_steps
|
||||
|
||||
# Note: in the original implementation, this is RMSProp that is used
|
||||
self.normalize_advantage = normalize_advantage
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
||||
|
|
@ -76,9 +74,9 @@ class A2C(PPO):
|
|||
|
||||
values, log_prob, entropy = self.policy.get_policy_stats(obs, action)
|
||||
values = values.flatten()
|
||||
# Normalize advantage
|
||||
# TODO: check without
|
||||
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
|
||||
# Normalize advantage (not present in the original implementation)
|
||||
if self.normalize_advantage:
|
||||
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
|
||||
|
||||
policy_loss = -(advantage * log_prob).mean()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue