diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py
index 01cbd39..ea27d71 100644
--- a/torchy_baselines/a2c/a2c.py
+++ b/torchy_baselines/a2c/a2c.py
@@ -81,30 +81,31 @@ class A2C(PPO):
         # Update optimizer learning rate
         self._update_learning_rate(self.policy.optimizer)
         # A2C with gradient_steps > 1 does not make sense
-        assert gradient_steps == 1
+        assert gradient_steps == 1, "A2C does not support multiple gradient steps"
         # We do not use minibatches for A2C
-        assert batch_size is None
+        assert batch_size is None, "A2C does not support minibatch"
 
         for rollout_data in self.rollout_buffer.get(batch_size=None):
-            # Unpack
-            obs, action, _, _, advantage, return_batch = rollout_data
+            actions = rollout_data.actions
 
             if isinstance(self.action_space, spaces.Discrete):
-                # Convert discrete action for float to long
-                action = action.long().flatten()
+                # Convert discrete action from float to long
+                actions = actions.long().flatten()
 
             # TODO: avoid second computation of everything because of the gradient
-            values, log_prob, entropy = self.policy.evaluate_actions(obs, action)
+            values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
             values = values.flatten()
 
             # Normalize advantage (not present in the original implementation)
+            advantages = rollout_data.advantages
             if self.normalize_advantage:
-                advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
+                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
 
-            policy_loss = -(advantage * log_prob).mean()
+            # Policy gradient loss
+            policy_loss = -(advantages * log_prob).mean()
 
             # Value loss using the TD(gae_lambda) target
-            value_loss = F.mse_loss(return_batch, values)
+            value_loss = F.mse_loss(rollout_data.returns, values)
 
             # Entropy loss favor exploration
             if entropy is None:
diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py
index 6b3fb6a..b010025 100644
--- a/torchy_baselines/common/base_class.py
+++ b/torchy_baselines/common/base_class.py
@@ -956,7 +956,6 @@ class OffPolicyRLModel(BaseRLModel):
                 total_episodes += 1
                 episode_rewards.append(episode_reward)
                 total_timesteps.append(episode_timesteps)
-                # TODO: reset SDE matrix at the end of the episode?
                 if action_noise is not None:
                     action_noise.reset()
 
diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py
index c6e9d5c..6ffe479 100644
--- a/torchy_baselines/common/buffers.py
+++ b/torchy_baselines/common/buffers.py
@@ -1,4 +1,4 @@
-from typing import Union, Optional, Tuple, Generator
+from typing import Union, Optional, Generator
 
 import numpy as np
 import torch as th
@@ -80,11 +80,12 @@ class BaseBuffer(object):
     def sample(self,
                batch_size: int,
                env: Optional[VecNormalize] = None
-               ) -> Tuple[th.Tensor, ...]:
+               ):
         """
         :param batch_size: (int) Number of element to sample
         :param env: (Optional[VecNormalize]) associated gym VecEnv
             to normalize the observations/rewards when sampling
+        :return: (Union[RolloutBufferSamples, ReplayBufferSamples])
         """
         upper_bound = self.buffer_size if self.full else self.pos
         batch_inds = np.random.randint(0, upper_bound, size=batch_size)
@@ -93,11 +94,11 @@ class BaseBuffer(object):
     def _get_samples(self,
                      batch_inds: np.ndarray,
                      env: Optional[VecNormalize] = None
-                     ) -> Tuple[th.Tensor, ...]:
+                     ):
         """
         :param batch_inds: (th.Tensor)
         :param env: (Optional[VecNormalize])
-        :return: ([th.Tensor])
+        :return: (Union[RolloutBufferSamples, ReplayBufferSamples])
         """
         raise NotImplementedError()
 
@@ -184,7 +185,7 @@ class ReplayBuffer(BaseBuffer):
                 self._normalize_obs(self.next_observations[batch_inds, 0, :], env),
                 self.dones[batch_inds],
                 self._normalize_reward(self.rewards[batch_inds], env))
-        return tuple(map(self.to_torch, data))
+        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
 
 
 class RolloutBuffer(BaseBuffer):
@@ -333,4 +334,4 @@ class RolloutBuffer(BaseBuffer):
                 self.log_probs[batch_inds].flatten(),
                 self.advantages[batch_inds].flatten(),
                 self.returns[batch_inds].flatten())
