stable-baselines3/stable_baselines3/her/her.py
2021-03-06 15:57:27 +01:00

582 lines
25 KiB
Python

import io
import pathlib
import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
import numpy as np
import torch as th
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_zip_file, recursive_setattr
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, TrainFreq
from stable_baselines3.common.utils import check_for_correct_spaces, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> int:
"""
Get time limit from environment.
:param env: Environment from which we want to get the time limit.
:param current_max_episode_length: Current value for max_episode_length.
:return: max episode length
"""
# try to get the attribute from environment
if current_max_episode_length is None:
try:
current_max_episode_length = env.get_attr("spec")[0].max_episode_steps
# Raise the error because the attribute is present but is None
if current_max_episode_length is None:
raise AttributeError
# if not available check if a valid value was passed as an argument
except AttributeError:
raise ValueError(
"The max episode length could not be inferred.\n"
"You must specify a `max_episode_steps` when registering the environment,\n"
"use a `gym.wrappers.TimeLimit` wrapper "
"or pass `max_episode_length` to the model constructor"
)
return current_max_episode_length
# TODO: rewrite HER class as soon as dict obs are supported
class HER(BaseAlgorithm):
"""
Hindsight Experience Replay (HER)
Paper: https://arxiv.org/abs/1707.01495
.. warning::
For performance reasons, the maximum number of steps per episodes must be specified.
In most cases, it will be inferred if you specify ``max_episode_steps`` when registering the environment
or if you use a ``gym.wrappers.TimeLimit`` (and ``env.spec`` is not None).
Otherwise, you can directly pass ``max_episode_length`` to the model constructor
For additional offline algorithm specific arguments please have a look at the corresponding documentation.
:param policy: The policy model to use.
:param env: The environment to learn from (if registered in Gym, can be str)
:param model_class: Off policy model which will be used with hindsight experience replay. (SAC, TD3, DDPG, DQN)
:param n_sampled_goal: Number of sampled goals for replay. (offline sampling)
:param goal_selection_strategy: Strategy for sampling goals for replay.
One of ['episode', 'final', 'future', 'random']
:param online_sampling: Sample HER transitions online.
:param learning_rate: learning rate for the optimizer,
it can be a function of the current progress remaining (from 1 to 0)
:param max_episode_length: The maximum length of an episode. If not specified,
it will be automatically inferred if the environment uses a ``gym.wrappers.TimeLimit`` wrapper.
"""
def __init__(
self,
policy: Union[str, Type[BasePolicy]],
env: Union[GymEnv, str],
model_class: Type[OffPolicyAlgorithm],
n_sampled_goal: int = 4,
goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future",
online_sampling: bool = False,
max_episode_length: Optional[int] = None,
*args,
**kwargs,
):
# we will use the policy and learning rate from the model
super(HER, self).__init__(policy=BasePolicy, env=env, policy_base=BasePolicy, learning_rate=0.0)
del self.policy, self.learning_rate
if self.get_vec_normalize_env() is not None:
assert online_sampling, "You must pass `online_sampling=True` if you want to use `VecNormalize` with `HER`"
_init_setup_model = kwargs.get("_init_setup_model", True)
if "_init_setup_model" in kwargs:
del kwargs["_init_setup_model"]
# model initialization
self.model_class = model_class
self.model = model_class(
policy=policy,
env=self.env,
_init_setup_model=False, # pytype: disable=wrong-keyword-args
*args,
**kwargs, # pytype: disable=wrong-keyword-args
)
# Make HER use self.model.action_noise
del self.action_noise
self.verbose = self.model.verbose
self.tensorboard_log = self.model.tensorboard_log
# convert goal_selection_strategy into GoalSelectionStrategy if string
if isinstance(goal_selection_strategy, str):
self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy.lower()]
else:
self.goal_selection_strategy = goal_selection_strategy
# check if goal_selection_strategy is valid
assert isinstance(
self.goal_selection_strategy, GoalSelectionStrategy
), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}"
self.n_sampled_goal = n_sampled_goal
# if we sample her transitions online use custom replay buffer
self.online_sampling = online_sampling
# compute ratio between HER replays and regular replays in percent for online HER sampling
self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1))
# maximum steps in episode
self.max_episode_length = get_time_limit(self.env, max_episode_length)
# storage for transitions of current episode for offline sampling
# for online sampling, it replaces the "classic" replay buffer completely
her_buffer_size = self.buffer_size if online_sampling else self.max_episode_length
assert self.env is not None, "Because it needs access to `env.compute_reward()` HER you must provide the env."
self._episode_storage = HerReplayBuffer(
self.env,
her_buffer_size,
self.max_episode_length,
self.goal_selection_strategy,
self.env.observation_space,
self.env.action_space,
self.device,
self.n_envs,
self.her_ratio, # pytype: disable=wrong-arg-types
)
# counter for steps in episode
self.episode_steps = 0
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
self.model._setup_model()
# assign episode storage to replay buffer when using online HER sampling
if self.online_sampling:
self.model.replay_buffer = self._episode_storage
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
return self.model.predict(observation, state, mask, deterministic)
def learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "HER",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> BaseAlgorithm:
total_timesteps, callback = self._setup_learn(
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
)
self.model.start_time = self.start_time
self.model.ep_info_buffer = self.ep_info_buffer
self.model.ep_success_buffer = self.ep_success_buffer
self.model.num_timesteps = self.num_timesteps
self.model._episode_num = self._episode_num
self.model._last_obs = self._last_obs
self.model._total_timesteps = self._total_timesteps
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps:
rollout = self.collect_rollouts(
self.env,
train_freq=self.train_freq,
action_noise=self.action_noise,
callback=callback,
learning_starts=self.learning_starts,
log_interval=log_interval,
)
if rollout.continue_training is False:
break
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts and self.replay_buffer.size() > 0:
# If no `gradient_steps` is specified,
# do as many gradients steps as steps performed during the rollout
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
callback.on_training_end()
return self
def collect_rollouts(
self,
env: VecEnv,
callback: BaseCallback,
train_freq: TrainFreq,
action_noise: Optional[ActionNoise] = None,
learning_starts: int = 0,
log_interval: Optional[int] = None,
) -> RolloutReturn:
"""
Collect experiences and store them into a ReplayBuffer.
:param env: The training environment
:param callback: Callback that will be called at each step
(and at the beginning and end of the rollout)
:param train_freq: How much experience to collect
by doing rollouts of current policy.
Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
with ``<n>`` being an integer greater than 0.
:param action_noise: Action noise that will be used for exploration
Required for deterministic policy (e.g. TD3). This can also be used
in addition to the stochastic policy for SAC.
:param learning_starts: Number of steps before learning for the warm-up phase.
:param log_interval: Log data every ``log_interval`` episodes
:return:
"""
episode_rewards, total_timesteps = [], []
num_collected_steps, num_collected_episodes = 0, 0
assert isinstance(env, VecEnv), "You must pass a VecEnv"
assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
assert train_freq.frequency > 0, "Should at least collect one step or episode."
if self.model.use_sde:
self.actor.reset_noise()
callback.on_rollout_start()
continue_training = True
while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
done = False
episode_reward, episode_timesteps = 0.0, 0
while not done:
# concatenate observation and (desired) goal
observation = self._last_obs
self._last_obs = ObsDictWrapper.convert_dict(observation)
if (
self.model.use_sde
and self.model.sde_sample_freq > 0
and num_collected_steps % self.model.sde_sample_freq == 0
):
# Sample a new noise matrix
self.actor.reset_noise()
# Select action randomly or according to policy
self.model._last_obs = self._last_obs
action, buffer_action = self._sample_action(learning_starts, action_noise)
# Perform action
new_obs, reward, done, infos = env.step(action)
self.num_timesteps += 1
self.model.num_timesteps = self.num_timesteps
episode_timesteps += 1
num_collected_steps += 1
# Only stop training if return value is False, not when it is None.
if callback.on_step() is False:
return RolloutReturn(0.0, num_collected_steps, num_collected_episodes, continue_training=False)
episode_reward += reward
# Retrieve reward and episode length if using Monitor wrapper
self._update_info_buffer(infos, done)
self.model.ep_info_buffer = self.ep_info_buffer
self.model.ep_success_buffer = self.ep_success_buffer
# == Store transition in the replay buffer and/or in the episode storage ==
if self._vec_normalize_env is not None:
# Store only the unnormalized version
new_obs_ = self._vec_normalize_env.get_original_obs()
reward_ = self._vec_normalize_env.get_original_reward()
else:
# Avoid changing the original ones
self._last_original_obs, new_obs_, reward_ = observation, new_obs, reward
self.model._last_original_obs = self._last_original_obs
# As the VecEnv resets automatically, new_obs is already the
# first observation of the next episode
if done and infos[0].get("terminal_observation") is not None:
next_obs = infos[0]["terminal_observation"]
# VecNormalize normalizes the terminal observation
if self._vec_normalize_env is not None:
next_obs = self._vec_normalize_env.unnormalize_obs(next_obs)
else:
next_obs = new_obs_
if self.online_sampling:
self.replay_buffer.add(self._last_original_obs, next_obs, buffer_action, reward_, done, infos)
else:
# concatenate observation with (desired) goal
flattened_obs = ObsDictWrapper.convert_dict(self._last_original_obs)
flattened_next_obs = ObsDictWrapper.convert_dict(next_obs)
# add to replay buffer
self.replay_buffer.add(flattened_obs, flattened_next_obs, buffer_action, reward_, done)
# add current transition to episode storage
self._episode_storage.add(self._last_original_obs, next_obs, buffer_action, reward_, done, infos)
self._last_obs = new_obs
self.model._last_obs = self._last_obs
# Save the unnormalized new observation
if self._vec_normalize_env is not None:
self._last_original_obs = new_obs_
self.model._last_original_obs = self._last_original_obs
self.model._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)
# For DQN, check if the target network should be updated
# and update the exploration schedule
# For SAC/TD3, the update is done as the same time as the gradient update
# see https://github.com/hill-a/stable-baselines/issues/900
self.model._on_step()
self.episode_steps += 1
if not should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
break
if done or self.episode_steps >= self.max_episode_length:
if self.online_sampling:
self.replay_buffer.store_episode()
else:
self._episode_storage.store_episode()
# sample virtual transitions and store them in replay buffer
self._sample_her_transitions()
# clear storage for current episode
self._episode_storage.reset()
num_collected_episodes += 1
self._episode_num += 1
self.model._episode_num = self._episode_num
episode_rewards.append(episode_reward)
total_timesteps.append(episode_timesteps)
if action_noise is not None:
action_noise.reset()
# Log training infos
if log_interval is not None and self._episode_num % log_interval == 0:
self._dump_logs()
self.episode_steps = 0
mean_reward = np.mean(episode_rewards) if num_collected_episodes > 0 else 0.0
callback.on_rollout_end()
return RolloutReturn(mean_reward, num_collected_steps, num_collected_episodes, continue_training)
def _sample_her_transitions(self) -> None:
"""
Sample additional goals and store new transitions in replay buffer
when using offline sampling.
"""
# Sample goals and get new observations
# maybe_vec_env=None as we should store unnormalized transitions,
# they will be normalized at sampling time
observations, next_observations, actions, rewards = self._episode_storage.sample_offline(
n_sampled_goal=self.n_sampled_goal
)
# store data in replay buffer
dones = np.zeros((len(observations)), dtype=bool)
self.replay_buffer.extend(observations, next_observations, actions, rewards, dones)
def __getattr__(self, item: str) -> Any:
"""
Find attribute from model class if this class does not have it.
"""
if hasattr(self.model, item):
return getattr(self.model, item)
else:
raise AttributeError(f"{self} has no attribute {item}")
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
return self.model._get_torch_save_params()
def save(
self,
path: Union[str, pathlib.Path, io.BufferedIOBase],
exclude: Optional[Iterable[str]] = None,
include: Optional[Iterable[str]] = None,
) -> None:
"""
Save all the attributes of the object and the model parameters in a zip-file.
:param path: path to the file where the rl agent should be saved
:param exclude: name of parameters that should be excluded in addition to the default one
:param include: name of parameters that might be excluded but should be included anyway
"""
# add HER parameters to model
self.model.n_sampled_goal = self.n_sampled_goal
self.model.goal_selection_strategy = self.goal_selection_strategy
self.model.online_sampling = self.online_sampling
self.model.model_class = self.model_class
self.model.max_episode_length = self.max_episode_length
self.model.save(path, exclude, include)
@classmethod
def load(
cls,
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
custom_objects: Optional[Dict[str, Any]] = None,
**kwargs,
) -> "BaseAlgorithm":
"""
Load the model from a zip-file
:param path: path to the file (or a file-like) where to
load the agent from
:param env: the new environment to run the loaded model on
(can be None if you only need prediction from a trained model) has priority over any saved environment
:param device: Device on which the code should run.
:param custom_objects: Dictionary of objects to replace
upon loading. If a variable is present in this dictionary as a
key, it will not be deserialized and the corresponding item
will be used instead. Similar to custom_objects in
``keras.models.load_model``. Useful when you have an object in
file that can not be deserialized.
:param kwargs: extra arguments to change the model when loading
"""
data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects)
# Remove stored device information and replace with ours
if "policy_kwargs" in data:
if "device" in data["policy_kwargs"]:
del data["policy_kwargs"]["device"]
if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
raise ValueError(
f"The specified policy kwargs do not equal the stored policy kwargs."
f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
)
# check if observation space and action space are part of the saved parameters
if "observation_space" not in data or "action_space" not in data:
raise KeyError("The observation_space and action_space were not given, can't verify new environments")
# check if given env is valid
if env is not None:
# Wrap first if needed
env = cls._wrap_env(env, data["verbose"])
# Check if given env is valid
check_for_correct_spaces(env, data["observation_space"], data["action_space"])
else:
# Use stored env, if one exists. If not, continue as is (can be used for predict)
if "env" in data:
env = data["env"]
if "use_sde" in data and data["use_sde"]:
kwargs["use_sde"] = True
# Keys that cannot be changed
for key in {"model_class", "online_sampling", "max_episode_length"}:
if key in kwargs:
del kwargs[key]
# Keys that can be changed
for key in {"n_sampled_goal", "goal_selection_strategy"}:
if key in kwargs:
data[key] = kwargs[key] # pytype: disable=unsupported-operands
del kwargs[key]
# noinspection PyArgumentList
her_model = cls(
policy=data["policy_class"],
env=env,
model_class=data["model_class"],
n_sampled_goal=data["n_sampled_goal"],
goal_selection_strategy=data["goal_selection_strategy"],
online_sampling=data["online_sampling"],
max_episode_length=data["max_episode_length"],
policy_kwargs=data["policy_kwargs"],
_init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
**kwargs,
)
# load parameters
her_model.model.__dict__.update(data)
her_model.model.__dict__.update(kwargs)
her_model._setup_model()
her_model._total_timesteps = her_model.model._total_timesteps
her_model.num_timesteps = her_model.model.num_timesteps
her_model._episode_num = her_model.model._episode_num
# put state_dicts back in place
her_model.model.set_parameters(params, exact_match=True, device=device)
# put other pytorch variables back in place
if pytorch_variables is not None:
for name in pytorch_variables:
recursive_setattr(her_model.model, name, pytorch_variables[name])
# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
if her_model.model.use_sde:
her_model.model.policy.reset_noise() # pytype: disable=attribute-error
return her_model
def load_replay_buffer(
self, path: Union[str, pathlib.Path, io.BufferedIOBase], truncate_last_trajectory: bool = True
) -> None:
"""
Load a replay buffer from a pickle file and set environment for replay buffer (only online sampling).
:param path: Path to the pickled replay buffer.
:param truncate_last_trajectory: Only for online sampling.
If set to ``True`` we assume that the last trajectory in the replay buffer was finished.
If it is set to ``False`` we assume that we continue the same trajectory (same episode).
"""
self.model.load_replay_buffer(path=path)
if self.online_sampling:
# set environment
self.replay_buffer.set_env(self.env)
# If we are at the start of an episode, no need to truncate
current_idx = self.replay_buffer.current_idx
# truncate interrupted episode
if truncate_last_trajectory and current_idx > 0:
warnings.warn(
"The last trajectory in the replay buffer will be truncated.\n"
"If you are in the same episode as when the replay buffer was saved,\n"
"you should use `truncate_last_trajectory=False` to avoid that issue."
)
# get current episode and transition index
pos = self.replay_buffer.pos
# set episode length for current episode
self.replay_buffer.episode_lengths[pos] = current_idx
# set done = True for current episode
# current_idx was already incremented
self.replay_buffer.buffer["done"][pos][current_idx - 1] = np.array([True], dtype=np.float32)
# reset current transition index
self.replay_buffer.current_idx = 0
# increment episode counter
self.replay_buffer.pos = (self.replay_buffer.pos + 1) % self.replay_buffer.max_episode_stored
# update "full" indicator
self.replay_buffer.full = self.replay_buffer.full or self.replay_buffer.pos == 0