mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-25 02:50:59 +00:00
* First commit * Fixing missing refs from a quick merge from master * Reformat * Adding DictBuffers * Reformat * Minor reformat * added slow dict test. Added SACMultiInputPolicy for future. Added private static image transpose helper to common policy * Ran black on buffers * Ran isort * Adding StackedObservations classes used within VecStackEnvs wrappers. Made test_dict_env shorter and removed slow * Running isort :facepalm * Fixed typing issues * Adding docstrings and typing. Using util for moving data to device. * Fixed trailing commas * Fix types * Minor edits * Avoid duplicating code * Fix calls to parents * Adding assert to buffers. Updating changelong * Running format on buffers * Adding multi-input policies to dqn,td3,a2c. Fixing warnings. Fixed bug with DictReplayBuffer as Replay buffers use only 1 env * Fixing warnings, splitting is_vectorized_observation into multiple functions based on space type * Created envs folder in common. Updated imports. Moved stacked_obs to vec_env folder * Moved envs to envs directory. Moved stacked obs to vec_envs. Started update on documentation * Fixes * Running code style * Update docstrings on torch_layers * Decapitalize non-constant variables * Using NatureCNN architecture in combined extractor. Increasing img size in multi input env. Adding memory reduction in test * Update doc * Update doc * Fix format * Removing NineRoom env. Using nested preprocess. Removing mutable default args * running code style * Passing channel check through to stacked dict observations. * Running black * Adding channel control to SimpleMultiObsEnv. Passing check_channels to CombinedExtractor * Remove optimize memory for dict buffers * Update doc * Move identity env * Minor edits + bump version * Update doc * Fix doc build * Bug fixes + add support for more type of dict env * Fixes + add multi env test * Add support for vectranspose * Fix stacked obs for dict and add tests * Add check for nested spaces. 
Fix dict-subprocvecenv test * Fix (single) pytype error * Simplify CombinedExtractor * Fix tests * Fix check * Merge branch 'master' into feat/dict_observations * Fix for net_arch with dict and vector obs * Fixes * Add consistency test * Update env checker * Add some docs on dict obs * Update default CNN feature vector size * Refactor HER (#351) * Start refactoring HER * Fixes * Additional fixes * Faster tests * WIP: HER as a custom replay buffer * New replay only version (working with DQN) * Add support for all off-policy algorithms * Fix saving/loading * Remove ObsDictWrapper and add VecNormalize tests with dict * Stable-Baselines3 v1.0 (#354) * Bump version and update doc * Fix name * Apply suggestions from code review Co-authored-by: Adam Gleave <adam@gleave.me> * Update docs/index.rst Co-authored-by: Adam Gleave <adam@gleave.me> * Update wording for RL zoo Co-authored-by: Adam Gleave <adam@gleave.me> * Add gym-pybullet-drones project (#358) * Update projects.rst Added gym-pybullet-drones * Update projects.rst Longer title underline * Update changelog Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org> * Include SuperSuit in projects (#359) * include supersuit * longer title underline * Update changelog.rst * Fix default arguments + add bugbear (#363) * Fix potential bug + add bug bear * Remove unused variables * Minor: version bump * Add code of conduct + update doc (#373) * Add code of conduct * Fix DQN doc example * Update doc (channel-last/first) * Apply suggestions from code review Co-authored-by: Anssi <kaneran21@hotmail.com> * Apply suggestions from code review Co-authored-by: Adam Gleave <adam@gleave.me> Co-authored-by: Anssi <kaneran21@hotmail.com> Co-authored-by: Adam Gleave <adam@gleave.me> * Make installation command compatible with ZSH (#376) * Add quotes * Add Zsh bracket info * Add clarify pip installation line * Make note bold * Add Zsh pip installation note * Add handle timeouts param * Fixes * Fixes (buffer size, extend test) * Fix 
`max_episode_length` redefinition * Fix potential issue * Add some docs on dict obs * Fix performance bug * Fix slowdown * Add package to install (#378) * Add package to install * Update docs packages installation command Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Fix backward compat + add test * Fix VecEnv detection * Update doc * Fix vec env check * Support for `VecMonitor` for gym3-style environments (#311) * add vectorized monitor * auto format of the code * add documentation and VecExtractDictObs * refactor and add test cases * add test cases and format * avoid circular import and fix doc * fix type * fix type * oops * Update stable_baselines3/common/monitor.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update stable_baselines3/common/monitor.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * add test cases * update changelog * fix mutable argument * quick fix * Apply suggestions from code review * fix terminal observation for gym3 envs * delete comment * Update doc and bump version * Add warning when already using `Monitor` wrapper * Update vecmonitor tests * Fixes Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Reformat * Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``, it was not optimized anymore (#392) * Fix ent coef loading bug * Add test * Add comment * Reuse save path * Add test for GAE + rename `RolloutBuffer.dones` for clarification (#375) * Fix return computation + add test for GAE * Rename `last_dones` to `episode_starts` for clarification * Revert advantage * Cleanup test * Rename variable * Clarify return computation * Clarify docs * Add multi-episode rollout test * Reformat Co-authored-by: Anssi "Miffyli" Kanervisto <kaneran21@hotmail.com> * Fixed saving of `A2C` and `PPO` policy when using gSDE (#401) * Improve doc and replay buffer loading * Add support for images * Fix doc * Update Procgen doc * Update changelog * Update docstrings Co-authored-by: Adam Gleave <adam@gleave.me> 
Co-authored-by: Jacopo Panerati <jacopo.panerati@utoronto.ca> Co-authored-by: Justin Terry <justinkterry@gmail.com> Co-authored-by: Anssi <kaneran21@hotmail.com> Co-authored-by: Tom Dörr <tomdoerr96@gmail.com> Co-authored-by: Tom Dörr <tom.doerr@tum.de> Co-authored-by: Costa Huang <costa.huang@outlook.com> * Update doc and minor fixes * Update doc * Added note about MultiInputPolicy in error of NatureCNN * Merge branch 'master' into feat/dict_observations * Address comments * Naming clarifications * Actually saving the file would be nice * Fix edge case when doing online sampling with HER * Cleanup * Add sanity check Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> Co-authored-by: Anssi "Miffyli" Kanervisto <kaneran21@hotmail.com> Co-authored-by: Adam Gleave <adam@gleave.me> Co-authored-by: Jacopo Panerati <jacopo.panerati@utoronto.ca> Co-authored-by: Justin Terry <justinkterry@gmail.com> Co-authored-by: Tom Dörr <tomdoerr96@gmail.com> Co-authored-by: Tom Dörr <tom.doerr@tum.de> Co-authored-by: Costa Huang <costa.huang@outlook.com>
739 lines
30 KiB
Python
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Union

import numpy as np
import torch as th
from gym import spaces

from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
    DictReplayBufferSamples,
    DictRolloutBufferSamples,
    ReplayBufferSamples,
    RolloutBufferSamples,
)
from stable_baselines3.common.vec_env import VecNormalize

try:
    # Check memory used by replay buffer when possible
    import psutil
except ImportError:
    psutil = None

class BaseBuffer(ABC):
    """
    Base class that represents a buffer (rollout or replay)

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
        to which the values will be converted
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
    ):
        super(BaseBuffer, self).__init__()
        self.buffer_size = buffer_size
        self.observation_space = observation_space
        self.action_space = action_space
        self.obs_shape = get_obs_shape(observation_space)

        self.action_dim = get_action_dim(action_space)
        self.pos = 0
        self.full = False
        self.device = device
        self.n_envs = n_envs

    @staticmethod
    def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
        """
        Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
        to convert shape from [n_steps, n_envs, ...] (where ... is the shape of the features)
        to [n_steps * n_envs, ...] (which maintains the order)

        :param arr: Array of shape [n_steps, n_envs, ...]
        :return: Array of shape [n_envs * n_steps, ...], grouped by environment
            (a trailing axis of size 1 is added to 2D inputs)
        """
        shape = arr.shape
        if len(shape) < 3:
            shape = shape + (1,)
        return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

    def size(self) -> int:
        """
        :return: The current size of the buffer
        """
        if self.full:
            return self.buffer_size
        return self.pos

    def add(self, *args, **kwargs) -> None:
        """
        Add elements to the buffer.
        """
        raise NotImplementedError()

    def extend(self, *args, **kwargs) -> None:
        """
        Add a new batch of transitions to the buffer,
        one transition at a time.
        """
        # Do a for loop along the batch axis (keyword arguments are not forwarded)
        for data in zip(*args):
            self.add(*data)

    def reset(self) -> None:
        """
        Reset the buffer.
        """
        self.pos = 0
        self.full = False

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
        """
        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: Samples drawn uniformly (with replacement) from the stored transitions
        """
        upper_bound = self.buffer_size if self.full else self.pos
        batch_inds = np.random.randint(0, upper_bound, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    @abstractmethod
    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
        """
        :param batch_inds: Indices of the transitions to retrieve
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: The selected transitions, converted to PyTorch tensors
        """
        raise NotImplementedError()

    def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
        """
        Convert a numpy array to a PyTorch tensor.
        Note: it copies the data by default

        :param array: Array to convert
        :param copy: Whether to copy the data or not
            (may be useful to avoid changing things by reference)
        :return: The converted tensor, placed on ``self.device``
        """
        if copy:
            return th.tensor(array).to(self.device)
        return th.as_tensor(array).to(self.device)

    @staticmethod
    def _normalize_obs(
        obs: Union[np.ndarray, Dict[str, np.ndarray]],
        env: Optional[VecNormalize] = None,
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        if env is not None:
            return env.normalize_obs(obs)
        return obs

    @staticmethod
    def _normalize_reward(reward: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray:
        if env is not None:
            return env.normalize_reward(reward).astype(np.float32)
        return reward
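

# Illustrative sketch, not part of the Stable-Baselines3 API: the hypothetical
# helper below reproduces the ordering produced by ``BaseBuffer.swap_and_flatten``
# using plain Python lists instead of numpy. A rollout stored as
# ``[n_steps][n_envs]`` becomes a flat list holding every step of env 0, then
# every step of env 1, and so on, so each environment's trajectory stays
# contiguous and in time order.
def _swap_and_flatten_lists(rows):
    """Flatten a ``[n_steps][n_envs]`` nested list into env-major order (hypothetical helper)."""
    n_steps = len(rows)
    n_envs = len(rows[0]) if n_steps > 0 else 0
    # Outer loop over envs (former axis 1), inner loop over steps (former axis 0)
    return [rows[step][env] for env in range(n_envs) for step in range(n_steps)]


# Usage, with entries labelled "s<step>e<env>" for 3 steps and 2 envs:
#   _swap_and_flatten_lists([["s0e0", "s0e1"], ["s1e0", "s1e1"], ["s2e0", "s2e1"]])
#   -> ["s0e0", "s1e0", "s2e0", "s0e1", "s1e1", "s2e1"]
# which matches arr.swapaxes(0, 1).reshape(-1) on the equivalent (3, 2) array.
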
class ReplayBuffer(BaseBuffer):
    """
    Replay buffer used in off-policy algorithms like SAC/TD3.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        of the replay buffer which reduces the memory used by almost a factor of two,
        at the cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
        and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as an infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert n_envs == 1, "Replay buffer only supports a single environment for now"

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        self.optimize_memory_usage = optimize_memory_usage

        self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)

        if optimize_memory_usage:
            # `observations` also contains the next observation
            self.next_observations = None
        else:
            self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)

        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)

        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes

            if self.next_observations is not None:
                total_memory_usage += self.next_observations.nbytes

            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system apparently does not have enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        # Copy to avoid modification by reference
        self.observations[self.pos] = np.array(obs).copy()

        if self.optimize_memory_usage:
            # The next observation is stored in the following slot (wrapping around)
            self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy()
        else:
            self.next_observations[self.pos] = np.array(next_obs).copy()

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.dones[self.pos] = np.array(done).copy()

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the replay buffer.
        Custom sampling is used with the memory efficient variant,
        as the element with index `self.pos` must not be sampled.
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: A batch of transitions as PyTorch tensors
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        # Do not sample the element with index `self.pos` as the transition is invalid
        # (we use only one array to store `obs` and `next_obs`)
        if self.full:
            batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        if self.optimize_memory_usage:
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, 0, :], env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, 0, :], env)

        data = (
            self._normalize_obs(self.observations[batch_inds, 0, :], env),
            self.actions[batch_inds, 0, :],
            next_obs,
            # Only use dones that are not due to timeouts
            # (timeouts stay at zero when handle_timeout_termination=False)
            self.dones[batch_inds] * (1 - self.timeouts[batch_inds]),
            self._normalize_reward(self.rewards[batch_inds], env),
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
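

# Illustrative sketch, not part of the Stable-Baselines3 API: the hypothetical
# helpers below restate two details of ``ReplayBuffer`` in plain Python.
# 1) With ``optimize_memory_usage=True``, ``observations[(i + 1) % buffer_size]``
#    doubles as the next observation of transition ``i``. The slot at ``pos``
#    holds the most recent next observation, while the slot after it holds old
#    (or not yet written) data, so index ``pos`` must never be sampled.
# 2) Dones caused by a time limit (``TimeLimit.truncated``) are masked out so the
#    target still bootstraps from the next state.
def _valid_transition_indices(pos, buffer_size, full, batch_size, seed=None):
    """Sample transition indices that skip ``pos`` (memory-optimized layout)."""
    import random

    rng = random.Random(seed)
    if full:
        # Offsets 1..buffer_size-1 from ``pos`` wrap around the ring and never hit ``pos``.
        # Note: random.randint includes both bounds, unlike np.random.randint.
        return [(rng.randint(1, buffer_size - 1) + pos) % buffer_size for _ in range(batch_size)]
    # Not full yet: only slots 0..pos-1 start a complete transition
    return [rng.randint(0, pos - 1) for _ in range(batch_size)]


def _mask_timeouts(dones, timeouts):
    """Keep only true terminations: ``done * (1 - timeout)`` per transition."""
    return [done * (1 - timeout) for done, timeout in zip(dones, timeouts)]


# Usage: with pos=3 in a full buffer of size 5, only indices 0, 1, 2 and 4 are drawn;
# _mask_timeouts([1.0, 1.0], [0.0, 1.0]) keeps the real termination and drops the timeout.
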
class RolloutBuffer(BaseBuffer):
    """
    Rollout buffer used in on-policy algorithms like A2C/PPO.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use the PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be confused with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super(RolloutBuffer, self).reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain vanilla advantage (A(s) = R - V(s))
        where R is the discounted reward with value bootstrap,
        set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator also has two special cases:

        - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()

        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                # `episode_starts[step + 1]` is 1 when `step` ended an episode
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward: Reward
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs,) + self.obs_shape)

        self.observations[self.pos] = np.array(obs).copy()
        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        assert self.full, "The rollout buffer must be full before sampling from it"
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:

            _tensor_names = [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
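

# Illustrative sketch, not part of the Stable-Baselines3 API: a plain-Python,
# single-environment restatement of ``RolloutBuffer.compute_returns_and_advantage``.
# The function name and the list-based inputs are hypothetical. ``episode_starts[t + 1]``
# equal to 1 means step ``t`` ended an episode, which cuts both the value bootstrap
# and the GAE accumulation at that step.
def _gae_single_env(rewards, values, episode_starts, last_value, last_done, gamma=0.99, gae_lambda=1.0):
    """Return ``(advantages, returns)`` lists for one environment."""
    n_steps = len(rewards)
    advantages = [0.0] * n_steps
    last_gae_lam = 0.0
    for step in reversed(range(n_steps)):
        if step == n_steps - 1:
            next_non_terminal = 1.0 - float(last_done)
            next_value = last_value
        else:
            next_non_terminal = 1.0 - float(episode_starts[step + 1])
            next_value = values[step + 1]
        # One-step TD error
        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
        # Exponentially weighted sum of TD errors, reset at episode boundaries
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    # TD(lambda) return = advantage + value baseline
    returns = [adv + value for adv, value in zip(advantages, values)]
    return advantages, returns


# Usage: with gamma=0.5, gae_lambda=1 and zero value estimates, three rewards of 1
# give returns [1.75, 1.5, 1.0], i.e. the discounted Monte-Carlo sums (TD(1) case):
#   _gae_single_env([1, 1, 1], [0, 0, 0], [1, 0, 0], last_value=0.0, last_done=False, gamma=0.5)
# With gae_lambda=0 the returns reduce to the one-step targets r_t + gamma * V(s_{t+1}) (TD(0) case).
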
class DictReplayBuffer(ReplayBuffer):
    """
    Dict Replay buffer used in off-policy algorithms like SAC/TD3.
    Extends the ReplayBuffer to use dictionary observations

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as an infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        # Call BaseBuffer.__init__ directly: ReplayBuffer.__init__ allocates non-dict observation arrays
        super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
        assert n_envs == 1, "Replay buffer only supports a single environment for now"

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage"
        # disabling as this adds quite a bit of complexity
        # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
        self.optimize_memory_usage = optimize_memory_usage

        # Store each key with the dtype of its sub-space (np.zeros would default to float64)
        self.observations = {
            key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space.spaces[key].dtype)
            for key, _obs_shape in self.obs_shape.items()
        }
        self.next_observations = {
            key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space.spaces[key].dtype)
            for key, _obs_shape in self.obs_shape.items()
        }

        # Only 1 env is supported, so actions are stored without the n_envs axis
        self.actions = np.zeros((self.buffer_size, self.action_dim), dtype=action_space.dtype)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            obs_nbytes = 0
            for _, obs in self.observations.items():
                obs_nbytes += obs.nbytes

            total_memory_usage = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
            if self.next_observations is not None:
                next_obs_nbytes = 0
                for _, obs in self.next_observations.items():
                    next_obs_nbytes += obs.nbytes
                total_memory_usage += next_obs_nbytes

            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system apparently does not have enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: Dict[str, np.ndarray],
        next_obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        # Copy to avoid modification by reference
        for key in self.observations.keys():
            self.observations[key][self.pos] = np.array(obs[key]).copy()

        for key in self.next_observations.keys():
            self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.dones[self.pos] = np.array(done).copy()

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
        """
        Sample elements from the replay buffer.

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: A batch of transitions with dict observations, as PyTorch tensors
        """
        return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
        # Normalize if needed and remove extra dimension (we are using only one env for now)
        obs_ = self._normalize_obs({key: obs[batch_inds, 0, :] for key, obs in self.observations.items()}, env)
        next_obs_ = self._normalize_obs({key: obs[batch_inds, 0, :] for key, obs in self.next_observations.items()}, env)

        # Convert to torch tensor
        observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
        next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()}

        return DictReplayBufferSamples(
            observations=observations,
            actions=self.to_torch(self.actions[batch_inds]),
            next_observations=next_observations,
            # Only use dones that are not due to timeouts
            # (timeouts stay at zero when handle_timeout_termination=False)
            dones=self.to_torch(self.dones[batch_inds] * (1 - self.timeouts[batch_inds])),
            rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds], env)),
        )
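

# Illustrative sketch, not part of the Stable-Baselines3 API: a back-of-the-envelope
# version of the memory check in ``DictReplayBuffer.__init__``. The hypothetical
# helper sums ``buffer_size * n_envs * prod(shape) * itemsize`` over the keys of a
# Dict observation space; ``DictReplayBuffer`` keeps two such dicts
# (``observations`` and ``next_observations``). Using one itemsize for every key
# is a simplifying assumption (e.g. 1 for uint8 images, 4 for float32, 8 for
# float64, which is numpy's default dtype for ``np.zeros``).
def _dict_obs_nbytes(obs_shapes, buffer_size, n_envs=1, itemsize=4):
    """Bytes needed to store one dict of observation arrays."""
    import math

    return sum(buffer_size * n_envs * math.prod(shape) * itemsize for shape in obs_shapes.values())


# Usage: an 84x84 RGB image plus a 5-dim vector, 1000 slots, float32 storage:
#   _dict_obs_nbytes({"image": (3, 84, 84), "vector": (5,)}, buffer_size=1000)
#   -> (3 * 84 * 84 + 5) * 1000 * 4 = 84_692_000 bytes, doubled for next_observations.
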
class DictRolloutBuffer(RolloutBuffer):
    """
    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RolloutBuffer to use dictionary observations

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use the PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be confused with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Call BaseBuffer.__init__ directly (skipping RolloutBuffer.__init__)
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super(RolloutBuffer, self).reset()

    def add(
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward: Reward
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key]).copy()
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]:
        assert self.full, "The rollout buffer must be full before sampling from it"
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:

            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples:
        return DictRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            actions=self.to_torch(self.actions[batch_inds]),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )
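

# Illustrative sketch, not part of the Stable-Baselines3 API: the hypothetical
# generator below restates the minibatching loop shared by ``RolloutBuffer.get``
# and ``DictRolloutBuffer.get``. After flattening, the ``buffer_size * n_envs``
# samples are visited once in a random order; every minibatch has ``batch_size``
# samples except possibly the last one, and ``batch_size=None`` yields a single
# batch containing everything.
def _minibatch_indices(n_samples, batch_size=None, seed=None):
    """Yield lists of indices covering a random permutation of ``range(n_samples)``."""
    import random

    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    if batch_size is None:
        # Return everything, don't create minibatches
        batch_size = n_samples
    start_idx = 0
    while start_idx < n_samples:
        yield indices[start_idx : start_idx + batch_size]
        start_idx += batch_size


# Usage: 10 samples with batch_size=4 give minibatches of sizes 4, 4 and 2,
# and together they contain each index exactly once.
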