mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-28 22:56:53 +00:00
Update changelog and add more namedtuples
This commit is contained in:
parent
fb4e66213d
commit
20ee8cb68d
5 changed files with 25 additions and 18 deletions
|
|
@ -23,6 +23,7 @@ Others:
|
|||
^^^^^^^
|
||||
- SAC with SDE now samples only one matrix
|
||||
- Added ``clip_mean`` parameter to SAC policy
|
||||
- Buffers now return ``NamedTuple``
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from torchy_baselines.common.policies import BasePolicy, get_policy_from_name
|
|||
from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize
|
||||
from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
|
||||
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, OptimizerStateDict
|
||||
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, OptimizerStateDict, RolloutReturn
|
||||
from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
from torchy_baselines.common.noise import ActionNoise
|
||||
|
|
@ -830,7 +830,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
replay_buffer: Optional[ReplayBuffer] = None,
|
||||
obs: Optional[np.ndarray] = None,
|
||||
episode_num: int = 0,
|
||||
log_interval: Optional[int] = None) -> Tuple[float, int, int, Optional[np.ndarray], bool]:
|
||||
log_interval: Optional[int] = None) -> RolloutReturn:
|
||||
"""
|
||||
Collect rollout using the current policy (and possibly fill the replay buffer)
|
||||
|
||||
|
|
@ -849,6 +849,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
:param obs: (np.ndarray) Last observation from the environment
|
||||
:param episode_num: (int) Episode index
|
||||
:param log_interval: (int) Log data every `log_interval` episodes
|
||||
:return: (RolloutReturn)
|
||||
"""
|
||||
episode_rewards, total_timesteps = [], []
|
||||
total_steps, total_episodes = 0, 0
|
||||
|
|
@ -878,8 +879,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
|
||||
# Only stop training if return value is False, not when it is None.
|
||||
if callback() is False:
|
||||
continue_training = False
|
||||
return 0.0, total_steps, total_episodes, None, continue_training
|
||||
return RolloutReturn(0.0, total_steps, total_episodes, None, continue_training=False)
|
||||
|
||||
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
|
|
@ -1003,4 +1003,4 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return mean_reward, total_steps, total_episodes, obs, continue_training
|
||||
return RolloutReturn(mean_reward, total_steps, total_episodes, obs, continue_training)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Common aliases for type hints
|
|||
from typing import Union, Type, Optional, Dict, Any, List, NamedTuple
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import gym
|
||||
|
||||
|
|
@ -13,7 +14,8 @@ from torchy_baselines.common.vec_env import VecEnv
|
|||
GymEnv = Union[gym.Env, VecEnv]
|
||||
TensorDict = Dict[str, th.Tensor]
|
||||
OptimizerStateDict = Dict[str, Any]
|
||||
# obs, action, old_values, old_log_prob, advantage, return_batch
|
||||
|
||||
|
||||
class RolloutBufferSamples(NamedTuple):
|
||||
observations: th.Tensor
|
||||
actions: th.Tensor
|
||||
|
|
@ -23,10 +25,17 @@ class RolloutBufferSamples(NamedTuple):
|
|||
returns: th.Tensor
|
||||
|
||||
|
||||
# obs, action, next_obs, done, reward
|
||||
class ReplayBufferSamples(NamedTuple):
|
||||
observations: th.Tensor
|
||||
actions: th.Tensor
|
||||
next_observations: th.Tensor
|
||||
dones: th.Tensor
|
||||
rewards: th.Tensor
|
||||
|
||||
|
||||
class RolloutReturn(NamedTuple):
|
||||
episode_reward: float
|
||||
episode_timesteps: int
|
||||
n_episodes: int
|
||||
obs: Optional[np.ndarray]
|
||||
continue_training: bool
|
||||
|
|
|
|||
|
|
@ -249,18 +249,16 @@ class SAC(OffPolicyRLModel):
|
|||
replay_buffer=self.replay_buffer,
|
||||
obs=obs, episode_num=episode_num,
|
||||
log_interval=log_interval)
|
||||
# Unpack
|
||||
episode_reward, episode_timesteps, n_episodes, obs, continue_training = rollout
|
||||
|
||||
if continue_training is False:
|
||||
if rollout.continue_training is False:
|
||||
break
|
||||
|
||||
episode_num += n_episodes
|
||||
obs = rollout.obs
|
||||
episode_num += rollout.n_episodes
|
||||
self._update_current_progress(self.num_timesteps, total_timesteps)
|
||||
|
||||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
|
||||
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
|
||||
self.train(gradient_steps, batch_size=self.batch_size)
|
||||
|
||||
callback.on_training_end()
|
||||
|
|
|
|||
|
|
@ -259,13 +259,12 @@ class TD3(OffPolicyRLModel):
|
|||
replay_buffer=self.replay_buffer,
|
||||
obs=obs, episode_num=episode_num,
|
||||
log_interval=log_interval)
|
||||
# Unpack
|
||||
episode_reward, episode_timesteps, n_episodes, obs, continue_training = rollout
|
||||
|
||||
if continue_training is False:
|
||||
if rollout.continue_training is False:
|
||||
break
|
||||
|
||||
episode_num += n_episodes
|
||||
obs = rollout.obs
|
||||
episode_num += rollout.n_episodes
|
||||
self._update_current_progress(self.num_timesteps, total_timesteps)
|
||||
|
||||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
|
|
@ -279,7 +278,7 @@ class TD3(OffPolicyRLModel):
|
|||
# On-policy gradient
|
||||
self.train_sde()
|
||||
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
|
||||
self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
|
||||
|
||||
callback.on_training_end()
|
||||
|
|
|
|||
Loading…
Reference in a new issue