mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-25 22:35:14 +00:00
582 lines
25 KiB
Python
582 lines
25 KiB
Python
import io
|
|
import pathlib
|
|
import warnings
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
|
|
|
|
import numpy as np
|
|
import torch as th
|
|
|
|
from stable_baselines3.common.base_class import BaseAlgorithm
|
|
from stable_baselines3.common.callbacks import BaseCallback
|
|
from stable_baselines3.common.noise import ActionNoise
|
|
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
|
from stable_baselines3.common.policies import BasePolicy
|
|
from stable_baselines3.common.save_util import load_from_zip_file, recursive_setattr
|
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, TrainFreq
|
|
from stable_baselines3.common.utils import check_for_correct_spaces, should_collect_more_steps
|
|
from stable_baselines3.common.vec_env import VecEnv
|
|
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
|
|
from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy
|
|
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
|
|
|
|
|
def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> int:
    """
    Get time limit from environment.

    An explicitly passed ``current_max_episode_length`` always wins and is
    returned unchanged. Otherwise the value is read from
    ``env.get_attr("spec")[0].max_episode_steps``.

    :param env: Environment from which we want to get the time limit.
    :param current_max_episode_length: Current value for max_episode_length.
    :return: max episode length
    :raises ValueError: if no value was passed and none could be read from
        the environment (``spec`` or ``max_episode_steps`` missing or ``None``).
    """
    # A user-supplied value takes precedence over anything the env reports
    if current_max_episode_length is not None:
        return current_max_episode_length

    # try to get the attribute from environment
    try:
        inferred_length = env.get_attr("spec")[0].max_episode_steps
    except AttributeError:
        # e.g. `spec` is None or has no `max_episode_steps`
        inferred_length = None

    # Raised outside of the `except` block on purpose: the internal
    # AttributeError is an implementation detail and should not show up
    # as chained context in the traceback the user sees.
    if inferred_length is None:
        raise ValueError(
            "The max episode length could not be inferred.\n"
            "You must specify a `max_episode_steps` when registering the environment,\n"
            "use a `gym.wrappers.TimeLimit` wrapper "
            "or pass `max_episode_length` to the model constructor"
        )
    return inferred_length
