stable-baselines3/stable_baselines3/common/buffers.py
Jaden Travnik 75b6f3b3b0
Dictionary Observations (#243)
2021-05-11 12:29:30 +02:00

739 lines
30 KiB
Python

import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Union

import numpy as np
import torch as th
from gym import spaces

from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
    DictReplayBufferSamples,
    DictRolloutBufferSamples,
    ReplayBufferSamples,
    RolloutBufferSamples,
)
from stable_baselines3.common.vec_env import VecNormalize

try:
    # Check memory used by replay buffer when possible
    import psutil
except ImportError:
    psutil = None


class BaseBuffer(ABC):
    """
    Base class that represents a buffer (rollout or replay)

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
        to which the values will be converted
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
    ):
        super(BaseBuffer, self).__init__()
        self.buffer_size = buffer_size
        self.observation_space = observation_space
        self.action_space = action_space
        self.obs_shape = get_obs_shape(observation_space)
        self.action_dim = get_action_dim(action_space)
        self.pos = 0
        self.full = False
        self.device = device
        self.n_envs = n_envs

    @staticmethod
    def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
        """
        Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
        to convert shape from [n_steps, n_envs, ...] (where ... is the shape of the features)
        to [n_steps * n_envs, ...] (which maintains the order)

        :param arr:
        :return:
        """
        shape = arr.shape
        if len(shape) < 3:
            shape = shape + (1,)
        return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

    def size(self) -> int:
        """
        :return: The current size of the buffer
        """
        if self.full:
            return self.buffer_size
        return self.pos

    def add(self, *args, **kwargs) -> None:
        """
        Add elements to the buffer.
        """
        raise NotImplementedError()

    def extend(self, *args, **kwargs) -> None:
        """
        Add a new batch of transitions to the buffer
        """
        # Do a for loop along the batch axis
        for data in zip(*args):
            self.add(*data)

    def reset(self) -> None:
        """
        Reset the buffer.
        """
        self.pos = 0
        self.full = False

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
        """
        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return:
        """
        upper_bound = self.buffer_size if self.full else self.pos
        batch_inds = np.random.randint(0, upper_bound, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    @abstractmethod
    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
        """
        :param batch_inds:
        :param env:
        :return:
        """
        raise NotImplementedError()

    def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
        """
        Convert a numpy array to a PyTorch tensor.
        Note: it copies the data by default

        :param array:
        :param copy: Whether to copy the data or not
            (copying may be useful to avoid changing things by reference)
        :return:
        """
        if copy:
            return th.tensor(array).to(self.device)
        return th.as_tensor(array).to(self.device)

    @staticmethod
    def _normalize_obs(
        obs: Union[np.ndarray, Dict[str, np.ndarray]],
        env: Optional[VecNormalize] = None,
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        if env is not None:
            return env.normalize_obs(obs)
        return obs

    @staticmethod
    def _normalize_reward(reward: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray:
        if env is not None:
            return env.normalize_reward(reward).astype(np.float32)
        return reward
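
The ordering produced by `swap_and_flatten` can be illustrated with a small plain-Python sketch. It is not part of the module: `swap_and_flatten_2d` is a hypothetical helper that handles only a 2-D `[n_steps][n_envs]` table, and it uses nested lists instead of NumPy to make the ordering explicit. After the swap, all steps of environment 0 come first, then all steps of environment 1, so each environment keeps its time order.

```python
from typing import List


def swap_and_flatten_2d(steps_by_env: List[List[float]]) -> List[float]:
    """Flatten a [n_steps][n_envs] table env by env (hypothetical helper)."""
    n_steps = len(steps_by_env)
    n_envs = len(steps_by_env[0])
    # Swapping axes 0 and 1 groups the steps of env 0 first, then env 1, ...
    return [steps_by_env[step][env] for env in range(n_envs) for step in range(n_steps)]


# Two environments, three steps; each entry is 10 * env + step
rollout = [[0, 10], [1, 11], [2, 12]]
flat = swap_and_flatten_2d(rollout)
print(flat)
```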


class ReplayBuffer(BaseBuffer):
    """
    Replay buffer used in off-policy algorithms like SAC/TD3.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device:
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        of the replay buffer which reduces the memory used by almost a factor of two,
        at the cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
        and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert n_envs == 1, "Replay buffer only supports a single environment for now"

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        self.optimize_memory_usage = optimize_memory_usage
        self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
        if optimize_memory_usage:
            # `observations` contains also the next observation
            self.next_observations = None
        else:
            self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
            if self.next_observations is not None:
                total_memory_usage += self.next_observations.nbytes

            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not appear to have enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        # Copy to avoid modification by reference
        self.observations[self.pos] = np.array(obs).copy()
        if self.optimize_memory_usage:
            self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy()
        else:
            self.next_observations[self.pos] = np.array(next_obs).copy()

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.dones[self.pos] = np.array(done).copy()

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the replay buffer.
        Custom sampling when using the memory efficient variant,
        as we should not sample the element with index `self.pos`
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return:
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        # Do not sample the element with index `self.pos` as the transition is invalid
        # (we use only one array to store `obs` and `next_obs`)
        if self.full:
            batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        if self.optimize_memory_usage:
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, 0, :], env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, 0, :], env)

        data = (
            self._normalize_obs(self.observations[batch_inds, 0, :], env),
            self.actions[batch_inds, 0, :],
            next_obs,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            self.dones[batch_inds] * (1 - self.timeouts[batch_inds]),
            self._normalize_reward(self.rewards[batch_inds], env),
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
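
The index arithmetic used by `ReplayBuffer.sample` in the memory efficient variant can be checked in isolation. The sketch below is an illustration only: `sample_valid_indices` is a hypothetical function name, and it uses the standard `random` module instead of NumPy. When the buffer is full, it draws an offset in `[1, buffer_size - 1]` and adds it to `pos` modulo `buffer_size`, so the slot at `pos` (whose stored observation has been overwritten by the latest `next_obs`) is never returned.

```python
import random
from typing import List


def sample_valid_indices(pos: int, buffer_size: int, batch_size: int, full: bool) -> List[int]:
    """Draw indices that skip the invalid slot at `pos` (hypothetical helper)."""
    if full:
        # Offsets 1..buffer_size-1 from `pos` cover every slot except `pos` itself
        return [(random.randrange(1, buffer_size) + pos) % buffer_size for _ in range(batch_size)]
    # Buffer not yet wrapped: only the first `pos` slots have been written
    return [random.randrange(0, pos) for _ in range(batch_size)]


random.seed(0)
indices = sample_valid_indices(pos=3, buffer_size=8, batch_size=1000, full=True)
partial = sample_valid_indices(pos=3, buffer_size=8, batch_size=1000, full=False)
```

As in the original method, calling this on an empty buffer (`pos == 0` and not full) raises an error, because there is nothing to sample from.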


class RolloutBuffer(BaseBuffer):
    """
    Rollout buffer used in on-policy algorithms like A2C/PPO.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be confused with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device:
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super(RolloutBuffer, self).reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain the vanilla advantage (A(s) = R - V(s)),
        where R is the discounted reward with value bootstrap,
        set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator also has two special cases:
        - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()

        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs,) + self.obs_shape)

        self.observations[self.pos] = np.array(obs).copy()
        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling from it"
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            _tensor_names = [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
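
The backward recursion in `compute_returns_and_advantage` can be followed by hand on a tiny single-environment rollout. The sketch below is a simplified plain-Python restatement, not the module's code: `compute_gae` is a hypothetical name, and it works on lists of floats for one environment rather than on `[buffer_size, n_envs]` arrays. With `gae_lambda=1.0` and zero value estimates, the resulting returns should match the Monte-Carlo sum of discounted rewards.

```python
from typing import List, Sequence, Tuple


def compute_gae(
    rewards: Sequence[float],
    values: Sequence[float],
    episode_starts: Sequence[bool],
    last_value: float,
    last_done: bool,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Single-env GAE(lambda) advantages and TD(lambda) returns (hypothetical helper)."""
    n_steps = len(rewards)
    advantages = [0.0] * n_steps
    last_gae_lam = 0.0
    for step in reversed(range(n_steps)):
        if step == n_steps - 1:
            # Bootstrap from the value of the state reached after the last step
            next_non_terminal = 1.0 - float(last_done)
            next_value = last_value
        else:
            # A new episode at step + 1 means step was terminal: no bootstrap
            next_non_terminal = 1.0 - float(episode_starts[step + 1])
            next_value = values[step + 1]
        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    returns = [adv + val for adv, val in zip(advantages, values)]
    return advantages, returns


# One three-step episode that terminates at the last step
advantages, returns = compute_gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    episode_starts=[True, False, False],
    last_value=0.0,
    last_done=True,
    gamma=0.5,
    gae_lambda=1.0,
)
print(returns)  # [1.75, 1.5, 1.0]
```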


class DictReplayBuffer(ReplayBuffer):
    """
    Dict Replay buffer used in off-policy algorithms like SAC/TD3.
    Extends the ReplayBuffer to use dictionary observations

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device:
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        # Call BaseBuffer.__init__ directly: ReplayBuffer.__init__ would allocate array (non-dict) storage
        super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
        assert n_envs == 1, "Replay buffer only supports a single environment for now"

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage"
        # disabling as this adds quite a bit of complexity
        # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
        self.optimize_memory_usage = optimize_memory_usage

        self.observations = {
            key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape) for key, _obs_shape in self.obs_shape.items()
        }
        self.next_observations = {
            key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape) for key, _obs_shape in self.obs_shape.items()
        }

        # only 1 env is supported
        self.actions = np.zeros((self.buffer_size, self.action_dim), dtype=action_space.dtype)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            obs_nbytes = 0
            for _, obs in self.observations.items():
                obs_nbytes += obs.nbytes

            total_memory_usage = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
            if self.next_observations is not None:
                next_obs_nbytes = 0
                # Count the next-observation arrays (not the observation arrays a second time)
                for _, obs in self.next_observations.items():
                    next_obs_nbytes += obs.nbytes
                total_memory_usage += next_obs_nbytes

            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not appear to have enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: Dict[str, np.ndarray],
        next_obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        # Copy to avoid modification by reference
        for key in self.observations.keys():
            self.observations[key][self.pos] = np.array(obs[key]).copy()

        for key in self.next_observations.keys():
            self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.dones[self.pos] = np.array(done).copy()

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
        """
        Sample elements from the replay buffer.

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return:
        """
        return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
        # Normalize if needed and remove extra dimension (we are using only one env for now)
        # `env` must be passed through, otherwise the observations are never normalized
        obs_ = self._normalize_obs({key: obs[batch_inds, 0, :] for key, obs in self.observations.items()}, env)
        next_obs_ = self._normalize_obs({key: obs[batch_inds, 0, :] for key, obs in self.next_observations.items()}, env)

        # Convert to torch tensor
        observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
        next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()}

        return DictReplayBufferSamples(
            observations=observations,
            actions=self.to_torch(self.actions[batch_inds]),
            next_observations=next_observations,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            dones=self.to_torch(self.dones[batch_inds] * (1 - self.timeouts[batch_inds])),
            rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds], env)),
        )
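
The `dones * (1 - timeouts)` masking used by both replay buffers can be illustrated with a short plain-Python sketch. It is an assumption-laden illustration rather than library code: `effective_dones` and `td_targets` are hypothetical names, and the one-step target `r + gamma * (1 - done) * V(s')` stands in for whatever target the off-policy algorithm actually computes. A transition that ended because of a time limit keeps bootstrapping from the next value, while a true termination does not.

```python
from typing import List, Sequence


def effective_dones(dones: Sequence[float], timeouts: Sequence[float]) -> List[float]:
    """Keep only the dones that are not caused by a time limit (hypothetical helper)."""
    return [done * (1.0 - timeout) for done, timeout in zip(dones, timeouts)]


def td_targets(
    rewards: Sequence[float],
    next_values: Sequence[float],
    dones: Sequence[float],
    gamma: float,
) -> List[float]:
    """One-step targets r + gamma * (1 - done) * V(s') (illustrative only)."""
    return [r + gamma * (1.0 - d) * v for r, v, d in zip(rewards, next_values, dones)]


# Transition 0: episode continues, 1: true termination, 2: time-limit truncation
dones = [0.0, 1.0, 1.0]
timeouts = [0.0, 0.0, 1.0]
masked = effective_dones(dones, timeouts)
targets = td_targets([1.0, 1.0, 1.0], [10.0, 10.0, 10.0], masked, gamma=0.5)
print(masked)   # [0.0, 1.0, 0.0]
print(targets)  # [6.0, 1.0, 6.0]
```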


class DictRolloutBuffer(RolloutBuffer):
    """
    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RolloutBuffer to use dictionary observations

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be confused with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device:
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Call BaseBuffer.__init__ directly: RolloutBuffer.__init__ would allocate array (non-dict) storage
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super(RolloutBuffer, self).reset()

    def add(
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key]).copy()
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling from it"
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:

            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples:
        return DictRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            actions=self.to_torch(self.actions[batch_inds]),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )