diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 4de140b..709f67b 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -22,13 +22,12 @@ class A2C(PPO): :param learning_rate: (float or callable) The learning rate, it can be a function :param n_steps: (int) The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) - :param batch_size: (int) Minibatch size - :param n_epochs: (int) Number of epoch when optimizing the surrogate loss :param gamma: (float) Discount factor :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator :param ent_coef: (float) Entropy coefficient for the loss calculation :param vf_coef: (float) Value function coefficient for the loss calculation :param max_grad_norm: (float) The maximum value for the gradient clipping + :param normalize_advantage: (bool) Whether to normalize or not the advantage :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param create_eval_env: (bool) Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) @@ -41,23 +40,22 @@ class A2C(PPO): """ def __init__(self, policy, env, learning_rate=3e-4, - n_steps=2048, batch_size=64, n_epochs=1, - gamma=0.99, gae_lambda=0.95, + n_steps=5, gamma=0.99, gae_lambda=0.95, ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, - tensorboard_log=None, create_eval_env=False, + normalize_advantage=True, tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=0, device='auto', _init_setup_model=True): super(A2C, self).__init__(policy, env, learning_rate=learning_rate, - n_steps=n_steps, batch_size=batch_size, n_epochs=n_epochs, + n_steps=n_steps, batch_size=n_steps, n_epochs=1, gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, verbose=verbose, device=device, create_eval_env=create_eval_env, seed=seed, _init_setup_model=False) - self.batch_size = n_steps - + # Note: in the original implementation, this is RMSProp that is used + self.normalize_advantage = normalize_advantage if _init_setup_model: self._setup_model() @@ -76,9 +74,9 @@ class A2C(PPO): values, log_prob, entropy = self.policy.get_policy_stats(obs, action) values = values.flatten() - # Normalize advantage - # TODO: check without - advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) + # Normalize advantage (not present in the original implementation) + if self.normalize_advantage: + advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) policy_loss = -(advantage * log_prob).mean()