stable-baselines3/stable_baselines3/common/base_class.py

575 lines
24 KiB
Python
Raw Normal View History

2020-07-03 01:49:59 +00:00
"""Abstract base classes for RL algorithms."""
import io
import pathlib
2019-10-10 11:47:13 +00:00
import time
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
2019-09-05 15:29:41 +00:00
import gym
2019-09-12 09:19:06 +00:00
import numpy as np
import torch as th
2019-09-05 15:29:41 +00:00
from stable_baselines3.common import logger, utils
2020-05-05 13:02:35 +00:00
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import (
check_for_correct_spaces,
get_device,
get_schedule_fn,
set_random_seed,
update_learning_rate,
)
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecNormalize, VecTransposeImage, unwrap_vec_normalize
2019-09-05 15:29:41 +00:00
2020-07-03 01:49:59 +00:00
def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]:
    """
    Build the environment when ``env`` is a string, otherwise hand ``env`` back unchanged.

    :param env: (Union[GymEnv, str, None]) The environment to learn from,
        or the name passed to ``gym.make`` to create it.
    :param monitor_wrapper: (bool) Whether to wrap a newly created env in a ``Monitor``
        (only applies when the env is created here from a string).
    :param verbose: (int) logging verbosity, a message is printed when ``verbose >= 1``
    :return: (Optional[GymEnv]) The created environment, or ``env`` itself when it is not a string.
    """
    # Anything that is not an environment name is returned untouched (including None)
    if not isinstance(env, str):
        return env

    env_id = env
    if verbose >= 1:
        print(f"Creating environment from the given name '{env_id}'")

    created_env = gym.make(env_id)
    # The Monitor does not write to disk here (filename=None)
    return Monitor(created_env, filename=None) if monitor_wrapper else created_env