|
|
|
|
|
|
# TODO: rewrite HER class as soon as dict obs are supported
|
|
class HER(BaseAlgorithm):
    """
    Hindsight Experience Replay (HER)
    Paper: https://arxiv.org/abs/1707.01495

    This class wraps an off-policy algorithm instance (``self.model``): attributes
    that are not found on the ``HER`` object itself are looked up on the wrapped
    model (see ``__getattr__``).

    .. warning::

      For performance reasons, the maximum number of steps per episodes must be specified.
      In most cases, it will be inferred if you specify ``max_episode_steps`` when registering the environment
      or if you use a ``gym.wrappers.TimeLimit`` (and ``env.spec`` is not None).
      Otherwise, you can directly pass ``max_episode_length`` to the model constructor


    For additional offline algorithm specific arguments please have a look at the corresponding documentation.
    Extra positional and keyword arguments are forwarded to ``model_class``.

    :param policy: The policy model to use.
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param model_class: Off policy model which will be used with hindsight experience replay. (SAC, TD3, DDPG, DQN)
    :param n_sampled_goal: Number of sampled goals for replay. (offline sampling)
    :param goal_selection_strategy: Strategy for sampling goals for replay.
        One of ['episode', 'final', 'future', 'random']
    :param online_sampling: Sample HER transitions online.
    :param learning_rate: learning rate for the optimizer,
        it can be a function of the current progress remaining (from 1 to 0)
    :param max_episode_length: The maximum length of an episode. If not specified,
        it will be automatically inferred if the environment uses a ``gym.wrappers.TimeLimit`` wrapper.
    """

    def __init__(
        self,
        policy: Union[str, Type[BasePolicy]],
        env: Union[GymEnv, str],
        model_class: Type[OffPolicyAlgorithm],
        n_sampled_goal: int = 4,
        goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future",
        online_sampling: bool = False,
        max_episode_length: Optional[int] = None,
        *args,
        **kwargs,
    ):

        # we will use the policy and learning rate from the model
        super(HER, self).__init__(policy=BasePolicy, env=env, policy_base=BasePolicy, learning_rate=0.0)
        # Remove the placeholders so that lookups fall through to `self.model`
        del self.policy, self.learning_rate

        if self.get_vec_normalize_env() is not None:
            assert online_sampling, "You must pass `online_sampling=True` if you want to use `VecNormalize` with `HER`"

        # `_init_setup_model` is consumed here: the wrapped model is always
        # created without setup, HER triggers it itself in `_setup_model()`
        _init_setup_model = kwargs.pop("_init_setup_model", True)
        # model initialization
        self.model_class = model_class
        self.model = model_class(
            policy=policy,
            env=self.env,
            _init_setup_model=False,  # pytype: disable=wrong-keyword-args
            *args,
            **kwargs,  # pytype: disable=wrong-keyword-args
        )

        # Make HER use self.model.action_noise
        del self.action_noise
        self.verbose = self.model.verbose
        self.tensorboard_log = self.model.tensorboard_log

        # convert goal_selection_strategy into GoalSelectionStrategy if string
        if isinstance(goal_selection_strategy, str):
            self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy.lower()]
        else:
            self.goal_selection_strategy = goal_selection_strategy

        # check if goal_selection_strategy is valid
        assert isinstance(
            self.goal_selection_strategy, GoalSelectionStrategy
        ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}"

        self.n_sampled_goal = n_sampled_goal
        # if we sample her transitions online use custom replay buffer
        self.online_sampling = online_sampling
        # compute ratio between HER replays and regular replays in percent for online HER sampling
        self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1))
        # maximum steps in episode
        self.max_episode_length = get_time_limit(self.env, max_episode_length)
        # storage for transitions of current episode for offline sampling
        # for online sampling, it replaces the "classic" replay buffer completely
        # (`self.buffer_size` is resolved on the wrapped model via `__getattr__`)
        her_buffer_size = self.buffer_size if online_sampling else self.max_episode_length

        assert self.env is not None, "Because it needs access to `env.compute_reward()` HER you must provide the env."

        self._episode_storage = HerReplayBuffer(
            self.env,
            her_buffer_size,
            self.max_episode_length,
            self.goal_selection_strategy,
            self.env.observation_space,
            self.env.action_space,
            self.device,
            self.n_envs,
            self.her_ratio,  # pytype: disable=wrong-arg-types
        )

        # counter for steps in episode
        self.episode_steps = 0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Set up the wrapped model and, for online sampling, make the HER
        episode storage its replay buffer.
        """
        self.model._setup_model()
        # assign episode storage to replay buffer when using online HER sampling
        if self.online_sampling:
            self.model.replay_buffer = self._episode_storage

    def predict(
        self,
        observation: np.ndarray,
        state: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Forward the call unchanged to ``self.model.predict``.

        :param observation: passed through to the wrapped model
        :param state: passed through to the wrapped model
        :param mask: passed through to the wrapped model
        :param deterministic: passed through to the wrapped model
        :return: whatever the wrapped model's ``predict`` returns
        """

        return self.model.predict(observation, state, mask, deterministic)

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "HER",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> BaseAlgorithm:
        """
        Alternate between collecting rollouts and training the wrapped model
        until ``total_timesteps`` is reached or a callback stops training.

        All arguments except ``log_interval`` are handed to ``_setup_learn``;
        ``log_interval`` is handed to ``collect_rollouts``.

        :return: ``self``
        """

        total_timesteps, callback = self._setup_learn(
            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
        )
        # Mirror the bookkeeping prepared by `_setup_learn` onto the wrapped model
        self.model.start_time = self.start_time
        self.model.ep_info_buffer = self.ep_info_buffer
        self.model.ep_success_buffer = self.ep_success_buffer
        self.model.num_timesteps = self.num_timesteps
        self.model._episode_num = self._episode_num
        self.model._last_obs = self._last_obs
        self.model._total_timesteps = self._total_timesteps

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:
            # train_freq, action_noise, learning_starts, ... come from the wrapped model
            rollout = self.collect_rollouts(
                self.env,
                train_freq=self.train_freq,
                action_noise=self.action_noise,
                callback=callback,
                learning_starts=self.learning_starts,
                log_interval=log_interval,
            )

            if rollout.continue_training is False:
                break

            if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts and self.replay_buffer.size() > 0:
                # If no `gradient_steps` is specified,
                # do as many gradients steps as steps performed during the rollout
                gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
                self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)

        callback.on_training_end()

        return self

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        train_freq: TrainFreq,
        action_noise: Optional[ActionNoise] = None,
        learning_starts: int = 0,
        log_interval: Optional[int] = None,
    ) -> RolloutReturn:
        """
        Collect experiences and store them into a ReplayBuffer.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param train_freq: How much experience to collect
            by doing rollouts of current policy.
            Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
            or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
            with ``<n>`` being an integer greater than 0.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param log_interval: Log data every ``log_interval`` episodes
        :return: ``RolloutReturn`` built from the mean episode reward, the number of
            collected steps, the number of collected episodes and the
            ``continue_training`` flag (``False`` when a callback asked to stop).
        """

        episode_rewards, total_timesteps = [], []
        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."

        if self.model.use_sde:
            self.actor.reset_noise()

        callback.on_rollout_start()
        continue_training = True

        while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
            done = False
            episode_reward, episode_timesteps = 0.0, 0

            while not done:
                # concatenate observation and (desired) goal
                observation = self._last_obs
                self._last_obs = ObsDictWrapper.convert_dict(observation)

                if (
                    self.model.use_sde
                    and self.model.sde_sample_freq > 0
                    and num_collected_steps % self.model.sde_sample_freq == 0
                ):
                    # Sample a new noise matrix
                    self.actor.reset_noise()

                # Select action randomly or according to policy
                self.model._last_obs = self._last_obs
                action, buffer_action = self._sample_action(learning_starts, action_noise)

                # Perform action
                new_obs, reward, done, infos = env.step(action)

                self.num_timesteps += 1
                self.model.num_timesteps = self.num_timesteps
                episode_timesteps += 1
                num_collected_steps += 1

                # Only stop training if return value is False, not when it is None.
                if callback.on_step() is False:
                    return RolloutReturn(0.0, num_collected_steps, num_collected_episodes, continue_training=False)

                episode_reward += reward

                # Retrieve reward and episode length if using Monitor wrapper
                self._update_info_buffer(infos, done)
                self.model.ep_info_buffer = self.ep_info_buffer
                self.model.ep_success_buffer = self.ep_success_buffer

                # == Store transition in the replay buffer and/or in the episode storage ==

                if self._vec_normalize_env is not None:
                    # Store only the unnormalized version
                    new_obs_ = self._vec_normalize_env.get_original_obs()
                    reward_ = self._vec_normalize_env.get_original_reward()
                else:
                    # Avoid changing the original ones
                    self._last_original_obs, new_obs_, reward_ = observation, new_obs, reward
                    self.model._last_original_obs = self._last_original_obs

                # As the VecEnv resets automatically, new_obs is already the
                # first observation of the next episode
                if done and infos[0].get("terminal_observation") is not None:
                    next_obs = infos[0]["terminal_observation"]
                    # VecNormalize normalizes the terminal observation
                    if self._vec_normalize_env is not None:
                        next_obs = self._vec_normalize_env.unnormalize_obs(next_obs)
                else:
                    next_obs = new_obs_

                if self.online_sampling:
                    self.replay_buffer.add(self._last_original_obs, next_obs, buffer_action, reward_, done, infos)
                else:
                    # concatenate observation with (desired) goal
                    flattened_obs = ObsDictWrapper.convert_dict(self._last_original_obs)
                    flattened_next_obs = ObsDictWrapper.convert_dict(next_obs)
                    # add to replay buffer
                    self.replay_buffer.add(flattened_obs, flattened_next_obs, buffer_action, reward_, done)
                    # add current transition to episode storage
                    self._episode_storage.add(self._last_original_obs, next_obs, buffer_action, reward_, done, infos)

                self._last_obs = new_obs
                self.model._last_obs = self._last_obs

                # Save the unnormalized new observation
                if self._vec_normalize_env is not None:
                    self._last_original_obs = new_obs_
                    self.model._last_original_obs = self._last_original_obs

                self.model._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

                # For DQN, check if the target network should be updated
                # and update the exploration schedule
                # For SAC/TD3, the update is done as the same time as the gradient update
                # see https://github.com/hill-a/stable-baselines/issues/900
                self.model._on_step()

                self.episode_steps += 1

                # Enough data for this rollout: leave the episode loop even if
                # the episode itself is not finished yet
                if not should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
                    break

            if done or self.episode_steps >= self.max_episode_length:
                if self.online_sampling:
                    self.replay_buffer.store_episode()
                else:
                    self._episode_storage.store_episode()
                    # sample virtual transitions and store them in replay buffer
                    self._sample_her_transitions()
                    # clear storage for current episode
                    self._episode_storage.reset()

                num_collected_episodes += 1
                self._episode_num += 1
                self.model._episode_num = self._episode_num
                episode_rewards.append(episode_reward)
                total_timesteps.append(episode_timesteps)

                if action_noise is not None:
                    action_noise.reset()

                # Log training infos
                if log_interval is not None and self._episode_num % log_interval == 0:
                    self._dump_logs()

                self.episode_steps = 0

        mean_reward = np.mean(episode_rewards) if num_collected_episodes > 0 else 0.0

        callback.on_rollout_end()

        return RolloutReturn(mean_reward, num_collected_steps, num_collected_episodes, continue_training)

    def _sample_her_transitions(self) -> None:
        """
        Sample additional goals and store new transitions in replay buffer
        when using offline sampling.
        """

        # Sample goals and get new observations
        # maybe_vec_env=None as we should store unnormalized transitions,
        # they will be normalized at sampling time
        observations, next_observations, actions, rewards = self._episode_storage.sample_offline(
            n_sampled_goal=self.n_sampled_goal
        )

        # store data in replay buffer
        # (all sampled transitions are stored with done=False)
        dones = np.zeros((len(observations)), dtype=bool)
        self.replay_buffer.extend(observations, next_observations, actions, rewards, dones)

    def __getattr__(self, item: str) -> Any:
        """
        Find attribute from model class if this class does not have it.

        Python only calls this hook after normal attribute lookup failed.

        :param item: name of the attribute that was not found on ``self``
        :return: the attribute of the same name on ``self.model``
        :raises AttributeError: if the wrapped model is not set yet or does not
            have the attribute either
        """
        # Read `model` straight from the instance dict: going through
        # `self.model` would re-enter `__getattr__` whenever the model has not
        # been assigned yet (e.g. an instance created with `__new__` by
        # `copy`/`pickle`) and end in a RecursionError instead of an AttributeError.
        model = self.__dict__.get("model")
        if model is not None and hasattr(model, item):
            return getattr(model, item)
        raise AttributeError(f"{self} has no attribute {item}")

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """
        Return the torch save parameters reported by the wrapped model.
        """
        return self.model._get_torch_save_params()

    def save(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        exclude: Optional[Iterable[str]] = None,
        include: Optional[Iterable[str]] = None,
    ) -> None:
        """
        Save all the attributes of the object and the model parameters in a zip-file.

        The HER specific settings are copied onto the wrapped model first, so that
        they are part of what ``self.model.save`` writes and ``load`` can read back.

        :param path: path to the file where the rl agent should be saved
        :param exclude: name of parameters that should be excluded in addition to the default one
        :param include: name of parameters that might be excluded but should be included anyway
        """

        # add HER parameters to model
        self.model.n_sampled_goal = self.n_sampled_goal
        self.model.goal_selection_strategy = self.goal_selection_strategy
        self.model.online_sampling = self.online_sampling
        self.model.model_class = self.model_class
        self.model.max_episode_length = self.max_episode_length

        self.model.save(path, exclude, include)

    @classmethod
    def load(
        cls,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        env: Optional[GymEnv] = None,
        device: Union[th.device, str] = "auto",
        custom_objects: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> "BaseAlgorithm":
        """
        Load the model from a zip-file

        :param path: path to the file (or a file-like) where to
            load the agent from
        :param env: the new environment to run the loaded model on
            (can be None if you only need prediction from a trained model) has priority over any saved environment
        :param device: Device on which the code should run.
        :param custom_objects: Dictionary of objects to replace
            upon loading. If a variable is present in this dictionary as a
            key, it will not be deserialized and the corresponding item
            will be used instead. Similar to custom_objects in
            ``keras.models.load_model``. Useful when you have an object in
            file that can not be deserialized.
        :param kwargs: extra arguments to change the model when loading.
            ``model_class``, ``online_sampling`` and ``max_episode_length`` are ignored
            (the stored values are used); ``n_sampled_goal`` and
            ``goal_selection_strategy`` override the stored values.
        :raises ValueError: if ``policy_kwargs`` is given and differs from the stored one
        :raises KeyError: if the saved data lacks the observation or action space
        """
        data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects)

        # Remove stored device information and replace with ours
        if "policy_kwargs" in data:
            data["policy_kwargs"].pop("device", None)

        if "policy_kwargs" in kwargs:
            if kwargs["policy_kwargs"] != data["policy_kwargs"]:
                raise ValueError(
                    "The specified policy kwargs do not equal the stored policy kwargs. "
                    f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
                )
            # The stored value is passed explicitly to the constructor below;
            # leaving the (identical) user value in `kwargs` would make that
            # call fail with "got multiple values for keyword argument".
            del kwargs["policy_kwargs"]

        # check if observation space and action space are part of the saved parameters
        if "observation_space" not in data or "action_space" not in data:
            raise KeyError("The observation_space and action_space were not given, can't verify new environments")

        # check if given env is valid
        if env is not None:
            # Wrap first if needed
            env = cls._wrap_env(env, data["verbose"])
            # Check if given env is valid
            check_for_correct_spaces(env, data["observation_space"], data["action_space"])
        else:
            # Use stored env, if one exists. If not, continue as is (can be used for predict)
            if "env" in data:
                env = data["env"]

        if "use_sde" in data and data["use_sde"]:
            kwargs["use_sde"] = True

        # Keys that cannot be changed
        for key in ("model_class", "online_sampling", "max_episode_length"):
            kwargs.pop(key, None)

        # Keys that can be changed
        for key in ("n_sampled_goal", "goal_selection_strategy"):
            if key in kwargs:
                data[key] = kwargs.pop(key)  # pytype: disable=unsupported-operands

        # NOTE(review): `device` is used for deserialization and `set_parameters`
        # but is not forwarded to the constructor here - confirm this is intended.
        # noinspection PyArgumentList
        her_model = cls(
            policy=data["policy_class"],
            env=env,
            model_class=data["model_class"],
            n_sampled_goal=data["n_sampled_goal"],
            goal_selection_strategy=data["goal_selection_strategy"],
            online_sampling=data["online_sampling"],
            max_episode_length=data["max_episode_length"],
            policy_kwargs=data["policy_kwargs"],
            _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
            **kwargs,
        )

        # load parameters
        her_model.model.__dict__.update(data)
        her_model.model.__dict__.update(kwargs)
        her_model._setup_model()

        # keep the counters of the wrapper in sync with the restored model
        her_model._total_timesteps = her_model.model._total_timesteps
        her_model.num_timesteps = her_model.model.num_timesteps
        her_model._episode_num = her_model.model._episode_num

        # put state_dicts back in place
        her_model.model.set_parameters(params, exact_match=True, device=device)

        # put other pytorch variables back in place
        if pytorch_variables is not None:
            for name in pytorch_variables:
                recursive_setattr(her_model.model, name, pytorch_variables[name])

        # Sample gSDE exploration matrix, so it uses the right device
        # see issue #44
        if her_model.model.use_sde:
            her_model.model.policy.reset_noise()  # pytype: disable=attribute-error
        return her_model

    def load_replay_buffer(
        self, path: Union[str, pathlib.Path, io.BufferedIOBase], truncate_last_trajectory: bool = True
    ) -> None:
        """
        Load a replay buffer from a pickle file and set environment for replay buffer (only online sampling).

        :param path: Path to the pickled replay buffer.
        :param truncate_last_trajectory: Only for online sampling.
            If set to ``True`` we assume that the last trajectory in the replay buffer was finished.
            If it is set to ``False`` we assume that we continue the same trajectory (same episode).
        """
        self.model.load_replay_buffer(path=path)

        if self.online_sampling:
            # set environment
            self.replay_buffer.set_env(self.env)
            # If we are at the start of an episode, no need to truncate
            current_idx = self.replay_buffer.current_idx

            # truncate interrupted episode
            if truncate_last_trajectory and current_idx > 0:
                warnings.warn(
                    "The last trajectory in the replay buffer will be truncated.\n"
                    "If you are in the same episode as when the replay buffer was saved,\n"
                    "you should use `truncate_last_trajectory=False` to avoid that issue."
                )
                # get current episode and transition index
                pos = self.replay_buffer.pos
                # set episode length for current episode
                self.replay_buffer.episode_lengths[pos] = current_idx
                # set done = True for current episode
                # current_idx was already incremented
                self.replay_buffer.buffer["done"][pos][current_idx - 1] = np.array([True], dtype=np.float32)
                # reset current transition index
                self.replay_buffer.current_idx = 0
                # increment episode counter
                self.replay_buffer.pos = (self.replay_buffer.pos + 1) % self.replay_buffer.max_episode_stored
                # update "full" indicator
                self.replay_buffer.full = self.replay_buffer.full or self.replay_buffer.pos == 0
|