mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
542 lines
23 KiB
Python
542 lines
23 KiB
Python
"""Abstract base classes for RL algorithms."""
|
|
|
|
import time
|
|
from typing import Union, Type, Optional, Dict, Any, Iterable, List, Tuple, Callable
|
|
from abc import ABC, abstractmethod
|
|
from collections import deque
|
|
import pathlib
|
|
import io
|
|
|
|
import gym
|
|
import torch as th
|
|
import numpy as np
|
|
|
|
from stable_baselines3.common import logger, utils
|
|
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
|
|
from stable_baselines3.common.utils import (set_random_seed, get_schedule_fn, update_learning_rate, get_device,
|
|
check_for_correct_spaces)
|
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage
|
|
from stable_baselines3.common.preprocessing import is_image_space
|
|
from stable_baselines3.common.save_util import (recursive_getattr, recursive_setattr, save_to_zip_file,
|
|
load_from_zip_file)
|
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
|
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
|
|
from stable_baselines3.common.monitor import Monitor
|
|
from stable_baselines3.common.noise import ActionNoise
|
|
|
|
|
|
def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]:
    """Build the environment when it is given by name; otherwise hand it back untouched.

    :param env: (Union[GymEnv, str, None]) The environment to learn from, or the
        id passed to ``gym.make`` to create it.
    :param monitor_wrapper: (bool) Whether to wrap a newly created env in a ``Monitor``.
        Ignored when ``env`` is not a string.
    :param verbose: (int) logging verbosity; a message is printed before creation when >= 1
    :return: (Optional[GymEnv]) The created (and possibly monitored) environment,
        or ``env`` itself (which can be None) when it is not a string.
    """
    # Anything that is not a name (an env instance or None) is passed through as is
    if not isinstance(env, str):
        return env

    if verbose >= 1:
        print(f"Creating environment from the given name '{env}'")

    created_env = gym.make(env)
    return Monitor(created_env, filename=None) if monitor_wrapper else created_env