class BaseAlgorithm(ABC):
2019-09-05 15:29:41 +00:00
"""
The base of RL algorithms
2019-09-05 15:29:41 +00:00
:param policy: (Type[BasePolicy]) Policy object
2020-07-03 01:49:59 +00:00
:param env: (Union[GymEnv, str, None]) The environment to learn from
2019-09-05 15:29:41 +00:00
(if registered in Gym, can be str. Can be None for loading trained models)
:param policy_base: (Type[BasePolicy]) The base policy used by this method
:param learning_rate: (float or callable) learning rate for the optimizer,
it can be a function of the current progress remaining (from 1 to 0)
:param policy_kwargs: (Dict[str, Any]) Additional arguments to be passed to the policy on creation
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param verbose: (int) The verbosity level: 0 none, 1 training information, 2 debug
:param device: (Union[th.device, str]) Device on which the code should run.
2019-09-12 09:19:06 +00:00
By default, it will try to use a Cuda compatible device and fallback to cpu
if it is not possible.
:param support_multi_env: (bool) Whether the algorithm supports training
2019-11-22 12:33:12 +00:00
with multiple environments (as in A2C)
:param create_eval_env: (bool) Whether to create a second environment that will be
2019-11-22 12:33:12 +00:00
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param monitor_wrapper: (bool) When creating an environment, whether to wrap it
2019-10-10 11:47:13 +00:00
or not in a Monitor wrapper.
:param seed: (Optional[int]) Seed for the pseudo random generators
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
2019-11-26 14:26:12 +00:00
instead of action noise exploration (default: False)
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
2019-12-17 10:47:21 +00:00
Default: -1 (only sample at the beginning of the rollout)
2019-09-05 15:29:41 +00:00
"""
2020-01-27 14:53:27 +00:00
def __init__(
    self,
    policy: Type[BasePolicy],
    env: Union[GymEnv, str, None],
    policy_base: Type[BasePolicy],
    learning_rate: Union[float, Callable],
    policy_kwargs: Optional[Dict[str, Any]] = None,
    tensorboard_log: Optional[str] = None,
    verbose: int = 0,
    device: Union[th.device, str] = "auto",
    support_multi_env: bool = False,
    create_eval_env: bool = False,
    monitor_wrapper: bool = True,
    seed: Optional[int] = None,
    use_sde: bool = False,
    sde_sample_freq: int = -1,
):
    """
    Store the hyperparameters, initialize the bookkeeping attributes
    and create/wrap the training environment when one is given.

    :param policy: (Type[BasePolicy]) Policy class, or its name (resolved with ``policy_base``)
    :param env: (Union[GymEnv, str, None]) The environment to learn from
        (a string is turned into an env, None leaves the model without env)
    :param policy_base: (Type[BasePolicy]) The base policy used to resolve a policy name
    :param learning_rate: (float or callable) learning rate for the optimizer
    :param policy_kwargs: (Optional[Dict[str, Any]]) Additional arguments for the policy,
        an empty dict is used when None
    :param tensorboard_log: (Optional[str]) the log location for tensorboard
    :param verbose: (int) The verbosity level
    :param device: (Union[th.device, str]) Device on which the code should run
    :param support_multi_env: (bool) Whether more than one environment is accepted
    :param create_eval_env: (bool) Whether to create a second environment for evaluation
        (only done when ``env`` is a string)
    :param monitor_wrapper: (bool) Whether to wrap a created env in a Monitor wrapper
    :param seed: (Optional[int]) Seed for the pseudo random generators
    :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
    :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
    :raises ValueError: if several envs are given while ``support_multi_env`` is False,
        or if gSDE is requested with an action space that is not a ``gym.spaces.Box``
    """
    if isinstance(policy, str) and policy_base is not None:
        self.policy_class = get_policy_from_name(policy_base, policy)
    else:
        self.policy_class = policy

    self.device = get_device(device)
    if verbose > 0:
        print(f"Using {self.device} device")

    self.env = None  # type: Optional[GymEnv]
    # get VecNormalize object if needed
    self._vec_normalize_env = unwrap_vec_normalize(env)
    self.verbose = verbose
    self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
    self.observation_space = None  # type: Optional[gym.spaces.Space]
    self.action_space = None  # type: Optional[gym.spaces.Space]
    self.n_envs = None
    self.num_timesteps = 0
    # Used for updating schedules
    self._total_timesteps = 0
    self.eval_env = None
    self.seed = seed
    self.action_noise = None  # type: Optional[ActionNoise]
    self.start_time = None
    self.policy = None
    self.learning_rate = learning_rate
    self.tensorboard_log = tensorboard_log
    self.lr_schedule = None  # type: Optional[Callable]
    self._last_obs = None  # type: Optional[np.ndarray]
    self._last_dones = None  # type: Optional[np.ndarray]
    # When using VecNormalize:
    self._last_original_obs = None  # type: Optional[np.ndarray]
    self._episode_num = 0
    # Used for gSDE only
    self.use_sde = use_sde
    self.sde_sample_freq = sde_sample_freq
    # Track the training progress remaining (from 1 to 0)
    # this is used to update the learning rate
    self._current_progress_remaining = 1
    # Buffers for logging
    self.ep_info_buffer = None  # type: Optional[deque]
    self.ep_success_buffer = None  # type: Optional[deque]
    # For logging
    self._n_updates = 0  # type: int

    # Create and wrap the env if needed
    if env is not None:
        # A separate evaluation env can only be built from an environment name
        if isinstance(env, str) and create_eval_env:
            self.eval_env = maybe_make_env(env, monitor_wrapper, self.verbose)

        env = maybe_make_env(env, monitor_wrapper, self.verbose)
        env = self._wrap_env(env)

        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.n_envs = env.num_envs
        self.env = env

        if not support_multi_env and self.n_envs > 1:
            raise ValueError(
                "Error: the model does not support multiple envs; it requires a single vectorized environment."
            )

        # gSDE perturbs continuous actions, so the *action* space is the one that must be a Box
        if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
            raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
def _wrap_env(self, env: GymEnv) -> VecEnv:
    """
    Make sure the environment is vectorized and, for image observations, transposed.

    :param env: (GymEnv) The environment to wrap
    :return: (VecEnv) ``env`` wrapped in a ``DummyVecEnv`` when it was not a ``VecEnv``,
        and in a ``VecTransposeImage`` when its observation space is an image space
        and it is not already a ``VecTransposeImage``.
    """
    announce = self.verbose >= 1

    if not isinstance(env, VecEnv):
        if announce:
            print("Wrapping the env in a DummyVecEnv.")
        single_env = env
        env = DummyVecEnv([lambda: single_env])

    needs_transpose = is_image_space(env.observation_space) and not isinstance(env, VecTransposeImage)
    if needs_transpose:
        if announce:
            print("Wrapping the env in a VecTransposeImage.")
        env = VecTransposeImage(env)

    return env
@abstractmethod
def _setup_model(self) -> None:
2020-07-03 01:49:59 +00:00
"""Create networks, buffer and optimizers."""
2020-01-27 13:32:31 +00:00
def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
    """
    Select and wrap the environment that will be used for evaluation.

    :param eval_env: (Optional[GymEnv]) Candidate evaluation env,
        ``self.eval_env`` is used instead when it is None
    :return: (Optional[GymEnv]) The wrapped evaluation env, or None when there is none
    """
    # Fall back to the env stored on the model when none is passed explicitly
    selected_env = self.eval_env if eval_env is None else eval_env
    if selected_env is None:
        return selected_env

    selected_env = self._wrap_env(selected_env)
    # Evaluation is expected to run on a single environment
    assert selected_env.num_envs == 1
    return selected_env
2019-09-05 15:29:41 +00:00
2020-03-16 13:05:21 +00:00
def _setup_lr_schedule(self) -> None:
    """Build ``self.lr_schedule`` from ``self.learning_rate`` (transformed to a callable if needed)."""
    schedule = get_schedule_fn(self.learning_rate)
    self.lr_schedule = schedule
2019-10-28 15:47:13 +00:00
def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
2019-10-28 15:47:13 +00:00
"""
Compute current progress remaining (starts from 1 and ends to 0)
2019-10-28 15:47:13 +00:00
2020-01-22 16:51:27 +00:00
:param num_timesteps: current number of timesteps
:param total_timesteps:
2019-10-28 15:47:13 +00:00
"""
self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
2019-10-28 15:47:13 +00:00
2020-01-22 16:51:27 +00:00
def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
    """
    Log the scheduled learning rate and apply it to the given optimizer(s).

    The value comes from ``self.lr_schedule`` evaluated at
    the current progress remaining (from 1 to 0).

    :param optimizers: (Union[List[th.optim.Optimizer], th.optim.Optimizer])
        An optimizer or a list of optimizers.
    """
    # Log the current learning rate
    logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining))

    # Accept a bare optimizer as well as a list of them
    optimizer_list = optimizers if isinstance(optimizers, list) else [optimizers]
    for single_optimizer in optimizer_list:
        update_learning_rate(single_optimizer, self.lr_schedule(self._current_progress_remaining))
2019-10-10 11:47:13 +00:00
2020-01-27 13:32:31 +00:00
def get_env(self) -> Optional[VecEnv]:
    """
    Give access to the environment currently stored on the model.

    :return: (Optional[VecEnv]) The current environment (can be None if not defined)
    """
    current_env = self.env
    return current_env
2020-02-12 10:34:29 +00:00
def get_vec_normalize_env(self) -> Optional[VecNormalize]:
    """
    Give access to the ``VecNormalize`` wrapper of the training env, if it exists.

    :return: (Optional[VecNormalize]) The ``VecNormalize`` env stored on the model (may be None).
    """
    normalize_wrapper = self._vec_normalize_env
    return normalize_wrapper
2020-01-27 13:32:31 +00:00
def set_env(self, env: GymEnv) -> None:
    """
    Validate the environment against the model's spaces and set it as the current environment.

    The following are checked with ``check_for_correct_spaces``:

    - observation_space
    - action_space

    The env is then passed through ``_wrap_env`` (which vectorizes a non vectorized env).

    :param env: (GymEnv) The environment for learning a policy
    """
    check_for_correct_spaces(env, self.observation_space, self.action_space)
    # Past this point the spaces are coherent: vectorize if needed and store
    vec_env = self._wrap_env(env)
    self.n_envs = vec_env.num_envs
    self.env = vec_env
2019-09-05 15:29:41 +00:00
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
    """
    List the names of the torch variables that will be saved.

    ``th.save`` and ``th.load`` will be used with the right device
    instead of the default pickling strategy.

    :return: (Tuple[List[str], List[str]])
        names of the variables with state dicts to save, names of additional torch tensors
    """
    state_dict_names = ["policy"]
    extra_tensor_names = []
    return state_dict_names, extra_tensor_names
2019-09-05 15:29:41 +00:00
@abstractmethod
def learn(
    self,
    total_timesteps: int,
    callback: MaybeCallback = None,
    log_interval: int = 100,
    tb_log_name: str = "run",
    eval_env: Optional[GymEnv] = None,
    eval_freq: int = -1,
    n_eval_episodes: int = 5,
    eval_log_path: Optional[str] = None,
    reset_num_timesteps: bool = True,
) -> "BaseAlgorithm":
    """
    Train the model and return it (abstract: has no implementation here).

    :param total_timesteps: (int) The total number of samples (env steps) to train on
    :param callback: (MaybeCallback) callback(s) called at every step with state of the algorithm.
    :param log_interval: (int) The number of timesteps before logging.
    :param tb_log_name: (str) the name of the run for TensorBoard logging
    :param eval_env: (Optional[GymEnv]) Environment that will be used to evaluate the agent
    :param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
    :param n_eval_episodes: (int) Number of episode to evaluate the agent
    :param eval_log_path: (Optional[str]) Path to a folder where the evaluations will be saved
    :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
    :return: (BaseAlgorithm) the trained model
    """