-        return tuple(map(self.to_torch, data))
+        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
diff --git a/torchy_baselines/common/type_aliases.py b/torchy_baselines/common/type_aliases.py
index b9035db..16576ff 100644
--- a/torchy_baselines/common/type_aliases.py
+++ b/torchy_baselines/common/type_aliases.py
@@ -1,7 +1,7 @@
 """
 Common aliases for type hing
 """
-from typing import Union, Type, Optional, Dict, Any, List, Tuple
+from typing import Union, Type, Optional, Dict, Any, List, NamedTuple
 
 import torch as th
 import gym
@@ -13,6 +13,19 @@ GymEnv = Union[gym.Env, VecEnv]
 TensorDict = Dict[str, th.Tensor]
 OptimizerStateDict = Dict[str, Any]
 # obs, action, old_values, old_log_prob, advantage, return_batch
-RolloutBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
+class RolloutBufferSamples(NamedTuple):
+    observations: th.Tensor
+    actions: th.Tensor
+    old_values: th.Tensor
+    old_log_prob: th.Tensor
+    advantages: th.Tensor
+    returns: th.Tensor
+
+
 # obs, action, next_obs, done, reward
-ReplayBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
+class ReplayBufferSamples(NamedTuple):
+    observations: th.Tensor
+    actions: th.Tensor
+    next_observations: th.Tensor
+    dones: th.Tensor
+    rewards: th.Tensor
diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py
index 6f30f7f..f10f5e2 100644
--- a/torchy_baselines/ppo/ppo.py
+++ b/torchy_baselines/ppo/ppo.py
@@ -195,13 +195,12 @@ class PPO(BaseRLModel):
         for gradient_step in range(gradient_steps):
             approx_kl_divs = []
             # Sample replay buffer
-            for replay_data in self.rollout_buffer.get(batch_size):
-                # Unpack
-                obs, action, old_values, old_log_prob, advantage, return_batch = replay_data
+            for rollout_data in self.rollout_buffer.get(batch_size):
+                actions = rollout_data.actions
 
                 if isinstance(self.action_space, spaces.Discrete):
-                    # Convert discrete action for float to long
-                    action = action.long().flatten()
+                    # Convert discrete action from float to long
+                    actions = rollout_data.actions.long().flatten()
 
                 # Re-sample the noise matrix because the log_std has changed
                 # TODO: investigate why there is no issue with the gradient
@@ -209,16 +208,16 @@ class PPO(BaseRLModel):
                 if self.use_sde:
                     self.policy.reset_noise(batch_size)
 
-                values, log_prob, entropy = self.policy.evaluate_actions(obs, action)
+                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                 values = values.flatten()
                 # Normalize advantage
-                advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
+                advantages = (rollout_data.advantages - rollout_data.advantages.mean()) / (rollout_data.advantages.std() + 1e-8)
 
                 # ratio between old and new policy, should be one at the first iteration
-                ratio = th.exp(log_prob - old_log_prob)
+                ratio = th.exp(log_prob - rollout_data.old_log_prob)
                 # clipped surrogate loss
-                policy_loss_1 = advantage * ratio
-                policy_loss_2 = advantage * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
+                policy_loss_1 = advantages * ratio
+                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                 policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
 
                 if self.clip_range_vf is None:
@@ -227,9 +226,9 @@ class PPO(BaseRLModel):
                 else:
                     # Clip the different between old and new value
                     # NOTE: this depends on the reward scaling
-                    values_pred = old_values + th.clamp(values - old_values, -clip_range_vf, clip_range_vf)
+                    values_pred = rollout_data.old_values + th.clamp(values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
                 # Value loss using the TD(gae_lambda) target
-                value_loss = F.mse_loss(return_batch, values_pred)
+                value_loss = F.mse_loss(rollout_data.returns, values_pred)
 
                 # Entropy loss favor exploration
                 if entropy is None:
@@ -246,7 +245,7 @@ class PPO(BaseRLModel):
                 # Clip grad norm
                 th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
-                approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy())
+                approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy())
 
             if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:
                 print("Early stopping at step {} due to reaching max kl: {:.2f}".format(gradient_step,
diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py
index fa6738b..d7906e8 100644
--- a/torchy_baselines/sac/sac.py
+++ b/torchy_baselines/sac/sac.py
@@ -163,14 +163,12 @@ class SAC(OffPolicyRLModel):
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
 
-            obs, action_batch, next_obs, done, reward = replay_data
-
             # We need to sample because `log_std` may have changed between two gradient steps
             if self.use_sde:
                 self.actor.reset_noise()
 
             # Action by the current actor for the sampled state
-            action_pi, log_prob = self.actor.action_log_prob(obs)
+            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
             log_prob = log_prob.reshape(-1, 1)
 
             ent_coef_loss = None
@@ -192,17 +190,17 @@ class SAC(OffPolicyRLModel):
 
             with th.no_grad():
                 # Select action according to policy
-                next_action, next_log_prob = self.actor.action_log_prob(next_obs)
+                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                 # Compute the target Q value
-                target_q1, target_q2 = self.critic_target(next_obs, next_action)
+                target_q1, target_q2 = self.critic_target(replay_data.next_observations, next_actions)
                 target_q = th.min(target_q1, target_q2)
-                target_q = reward + (1 - done) * self.gamma * target_q
+                target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
                 # td error + entropy term
                 q_backup = target_q - ent_coef * next_log_prob.reshape(-1, 1)
 
             # Get current Q estimates
             # using action from the replay buffer
-            current_q1, current_q2 = self.critic(obs, action_batch)
+            current_q1, current_q2 = self.critic(replay_data.observations, replay_data.actions)
 
             # Compute critic loss
             critic_loss = 0.5 * (F.mse_loss(current_q1, q_backup) + F.mse_loss(current_q2, q_backup))
@@ -214,7 +212,7 @@ class SAC(OffPolicyRLModel):
 
             # Compute actor loss
             # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
-            qf1_pi, qf2_pi = self.critic.forward(obs, action_pi)
+            qf1_pi, qf2_pi = self.critic.forward(replay_data.observations, actions_pi)
             min_qf_pi = th.min(qf1_pi, qf2_pi)
             actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
 
diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py
index 1ba5947..963338d 100644
--- a/torchy_baselines/td3/td3.py
+++ b/torchy_baselines/td3/td3.py
@@ -124,22 +124,23 @@ class TD3(OffPolicyRLModel):
         for gradient_step in range(gradient_steps):
             # Sample replay buffer
+            # Bind to a local name: rebinding `replay_data` would reuse the first batch on every later step
             if replay_data is None:
-                obs, action, next_obs, done, reward = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
+                batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
             else:
-                obs, action, next_obs, done, reward = replay_data
+                batch = replay_data
 
             # Select action according to policy and add clipped noise
-            noise = action.clone().data.normal_(0, self.target_policy_noise)
+            noise = batch.actions.clone().data.normal_(0, self.target_policy_noise)
             noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
-            next_action = (self.actor_target(next_obs) + noise).clamp(-1, 1)
+            next_actions = (self.actor_target(batch.next_observations) + noise).clamp(-1, 1)
 
             # Compute the target Q value
-            target_q1, target_q2 = self.critic_target(next_obs, next_action)
+            target_q1, target_q2 = self.critic_target(batch.next_observations, next_actions)
             target_q = th.min(target_q1, target_q2)
-            target_q = reward + ((1 - done) * self.gamma * target_q).detach()
+            target_q = batch.rewards + ((1 - batch.dones) * self.gamma * target_q).detach()
 
             # Get current Q estimates
-            current_q1, current_q2 = self.critic(obs, action)
+            current_q1, current_q2 = self.critic(batch.observations, batch.actions)
 
             # Compute critic loss
             critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
@@ -167,12 +168,13 @@ class TD3(OffPolicyRLModel):
         for gradient_step in range(gradient_steps):
             # Sample replay buffer
+            # Bind to a local name: rebinding `replay_data` would reuse the first batch on every later step
             if replay_data is None:
-                obs, _, next_obs, done, reward = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
+                batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
             else:
-                obs, _, next_obs, done, reward = replay_data
+                batch = replay_data
 
             # Compute actor loss
-            actor_loss = -self.critic.q1_forward(obs, self.actor(obs)).mean()
+            actor_loss = -self.critic.q1_forward(batch.observations, self.actor(batch.observations)).mean()
 
             # Optimize the actor
             self.actor.optimizer.zero_grad()