Use NamedTuple for buffers

This commit is contained in:
Antonin Raffin 2020-03-10 16:43:10 +01:00
parent 1e81f38d66
commit fb4e66213d
7 changed files with 60 additions and 53 deletions

View file

@ -81,30 +81,30 @@ class A2C(PPO):
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# A2C with gradient_steps > 1 does not make sense
assert gradient_steps == 1
assert gradient_steps == 1, "A2C does not support multiple gradient steps"
# We do not use minibatches for A2C
assert batch_size is None
assert batch_size is None, "A2C does not support minibatch"
for rollout_data in self.rollout_buffer.get(batch_size=None):
# Unpack
obs, action, _, _, advantage, return_batch = rollout_data
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action for float to long
action = action.long().flatten()
# Convert discrete action from float to long
actions = actions.long().flatten()
# TODO: avoid second computation of everything because of the gradient
values, log_prob, entropy = self.policy.evaluate_actions(obs, action)
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
values = values.flatten()
# Normalize advantage (not present in the original implementation)
if self.normalize_advantage:
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
advantages = (rollout_data.advantages - rollout_data.advantages.mean()) / (rollout_data.advantages.std() + 1e-8)
policy_loss = -(advantage * log_prob).mean()
# Policy gradient loss
policy_loss = -(advantages * log_prob).mean()
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(return_batch, values)
value_loss = F.mse_loss(rollout_data.returns, values)
# Entropy loss favor exploration
if entropy is None:

View file

@ -956,7 +956,6 @@ class OffPolicyRLModel(BaseRLModel):
total_episodes += 1
episode_rewards.append(episode_reward)
total_timesteps.append(episode_timesteps)
# TODO: reset SDE matrix at the end of the episode?
if action_noise is not None:
action_noise.reset()

View file

@ -1,4 +1,4 @@
from typing import Union, Optional, Tuple, Generator
from typing import Union, Optional, Generator
import numpy as np
import torch as th
@ -80,11 +80,12 @@ class BaseBuffer(object):
def sample(self,
batch_size: int,
env: Optional[VecNormalize] = None
) -> Tuple[th.Tensor, ...]:
):
"""
:param batch_size: (int) Number of element to sample
:param env: (Optional[VecNormalize]) associated gym VecEnv
to normalize the observations/rewards when sampling
:return: (Union[RolloutBufferSamples, ReplayBufferSamples])
"""
upper_bound = self.buffer_size if self.full else self.pos
batch_inds = np.random.randint(0, upper_bound, size=batch_size)
@ -93,11 +94,11 @@ class BaseBuffer(object):
def _get_samples(self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None
) -> Tuple[th.Tensor, ...]:
):
"""
:param batch_inds: (th.Tensor)
:param env: (Optional[VecNormalize])
:return: ([th.Tensor])
:return: (Union[RolloutBufferSamples, ReplayBufferSamples])
"""
raise NotImplementedError()
@ -184,7 +185,7 @@ class ReplayBuffer(BaseBuffer):
self._normalize_obs(self.next_observations[batch_inds, 0, :], env),
self.dones[batch_inds],
self._normalize_reward(self.rewards[batch_inds], env))
return tuple(map(self.to_torch, data))
return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
class RolloutBuffer(BaseBuffer):
@ -333,4 +334,4 @@ class RolloutBuffer(BaseBuffer):
self.log_probs[batch_inds].flatten(),
self.advantages[batch_inds].flatten(),
self.returns[batch_inds].flatten())
return tuple(map(self.to_torch, data))
return RolloutBufferSamples(*tuple(map(self.to_torch, data)))

View file

@ -1,7 +1,8 @@
"""
Common aliases for type hing
"""
from typing import Union, Type, Optional, Dict, Any, List, Tuple
from typing import Union, Type, Optional, Dict, Any, List, NamedTuple
from collections import namedtuple
import torch as th
import gym
@ -13,6 +14,19 @@ GymEnv = Union[gym.Env, VecEnv]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
# obs, action, old_values, old_log_prob, advantage, return_batch
RolloutBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
class RolloutBufferSamples(NamedTuple):
observations: th.Tensor
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
advantages: th.Tensor
returns: th.Tensor
# obs, action, next_obs, done, reward
ReplayBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
class ReplayBufferSamples(NamedTuple):
observations: th.Tensor
actions: th.Tensor
next_observations: th.Tensor
dones: th.Tensor
rewards: th.Tensor

View file

