stable-baselines3/stable_baselines3/common/utils.py
Antonin RAFFIN 40e0b9d2c8
Add Gymnasium support (#1327)
* Fix failing set_env test

* Fix test failiing due to deprectation of env.seed

* Adjust mean reward threshold in failing test

* Fix her test failing due to rng

* Change seed and revert reward threshold to 90

* Pin gym version

* Make VecEnv compatible with gym seeding change

* Revert change to VecEnv reset signature

* Change subprocenv seed cmd to call reset instead

* Fix type check

* Add backward compat

* Add `compat_gym_seed` helper

* Add goal env checks in env_checker

* Add docs on  HER requirements for envs

* Capture user warning in test with inverted box space

* Update ale-py version

* Fix randint

* Allow noop_max to be zero

* Update changelog

* Update docker image

* Update doc conda env and dockerfile

* Custom envs should not have any warnings

* Fix test for numpy >= 1.21

* Add check for vectorized compute reward

* Bump to gym 0.24

* Fix gym default step docstring

* Test downgrading gym

* Revert "Test downgrading gym"

This reverts commit 0072b77156c006ada8a1d6e26ce347ed85a83eeb.

* Fix protobuf error

* Fix in dependencies

* Fix protobuf dep

* Use newest version of cartpole

* Update gym

* Fix warning

* Loosen required scipy version

* Scipy no longer needed

* Try gym 0.25

* Silence warnings from gym

* Filter warnings during tests

* Update doc

* Update requirements

* Add gym 26 compat in vec env

* Fixes in envs and tests for gym 0.26+

* Enforce gym 0.26 api

* format

* Fix formatting

* Fix dependencies

* Fix syntax

* Cleanup doc and warnings

* Faster tests

* Higher budget for HER perf test (revert prev change)

* Fixes and update doc

* Fix doc build

* Fix breaking change

* Fixes for rendering

* Rename variables in monitor

* update render method for gym 0.26 API

backwards compatible (mode argument is allowed) while using the gym 0.26 API (render mode is determined at environment creation)

* update tests and docs to new gym render API

* undo removal of render modes metatadata check

* set rgb_array as default render mode for gym.make

* undo changes & raise warning if not 'rgb_array'

* Fix type check

* Remove recursion and fix type checking

* Remove hacks for protobuf and gym 0.24

* Fix type annotations

* reuse existing render_mode attribute

* return tiled images for 'human' render mode

* Allow to use opencv for human render, fix typos

* Add warning when using non-zero start with Discrete (fixes #1197)

* Fix type checking

* Bug fixes and handle more cases

* Throw proper warnings

* Update test

* Fix new metadata name

* Ignore numpy warnings

* Fixes in vec recorder

* Global ignore

* Filter local warning too

* Monkey patch not needed for gym 26

* Add doc of VecEnv vs Gym API

* Add render test

* Fix return type

* Update VecEnv vs Gym API doc

* Fix for custom render mode

* Fix return type

* Fix type checking

* check test env test_buffer

* skip render check

* check env test_dict_env

* test_env test_gae

* check envs in remaining tests

* Update tests

* Add warning for Discrete action space with non-zero (#1295)

* Fix atari annotation

* ignore get_action_meanings [attr-defined]

* Fix mypy issues

* Add patch for gym/gymnasium transition

* Switch to gymnasium

* Rely on signature instead of version

* More patches

* Type ignore because of https://github.com/Farama-Foundation/Gymnasium/pull/39

* Fix doc build

* Fix pytype errors

* Fix atari requirement

* Update env checker due to change in dtype for Discrete

* Fix type hint

* Convert spaces for saved models

* Ignore pytype

* Remove gitlab CI

* Disable pytype for convert space

* Fix undefined info

* Fix undefined info

* Upgrade shimmy

* Fix wrappers type annotation (need PR from Gymnasium)

* Fix gymnasium dependency

* Fix dependency declaration

* Cap pygame version for python 3.7

* Point to master branch (v0.28.0)

* Fix: use main not master branch

* Rename done to terminated

* Fix pygame dependency for python 3.7

* Rename gym to gymnasium

* Update Gymnasium

* Fix test

* Fix tests

* Forks don't have access to private variables

* Fix linter warnings

* Update read the doc env

* Fix env checker for GoalEnv

* Fix import

* Update env checker (more info) and fix dtype

* Use micromamab for Docker

* Update dependencies

* Clarify VecEnv doc

* Fix Gymnasium version

* Copy file only after mamba install

* [ci skip] Update docker doc

* Polish code

* Reformat

* Remove deprecated features

* Ignore warning

* Update doc

* Update examples and changelog

* Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) (#1436)

* Fix SAC type hints, improve DQN ones

* Fix A2C and TD3 type hints

* Fix PPO type hints

* Fix on-policy type hints

* Fix base class type annotation, do not use defaults

* Update version

* Disable mypy for python 3.7

* Rename Gym26StepReturn

* Update continuous critic type annotation

* Fix pytype complain

---------

Co-authored-by: Carlos Luis <carlos.luisgonc@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Thomas Lips <37955681+tlpss@users.noreply.github.com>
Co-authored-by: tlips <thomas.lips@ugent.be>
Co-authored-by: tlpss <thomas17.lips@gmail.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
2023-04-14 13:13:59 +02:00

557 lines
20 KiB
Python

import glob
import os
import platform
import random
import re
from collections import deque
from inspect import signature
from itertools import zip_longest
from typing import Dict, Iterable, List, Optional, Tuple, Union
import gymnasium as gym
import numpy as np
import torch as th
from gymnasium import spaces
import stable_baselines3 as sb3
# Check if tensorboard is available for pytorch
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = None
from stable_baselines3.common.logger import Logger, configure
from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
"""
Seed the different random generators.
:param seed:
:param using_cuda:
"""
# Seed python RNG
random.seed(seed)
# Seed numpy RNG
np.random.seed(seed)
# seed the RNG for all devices (both CPU and CUDA)
th.manual_seed(seed)
if using_cuda:
# Deterministic operations for CuDNN, it may impact performances
th.backends.cudnn.deterministic = True
th.backends.cudnn.benchmark = False
# From stable baselines
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
"""
Computes fraction of variance that ypred explains about y.
Returns 1 - Var[y-ypred] / Var[y]
interpretation:
ev=0 => might as well have predicted zero
ev=1 => perfect prediction
ev<0 => worse than just predicting zero
:param y_pred: the prediction
:param y_true: the expected value
:return: explained variance of ypred and y
"""
assert y_true.ndim == 1 and y_pred.ndim == 1
var_y = np.var(y_true)
return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
"""
Update the learning rate for a given optimizer.
Useful when doing linear schedule.
:param optimizer: Pytorch optimizer
:param learning_rate: New learning rate value
"""
for param_group in optimizer.param_groups:
param_group["lr"] = learning_rate
def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
"""
Transform (if needed) learning rate and clip range (for PPO)
to callable.
:param value_schedule: Constant value of schedule function
:return: Schedule function (can return constant value)
"""
# If the passed schedule is a float
# create a constant function
if isinstance(value_schedule, (float, int)):
# Cast to float to avoid errors
value_schedule = constant_fn(float(value_schedule))
else:
assert callable(value_schedule)
return value_schedule
def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
"""
Create a function that interpolates linearly between start and end
between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``.
This is used in DQN for linearly annealing the exploration fraction
(epsilon for the epsilon-greedy strategy).
:params start: value to start with if ``progress_remaining`` = 1
:params end: value to end with if ``progress_remaining`` = 0
:params end_fraction: fraction of ``progress_remaining``
where end is reached e.g 0.1 then end is reached after 10%
of the complete training process.
:return: Linear schedule function.
"""
def func(progress_remaining: float) -> float:
if (1 - progress_remaining) > end_fraction:
return end
else:
return start + (1 - progress_remaining) * (end - start) / end_fraction
return func
def constant_fn(val: float) -> Schedule:
"""
Create a function that returns a constant
It is useful for learning rate schedule (to avoid code duplication)
:param val: constant value
:return: Constant schedule function.
"""
def func(_):
return val
return func
def get_device(device: Union[th.device, str] = "auto") -> th.device:
"""
Retrieve PyTorch device.
It checks that the requested device is available first.
For now, it supports only cpu and cuda.
By default, it tries to use the gpu.
:param device: One for 'auto', 'cuda', 'cpu'
:return: Supported Pytorch device
"""
# Cuda by default
if device == "auto":
device = "cuda"
# Force conversion to th.device
device = th.device(device)
# Cuda not available
if device.type == th.device("cuda").type and not th.cuda.is_available():
return th.device("cpu")
return device
def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
"""
Returns the latest run number for the given log name and log path,
by finding the greatest number in the directories.
:param log_path: Path to the log folder containing several runs.
:param log_name: Name of the experiment. Each run is stored
in a folder named ``log_name_1``, ``log_name_2``, ...
:return: latest run number
"""
max_run_id = 0
for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")):
file_name = path.split(os.sep)[-1]
ext = file_name.split("_")[-1]
if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
max_run_id = int(ext)
return max_run_id
def configure_logger(
verbose: int = 0,
tensorboard_log: Optional[str] = None,
tb_log_name: str = "",
reset_num_timesteps: bool = True,
) -> Logger:
"""
Configure the logger's outputs.
:param verbose: Verbosity level: 0 for no output, 1 for the standard output to be part of the logger outputs
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param tb_log_name: tensorboard log
:param reset_num_timesteps: Whether the ``num_timesteps`` attribute is reset or not.
It allows to continue a previous learning curve (``reset_num_timesteps=False``)
or start from t=0 (``reset_num_timesteps=True``, the default).
:return: The logger object
"""
save_path, format_strings = None, ["stdout"]
if tensorboard_log is not None and SummaryWriter is None:
raise ImportError("Trying to log data to tensorboard but tensorboard is not installed.")
if tensorboard_log is not None and SummaryWriter is not None:
latest_run_id = get_latest_run_id(tensorboard_log, tb_log_name)
if not reset_num_timesteps:
# Continue training in the same directory
latest_run_id -= 1
save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{latest_run_id + 1}")
if verbose >= 1:
format_strings = ["stdout", "tensorboard"]
else:
format_strings = ["tensorboard"]
elif verbose == 0:
format_strings = [""]
return configure(save_path, format_strings=format_strings)
def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, action_space: spaces.Space) -> None:
"""
Checks that the environment has same spaces as provided ones. Used by BaseAlgorithm to check if
spaces match after loading the model with given env.
Checked parameters:
- observation_space
- action_space
:param env: Environment to check for valid spaces
:param observation_space: Observation space to check against
:param action_space: Action space to check against
"""
if observation_space != env.observation_space:
raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}")
if action_space != env.action_space:
raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")
def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
"""
If the spaces are Box, check that they have the same shape.
If the spaces are Dict, it recursively checks the subspaces.
:param space1: Space
:param space2: Other space
"""
if isinstance(space1, spaces.Dict):
assert isinstance(space2, spaces.Dict), "spaces must be of the same type"
assert space1.spaces.keys() == space2.spaces.keys(), "spaces must have the same keys"
for key in space1.spaces.keys():
check_shape_equal(space1.spaces[key], space2.spaces[key])
elif isinstance(space1, spaces.Box):
assert space1.shape == space2.shape, "spaces must have the same shape"
def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
"""
For box observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
if observation.shape == observation_space.shape:
return False
elif observation.shape[1:] == observation_space.shape:
return True
else:
raise ValueError(
f"Error: Unexpected observation shape {observation.shape} for "
+ f"Box environment, please use {observation_space.shape} "
+ "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape)))
)
def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: spaces.Discrete) -> bool:
"""
For discrete observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
if isinstance(observation, int) or observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
return False
elif len(observation.shape) == 1:
return True
else:
raise ValueError(
f"Error: Unexpected observation shape {observation.shape} for "
+ "Discrete environment, please use () or (n_env,) for the observation shape."
)
def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: spaces.MultiDiscrete) -> bool:
"""
For multidiscrete observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
if observation.shape == (len(observation_space.nvec),):
return False
elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
return True
else:
raise ValueError(
f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
+ f"environment, please use ({len(observation_space.nvec)},) or "
+ f"(n_env, {len(observation_space.nvec)}) for the observation shape."
)
def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: spaces.MultiBinary) -> bool:
"""
For multibinary observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
if observation.shape == observation_space.shape:
return False
elif len(observation.shape) == len(observation_space.shape) + 1 and observation.shape[1:] == observation_space.shape:
return True
else:
raise ValueError(
f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
+ f"environment, please use {observation_space.shape} or "
+ f"(n_env, {observation_space.n}) for the observation shape."
)
def is_vectorized_dict_observation(observation: np.ndarray, observation_space: spaces.Dict) -> bool:
"""
For dict observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
# We first assume that all observations are not vectorized
all_non_vectorized = True
for key, subspace in observation_space.spaces.items():
# This fails when the observation is not vectorized
# or when it has the wrong shape
if observation[key].shape != subspace.shape:
all_non_vectorized = False
break
if all_non_vectorized:
return False
all_vectorized = True
# Now we check that all observation are vectorized and have the correct shape
for key, subspace in observation_space.spaces.items():
if observation[key].shape[1:] != subspace.shape:
all_vectorized = False
break
if all_vectorized:
return True
else:
# Retrieve error message
error_msg = ""
try:
is_vectorized_observation(observation[key], observation_space.spaces[key])
except ValueError as e:
error_msg = f"{e}"
raise ValueError(
f"There seems to be a mix of vectorized and non-vectorized observations. "
f"Unexpected observation shape {observation[key].shape} for key {key} "
f"of type {observation_space.spaces[key]}. {error_msg}"
)
def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: spaces.Space) -> bool:
"""
For every observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
is_vec_obs_func_dict = {
spaces.Box: is_vectorized_box_observation,
spaces.Discrete: is_vectorized_discrete_observation,
spaces.MultiDiscrete: is_vectorized_multidiscrete_observation,
spaces.MultiBinary: is_vectorized_multibinary_observation,
spaces.Dict: is_vectorized_dict_observation,
}
for space_type, is_vec_obs_func in is_vec_obs_func_dict.items():
if isinstance(observation_space, space_type):
return is_vec_obs_func(observation, observation_space)
else:
# for-else happens if no break is called
raise ValueError(f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}.")
def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
"""
Compute the mean of an array if there is at least one element.
For empty array, return NaN. It is used for logging only.
:param arr: Numpy array or list of values
:return:
"""
return np.nan if len(arr) == 0 else np.mean(arr)
def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]:
"""
Extract parameters from the state dict of ``model``
if the name contains one of the strings in ``included_names``.
:param model: the model where the parameters come from.
:param included_names: substrings of names to include.
:return: List of parameters values (Pytorch tensors)
that matches the queried names.
"""
return [param for name, param in model.state_dict().items() if any([key in name for key in included_names])]
def zip_strict(*iterables: Iterable) -> Iterable:
r"""
``zip()`` function but enforces that iterables are of equal length.
Raises ``ValueError`` if iterables not of equal length.
Code inspired by Stackoverflow answer for question #32954486.
:param \*iterables: iterables to ``zip()``
"""
# As in Stackoverflow #32954486, use
# new object for "empty" in case we have
# Nones in iterable.
sentinel = object()
for combo in zip_longest(*iterables, fillvalue=sentinel):
if sentinel in combo:
raise ValueError("Iterables have different lengths")
yield combo
def polyak_update(
params: Iterable[th.Tensor],
target_params: Iterable[th.Tensor],
tau: float,
) -> None:
"""
Perform a Polyak average update on ``target_params`` using ``params``:
target parameters are slowly updated towards the main parameters.
``tau``, the soft update coefficient controls the interpolation:
``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``.
The Polyak update is done in place, with ``no_grad``, and therefore does not create intermediate tensors,
or a computation graph, reducing memory cost and improving performance. We scale the target params
by ``1-tau`` (in-place), add the new weights, scaled by ``tau`` and store the result of the sum in the target
params (in place).
See https://github.com/DLR-RM/stable-baselines3/issues/93
:param params: parameters to use to update the target params
:param target_params: parameters to update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
"""
with th.no_grad():
# zip does not raise an exception if length of parameters does not match.
for param, target_param in zip_strict(params, target_params):
target_param.data.mul_(1 - tau)
th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
"""
Moves the observation to the given device.
:param obs:
:param device: PyTorch device
:return: PyTorch tensor of the observation on a desired device.
"""
if isinstance(obs, np.ndarray):
return th.as_tensor(obs, device=device)
elif isinstance(obs, dict):
return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
else:
raise Exception(f"Unrecognized type of observation {type(obs)}")
def should_collect_more_steps(
train_freq: TrainFreq,
num_collected_steps: int,
num_collected_episodes: int,
) -> bool:
"""
Helper used in ``collect_rollouts()`` of off-policy algorithms
to determine the termination condition.
:param train_freq: How much experience should be collected before updating the policy.
:param num_collected_steps: The number of already collected steps.
:param num_collected_episodes: The number of already collected episodes.
:return: Whether to continue or not collecting experience
by doing rollouts of the current policy.
"""
if train_freq.unit == TrainFrequencyUnit.STEP:
return num_collected_steps < train_freq.frequency
elif train_freq.unit == TrainFrequencyUnit.EPISODE:
return num_collected_episodes < train_freq.frequency
else:
raise ValueError(
"The unit of the `train_freq` must be either TrainFrequencyUnit.STEP "
f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!"
)
def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
"""
Retrieve system and python env info for the current system.
:param print_info: Whether to print or not those infos
:return: Dictionary summing up the version for each relevant package
and a formatted string.
"""
env_info = {
# In OS, a regex is used to add a space between a "#" and a number to avoid
# wrongly linking to another issue on GitHub. Example: turn "#42" to "# 42".
"OS": re.sub(r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}"),
"Python": platform.python_version(),
"Stable-Baselines3": sb3.__version__,
"PyTorch": th.__version__,
"GPU Enabled": str(th.cuda.is_available()),
"Numpy": np.__version__,
"Gym": gym.__version__,
}
env_info_str = ""
for key, value in env_info.items():
env_info_str += f"- {key}: {value}\n"
if print_info:
print(env_info_str)
return env_info, env_info_str
def compat_gym_seed(env: GymEnv, seed: int) -> None:
"""
Compatibility helper to seed Gym envs.
:param env: The Gym environment.
:param seed: The seed for the pseudo random generator
"""
if "seed" in signature(env.unwrapped.reset).parameters:
# gym >= 0.23.1
env.reset(seed=seed)
else:
# VecEnv and backward compatibility
env.seed(seed)