Bug fixes for A2C and PPO

This commit is contained in:
Antonin Raffin 2019-10-28 14:27:32 +01:00
parent b150167bdd
commit 799e30ff3d
3 changed files with 39 additions and 31 deletions

View file

@ -51,7 +51,7 @@ class A2C(PPO):
_init_setup_model=True):
super(A2C, self).__init__(policy, env, learning_rate=learning_rate,
n_steps=n_steps, batch_size=n_steps, n_epochs=1,
n_steps=n_steps, batch_size=None, n_epochs=1,
gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef,
vf_coef=vf_coef, max_grad_norm=max_grad_norm,
tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs,
@ -72,42 +72,46 @@ class A2C(PPO):
lr=self.learning_rate, alpha=0.99,
eps=self.rms_prop_eps, weight_decay=0)
def train(self, gradient_steps, batch_size=64):
def train(self, gradient_steps, batch_size=None):
for gradient_step in range(gradient_steps):
# approx_kl_divs = []
# Sample replay buffer
for replay_data in self.rollout_buffer.get(batch_size):
# Unpack
obs, action, _, _, advantage, return_batch = replay_data
# A2C with gradient_steps > 1 does not make sense
assert gradient_steps == 1
# We do not use minibatches for A2C
assert batch_size is None
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long
action = action.long().flatten()
for rollout_data in self.rollout_buffer.get(batch_size=None):
# Unpack
obs, action, _, _, advantage, return_batch = rollout_data
values, log_prob, entropy = self.policy.get_policy_stats(obs, action)
values = values.flatten()
# Normalize advantage (not present in the original implementation)
if self.normalize_advantage:
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long
action = action.long().flatten()
policy_loss = -(advantage * log_prob).mean()
# TODO: avoid second computation of everything because of the gradient
values, log_prob, entropy = self.policy.get_policy_stats(obs, action)
values = values.flatten()
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(return_batch, values)
# Normalize advantage (not present in the original implementation)
if self.normalize_advantage:
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
# Entropy loss favors exploration
entropy_loss = th.mean(entropy)
policy_loss = -(advantage * log_prob).mean()
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(return_batch, values)
# Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
# approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy())
# Entropy loss favors exploration
entropy_loss = -th.mean(entropy)
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
# Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
# approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy())
# print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(),
# self.rollout_buffer.values.flatten().cpu().numpy()))

View file

@ -145,7 +145,7 @@ class RolloutBuffer(BaseBuffer):
if self.pos == self.buffer_size:
self.full = True
def get(self, batch_size):
def get(self, batch_size=None):
assert self.full
indices = th.randperm(self.buffer_size * self.n_envs)
# Prepare the data
@ -155,6 +155,10 @@ class RolloutBuffer(BaseBuffer):
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs
start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
yield self._get_samples(indices[start_idx:start_idx + batch_size])

View file

@ -205,7 +205,7 @@ class PPO(BaseRLModel):
# Entropy loss favors exploration
entropy_loss = th.mean(entropy)
entropy_loss = -th.mean(entropy)
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss