Update changelog and add more namedtuples

This commit is contained in:
Antonin Raffin 2020-03-10 16:55:13 +01:00
parent fb4e66213d
commit 20ee8cb68d
5 changed files with 25 additions and 18 deletions

View file

@ -23,6 +23,7 @@ Others:
^^^^^^^
- SAC with SDE now samples only one matrix
- Added ``clip_mean`` parameter to SAC policy
- Buffers now return ``NamedTuple``
Documentation:
^^^^^^^^^^^^^^

View file

@ -16,7 +16,7 @@ from torchy_baselines.common.policies import BasePolicy, get_policy_from_name
from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize
from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, OptimizerStateDict
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, OptimizerStateDict, RolloutReturn
from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
from torchy_baselines.common.monitor import Monitor
from torchy_baselines.common.noise import ActionNoise
@ -830,7 +830,7 @@ class OffPolicyRLModel(BaseRLModel):
replay_buffer: Optional[ReplayBuffer] = None,
obs: Optional[np.ndarray] = None,
episode_num: int = 0,
log_interval: Optional[int] = None) -> Tuple[float, int, int, Optional[np.ndarray], bool]:
log_interval: Optional[int] = None) -> RolloutReturn:
"""
Collect rollout using the current policy (and possibly fill the replay buffer)
@ -849,6 +849,7 @@ class OffPolicyRLModel(BaseRLModel):
:param obs: (np.ndarray) Last observation from the environment
:param episode_num: (int) Episode index
:param log_interval: (int) Log data every `log_interval` episodes
:return: (RolloutReturn)
"""
episode_rewards, total_timesteps = [], []
total_steps, total_episodes = 0, 0
@ -878,8 +879,7 @@ class OffPolicyRLModel(BaseRLModel):
# Only stop training if return value is False, not when it is None.
if callback() is False:
continue_training = False
return 0.0, total_steps, total_episodes, None, continue_training
return RolloutReturn(0.0, total_steps, total_episodes, None, continue_training=False)
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
@ -1003,4 +1003,4 @@ class OffPolicyRLModel(BaseRLModel):
callback.on_rollout_end()
return mean_reward, total_steps, total_episodes, obs, continue_training
return RolloutReturn(mean_reward, total_steps, total_episodes, obs, continue_training)

View file

@ -4,6 +4,7 @@ Common aliases for type hints
from typing import Union, Type, Optional, Dict, Any, List, NamedTuple
from collections import namedtuple
import numpy as np
import torch as th
import gym
@ -13,7 +14,8 @@ from torchy_baselines.common.vec_env import VecEnv
GymEnv = Union[gym.Env, VecEnv]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
# obs, action, old_values, old_log_prob, advantage, return_batch
class RolloutBufferSamples(NamedTuple):
observations: th.Tensor
actions: th.Tensor
@ -23,10 +25,17 @@ class RolloutBufferSamples(NamedTuple):
returns: th.Tensor
# obs, action, next_obs, done, reward
class ReplayBufferSamples(NamedTuple):
    """
    Named tuple grouping one batch of transition data.

    Fields, in positional order: ``observations``, ``actions``,
    ``next_observations``, ``dones``, ``rewards`` -- each a ``th.Tensor``.
    Being a ``NamedTuple``, it can still be unpacked positionally
    like a plain tuple.
    """
    # NOTE(review): presumably the value returned when sampling the replay
    # buffer -- confirm against the buffer implementation.
    observations: th.Tensor
    actions: th.Tensor
    next_observations: th.Tensor
    dones: th.Tensor
    rewards: th.Tensor
class RolloutReturn(NamedTuple):
    """
    Named tuple returned by ``collect_rollouts``.

    Replaces the former plain 5-tuple
    ``(float, int, int, Optional[np.ndarray], bool)``; positional
    unpacking keeps working because this is still a tuple.
    """
    # Reward reported for the rollout (0.0 when the rollout is aborted early)
    episode_reward: float
    # Number of steps taken during the rollout
    episode_timesteps: int
    # Number of episodes completed during the rollout
    n_episodes: int
    # Last observation, or None when the rollout is aborted early
    obs: Optional[np.ndarray]
    # False when the callback requested that training stop
    continue_training: bool

View file

@ -249,18 +249,16 @@ class SAC(OffPolicyRLModel):
replay_buffer=self.replay_buffer,
obs=obs, episode_num=episode_num,
log_interval=log_interval)
# Unpack
episode_reward, episode_timesteps, n_episodes, obs, continue_training = rollout
if continue_training is False:
if rollout.continue_training is False:
break
episode_num += n_episodes
obs = rollout.obs
episode_num += rollout.n_episodes
self._update_current_progress(self.num_timesteps, total_timesteps)
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
self.train(gradient_steps, batch_size=self.batch_size)
callback.on_training_end()

View file

@ -259,13 +259,12 @@ class TD3(OffPolicyRLModel):
replay_buffer=self.replay_buffer,
obs=obs, episode_num=episode_num,
log_interval=log_interval)
# Unpack
episode_reward, episode_timesteps, n_episodes, obs, continue_training = rollout
if continue_training is False:
if rollout.continue_training is False:
break
episode_num += n_episodes
obs = rollout.obs
episode_num += rollout.n_episodes
self._update_current_progress(self.num_timesteps, total_timesteps)
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
@ -279,7 +278,7 @@ class TD3(OffPolicyRLModel):
# On-policy gradient
self.train_sde()
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
callback.on_training_end()