def predict(
    self,
    observation: np.ndarray,
    state: Optional[np.ndarray] = None,
    mask: Optional[np.ndarray] = None,
    deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Get the model's action(s) from an observation by delegating to ``self.policy.predict``.

    :param observation: (np.ndarray) the input observation
    :param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies)
    :param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies)
    :param deterministic: (bool) Whether or not to return deterministic actions.
    :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state
        (used in recurrent policies)
    """
    policy = self.policy
    return policy.predict(observation, state, mask, deterministic)
2019-09-05 15:29:41 +00:00
@classmethod
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAlgorithm":
    """
    Load the model from a zip-file

    :param load_path: (str) the location of the saved data
    :param env: (Optional[GymEnv]) the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model) has priority over any saved environment
    :param kwargs: extra arguments to change the model when loading
    :return: (BaseAlgorithm) the loaded model
    :raises ValueError: if ``policy_kwargs`` is given and differs from the stored one
    :raises KeyError: if the saved data contains no observation or action space
    """
    data, params, tensors = load_from_zip_file(load_path)

    # The stored device is dropped from the policy kwargs:
    # the model is re-created below with device="auto"
    if "policy_kwargs" in data:
        for arg_to_remove in ["device"]:
            if arg_to_remove in data["policy_kwargs"]:
                del data["policy_kwargs"][arg_to_remove]

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    # check if observation space and action space are part of the saved parameters
    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    # check if given env is valid
    if env is not None:
        check_for_correct_spaces(env, data["observation_space"], data["action_space"])
    # if no new env was given use stored env if possible
    if env is None and "env" in data:
        env = data["env"]

    # noinspection PyArgumentList
    model = cls(
        policy=data["policy_class"],
        env=env,
        device="auto",
        _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
    )

    # load parameters: saved attributes first, then user overrides on top
    model.__dict__.update(data)
    model.__dict__.update(kwargs)
    model._setup_model()

    # put state_dicts back in place
    for name in params:
        attr = recursive_getattr(model, name)
        attr.load_state_dict(params[name])

    # put tensors back in place
    if tensors is not None:
        for name in tensors:
            recursive_setattr(model, name, tensors[name])

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if model.use_sde:
        model.policy.reset_noise()  # pytype: disable=attribute-error
    return model
2020-01-22 16:51:27 +00:00
def set_random_seed(self, seed: Optional[int] = None) -> None:
    """
    Set the seed of the pseudo-random generators
    (python, numpy, pytorch, gym, action_space)

    Does nothing when ``seed`` is None.

    :param seed: (Optional[int])
    """
    if seed is None:
        return

    # Compare device *types*: an indexed device such as "cuda:0" is not equal
    # to th.device("cuda") but is still a CUDA device
    on_cuda = self.device.type == th.device("cuda").type
    set_random_seed(seed, using_cuda=on_cuda)
    self.action_space.seed(seed)
    if self.env is not None:
        self.env.seed(seed)
    if self.eval_env is not None:
        self.eval_env.seed(seed)
def _init_callback(
    self,
    callback: MaybeCallback,
    eval_env: Optional[VecEnv] = None,
    eval_freq: int = 10000,
    n_eval_episodes: int = 5,
    log_path: Optional[str] = None,
) -> BaseCallback:
    """
    Turn ``callback`` into a single callback object, add evaluation if requested, and initialize it.

    :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
    :param eval_env: (Optional[VecEnv]) Environment to use for evaluation; if None, no evaluation callback is added.
    :param eval_freq: (int) How many steps between evaluations
    :param n_eval_episodes: (int) How many episodes to play per evaluation
    :param log_path: (Optional[str]) Path to a folder where the evaluations will be saved
        (also used as the best model save path)
    :return: (BaseCallback) A hybrid callback calling `callback` and performing evaluation.
    """
    user_callback = callback

    # Convert a list of callbacks into a callback
    if isinstance(user_callback, list):
        user_callback = CallbackList(user_callback)

    # Convert functional callback to object
    if not isinstance(user_callback, BaseCallback):
        user_callback = ConvertCallback(user_callback)

    # Create eval callback in charge of the evaluation
    if eval_env is not None:
        evaluation = EvalCallback(
            eval_env,
            best_model_save_path=log_path,
            log_path=log_path,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
        )
        user_callback = CallbackList([user_callback, evaluation])

    user_callback.init_callback(self)
    return user_callback
def _setup_learn(
    self,
    total_timesteps: int,
    eval_env: Optional[GymEnv],
    callback: MaybeCallback = None,
    eval_freq: int = 10000,
    n_eval_episodes: int = 5,
    log_path: Optional[str] = None,
    reset_num_timesteps: bool = True,
    tb_log_name: str = "run",
) -> Tuple[int, BaseCallback]:
    """
    Initialize different variables needed for training.

    :param total_timesteps: (int) The total number of samples (env steps) to train on
    :param eval_env: (Optional[GymEnv]) Environment to use for evaluation.
    :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
    :param eval_freq: (int) How many steps between evaluations
    :param n_eval_episodes: (int) How many episodes to play per evaluation
    :param log_path: (Optional[str]) Path to a folder where the evaluations will be saved
    :param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute
    :param tb_log_name: (str) the name of the run for tensorboard log
    :return: (Tuple[int, BaseCallback]) the (possibly offset) total timesteps and the initialized callback
    """
    self.start_time = time.time()
    if self.ep_info_buffer is None or reset_num_timesteps:
        # Initialize buffers if they don't exist, or reinitialize if resetting counters
        self.ep_info_buffer = deque(maxlen=100)
        self.ep_success_buffer = deque(maxlen=100)

    if self.action_noise is not None:
        self.action_noise.reset()

    if reset_num_timesteps:
        self.num_timesteps = 0
        self._episode_num = 0
    else:
        # Make sure training timesteps are ahead of the internal counter
        total_timesteps += self.num_timesteps
    self._total_timesteps = total_timesteps

    # Avoid resetting the environment when calling ``.learn()`` consecutive times
    if reset_num_timesteps or self._last_obs is None:
        self._last_obs = self.env.reset()
        # Builtin ``bool`` instead of the ``np.bool`` alias, which recent NumPy releases removed
        self._last_dones = np.zeros((self.env.num_envs,), dtype=bool)
        # Retrieve unnormalized observation for saving into the buffer
        if self._vec_normalize_env is not None:
            self._last_original_obs = self._vec_normalize_env.get_original_obs()

    # Seed the evaluation env as it was passed in, before it gets wrapped below
    if eval_env is not None and self.seed is not None:
        eval_env.seed(self.seed)

    eval_env = self._get_eval_env(eval_env)

    # Configure logger's outputs
    utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)

    # Create eval callback if needed
    callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)

    return total_timesteps, callback
2019-10-10 11:47:13 +00:00
2020-02-04 12:24:09 +00:00
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
2019-10-17 11:44:48 +00:00
"""
2019-11-22 12:33:12 +00:00
Retrieve reward and episode length and update the buffer
if using Monitor wrapper.
2019-10-17 11:44:48 +00:00
:param infos: ([dict])
"""
2020-02-04 12:24:09 +00:00
if dones is None:
dones = np.array([False] * len(infos))
for idx, info in enumerate(infos):
maybe_ep_info = info.get("episode")
maybe_is_success = info.get("is_success")
2019-10-17 11:44:48 +00:00
if maybe_ep_info is not None:
self.ep_info_buffer.extend([maybe_ep_info])
2020-02-04 12:24:09 +00:00
if maybe_is_success is not None and dones[idx]:
self.ep_success_buffer.append(maybe_is_success)
2019-10-17 11:44:48 +00:00
2020-02-03 17:18:41 +00:00
def excluded_save_params(self) -> List[str]:
    """
    Give the attribute names that are left out by default
    when the model is saved.

    :return: ([str]) Names of the attributes excluded from save by default
    """
    default_exclusions = [
        "policy",
        "device",
        "env",
        "eval_env",
        "replay_buffer",
        "rollout_buffer",
        "_vec_normalize_env",
    ]
    return default_exclusions
def save(
    self,
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    exclude: Optional[Iterable[str]] = None,
    include: Optional[Iterable[str]] = None,
) -> None:
    """
    Save all the attributes of the object and the model parameters in a zip-file.

    :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) path to the file where the rl agent should be saved
    :param exclude: name of parameters that should be excluded in addition to the default one
    :param include: name of parameters that might be excluded but should be included anyway
    """
    # Shallow copy of the attribute dict, so that dropping entries below
    # does not remove attributes from the live object
    data = self.__dict__.copy()

    # Exclude is union of specified parameters (if any) and standard exclusions
    excluded = set() if exclude is None else set(exclude)
    excluded.update(self.excluded_save_params())

    # Do not exclude params if they are specifically included
    if include is not None:
        excluded.difference_update(include)

    state_dicts_names, tensors_names = self.get_torch_variables()

    # Any attribute saved as a torch variable must not be saved again through ``data``.
    # ``tensors_names`` may be None (see the guard below), so it is only
    # appended when it is actually provided.
    torch_variables = list(state_dicts_names)
    if tensors_names is not None:
        torch_variables.extend(tensors_names)
    for torch_var in torch_variables:
        # Names may be dotted paths; only the top-most attribute lives in ``data``
        excluded.add(torch_var.split(".")[0])

    # Remove parameter entries of parameters which are to be excluded
    for param_name in excluded:
        data.pop(param_name, None)

    # Build dict of tensor variables (stays None when no tensor names are given)
    tensors = None
    if tensors_names is not None:
        tensors = {name: recursive_getattr(self, name) for name in tensors_names}

    # Build dict of state_dicts
    params_to_save = {name: recursive_getattr(self, name).state_dict() for name in state_dicts_names}

    save_to_zip_file(path, data=data, params=params_to_save, tensors=tensors)