|
|
|
|
|
|
class BaseAlgorithm(ABC):
    """
    The base of RL algorithms

    :param policy: (Type[BasePolicy]) Policy object
    :param env: (Union[GymEnv, str, None]) The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param policy_base: (Type[BasePolicy]) The base policy used by this method
    :param learning_rate: (float or callable) learning rate for the optimizer,
        it can be a function of the current progress remaining (from 1 to 0)
    :param policy_kwargs: (Dict[str, Any]) Additional arguments to be passed to the policy on creation
    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
    :param verbose: (int) The verbosity level: 0 none, 1 training information, 2 debug
    :param device: (Union[th.device, str]) Device on which the code should run.
        By default, it will try to use a Cuda compatible device and fallback to cpu
        if it is not possible.
    :param support_multi_env: (bool) Whether the algorithm supports training
        with multiple environments (as in A2C)
    :param create_eval_env: (bool) Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param monitor_wrapper: (bool) When creating an environment, whether to wrap it
        or not in a Monitor wrapper.
    :param seed: (Optional[int]) Seed for the pseudo random generators
    :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    """

    def __init__(self,
                 policy: Type[BasePolicy],
                 env: Union[GymEnv, str, None],
                 policy_base: Type[BasePolicy],
                 learning_rate: Union[float, Callable],
                 policy_kwargs: Optional[Dict[str, Any]] = None,
                 tensorboard_log: Optional[str] = None,
                 verbose: int = 0,
                 device: Union[th.device, str] = 'auto',
                 support_multi_env: bool = False,
                 create_eval_env: bool = False,
                 monitor_wrapper: bool = True,
                 seed: Optional[int] = None,
                 use_sde: bool = False,
                 sde_sample_freq: int = -1):

        # A policy given by name is resolved through the registry of ``policy_base``
        if isinstance(policy, str) and policy_base is not None:
            self.policy_class = get_policy_from_name(policy_base, policy)
        else:
            self.policy_class = policy

        self.device = get_device(device)
        if verbose > 0:
            print(f"Using {self.device} device")

        self.env = None  # type: Optional[GymEnv]
        # get VecNormalize object if needed
        self._vec_normalize_env = unwrap_vec_normalize(env)
        self.verbose = verbose
        self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
        self.observation_space = None  # type: Optional[gym.spaces.Space]
        self.action_space = None  # type: Optional[gym.spaces.Space]
        self.n_envs = None
        self.num_timesteps = 0
        # Used for updating schedules
        self._total_timesteps = 0
        self.eval_env = None
        self.seed = seed
        self.action_noise = None  # type: Optional[ActionNoise]
        self.start_time = None
        self.policy = None
        self.learning_rate = learning_rate
        self.tensorboard_log = tensorboard_log
        self.lr_schedule = None  # type: Optional[Callable]
        self._last_obs = None  # type: Optional[np.ndarray]
        # When using VecNormalize:
        self._last_original_obs = None  # type: Optional[np.ndarray]
        self._episode_num = 0
        # Used for gSDE only
        self.use_sde = use_sde
        self.sde_sample_freq = sde_sample_freq
        # Track the training progress remaining (from 1 to 0)
        # this is used to update the learning rate
        self._current_progress_remaining = 1
        # Buffers for logging
        self.ep_info_buffer = None  # type: Optional[deque]
        self.ep_success_buffer = None  # type: Optional[deque]
        # For logging
        self._n_updates = 0  # type: int

        # Create and wrap the env if needed
        if env is not None:
            # A separate evaluation env can only be built when the env is given by name
            if isinstance(env, str) and create_eval_env:
                self.eval_env = maybe_make_env(env, monitor_wrapper, self.verbose)

            env = maybe_make_env(env, monitor_wrapper, self.verbose)
            env = self._wrap_env(env)

            self.observation_space = env.observation_space
            self.action_space = env.action_space
            self.n_envs = env.num_envs
            self.env = env

            if not support_multi_env and self.n_envs > 1:
                raise ValueError("Error: the model does not support multiple envs; it requires "
                                 "a single vectorized environment.")

    def _wrap_env(self, env: GymEnv) -> VecEnv:
        """
        Make sure the environment is a ``VecEnv``, transposing image observations if needed.

        A non-vectorized env is put in a ``DummyVecEnv``; an env whose observation space
        is an image space is additionally wrapped in a ``VecTransposeImage``
        (unless it already is one).

        :param env: (GymEnv) The environment to wrap
        :return: (VecEnv) The wrapped environment
        """
        if not isinstance(env, VecEnv):
            if self.verbose >= 1:
                print("Wrapping the env in a DummyVecEnv.")
            env = DummyVecEnv([lambda: env])

        if is_image_space(env.observation_space) and not isinstance(env, VecTransposeImage):
            if self.verbose >= 1:
                print("Wrapping the env in a VecTransposeImage.")
            env = VecTransposeImage(env)
        return env

    @abstractmethod
    def _setup_model(self) -> None:
        """Create networks, buffer and optimizers."""

    def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
        """
        Return the environment that will be used for evaluation.

        Falls back to ``self.eval_env`` when no env is given. The selected env
        (if any) is wrapped like the training env and must contain a single environment.

        :param eval_env: (Optional[GymEnv]))
        :return: (Optional[GymEnv])
        """
        if eval_env is None:
            eval_env = self.eval_env

        if eval_env is not None:
            eval_env = self._wrap_env(eval_env)
            assert eval_env.num_envs == 1, "The evaluation environment must contain a single environment."
        return eval_env

    def _setup_lr_schedule(self) -> None:
        """Transform to callable if needed."""
        self.lr_schedule = get_schedule_fn(self.learning_rate)

    def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
        """
        Compute current progress remaining (starts from 1 and ends to 0)

        :param num_timesteps: (int) current number of timesteps
        :param total_timesteps: (int) total number of timesteps (must be non-zero)
        """
        self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)

    def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
        """
        Update the optimizers learning rate using the current learning rate schedule
        and the current progress remaining (from 1 to 0).

        :param optimizers: (Union[List[th.optim.Optimizer], th.optim.Optimizer])
            An optimizer or a list of optimizers.
        """
        # Evaluate the schedule once: the same value is logged and applied to every optimizer
        current_lr = self.lr_schedule(self._current_progress_remaining)
        # Log the current learning rate
        logger.record("train/learning_rate", current_lr)

        if not isinstance(optimizers, list):
            optimizers = [optimizers]
        for optimizer in optimizers:
            update_learning_rate(optimizer, current_lr)

    def get_env(self) -> Optional[VecEnv]:
        """
        Returns the current environment (can be None if not defined).

        :return: (Optional[VecEnv]) The current environment
        """
        return self.env

    def get_vec_normalize_env(self) -> Optional[VecNormalize]:
        """
        Return the ``VecNormalize`` wrapper of the training env
        if it exists.

        :return: Optional[VecNormalize] The ``VecNormalize`` env.
        """
        return self._vec_normalize_env

    def set_env(self, env: GymEnv) -> None:
        """
        Checks the validity of the environment, and if it is coherent, set it as the current environment.
        Furthermore wrap any non vectorized env into a vectorized
        checked parameters:
        - observation_space
        - action_space

        :param env: The environment for learning a policy
        """
        # Wrap first: the stored spaces were read from the *wrapped* training env
        # (see ``__init__``), so the comparison must be done on a wrapped env too,
        # otherwise a valid non-transposed image env would be rejected.
        env = self._wrap_env(env)
        check_for_correct_spaces(env, self.observation_space, self.action_space)

        # Keep the cached VecNormalize in sync with the new env, otherwise
        # ``_setup_learn`` would read original observations from the previous env.
        self._vec_normalize_env = unwrap_vec_normalize(env)
        self.n_envs = env.num_envs
        self.env = env

    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
        """
        Get the name of the torch variables that will be saved.
        ``th.save`` and ``th.load`` will be used with the right device
        instead of the default pickling strategy.

        :return: (Tuple[List[str], List[str]])
            name of the variables with state dicts to save, name of additional torch tensors,
        """
        state_dicts = ["policy"]

        return state_dicts, []

    @abstractmethod
    def learn(self, total_timesteps: int,
              callback: MaybeCallback = None,
              log_interval: int = 100,
              tb_log_name: str = "run",
              eval_env: Optional[GymEnv] = None,
              eval_freq: int = -1,
              n_eval_episodes: int = 5,
              eval_log_path: Optional[str] = None,
              reset_num_timesteps: bool = True) -> 'BaseAlgorithm':
        """
        Return a trained model.

        :param total_timesteps: (int) The total number of samples (env steps) to train on
        :param callback: (MaybeCallback) callback(s) called at every step with state of the algorithm.
        :param log_interval: (int) The number of timesteps before logging.
        :param tb_log_name: (str) the name of the run for TensorBoard logging
        :param eval_env: (gym.Env) Environment that will be used to evaluate the agent
        :param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
        :param n_eval_episodes: (int) Number of episode to evaluate the agent
        :param eval_log_path: (Optional[str]) Path to a folder where the evaluations will be saved
        :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
        :return: (BaseAlgorithm) the trained model
        """

    def predict(self, observation: np.ndarray,
                state: Optional[np.ndarray] = None,
                mask: Optional[np.ndarray] = None,
                deterministic: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Get the model's action(s) from an observation

        :param observation: (np.ndarray) the input observation
        :param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies)
        :param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies)
        :param deterministic: (bool) Whether or not to return deterministic actions.
        :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state
            (used in recurrent policies)
        """
        return self.policy.predict(observation, state, mask, deterministic)

    @classmethod
    def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAlgorithm':
        """
        Load the model from a zip-file

        :param load_path: the location of the saved data
        :param env: the new environment to run the loaded model on
            (can be None if you only need prediction from a trained model) has priority over any saved environment
        :param kwargs: extra arguments to change the model when loading
        :return: (BaseAlgorithm) the loaded model
        :raises ValueError: if ``policy_kwargs`` is passed and differs from the stored one
        :raises KeyError: if the saved data does not contain the observation and action spaces
        """
        data, params, tensors = load_from_zip_file(load_path)

        # The device is chosen again at load time, drop the one stored with the policy kwargs
        if 'policy_kwargs' in data:
            data['policy_kwargs'].pop('device', None)

        if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
            raise ValueError(f"The specified policy kwargs do not equal the stored policy kwargs. "
                             f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}")

        # check if observation space and action space are part of the saved parameters
        if "observation_space" not in data or "action_space" not in data:
            raise KeyError("The observation_space and action_space were not given, can't verify new environments")
        # check if given env is valid
        if env is not None:
            check_for_correct_spaces(env, data["observation_space"], data["action_space"])
        # if no new env was given use stored env if possible
        if env is None and "env" in data:
            env = data["env"]

        # noinspection PyArgumentList
        model = cls(policy=data["policy_class"], env=env, device='auto', _init_setup_model=False)

        # load parameters: saved attributes first, then user overrides, before building the networks
        model.__dict__.update(data)
        model.__dict__.update(kwargs)
        model._setup_model()

        # put state_dicts back in place
        for name in params:
            attr = recursive_getattr(model, name)
            attr.load_state_dict(params[name])

        # put tensors back in place
        if tensors is not None:
            for name in tensors:
                recursive_setattr(model, name, tensors[name])

        # Sample gSDE exploration matrix, so it uses the right device
        # see issue #44
        if model.use_sde:
            model.policy.reset_noise()
        return model

    def set_random_seed(self, seed: Optional[int] = None) -> None:
        """
        Set the seed of the pseudo-random generators
        (python, numpy, pytorch, gym, action_space)

        Nothing is done when ``seed`` is None.

        :param seed: (int)
        """
        if seed is None:
            return
        # Compare the device *type*: an indexed device such as ``cuda:0``
        # is not equal to ``th.device('cuda')`` but still runs on Cuda.
        using_cuda = th.device(self.device).type == 'cuda'
        set_random_seed(seed, using_cuda=using_cuda)
        # The action space is only known once an env has been set
        if self.action_space is not None:
            self.action_space.seed(seed)
        if self.env is not None:
            self.env.seed(seed)
        if self.eval_env is not None:
            self.eval_env.seed(seed)

    def _init_callback(self,
                       callback: MaybeCallback,
                       eval_env: Optional[VecEnv] = None,
                       eval_freq: int = 10000,
                       n_eval_episodes: int = 5,
                       log_path: Optional[str] = None) -> BaseCallback:
        """
        Convert ``callback`` to a ``BaseCallback`` and, when an evaluation env is given,
        chain it with an ``EvalCallback``.

        :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
        :param eval_env: (Optional[VecEnv]) Environment to use for evaluation;
            if None, no evaluation callback is created.
        :param eval_freq: (int) How many steps between evaluations.
        :param n_eval_episodes: (int) Number of episodes to rollout during evaluation.
        :param log_path: (Optional[str]) Path to a folder where the evaluations will be saved
            (also used as the save path of the best model)
        :return: (BaseCallback) A hybrid callback calling `callback` and performing evaluation.
        """
        # Convert a list of callbacks into a callback
        if isinstance(callback, list):
            callback = CallbackList(callback)

        # Convert functional callback to object
        if not isinstance(callback, BaseCallback):
            callback = ConvertCallback(callback)

        # Create eval callback in charge of the evaluation
        if eval_env is not None:
            eval_callback = EvalCallback(eval_env,
                                         best_model_save_path=log_path,
                                         log_path=log_path, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes)
            callback = CallbackList([callback, eval_callback])

        callback.init_callback(self)
        return callback

    def _setup_learn(self,
                     total_timesteps: int,
                     eval_env: Optional[GymEnv],
                     callback: MaybeCallback = None,
                     eval_freq: int = 10000,
                     n_eval_episodes: int = 5,
                     log_path: Optional[str] = None,
                     reset_num_timesteps: bool = True,
                     tb_log_name: str = 'run',
                     ) -> Tuple[int, BaseCallback]:
        """
        Initialize different variables needed for training.

        :param total_timesteps: (int) The total number of samples (env steps) to train on
        :param eval_env: (Optional[VecEnv]) Environment to use for evaluation.
        :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
        :param eval_freq: (int) How many steps between evaluations
        :param n_eval_episodes: (int) How many episodes to play per evaluation
        :param log_path: (Optional[str]) Path to a folder where the evaluations will be saved
        :param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute
        :param tb_log_name: (str) the name of the run for tensorboard log
        :return: (Tuple[int, BaseCallback]) the (possibly shifted) total number of timesteps
            and the initialized callback
        """
        self.start_time = time.time()
        if self.ep_info_buffer is None or reset_num_timesteps:
            # Initialize buffers if they don't exist, or reinitialize if resetting counters
            self.ep_info_buffer = deque(maxlen=100)
            self.ep_success_buffer = deque(maxlen=100)

        if self.action_noise is not None:
            self.action_noise.reset()

        if reset_num_timesteps:
            self.num_timesteps = 0
            self._episode_num = 0
        else:
            # Make sure training timesteps are ahead of the internal counter
            total_timesteps += self.num_timesteps
        self._total_timesteps = total_timesteps

        # Avoid resetting the environment when calling ``.learn()`` consecutive times
        if reset_num_timesteps or self._last_obs is None:
            self._last_obs = self.env.reset()
            # Retrieve unnormalized observation for saving into the buffer
            if self._vec_normalize_env is not None:
                self._last_original_obs = self._vec_normalize_env.get_original_obs()

        # Only an explicitly given evaluation env is seeded here
        if eval_env is not None and self.seed is not None:
            eval_env.seed(self.seed)

        eval_env = self._get_eval_env(eval_env)

        # Configure logger's outputs
        utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)

        # Create eval callback if needed
        callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)

        return total_timesteps, callback

    def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
        """
        Retrieve reward and episode length and update the buffer
        if using Monitor wrapper.

        :param infos: ([dict]) info dictionaries, one per environment. The ``'episode'`` entry
            (if present) is stored in ``ep_info_buffer``; the ``'is_success'`` entry (if present)
            is stored in ``ep_success_buffer`` only for environments flagged as done.
        :param dones: (Optional[np.ndarray]) done flags, one per environment.
            When None, every environment is considered as not done.
        """
        if dones is None:
            dones = np.array([False] * len(infos))
        for idx, info in enumerate(infos):
            maybe_ep_info = info.get('episode')
            maybe_is_success = info.get('is_success')
            if maybe_ep_info is not None:
                self.ep_info_buffer.append(maybe_ep_info)
            if maybe_is_success is not None and dones[idx]:
                self.ep_success_buffer.append(maybe_is_success)

    def excluded_save_params(self) -> List[str]:
        """
        Returns the names of the parameters that should be excluded by default
        when saving the model.

        :return: ([str]) List of parameters that should be excluded from save
        """
        return ["policy", "device", "env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"]

    def save(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        exclude: Optional[Iterable[str]] = None,
        include: Optional[Iterable[str]] = None,
    ) -> None:
        """
        Save all the attributes of the object and the model parameters in a zip-file.

        :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) path to the file where the rl agent should be saved
        :param exclude: name of parameters that should be excluded in addition to the default one
        :param include: name of parameters that might be excluded but should be included anyway
        """
        # copy parameter list so we don't mutate the original dict
        data = self.__dict__.copy()

        # Exclude is union of specified parameters (if any) and standard exclusions
        if exclude is None:
            exclude = []
        exclude = set(exclude).union(self.excluded_save_params())

        # Do not exclude params if they are specifically included
        if include is not None:
            exclude = exclude.difference(include)

        state_dicts_names, tensors_names = self.get_torch_variables()
        # any params that are in the save vars must not be saved by data
        torch_variables = state_dicts_names + tensors_names
        for torch_var in torch_variables:
            # we need to get only the name of the top most module as we'll remove that
            var_name = torch_var.split('.')[0]
            exclude.add(var_name)

        # Remove parameter entries of parameters which are to be excluded
        for param_name in exclude:
            data.pop(param_name, None)

        # Build dict of tensor variables
        tensors = None
        if tensors_names is not None:
            tensors = {}
            for name in tensors_names:
                attr = recursive_getattr(self, name)
                tensors[name] = attr

        # Build dict of state_dicts
        params_to_save = {}
        for name in state_dicts_names:
            attr = recursive_getattr(self, name)
            # Retrieve state dict
            params_to_save[name] = attr.state_dict()

        save_to_zip_file(path, data=data, params=params_to_save, tensors=tensors)
|