mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-26 03:01:19 +00:00
Bug fixes for A2C and PPO
This commit is contained in:
parent
b150167bdd
commit
799e30ff3d
3 changed files with 39 additions and 31 deletions
|
|
@ -51,7 +51,7 @@ class A2C(PPO):
|
|||
_init_setup_model=True):
|
||||
|
||||
super(A2C, self).__init__(policy, env, learning_rate=learning_rate,
|
||||
n_steps=n_steps, batch_size=n_steps, n_epochs=1,
|
||||
n_steps=n_steps, batch_size=None, n_epochs=1,
|
||||
gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef,
|
||||
vf_coef=vf_coef, max_grad_norm=max_grad_norm,
|
||||
tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs,
|
||||
|
|
@ -72,42 +72,46 @@ class A2C(PPO):
|
|||
lr=self.learning_rate, alpha=0.99,
|
||||
eps=self.rms_prop_eps, weight_decay=0)
|
||||
|
||||
def train(self, gradient_steps, batch_size=64):
|
||||
def train(self, gradient_steps, batch_size=None):
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
# approx_kl_divs = []
|
||||
# Sample replay buffer
|
||||
for replay_data in self.rollout_buffer.get(batch_size):
|
||||
# Unpack
|
||||
obs, action, _, _, advantage, return_batch = replay_data
|
||||
# A2C with gradient_steps > 1 does not make sense
|
||||
assert gradient_steps == 1
|
||||
# We do not use minibatches for A2C
|
||||
assert batch_size is None
|
||||
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
# Convert discrete action for float to long
|
||||
action = action.long().flatten()
|
||||
for rollout_data in self.rollout_buffer.get(batch_size=None):
|
||||
# Unpack
|
||||
obs, action, _, _, advantage, return_batch = rollout_data
|
||||
|
||||
values, log_prob, entropy = self.policy.get_policy_stats(obs, action)
|
||||
values = values.flatten()
|
||||
# Normalize advantage (not present in the original implementation)
|
||||
if self.normalize_advantage:
|
||||
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
# Convert discrete action for float to long
|
||||
action = action.long().flatten()
|
||||
|
||||
policy_loss = -(advantage * log_prob).mean()
|
||||
# TODO: avoid second computation of everything because of the gradient
|
||||
values, log_prob, entropy = self.policy.get_policy_stats(obs, action)
|
||||
values = values.flatten()
|
||||
|
||||
# Value loss using the TD(gae_lambda) target
|
||||
value_loss = F.mse_loss(return_batch, values)
|
||||
# Normalize advantage (not present in the original implementation)
|
||||
if self.normalize_advantage:
|
||||
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = th.mean(entropy)
|
||||
policy_loss = -(advantage * log_prob).mean()
|
||||
|
||||
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
||||
# Value loss using the TD(gae_lambda) target
|
||||
value_loss = F.mse_loss(return_batch, values)
|
||||
|
||||
# Optimization step
|
||||
self.policy.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# Clip grad norm
|
||||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
# approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy())
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
||||
|
||||
# Optimization step
|
||||
self.policy.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# Clip grad norm
|
||||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
# approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy())
|
||||
|
||||
# print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(),
|
||||
# self.rollout_buffer.values.flatten().cpu().numpy()))
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
if self.pos == self.buffer_size:
|
||||
self.full = True
|
||||
|
||||
def get(self, batch_size):
|
||||
def get(self, batch_size=None):
|
||||
assert self.full
|
||||
indices = th.randperm(self.buffer_size * self.n_envs)
|
||||
# Prepare the data
|
||||
|
|
@ -155,6 +155,10 @@ class RolloutBuffer(BaseBuffer):
|
|||
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
|
||||
self.generator_ready = True
|
||||
|
||||
# Return everything, don't create minibatches
|
||||
if batch_size is None:
|
||||
batch_size = self.buffer_size * self.n_envs
|
||||
|
||||
start_idx = 0
|
||||
while start_idx < self.buffer_size * self.n_envs:
|
||||
yield self._get_samples(indices[start_idx:start_idx + batch_size])
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ class PPO(BaseRLModel):
|
|||
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = th.mean(entropy)
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue