mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-29 23:07:07 +00:00
Use NamedTuple for buffers
This commit is contained in:
parent
1e81f38d66
commit
fb4e66213d
7 changed files with 60 additions and 53 deletions
|
|
@ -81,30 +81,30 @@ class A2C(PPO):
|
|||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# A2C with gradient_steps > 1 does not make sense
|
||||
assert gradient_steps == 1
|
||||
assert gradient_steps == 1, "A2C does not support multiple gradient steps"
|
||||
# We do not use minibatches for A2C
|
||||
assert batch_size is None
|
||||
assert batch_size is None, "A2C does not support minibatch"
|
||||
|
||||
for rollout_data in self.rollout_buffer.get(batch_size=None):
|
||||
# Unpack
|
||||
obs, action, _, _, advantage, return_batch = rollout_data
|
||||
|
||||
actions = rollout_data.actions
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
# Convert discrete action for float to long
|
||||
action = action.long().flatten()
|
||||
# Convert discrete action from float to long
|
||||
actions = actions.long().flatten()
|
||||
|
||||
# TODO: avoid second computation of everything because of the gradient
|
||||
values, log_prob, entropy = self.policy.evaluate_actions(obs, action)
|
||||
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
|
||||
values = values.flatten()
|
||||
|
||||
# Normalize advantage (not present in the original implementation)
|
||||
if self.normalize_advantage:
|
||||
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
|
||||
advantages = (rollout_data.advantages - rollout_data.advantages.mean()) / (rollout_data.advantages.std() + 1e-8)
|
||||
|
||||
policy_loss = -(advantage * log_prob).mean()
|
||||
# Policy gradient loss
|
||||
policy_loss = -(advantages * log_prob).mean()
|
||||
|
||||
# Value loss using the TD(gae_lambda) target
|
||||
value_loss = F.mse_loss(return_batch, values)
|
||||
value_loss = F.mse_loss(rollout_data.returns, values)
|
||||
|
||||
# Entropy loss favor exploration
|
||||
if entropy is None:
|
||||
|
|
|
|||
|
|
@ -956,7 +956,6 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
total_episodes += 1
|
||||
episode_rewards.append(episode_reward)
|
||||
total_timesteps.append(episode_timesteps)
|
||||
# TODO: reset SDE matrix at the end of the episode?
|
||||
if action_noise is not None:
|
||||
action_noise.reset()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Union, Optional, Tuple, Generator
|
||||
from typing import Union, Optional, Generator
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -80,11 +80,12 @@ class BaseBuffer(object):
|
|||
def sample(self,
|
||||
batch_size: int,
|
||||
env: Optional[VecNormalize] = None
|
||||
) -> Tuple[th.Tensor, ...]:
|
||||
):
|
||||
"""
|
||||
:param batch_size: (int) Number of element to sample
|
||||
:param env: (Optional[VecNormalize]) associated gym VecEnv
|
||||
to normalize the observations/rewards when sampling
|
||||
:return: (Union[RolloutBufferSamples, ReplayBufferSamples])
|
||||
"""
|
||||
upper_bound = self.buffer_size if self.full else self.pos
|
||||
batch_inds = np.random.randint(0, upper_bound, size=batch_size)
|
||||
|
|
@ -93,11 +94,11 @@ class BaseBuffer(object):
|
|||
def _get_samples(self,
|
||||
batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None
|
||||
) -> Tuple[th.Tensor, ...]:
|
||||
):
|
||||
"""
|
||||
:param batch_inds: (th.Tensor)
|
||||
:param env: (Optional[VecNormalize])
|
||||
:return: ([th.Tensor])
|
||||
:return: (Union[RolloutBufferSamples, ReplayBufferSamples])
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
@ -184,7 +185,7 @@ class ReplayBuffer(BaseBuffer):
|
|||
self._normalize_obs(self.next_observations[batch_inds, 0, :], env),
|
||||
self.dones[batch_inds],
|
||||
self._normalize_reward(self.rewards[batch_inds], env))
|
||||
return tuple(map(self.to_torch, data))
|
||||
return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
|
||||
|
||||
|
||||
class RolloutBuffer(BaseBuffer):
|
||||
|
|
@ -333,4 +334,4 @@ class RolloutBuffer(BaseBuffer):
|
|||
self.log_probs[batch_inds].flatten(),
|
||||
self.advantages[batch_inds].flatten(),
|
||||
self.returns[batch_inds].flatten())
|
||||
return tuple(map(self.to_torch, data))
|
||||
return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
"""
|
||||
Common aliases for type hing
|
||||
"""
|
||||
from typing import Union, Type, Optional, Dict, Any, List, Tuple
|
||||
from typing import Union, Type, Optional, Dict, Any, List, NamedTuple
|
||||
from collections import namedtuple
|
||||
|
||||
import torch as th
|
||||
import gym
|
||||
|
|
@ -13,6 +14,19 @@ GymEnv = Union[gym.Env, VecEnv]
|
|||
TensorDict = Dict[str, th.Tensor]
|
||||
OptimizerStateDict = Dict[str, Any]
|
||||
# obs, action, old_values, old_log_prob, advantage, return_batch
|
||||
RolloutBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
|
||||
class RolloutBufferSamples(NamedTuple):
|
||||
observations: th.Tensor
|
||||
actions: th.Tensor
|
||||
old_values: th.Tensor
|
||||
old_log_prob: th.Tensor
|
||||
advantages: th.Tensor
|
||||
returns: th.Tensor
|
||||
|
||||
|
||||
# obs, action, next_obs, done, reward
|
||||
ReplayBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
|
||||
class ReplayBufferSamples(NamedTuple):
|
||||
observations: th.Tensor
|
||||
actions: th.Tensor
|
||||
next_observations: th.Tensor
|
||||
dones: th.Tensor
|
||||
rewards: th.Tensor
|
||||
|
|
|
|||
|
|
@ -195,13 +195,12 @@ class PPO(BaseRLModel):
|
|||
for gradient_step in range(gradient_steps):
|
||||
approx_kl_divs = []
|
||||
# Sample replay buffer
|
||||
for replay_data in self.rollout_buffer.get(batch_size):
|
||||
# Unpack
|
||||
obs, action, old_values, old_log_prob, advantage, return_batch = replay_data
|
||||
for rollout_data in self.rollout_buffer.get(batch_size):
|
||||
|
||||
actions = rollout_data.actions
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
# Convert discrete action for float to long
|
||||
action = action.long().flatten()
|
||||
# Convert discrete action from float to long
|
||||
actions = rollout_data.actions.long().flatten()
|
||||
|
||||
# Re-sample the noise matrix because the log_std has changed
|
||||
# TODO: investigate why there is no issue with the gradient
|
||||
|
|
@ -209,16 +208,16 @@ class PPO(BaseRLModel):
|
|||
if self.use_sde:
|
||||
self.policy.reset_noise(batch_size)
|
||||
|
||||
values, log_prob, entropy = self.policy.evaluate_actions(obs, action)
|
||||
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
|
||||
values = values.flatten()
|
||||
# Normalize advantage
|
||||
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
|
||||
advantages = (rollout_data.advantages - rollout_data.advantages.mean()) / (rollout_data.advantages.std() + 1e-8)
|
||||
|
||||
# ratio between old and new policy, should be one at the first iteration
|
||||
ratio = th.exp(log_prob - old_log_prob)
|
||||
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||
# clipped surrogate loss
|
||||
policy_loss_1 = advantage * ratio
|
||||
policy_loss_2 = advantage * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
policy_loss_1 = advantages * ratio
|
||||
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
|
||||
|
||||
if self.clip_range_vf is None:
|
||||
|
|
@ -227,9 +226,9 @@ class PPO(BaseRLModel):
|
|||
else:
|
||||
# Clip the different between old and new value
|
||||
# NOTE: this depends on the reward scaling
|
||||
values_pred = old_values + th.clamp(values - old_values, -clip_range_vf, clip_range_vf)
|
||||
values_pred = rollout_data.old_values + th.clamp(values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
|
||||
# Value loss using the TD(gae_lambda) target
|
||||
value_loss = F.mse_loss(return_batch, values_pred)
|
||||
value_loss = F.mse_loss(rollout_data.returns, values_pred)
|
||||
|
||||
# Entropy loss favor exploration
|
||||
if entropy is None:
|
||||
|
|
@ -246,7 +245,7 @@ class PPO(BaseRLModel):
|
|||
# Clip grad norm
|
||||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy())
|
||||
approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy())
|
||||
|
||||
if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:
|
||||
print("Early stopping at step {} due to reaching max kl: {:.2f}".format(gradient_step,
|
||||
|
|
|
|||
|
|
@ -163,14 +163,12 @@ class SAC(OffPolicyRLModel):
|
|||
# Sample replay buffer
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
|
||||
obs, action_batch, next_obs, done, reward = replay_data
|
||||
|
||||
# We need to sample because `log_std` may have changed between two gradient steps
|
||||
if self.use_sde:
|
||||
self.actor.reset_noise()
|
||||
|
||||
# Action by the current actor for the sampled state
|
||||
action_pi, log_prob = self.actor.action_log_prob(obs)
|
||||
actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
|
||||
log_prob = log_prob.reshape(-1, 1)
|
||||
|
||||
ent_coef_loss = None
|
||||
|
|
@ -192,17 +190,17 @@ class SAC(OffPolicyRLModel):
|
|||
|
||||
with th.no_grad():
|
||||
# Select action according to policy
|
||||
next_action, next_log_prob = self.actor.action_log_prob(next_obs)
|
||||
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
|
||||
# Compute the target Q value
|
||||
target_q1, target_q2 = self.critic_target(next_obs, next_action)
|
||||
target_q1, target_q2 = self.critic_target(replay_data.next_observations, next_actions)
|
||||
target_q = th.min(target_q1, target_q2)
|
||||
target_q = reward + (1 - done) * self.gamma * target_q
|
||||
target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
|
||||
# td error + entropy term
|
||||
q_backup = target_q - ent_coef * next_log_prob.reshape(-1, 1)
|
||||
|
||||
# Get current Q estimates
|
||||
# using action from the replay buffer
|
||||
current_q1, current_q2 = self.critic(obs, action_batch)
|
||||
current_q1, current_q2 = self.critic(replay_data.observations, replay_data.actions)
|
||||
|
||||
# Compute critic loss
|
||||
critic_loss = 0.5 * (F.mse_loss(current_q1, q_backup) + F.mse_loss(current_q2, q_backup))
|
||||
|
|
@ -214,7 +212,7 @@ class SAC(OffPolicyRLModel):
|
|||
|
||||
# Compute actor loss
|
||||
# Alternative: actor_loss = th.mean(log_prob - qf1_pi)
|
||||
qf1_pi, qf2_pi = self.critic.forward(obs, action_pi)
|
||||
qf1_pi, qf2_pi = self.critic.forward(replay_data.observations, actions_pi)
|
||||
min_qf_pi = th.min(qf1_pi, qf2_pi)
|
||||
actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
|
||||
|
||||
|
|
|
|||
|
|
@ -124,22 +124,20 @@ class TD3(OffPolicyRLModel):
|
|||
for gradient_step in range(gradient_steps):
|
||||
# Sample replay buffer
|
||||
if replay_data is None:
|
||||
obs, action, next_obs, done, reward = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
else:
|
||||
obs, action, next_obs, done, reward = replay_data
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
|
||||
# Select action according to policy and add clipped noise
|
||||
noise = action.clone().data.normal_(0, self.target_policy_noise)
|
||||
noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
|
||||
noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
|
||||
next_action = (self.actor_target(next_obs) + noise).clamp(-1, 1)
|
||||
next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)
|
||||
|
||||
# Compute the target Q value
|
||||
target_q1, target_q2 = self.critic_target(next_obs, next_action)
|
||||
target_q1, target_q2 = self.critic_target(replay_data.next_observations, next_actions)
|
||||
target_q = th.min(target_q1, target_q2)
|
||||
target_q = reward + ((1 - done) * self.gamma * target_q).detach()
|
||||
target_q = replay_data.rewards + ((1 - replay_data.dones) * self.gamma * target_q).detach()
|
||||
|
||||
# Get current Q estimates
|
||||
current_q1, current_q2 = self.critic(obs, action)
|
||||
current_q1, current_q2 = self.critic(replay_data.observations, replay_data.actions)
|
||||
|
||||
# Compute critic loss
|
||||
critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
|
||||
|
|
@ -167,12 +165,10 @@ class TD3(OffPolicyRLModel):
|
|||
for gradient_step in range(gradient_steps):
|
||||
# Sample replay buffer
|
||||
if replay_data is None:
|
||||
obs, _, next_obs, done, reward = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
else:
|
||||
obs, _, next_obs, done, reward = replay_data
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
|
||||
# Compute actor loss
|
||||
actor_loss = -self.critic.q1_forward(obs, self.actor(obs)).mean()
|
||||
actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean()
|
||||
|
||||
# Optimize the actor
|
||||
self.actor.optimizer.zero_grad()
|
||||
|
|
|
|||
Loading…
Reference in a new issue