@ -195,13 +195,12 @@ class PPO(BaseRLModel):
for gradient_step in range(gradient_steps):
approx_kl_divs = []
# Sample replay buffer
for replay_data in self.rollout_buffer.get(batch_size):
# Unpack
obs, action, old_values, old_log_prob, advantage, return_batch = replay_data
for rollout_data in self.rollout_buffer.get(batch_size):
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action for float to long
action = action.long().flatten()
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()
# Re-sample the noise matrix because the log_std has changed
# TODO: investigate why there is no issue with the gradient
@ -209,16 +208,16 @@ class PPO(BaseRLModel):
if self.use_sde:
self.policy.reset_noise(batch_size)
values, log_prob, entropy = self.policy.evaluate_actions(obs, action)
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
values = values.flatten()
# Normalize advantage
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
advantages = (rollout_data.advantages - rollout_data.advantages.mean()) / (rollout_data.advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - old_log_prob)
ratio = th.exp(log_prob - rollout_data.old_log_prob)
# clipped surrogate loss
policy_loss_1 = advantage * ratio
policy_loss_2 = advantage * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
if self.clip_range_vf is None:
@ -227,9 +226,9 @@ class PPO(BaseRLModel):
else:
# Clip the different between old and new value
# NOTE: this depends on the reward scaling
values_pred = old_values + th.clamp(values - old_values, -clip_range_vf, clip_range_vf)
values_pred = rollout_data.old_values + th.clamp(values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(return_batch, values_pred)
value_loss = F.mse_loss(rollout_data.returns, values_pred)
# Entropy loss favor exploration
if entropy is None:
@ -246,7 +245,7 @@ class PPO(BaseRLModel):
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy())
approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy())
if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:
print("Early stopping at step {} due to reaching max kl: {:.2f}".format(gradient_step,

View file

@ -163,14 +163,12 @@ class SAC(OffPolicyRLModel):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
obs, action_batch, next_obs, done, reward = replay_data
# We need to sample because `log_std` may have changed between two gradient steps
if self.use_sde:
self.actor.reset_noise()
# Action by the current actor for the sampled state
action_pi, log_prob = self.actor.action_log_prob(obs)
actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
log_prob = log_prob.reshape(-1, 1)
ent_coef_loss = None
@ -192,17 +190,17 @@ class SAC(OffPolicyRLModel):
with th.no_grad():
# Select action according to policy
next_action, next_log_prob = self.actor.action_log_prob(next_obs)
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute the target Q value
target_q1, target_q2 = self.critic_target(next_obs, next_action)
target_q1, target_q2 = self.critic_target(replay_data.next_observations, next_actions)
target_q = th.min(target_q1, target_q2)
target_q = reward + (1 - done) * self.gamma * target_q
target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
# td error + entropy term
q_backup = target_q - ent_coef * next_log_prob.reshape(-1, 1)
# Get current Q estimates
# using action from the replay buffer
current_q1, current_q2 = self.critic(obs, action_batch)
current_q1, current_q2 = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = 0.5 * (F.mse_loss(current_q1, q_backup) + F.mse_loss(current_q2, q_backup))
@ -214,7 +212,7 @@ class SAC(OffPolicyRLModel):
# Compute actor loss
# Alternative: actor_loss = th.mean(log_prob - qf1_pi)
qf1_pi, qf2_pi = self.critic.forward(obs, action_pi)
qf1_pi, qf2_pi = self.critic.forward(replay_data.observations, actions_pi)
min_qf_pi = th.min(qf1_pi, qf2_pi)
actor_loss = (ent_coef * log_prob - min_qf_pi).mean()

View file

@ -124,22 +124,20 @@ class TD3(OffPolicyRLModel):
for gradient_step in range(gradient_steps):
# Sample replay buffer
if replay_data is None:
obs, action, next_obs, done, reward = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
else:
obs, action, next_obs, done, reward = replay_data
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
# Select action according to policy and add clipped noise
noise = action.clone().data.normal_(0, self.target_policy_noise)
noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
next_action = (self.actor_target(next_obs) + noise).clamp(-1, 1)
next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)
# Compute the target Q value
target_q1, target_q2 = self.critic_target(next_obs, next_action)
target_q1, target_q2 = self.critic_target(replay_data.next_observations, next_actions)
target_q = th.min(target_q1, target_q2)
target_q = reward + ((1 - done) * self.gamma * target_q).detach()
target_q = replay_data.rewards + ((1 - replay_data.dones) * self.gamma * target_q).detach()
# Get current Q estimates
current_q1, current_q2 = self.critic(obs, action)
current_q1, current_q2 = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
@ -167,12 +165,10 @@ class TD3(OffPolicyRLModel):
for gradient_step in range(gradient_steps):
# Sample replay buffer
if replay_data is None:
obs, _, next_obs, done, reward = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
else:
obs, _, next_obs, done, reward = replay_data
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
# Compute actor loss
actor_loss = -self.critic.q1_forward(obs, self.actor(obs)).mean()
actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean()
# Optimize the actor
self.actor.optimizer.zero_grad()