From 799e30ff3d575015866a1a48da7b4641aa56eed2 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 28 Oct 2019 14:27:32 +0100 Subject: [PATCH] Bug fixes for A2C and PPO --- torchy_baselines/a2c/a2c.py | 62 ++++++++++++++++-------------- torchy_baselines/common/buffers.py | 6 ++- torchy_baselines/ppo/ppo.py | 2 +- 3 files changed, 39 insertions(+), 31 deletions(-) diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index f08a352..c426552 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -51,7 +51,7 @@ class A2C(PPO): _init_setup_model=True): super(A2C, self).__init__(policy, env, learning_rate=learning_rate, - n_steps=n_steps, batch_size=n_steps, n_epochs=1, + n_steps=n_steps, batch_size=None, n_epochs=1, gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, @@ -72,42 +72,46 @@ class A2C(PPO): lr=self.learning_rate, alpha=0.99, eps=self.rms_prop_eps, weight_decay=0) - def train(self, gradient_steps, batch_size=64): + def train(self, gradient_steps, batch_size=None): - for gradient_step in range(gradient_steps): - # approx_kl_divs = [] - # Sample replay buffer - for replay_data in self.rollout_buffer.get(batch_size): - # Unpack - obs, action, _, _, advantage, return_batch = replay_data + # A2C with gradient_steps > 1 does not make sense + assert gradient_steps == 1 + # We do not use minibatches for A2C + assert batch_size is None - if isinstance(self.action_space, spaces.Discrete): - # Convert discrete action for float to long - action = action.long().flatten() + for rollout_data in self.rollout_buffer.get(batch_size=None): + # Unpack + obs, action, _, _, advantage, return_batch = rollout_data - values, log_prob, entropy = self.policy.get_policy_stats(obs, action) - values = values.flatten() - # Normalize advantage (not present in the original implementation) - if self.normalize_advantage: - advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action for float to long + action = action.long().flatten() - policy_loss = -(advantage * log_prob).mean() + # TODO: avoid second computation of everything because of the gradient + values, log_prob, entropy = self.policy.get_policy_stats(obs, action) + values = values.flatten() - # Value loss using the TD(gae_lambda) target - value_loss = F.mse_loss(return_batch, values) + # Normalize advantage (not present in the original implementation) + if self.normalize_advantage: + advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) - # Entropy loss favor exploration - entropy_loss = th.mean(entropy) + policy_loss = -(advantage * log_prob).mean() - loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(return_batch, values) - # Optimization step - self.policy.optimizer.zero_grad() - loss.backward() - # Clip grad norm - th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) - self.policy.optimizer.step() - # approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy()) + # Entropy loss favor exploration + entropy_loss = -th.mean(entropy) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + # approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy()) # print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(), # self.rollout_buffer.values.flatten().cpu().numpy())) diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py index b169dd3..34a7098 100644 --- a/torchy_baselines/common/buffers.py +++ b/torchy_baselines/common/buffers.py @@ -145,7 +145,7 @@ class RolloutBuffer(BaseBuffer): if self.pos == self.buffer_size: self.full = True - def get(self, batch_size): + def get(self, batch_size=None): assert self.full indices = th.randperm(self.buffer_size * self.n_envs) # Prepare the data @@ -155,6 +155,10 @@ class RolloutBuffer(BaseBuffer): self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.generator_ready = True + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + start_idx = 0 while start_idx < self.buffer_size * self.n_envs: yield self._get_samples(indices[start_idx:start_idx + batch_size]) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index b6584dd..cf36770 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -205,7 +205,7 @@ class PPO(BaseRLModel): # Entropy loss favor exploration - entropy_loss = th.mean(entropy) + entropy_loss = -th.mean(entropy) loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss