stable-baselines3/stable_baselines3/her/her_replay_buffer.py
Antonin RAFFIN 40e0b9d2c8
Add Gymnasium support (#1327)
* Fix failing set_env test

* Fix test failing due to deprecation of env.seed

* Adjust mean reward threshold in failing test

* Fix her test failing due to rng

* Change seed and revert reward threshold to 90

* Pin gym version

* Make VecEnv compatible with gym seeding change

* Revert change to VecEnv reset signature

* Change subprocenv seed cmd to call reset instead

* Fix type check

* Add backward compat

* Add `compat_gym_seed` helper

* Add goal env checks in env_checker

* Add docs on  HER requirements for envs

* Capture user warning in test with inverted box space

* Update ale-py version

* Fix randint

* Allow noop_max to be zero

* Update changelog

* Update docker image

* Update doc conda env and dockerfile

* Custom envs should not have any warnings

* Fix test for numpy >= 1.21

* Add check for vectorized compute reward

* Bump to gym 0.24

* Fix gym default step docstring

* Test downgrading gym

* Revert "Test downgrading gym"

This reverts commit 0072b77156c006ada8a1d6e26ce347ed85a83eeb.

* Fix protobuf error

* Fix in dependencies

* Fix protobuf dep

* Use newest version of cartpole

* Update gym

* Fix warning

* Loosen required scipy version

* Scipy no longer needed

* Try gym 0.25

* Silence warnings from gym

* Filter warnings during tests

* Update doc

* Update requirements

* Add gym 26 compat in vec env

* Fixes in envs and tests for gym 0.26+

* Enforce gym 0.26 api

* format

* Fix formatting

* Fix dependencies

* Fix syntax

* Cleanup doc and warnings

* Faster tests

* Higher budget for HER perf test (revert prev change)

* Fixes and update doc

* Fix doc build

* Fix breaking change

* Fixes for rendering

* Rename variables in monitor

* update render method for gym 0.26 API

backwards compatible (mode argument is allowed) while using the gym 0.26 API (render mode is determined at environment creation)

* update tests and docs to new gym render API

* undo removal of render modes metadata check

* set rgb_array as default render mode for gym.make

* undo changes & raise warning if not 'rgb_array'

* Fix type check

* Remove recursion and fix type checking

* Remove hacks for protobuf and gym 0.24

* Fix type annotations

* reuse existing render_mode attribute

* return tiled images for 'human' render mode

* Allow to use opencv for human render, fix typos

* Add warning when using non-zero start with Discrete (fixes #1197)

* Fix type checking

* Bug fixes and handle more cases

* Throw proper warnings

* Update test

* Fix new metadata name

* Ignore numpy warnings

* Fixes in vec recorder

* Global ignore

* Filter local warning too

* Monkey patch not needed for gym 26

* Add doc of VecEnv vs Gym API

* Add render test

* Fix return type

* Update VecEnv vs Gym API doc

* Fix for custom render mode

* Fix return type

* Fix type checking

* check test env test_buffer

* skip render check

* check env test_dict_env

* test_env test_gae

* check envs in remaining tests

* Update tests

* Add warning for Discrete action space with non-zero (#1295)

* Fix atari annotation

* ignore get_action_meanings [attr-defined]

* Fix mypy issues

* Add patch for gym/gymnasium transition

* Switch to gymnasium

* Rely on signature instead of version

* More patches

* Type ignore because of https://github.com/Farama-Foundation/Gymnasium/pull/39

* Fix doc build

* Fix pytype errors

* Fix atari requirement

* Update env checker due to change in dtype for Discrete

* Fix type hint

* Convert spaces for saved models

* Ignore pytype

* Remove gitlab CI

* Disable pytype for convert space

* Fix undefined info

* Fix undefined info

* Upgrade shimmy

* Fix wrappers type annotation (need PR from Gymnasium)

* Fix gymnasium dependency

* Fix dependency declaration

* Cap pygame version for python 3.7

* Point to master branch (v0.28.0)

* Fix: use main not master branch

* Rename done to terminated

* Fix pygame dependency for python 3.7

* Rename gym to gymnasium

* Update Gymnasium

* Fix test

* Fix tests

* Forks don't have access to private variables

* Fix linter warnings

* Update read the doc env

* Fix env checker for GoalEnv

* Fix import

* Update env checker (more info) and fix dtype

* Use micromamba for Docker

* Update dependencies

* Clarify VecEnv doc

* Fix Gymnasium version

* Copy file only after mamba install

* [ci skip] Update docker doc

* Polish code

* Reformat

* Remove deprecated features

* Ignore warning

* Update doc

* Update examples and changelog

* Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) (#1436)

* Fix SAC type hints, improve DQN ones

* Fix A2C and TD3 type hints

* Fix PPO type hints

* Fix on-policy type hints

* Fix base class type annotation, do not use defaults

* Update version

* Disable mypy for python 3.7

* Rename Gym26StepReturn

* Update continuous critic type annotation

* Fix pytype complaint

---------

Co-authored-by: Carlos Luis <carlos.luisgonc@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Thomas Lips <37955681+tlpss@users.noreply.github.com>
Co-authored-by: tlips <thomas.lips@ugent.be>
Co-authored-by: tlpss <thomas17.lips@gmail.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
2023-04-14 13:13:59 +02:00

386 lines
17 KiB
Python

import copy
import warnings
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3.common.buffers import DictReplayBuffer
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, TensorDict
from stable_baselines3.common.vec_env import VecEnv, VecNormalize
from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy


class HerReplayBuffer(DictReplayBuffer):
    """
    Hindsight Experience Replay (HER) buffer.
    Paper: https://arxiv.org/abs/1707.01495

    Replay buffer for sampling HER (Hindsight Experience Replay) transitions.

    .. note::

      Compared to other implementations, the ``future`` goal sampling strategy is inclusive:
      the current transition can be used when re-sampling.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param env: The training environment
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    :param n_sampled_goal: Number of virtual transitions to create per real transition,
        by sampling new goals.
    :param goal_selection_strategy: Strategy for sampling goals for replay.
        One of ['episode', 'final', 'future']
    :param copy_info_dict: Whether to copy the info dictionary and pass it to
        ``compute_reward()`` method.
        Please note that the copy may cause a slowdown.
        False by default.
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        env: VecEnv,
        device: Union[th.device, str] = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
        n_sampled_goal: int = 4,
        goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future",
        copy_info_dict: bool = False,
    ):
        super().__init__(
            buffer_size,
            observation_space,
            action_space,
            device=device,
            n_envs=n_envs,
            optimize_memory_usage=optimize_memory_usage,
            handle_timeout_termination=handle_timeout_termination,
        )
        self.env = env
        self.copy_info_dict = copy_info_dict

        # Convert goal_selection_strategy into GoalSelectionStrategy if it is a string
        if isinstance(goal_selection_strategy, str):
            self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy.lower()]
        else:
            self.goal_selection_strategy = goal_selection_strategy

        # Check that goal_selection_strategy is valid
        assert isinstance(
            self.goal_selection_strategy, GoalSelectionStrategy
        ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}"

        self.n_sampled_goal = n_sampled_goal

        # Fraction of each sampled batch made of HER (virtual) transitions,
        # a value in [0, 1): n_sampled_goal / (n_sampled_goal + 1)
        self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1))

        # In some environments, the info dict is used to compute the reward. Then, we need to store it.
        self.infos = np.array([[{} for _ in range(self.n_envs)] for _ in range(self.buffer_size)])
        # To create virtual transitions, we need to know for each transition
        # when an episode starts and ends.
        # We use the following arrays to store the indices,
        # and update them when an episode ends.
        self.ep_start = np.zeros((self.buffer_size, self.n_envs), dtype=np.int64)
        self.ep_length = np.zeros((self.buffer_size, self.n_envs), dtype=np.int64)
        self._current_ep_start = np.zeros(self.n_envs, dtype=np.int64)

    def __getstate__(self) -> Dict[str, Any]:
        """
        Gets state for pickling.

        Excludes self.env, as in general environments may not be pickleable.
        """
        state = self.__dict__.copy()
        # The env attribute is not pickleable
        del state["env"]
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restores pickled state.

        User must call ``set_env()`` after unpickling before using.

        :param state:
        """
        self.__dict__.update(state)
        assert "env" not in state
        self.env = None

    def set_env(self, env: VecEnv) -> None:
        """
        Sets the environment.

        :param env:
        """
        if self.env is not None:
            raise ValueError("Trying to set env of already initialized environment.")

        self.env = env

    def add(
        self,
        obs: TensorDict,
        next_obs: TensorDict,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        # When the buffer is full, we overwrite old episodes. When we start to
        # overwrite an old episode, we want the whole old episode to be deleted
        # (and not only the transition that is overwritten). To do this, we set
        # the length of the old episode to 0, so it can't be sampled anymore.
        for env_idx in range(self.n_envs):
            episode_start = self.ep_start[self.pos, env_idx]
            episode_length = self.ep_length[self.pos, env_idx]
            if episode_length > 0:
                episode_end = episode_start + episode_length
                episode_indices = np.arange(self.pos, episode_end) % self.buffer_size
                self.ep_length[episode_indices, env_idx] = 0

        # Update episode start
        self.ep_start[self.pos] = self._current_ep_start.copy()

        if self.copy_info_dict:
            self.infos[self.pos] = infos
        # Store the transition
        super().add(obs, next_obs, action, reward, done, infos)

        # When an episode ends, compute and store the episode length
        for env_idx in range(self.n_envs):
            if done[env_idx]:
                episode_start = self._current_ep_start[env_idx]
                episode_end = self.pos
                if episode_end < episode_start:
                    # Occurs when the buffer becomes full and storage resumes at the
                    # beginning of the buffer. This can happen in the middle of an episode.
                    episode_end += self.buffer_size
                episode_indices = np.arange(episode_start, episode_end) % self.buffer_size
                self.ep_length[episode_indices, env_idx] = episode_end - episode_start
                # Update the current episode start
                self._current_ep_start[env_idx] = self.pos

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
        """
        Sample elements from the replay buffer.

        :param batch_size: Number of elements to sample
        :param env: Associated VecEnv to normalize the observations/rewards when sampling
        :return: Samples
        """
        # When the buffer is full, we overwrite old episodes. We don't want to
        # sample transitions from incomplete episodes, so we have to eliminate some indices.
        is_valid = self.ep_length > 0
        if not np.any(is_valid):
            raise RuntimeError(
                "Unable to sample before the end of the first episode. We recommend choosing a value "
                "for learning_starts that is greater than the maximum number of timesteps in the environment."
            )
        # Get the indices of valid transitions
        # Example:
        # if is_valid = [[True, False, False], [True, False, True]],
        # is_valid has shape (buffer_size=2, n_envs=3)
        # then valid_indices = [0, 3, 5]
        # they correspond to is_valid[0, 0], is_valid[1, 0] and is_valid[1, 2]
        # or in numpy format ([rows], [columns]): (array([0, 1, 1]), array([0, 0, 2]))
        # Those indices are obtained back using np.unravel_index(valid_indices, is_valid.shape)
        valid_indices = np.flatnonzero(is_valid)
        # Sample valid transitions that will constitute the minibatch of size batch_size
        sampled_indices = np.random.choice(valid_indices, size=batch_size, replace=True)
        # Unravel the indices, i.e. recover the batch and env indices.
        # Example: if sampled_indices = [0, 3, 5], then batch_indices = [0, 1, 1] and env_indices = [0, 0, 2]
        batch_indices, env_indices = np.unravel_index(sampled_indices, is_valid.shape)

        # Split the indices between real and virtual transitions.
        nb_virtual = int(self.her_ratio * batch_size)
        virtual_batch_indices, real_batch_indices = np.split(batch_indices, [nb_virtual])
        virtual_env_indices, real_env_indices = np.split(env_indices, [nb_virtual])

        # Get real and virtual data
        real_data = self._get_real_samples(real_batch_indices, real_env_indices, env)
        # Create virtual transitions by sampling new desired goals and computing new rewards
        virtual_data = self._get_virtual_samples(virtual_batch_indices, virtual_env_indices, env)

        # Concatenate real and virtual data
        observations = {
            key: th.cat((real_data.observations[key], virtual_data.observations[key]))
            for key in virtual_data.observations.keys()
        }
        actions = th.cat((real_data.actions, virtual_data.actions))
        next_observations = {
            key: th.cat((real_data.next_observations[key], virtual_data.next_observations[key]))
            for key in virtual_data.next_observations.keys()
        }
        dones = th.cat((real_data.dones, virtual_data.dones))
        rewards = th.cat((real_data.rewards, virtual_data.rewards))

        return DictReplayBufferSamples(
            observations=observations,
            actions=actions,
            next_observations=next_observations,
            dones=dones,
            rewards=rewards,
        )

    def _get_real_samples(
        self,
        batch_indices: np.ndarray,
        env_indices: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> DictReplayBufferSamples:
        """
        Get the samples corresponding to the batch and environment indices.

        :param batch_indices: Indices of the transitions
        :param env_indices: Indices of the environments
        :param env: associated gym VecEnv to normalize the
            observations/rewards when sampling, defaults to None
        :return: Samples
        """
        # Select the (transition, env) pairs and normalize if needed
        obs_ = self._normalize_obs({key: obs[batch_indices, env_indices, :] for key, obs in self.observations.items()}, env)
        next_obs_ = self._normalize_obs(
            {key: obs[batch_indices, env_indices, :] for key, obs in self.next_observations.items()}, env
        )

        # Convert to torch tensor
        observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
        next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()}

        return DictReplayBufferSamples(
            observations=observations,
            actions=self.to_torch(self.actions[batch_indices, env_indices]),
            next_observations=next_observations,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            dones=self.to_torch(
                self.dones[batch_indices, env_indices] * (1 - self.timeouts[batch_indices, env_indices])
            ).reshape(-1, 1),
            rewards=self.to_torch(self._normalize_reward(self.rewards[batch_indices, env_indices].reshape(-1, 1), env)),
        )

    def _get_virtual_samples(
        self,
        batch_indices: np.ndarray,
        env_indices: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> DictReplayBufferSamples:
        """
        Get the samples, sample new desired goals and compute new rewards.

        :param batch_indices: Indices of the transitions
        :param env_indices: Indices of the environments
        :param env: associated gym VecEnv to normalize the
            observations/rewards when sampling, defaults to None
        :return: Samples, with new desired goals and new rewards
        """
        # Get infos and obs
        obs = {key: obs[batch_indices, env_indices, :] for key, obs in self.observations.items()}
        next_obs = {key: obs[batch_indices, env_indices, :] for key, obs in self.next_observations.items()}
        if self.copy_info_dict:
            # The copy may cause a slowdown
            infos = copy.deepcopy(self.infos[batch_indices, env_indices])
        else:
            infos = [{} for _ in range(len(batch_indices))]
        # Sample and set new goals
        new_goals = self._sample_goals(batch_indices, env_indices)
        obs["desired_goal"] = new_goals
        # The desired goal for the next observation must be the same as the previous one
        next_obs["desired_goal"] = new_goals

        # Compute new reward
        rewards = self.env.env_method(
            "compute_reward",
            # here we use the new desired goal
            obs["desired_goal"],
            # the new state depends on the previous state and action
            # s_{t+1} = f(s_t, a_t)
            # so the next achieved_goal depends also on the previous state and action
            # because we are in a GoalEnv:
            # r_t = reward(s_t, a_t) = reward(next_achieved_goal, desired_goal)
            # therefore we have to use next_obs["achieved_goal"] and not obs["achieved_goal"]
            next_obs["achieved_goal"],
            infos,
            # we use the method of the first environment assuming that all environments are identical.
            indices=[0],
        )
        rewards = rewards[0].astype(np.float32)  # env_method returns a list containing one element
        obs = self._normalize_obs(obs, env)
        next_obs = self._normalize_obs(next_obs, env)

        # Convert to torch tensor
        observations = {key: self.to_torch(obs) for key, obs in obs.items()}
        next_observations = {key: self.to_torch(obs) for key, obs in next_obs.items()}

        return DictReplayBufferSamples(
            observations=observations,
            actions=self.to_torch(self.actions[batch_indices, env_indices]),
            next_observations=next_observations,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            dones=self.to_torch(
                self.dones[batch_indices, env_indices] * (1 - self.timeouts[batch_indices, env_indices])
            ).reshape(-1, 1),
            rewards=self.to_torch(self._normalize_reward(rewards.reshape(-1, 1), env)),
        )

    def _sample_goals(self, batch_indices: np.ndarray, env_indices: np.ndarray) -> np.ndarray:
        """
        Sample goals based on goal_selection_strategy.

        :param batch_indices: Indices of the transitions
        :param env_indices: Indices of the environments
        :return: Sampled goals
        """
        batch_ep_start = self.ep_start[batch_indices, env_indices]
        batch_ep_length = self.ep_length[batch_indices, env_indices]

        if self.goal_selection_strategy == GoalSelectionStrategy.FINAL:
            # Replay with final state of current episode
            transition_indices_in_episode = batch_ep_length - 1

        elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
            # Replay with a random state which comes from the same episode and was observed after the current transition
            # Note: our implementation is inclusive: the current transition can be sampled
            current_indices_in_episode = (batch_indices - batch_ep_start) % self.buffer_size
            transition_indices_in_episode = np.random.randint(current_indices_in_episode, batch_ep_length)

        elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE:
            # Replay with a random state which comes from the same episode as the current transition
            transition_indices_in_episode = np.random.randint(0, batch_ep_length)

        else:
            raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!")

        transition_indices = (transition_indices_in_episode + batch_ep_start) % self.buffer_size
        return self.next_observations["achieved_goal"][transition_indices, env_indices]

    def truncate_last_trajectory(self) -> None:
        """
        If called, we assume that the last trajectory in the replay buffer was finished
        (and truncate it).
        If not called, we assume that we continue the same trajectory (same episode).
        """
        # If we are at the start of an episode, no need to truncate
        if (self.ep_start[self.pos] != self.pos).any():
            warnings.warn(
                "The last trajectory in the replay buffer will be truncated.\n"
                "If you are in the same episode as when the replay buffer was saved,\n"
                "you should use `truncate_last_trajectory=False` to avoid that issue."
            )
            self.ep_start[-1] = self.pos
            # set done = True for current episodes
            self.dones[self.pos - 1] = True
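
The index bookkeeping in `sample()` can be tried in isolation. The following is a minimal sketch, not part of the library: the helper name `split_her_batch` is hypothetical, and it uses a seeded `np.random.Generator` where the class above calls `np.random.choice`. It reproduces only the flatten / sample / unravel / split steps, using the `is_valid` example from the comments in `sample()`.

```python
import numpy as np


def split_her_batch(is_valid, batch_size, n_sampled_goal, rng):
    """Hypothetical helper mirroring the index logic of HerReplayBuffer.sample()."""
    # Fraction of the batch that becomes virtual (relabelled) transitions
    her_ratio = 1 - (1.0 / (n_sampled_goal + 1))
    # Flat indices of the (transition, env) pairs that belong to finished episodes
    valid_indices = np.flatnonzero(is_valid)
    # Sample with replacement, as the buffer does
    sampled_indices = rng.choice(valid_indices, size=batch_size, replace=True)
    # Recover the (row, column) = (transition, env) indices
    batch_indices, env_indices = np.unravel_index(sampled_indices, is_valid.shape)
    # The first nb_virtual samples are relabelled, the rest stay real
    nb_virtual = int(her_ratio * batch_size)
    virtual_batch, real_batch = np.split(batch_indices, [nb_virtual])
    virtual_env, real_env = np.split(env_indices, [nb_virtual])
    return (virtual_batch, virtual_env), (real_batch, real_env)


# Same layout as the comment in sample(): shape (buffer_size=2, n_envs=3)
is_valid = np.array([[True, False, False], [True, False, True]])
rng = np.random.default_rng(0)
virtual, real = split_her_batch(is_valid, batch_size=10, n_sampled_goal=4, rng=rng)
print(len(virtual[0]), len(real[0]))
```

With `n_sampled_goal=4` the ratio is 0.8, so a batch of 10 is split into 8 virtual and 2 real transitions, and every sampled pair points at a valid entry of `is_valid`.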