From 20ee8cb68dcd036b5045e6fce41d6b31f359e16a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 10 Mar 2020 16:55:13 +0100 Subject: [PATCH] Update changelog and add more namedtuples --- docs/misc/changelog.rst | 1 + torchy_baselines/common/base_class.py | 10 +++++----- torchy_baselines/common/type_aliases.py | 13 +++++++++++-- torchy_baselines/sac/sac.py | 10 ++++------ torchy_baselines/td3/td3.py | 9 ++++----- 5 files changed, 25 insertions(+), 18 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b05e73a..2aef391 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -23,6 +23,7 @@ Others: ^^^^^^^ - SAC with SDE now sample only one matrix - Added ``clip_mean`` parameter to SAC policy +- Buffers now return ``NamedTuple`` Documentation: ^^^^^^^^^^^^^^ diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index b010025..d6bd69d 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -16,7 +16,7 @@ from torchy_baselines.common.policies import BasePolicy, get_policy_from_name from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr -from torchy_baselines.common.type_aliases import GymEnv, TensorDict, OptimizerStateDict +from torchy_baselines.common.type_aliases import GymEnv, TensorDict, OptimizerStateDict, RolloutReturn from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback from torchy_baselines.common.monitor import Monitor from torchy_baselines.common.noise import ActionNoise @@ -830,7 +830,7 @@ class OffPolicyRLModel(BaseRLModel): replay_buffer: Optional[ReplayBuffer] = None, obs: Optional[np.ndarray] = None, episode_num: int = 0, - log_interval: Optional[int] = None) -> Tuple[float, int, int, Optional[np.ndarray], bool]: + log_interval: Optional[int] = None) -> RolloutReturn: """ Collect rollout using the current policy (and possibly fill the replay buffer) @@ -849,6 +849,7 @@ class OffPolicyRLModel(BaseRLModel): :param obs: (np.ndarray) Last observation from the environment :param episode_num: (int) Episode index :param log_interval: (int) Log data every `log_interval` episodes + :return: (RolloutReturn) """ episode_rewards, total_timesteps = [], [] total_steps, total_episodes = 0, 0 @@ -878,8 +879,7 @@ class OffPolicyRLModel(BaseRLModel): # Only stop training if return value is False, not when it is None. if callback() is False: - continue_training = False - return 0.0, total_steps, total_episodes, None, continue_training + return RolloutReturn(0.0, total_steps, total_episodes, None, continue_training=False) if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: # Sample a new noise matrix @@ -1003,4 +1003,4 @@ class OffPolicyRLModel(BaseRLModel): callback.on_rollout_end() - return mean_reward, total_steps, total_episodes, obs, continue_training + return RolloutReturn(mean_reward, total_steps, total_episodes, obs, continue_training) diff --git a/torchy_baselines/common/type_aliases.py b/torchy_baselines/common/type_aliases.py index 16576ff..12c220f 100644 --- a/torchy_baselines/common/type_aliases.py +++ b/torchy_baselines/common/type_aliases.py @@ -4,6 +4,7 @@ Common aliases for type hing from typing import Union, Type, Optional, Dict, Any, List, NamedTuple from collections import namedtuple +import numpy as np import torch as th import gym @@ -13,7 +14,8 @@ from torchy_baselines.common.vec_env import VecEnv GymEnv = Union[gym.Env, VecEnv] TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] -# obs, action, old_values, old_log_prob, advantage, return_batch + + class RolloutBufferSamples(NamedTuple): observations: th.Tensor actions: th.Tensor @@ -23,10 +25,17 @@ class RolloutBufferSamples(NamedTuple): returns: th.Tensor -# obs, action, next_obs, done, reward class ReplayBufferSamples(NamedTuple): observations: th.Tensor actions: th.Tensor next_observations: th.Tensor dones: th.Tensor rewards: th.Tensor + + +class RolloutReturn(NamedTuple): + episode_reward: float + episode_timesteps: int + n_episodes: int + obs: Optional[np.ndarray] + continue_training: bool diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index d7906e8..58b19db 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -249,18 +249,16 @@ class SAC(OffPolicyRLModel): replay_buffer=self.replay_buffer, obs=obs, episode_num=episode_num, log_interval=log_interval) - # Unpack - episode_reward, episode_timesteps, n_episodes, obs, continue_training = rollout - if continue_training is False: + if rollout.continue_training is False: break - episode_num += n_episodes + obs = rollout.obs + episode_num += rollout.n_episodes self._update_current_progress(self.num_timesteps, total_timesteps) if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts: - gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps - + gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps self.train(gradient_steps, batch_size=self.batch_size) callback.on_training_end() diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 963338d..767f2f0 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -259,13 +259,12 @@ class TD3(OffPolicyRLModel): replay_buffer=self.replay_buffer, obs=obs, episode_num=episode_num, log_interval=log_interval) - # Unpack - episode_reward, episode_timesteps, n_episodes, obs, continue_training = rollout - if continue_training is False: + if rollout.continue_training is False: break - episode_num += n_episodes + obs = rollout.obs + episode_num += rollout.n_episodes self._update_current_progress(self.num_timesteps, total_timesteps) if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts: @@ -279,7 +278,7 @@ class TD3(OffPolicyRLModel): # On-policy gradient self.train_sde() - gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps + gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay) callback.on_training_end()