mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-26 22:45:15 +00:00
* Fix bug with full replay buffer * Updated changelog * Update tests/test_her.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
376 lines
16 KiB
Python
376 lines
16 KiB
Python
from collections import deque
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch as th
|
|
from gym import spaces
|
|
|
|
from stable_baselines3.common.buffers import ReplayBuffer
|
|
from stable_baselines3.common.type_aliases import ReplayBufferSamples, RolloutBufferSamples
|
|
from stable_baselines3.common.vec_env import VecNormalize
|
|
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
|
|
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
|
|
|
|
|
|
class HerReplayBuffer(ReplayBuffer):
    """
    Replay buffer for sampling HER (Hindsight Experience Replay) transitions.
    In the online sampling case, these new transitions will not be saved in the replay buffer
    and will only be created at sampling time.

    Transitions are stored episode by episode: every array in ``self.buffer`` has shape
    ``(max_episode_stored, max_episode_length, *dim)``. ``self.pos`` is the index of the
    episode currently being written and ``self.current_idx`` the index of the next
    transition inside that episode.

    :param env: The training environment
    :param buffer_size: The size of the buffer measured in transitions.
    :param max_episode_length: The length of an episode. (time horizon)
    :param goal_selection_strategy: Strategy for sampling goals for replay.
        One of ['episode', 'final', 'future']
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param her_ratio: The ratio between HER transitions and regular transitions,
        as a fraction between 0 and 1 (only used for online sampling).
        The default value ``her_ratio=0.8`` corresponds to 4 virtual transitions
        for one real transition (4 / (4 + 1) = 0.8)
    """

    def __init__(
        self,
        env: ObsDictWrapper,
        buffer_size: int,
        max_episode_length: int,
        goal_selection_strategy: GoalSelectionStrategy,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        her_ratio: float = 0.8,
    ):

        super(HerReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs)

        self.env = env
        self.buffer_size = buffer_size
        self.max_episode_length = max_episode_length

        # buffer with episodes
        # number of episodes which can be stored until buffer size is reached
        # (floor division: any remainder of buffer_size is left unused)
        self.max_episode_stored = self.buffer_size // self.max_episode_length
        self.current_idx = 0

        # input dimensions for buffer initialization
        input_shape = {
            "observation": (self.env.num_envs, self.env.obs_dim),
            "achieved_goal": (self.env.num_envs, self.env.goal_dim),
            "desired_goal": (self.env.num_envs, self.env.goal_dim),
            "action": (self.action_dim,),
            "reward": (1,),
            "next_obs": (self.env.num_envs, self.env.obs_dim),
            "next_achieved_goal": (self.env.num_envs, self.env.goal_dim),
            "next_desired_goal": (self.env.num_envs, self.env.goal_dim),
            "done": (1,),
        }
        self.buffer = {
            key: np.zeros((self.max_episode_stored, self.max_episode_length, *dim), dtype=np.float32)
            for key, dim in input_shape.items()
        }
        # Store info dicts as they can be used to compute the reward (e.g. continuity cost)
        self.info_buffer = [deque(maxlen=self.max_episode_length) for _ in range(self.max_episode_stored)]
        # episode length storage, needed for episodes which have fewer steps than the maximum length
        self.episode_lengths = np.zeros(self.max_episode_stored, dtype=np.int64)

        self.goal_selection_strategy = goal_selection_strategy
        # fraction of the sampled batch that is relabeled with HER
        self.her_ratio = her_ratio

    def __getstate__(self) -> Dict[str, Any]:
        """
        Gets state for pickling.

        Excludes self.env, as in general Env's may not be pickleable."""
        state = self.__dict__.copy()
        # these attributes are not pickleable
        del state["env"]
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restores pickled state.

        User must call ``set_env()`` after unpickling before using.

        :param state: State dictionary produced by ``__getstate__`` (without ``env``).
        """
        self.__dict__.update(state)
        assert "env" not in state
        self.env = None

    def set_env(self, env: ObsDictWrapper) -> None:
        """
        Sets the environment.

        :param env: Environment to attach; only allowed when no env is currently set.
        :raises ValueError: if an environment is already set.
        """
        if self.env is not None:
            raise ValueError("Trying to set env of already initialized environment.")

        self.env = env

    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
        """
        Abstract method from base class.
        Not used by this buffer: sampling goes through ``sample()`` / ``sample_offline()``.
        """
        raise NotImplementedError()

    def sample(
        self,
        batch_size: int,
        env: Optional[VecNormalize] = None,
    ) -> Union[ReplayBufferSamples, Tuple[np.ndarray, ...]]:
        """
        Sample function for online sampling of HER transition,
        this replaces the "regular" replay buffer ``sample()``
        method in the ``train()`` function.

        :param batch_size: Number of element to sample
        :param env: Associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: Samples.
        """
        return self._sample_transitions(batch_size, maybe_vec_env=env, online_sampling=True)

    def sample_offline(
        self,
        n_sampled_goal: Optional[int] = None,
    ) -> Union[ReplayBufferSamples, Tuple[np.ndarray, ...]]:
        """
        Sample function for offline sampling of HER transition,
        in that case, only one episode is used and transitions
        are added to the regular replay buffer.

        :param n_sampled_goal: Number of sampled goals for replay
        :return: at most (n_sampled_goal * episode_length) HER transitions.
        """
        # env=None as we should store unnormalized transitions, they will be normalized at sampling time
        return self._sample_transitions(
            batch_size=None, maybe_vec_env=None, online_sampling=False, n_sampled_goal=n_sampled_goal
        )

    def sample_goals(
        self,
        episode_indices: np.ndarray,
        her_indices: np.ndarray,
        transitions_indices: np.ndarray,
    ) -> np.ndarray:
        """
        Sample goals based on goal_selection_strategy.
        This is a vectorized (fast) version.

        :param episode_indices: Episode indices to use.
        :param her_indices: HER indices (positions in the batch that get a new goal).
        :param transitions_indices: Transition indices to use.
        :return: Return sampled goals (``achieved_goal`` entries), one per HER index.
        :raises ValueError: if the goal selection strategy is not supported.
        """
        her_episode_indices = episode_indices[her_indices]

        if self.goal_selection_strategy == GoalSelectionStrategy.FINAL:
            # replay with final state of current episode
            transitions_indices = self.episode_lengths[her_episode_indices] - 1

        elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
            # replay with random state which comes from the same episode and was observed after current transition
            # (the caller guarantees transitions_indices[her_indices] + 1 < episode length)
            transitions_indices = np.random.randint(
                transitions_indices[her_indices] + 1, self.episode_lengths[her_episode_indices]
            )

        elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE:
            # replay with random state which comes from the same episode as current transition
            transitions_indices = np.random.randint(self.episode_lengths[her_episode_indices])

        else:
            raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!")

        return self.buffer["achieved_goal"][her_episode_indices, transitions_indices]

    def _sample_transitions(
        self,
        batch_size: Optional[int],
        maybe_vec_env: Optional[VecNormalize],
        online_sampling: bool,
        n_sampled_goal: Optional[int] = None,
    ) -> Union[ReplayBufferSamples, Tuple[np.ndarray, ...]]:
        """
        :param batch_size: Number of element to sample (only used for online sampling)
        :param maybe_vec_env: associated gym VecEnv to normalize the observations/rewards
            Only valid when using online sampling
        :param online_sampling: Using online_sampling for HER or not.
        :param n_sampled_goal: Number of sampled goals for replay. (offline sampling)
        :return: Samples. For online sampling a ``ReplayBufferSamples``; for offline sampling
            a tuple ``(observations, next_observations, actions, rewards)``.
        """
        # Select which episodes to use
        if online_sampling:
            assert batch_size is not None, "No batch_size specified for online sampling of HER transitions"
            # Do not sample the episode with index `self.pos` as the episode is invalid
            # (it is the episode currently being written / overwritten)
            if self.full:
                episode_indices = (
                    np.random.randint(1, self.n_episodes_stored, batch_size) + self.pos
                ) % self.n_episodes_stored
            else:
                episode_indices = np.random.randint(0, self.n_episodes_stored, batch_size)
            # A subset of the transitions will be relabeled using HER algorithm
            her_indices = np.arange(batch_size)[: int(self.her_ratio * batch_size)]
        else:
            assert maybe_vec_env is None, "Transitions must be stored unnormalized in the replay buffer"
            assert n_sampled_goal is not None, "No n_sampled_goal specified for offline sampling of HER transitions"
            # Offline sampling: there is only one episode stored
            episode_length = self.episode_lengths[0]
            # we sample n_sampled_goal per timestep in the episode (only one is stored).
            episode_indices = np.tile(0, (episode_length * n_sampled_goal))
            # we only sample virtual transitions
            # as real transitions are already stored in the replay buffer
            her_indices = np.arange(len(episode_indices))

        # fancy indexing returns a copy, so decrementing below does not touch self.episode_lengths
        ep_lengths = self.episode_lengths[episode_indices]

        # Special case when using the "future" goal sampling strategy
        # we cannot sample all transitions, we have to remove the last timestep
        if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
            # restrict the sampling domain when ep_lengths > 1
            # otherwise filter out the indices
            her_indices = her_indices[ep_lengths[her_indices] > 1]
            ep_lengths[her_indices] -= 1

        if online_sampling:
            # Select which transitions to use
            transitions_indices = np.random.randint(ep_lengths)
        else:
            if her_indices.size == 0:
                # Episode of one timestep, not enough for using the "future" strategy
                # no virtual transitions are created in that case
                return np.zeros(0), np.zeros(0), np.zeros(0), np.zeros(0)
            else:
                # Repeat every transition index n_sampled_goals times
                # to sample n_sampled_goal per timestep in the episode (only one is stored).
                # Now with the corrected episode length when using "future" strategy
                transitions_indices = np.tile(np.arange(ep_lengths[0]), n_sampled_goal)
                # all entries are episode 0; this just resizes to len(transitions_indices)
                episode_indices = episode_indices[transitions_indices]
                her_indices = np.arange(len(episode_indices))

        # get selected transitions
        transitions = {key: self.buffer[key][episode_indices, transitions_indices].copy() for key in self.buffer.keys()}

        # sample new desired goals and relabel the transitions
        new_goals = self.sample_goals(episode_indices, her_indices, transitions_indices)
        transitions["desired_goal"][her_indices] = new_goals

        # Convert info buffer to numpy array
        transitions["info"] = np.array(
            [
                self.info_buffer[episode_idx][transition_idx]
                for episode_idx, transition_idx in zip(episode_indices, transitions_indices)
            ]
        )

        # Vectorized computation of the new reward
        transitions["reward"][her_indices, 0] = self.env.env_method(
            "compute_reward",
            # the new state depends on the previous state and action
            # s_{t+1} = f(s_t, a_t)
            # so the next_achieved_goal depends also on the previous state and action
            # because we are in a GoalEnv:
            # r_t = reward(s_t, a_t) = reward(next_achieved_goal, desired_goal)
            # therefore we have to use "next_achieved_goal" and not "achieved_goal"
            transitions["next_achieved_goal"][her_indices, 0],
            # here we use the new desired goal
            transitions["desired_goal"][her_indices, 0],
            transitions["info"][her_indices, 0],
        )

        # concatenate observation with (desired) goal
        observations = ObsDictWrapper.convert_dict(self._normalize_obs(transitions, maybe_vec_env))
        # HACK to make normalize obs work with the next observation
        transitions["observation"] = transitions["next_obs"]
        next_observations = ObsDictWrapper.convert_dict(self._normalize_obs(transitions, maybe_vec_env))

        if online_sampling:
            data = (
                observations[:, 0],
                transitions["action"],
                next_observations[:, 0],
                transitions["done"],
                self._normalize_reward(transitions["reward"], maybe_vec_env),
            )

            return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
        else:
            return observations, next_observations, transitions["action"], transitions["reward"]

    def add(
        self,
        obs: Dict[str, np.ndarray],
        next_obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[dict],
    ) -> None:
        """
        Store one transition at position ``(self.pos, self.current_idx)``
        and advance the transition pointer.

        :param obs: Observation dict with "observation", "achieved_goal" and "desired_goal" keys.
        :param next_obs: Next observation dict with the same keys.
        :param action: Action taken.
        :param reward: Reward received.
        :param done: Episode termination flag.
        :param infos: Info dicts returned by the environment for this step.
        """
        if self.current_idx == 0:
            # A new episode starts at self.pos: start from an empty info buffer.
            # Without this, infos of an overwritten episode (buffer full) or of an
            # episode stored before ``reset()`` would stay in the bounded deque and
            # shift the info <-> transition alignment for shorter episodes.
            self.info_buffer[self.pos] = deque(maxlen=self.max_episode_length)

        self.buffer["observation"][self.pos][self.current_idx] = obs["observation"]
        self.buffer["achieved_goal"][self.pos][self.current_idx] = obs["achieved_goal"]
        self.buffer["desired_goal"][self.pos][self.current_idx] = obs["desired_goal"]
        self.buffer["action"][self.pos][self.current_idx] = action
        self.buffer["done"][self.pos][self.current_idx] = done
        self.buffer["reward"][self.pos][self.current_idx] = reward
        self.buffer["next_obs"][self.pos][self.current_idx] = next_obs["observation"]
        self.buffer["next_achieved_goal"][self.pos][self.current_idx] = next_obs["achieved_goal"]
        self.buffer["next_desired_goal"][self.pos][self.current_idx] = next_obs["desired_goal"]

        self.info_buffer[self.pos].append(infos)

        # update current pointer
        self.current_idx += 1

    def store_episode(self) -> None:
        """
        Increment episode counter
        and reset transition pointer.
        """
        # add episode length to length storage
        self.episode_lengths[self.pos] = self.current_idx

        # update current episode pointer
        # Note: in the OpenAI implementation
        # when the buffer is full, the episode replaced
        # is randomly chosen
        self.pos += 1
        if self.pos == self.max_episode_stored:
            self.full = True
            self.pos = 0
        # reset transition pointer
        self.current_idx = 0

    @property
    def n_episodes_stored(self) -> int:
        """
        :return: Number of completed episodes available in the buffer.
        """
        if self.full:
            return self.max_episode_stored
        return self.pos

    def size(self) -> int:
        """
        :return: The current number of transitions in the buffer.
        """
        return int(np.sum(self.episode_lengths))

    def reset(self) -> None:
        """
        Reset the buffer.
        Clears the episode/transition pointers, the stored episode lengths
        and the stored info dicts.
        """
        self.pos = 0
        self.current_idx = 0
        self.full = False
        self.episode_lengths = np.zeros(self.max_episode_stored, dtype=np.int64)
        # info dicts are not overwritten in place like the numpy buffers, so drop them explicitly
        self.info_buffer = [deque(maxlen=self.max_episode_length) for _ in range(self.max_episode_stored)]