mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
* Fix failing set_env test * Fix test failiing due to deprectation of env.seed * Adjust mean reward threshold in failing test * Fix her test failing due to rng * Change seed and revert reward threshold to 90 * Pin gym version * Make VecEnv compatible with gym seeding change * Revert change to VecEnv reset signature * Change subprocenv seed cmd to call reset instead * Fix type check * Add backward compat * Add `compat_gym_seed` helper * Add goal env checks in env_checker * Add docs on HER requirements for envs * Capture user warning in test with inverted box space * Update ale-py version * Fix randint * Allow noop_max to be zero * Update changelog * Update docker image * Update doc conda env and dockerfile * Custom envs should not have any warnings * Fix test for numpy >= 1.21 * Add check for vectorized compute reward * Bump to gym 0.24 * Fix gym default step docstring * Test downgrading gym * Revert "Test downgrading gym" This reverts commit 0072b77156c006ada8a1d6e26ce347ed85a83eeb. * Fix protobuf error * Fix in dependencies * Fix protobuf dep * Use newest version of cartpole * Update gym * Fix warning * Loosen required scipy version * Scipy no longer needed * Try gym 0.25 * Silence warnings from gym * Filter warnings during tests * Update doc * Update requirements * Add gym 26 compat in vec env * Fixes in envs and tests for gym 0.26+ * Enforce gym 0.26 api * format * Fix formatting * Fix dependencies * Fix syntax * Cleanup doc and warnings * Faster tests * Higher budget for HER perf test (revert prev change) * Fixes and update doc * Fix doc build * Fix breaking change * Fixes for rendering * Rename variables in monitor * update render method for gym 0.26 API backwards compatible (mode argument is allowed) while using the gym 0.26 API (render mode is determined at environment creation) * update tests and docs to new gym render API * undo removal of render modes metatadata check * set rgb_array as default render mode for gym.make * undo changes & raise warning if 
not 'rgb_array' * Fix type check * Remove recursion and fix type checking * Remove hacks for protobuf and gym 0.24 * Fix type annotations * reuse existing render_mode attribute * return tiled images for 'human' render mode * Allow to use opencv for human render, fix typos * Add warning when using non-zero start with Discrete (fixes #1197) * Fix type checking * Bug fixes and handle more cases * Throw proper warnings * Update test * Fix new metadata name * Ignore numpy warnings * Fixes in vec recorder * Global ignore * Filter local warning too * Monkey patch not needed for gym 26 * Add doc of VecEnv vs Gym API * Add render test * Fix return type * Update VecEnv vs Gym API doc * Fix for custom render mode * Fix return type * Fix type checking * check test env test_buffer * skip render check * check env test_dict_env * test_env test_gae * check envs in remaining tests * Update tests * Add warning for Discrete action space with non-zero (#1295) * Fix atari annotation * ignore get_action_meanings [attr-defined] * Fix mypy issues * Add patch for gym/gymnasium transition * Switch to gymnasium * Rely on signature instead of version * More patches * Type ignore because of https://github.com/Farama-Foundation/Gymnasium/pull/39 * Fix doc build * Fix pytype errors * Fix atari requirement * Update env checker due to change in dtype for Discrete * Fix type hint * Convert spaces for saved models * Ignore pytype * Remove gitlab CI * Disable pytype for convert space * Fix undefined info * Fix undefined info * Upgrade shimmy * Fix wrappers type annotation (need PR from Gymnasium) * Fix gymnasium dependency * Fix dependency declaration * Cap pygame version for python 3.7 * Point to master branch (v0.28.0) * Fix: use main not master branch * Rename done to terminated * Fix pygame dependency for python 3.7 * Rename gym to gymnasium * Update Gymnasium * Fix test * Fix tests * Forks don't have access to private variables * Fix linter warnings * Update read the doc env * Fix env checker 
for GoalEnv * Fix import * Update env checker (more info) and fix dtype * Use micromamab for Docker * Update dependencies * Clarify VecEnv doc * Fix Gymnasium version * Copy file only after mamba install * [ci skip] Update docker doc * Polish code * Reformat * Remove deprecated features * Ignore warning * Update doc * Update examples and changelog * Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) (#1436) * Fix SAC type hints, improve DQN ones * Fix A2C and TD3 type hints * Fix PPO type hints * Fix on-policy type hints * Fix base class type annotation, do not use defaults * Update version * Disable mypy for python 3.7 * Rename Gym26StepReturn * Update continuous critic type annotation * Fix pytype complain --------- Co-authored-by: Carlos Luis <carlos.luisgonc@gmail.com> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Thomas Lips <37955681+tlpss@users.noreply.github.com> Co-authored-by: tlips <thomas.lips@ugent.be> Co-authored-by: tlpss <thomas17.lips@gmail.com> Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
320 lines
12 KiB
Python
320 lines
12 KiB
Python
import inspect
|
|
import pickle
|
|
from copy import deepcopy
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
from gymnasium import spaces
|
|
|
|
from stable_baselines3.common import utils
|
|
from stable_baselines3.common.preprocessing import is_image_space
|
|
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
|
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
|
|
|
|
|
|
class VecNormalize(VecEnvWrapper):
    """
    A moving average, normalizing wrapper for vectorized environment.
    It has support for saving/loading the moving average statistics
    (see :meth:`save` and :meth:`load`).

    :param venv: the vectorized environment to wrap
    :param training: Whether to update or not the moving average
    :param norm_obs: Whether to normalize observation or not (default: True)
    :param norm_reward: Whether to normalize rewards or not (default: True)
    :param clip_obs: Max absolute value for observation
    :param clip_reward: Max absolute value for discounted reward
    :param gamma: discount factor
    :param epsilon: To avoid division by zero
    :param norm_obs_keys: Which keys from observation dict to normalize.
        If not specified, all keys will be normalized.
    """

    def __init__(
        self,
        venv: VecEnv,
        training: bool = True,
        norm_obs: bool = True,
        norm_reward: bool = True,
        clip_obs: float = 10.0,
        clip_reward: float = 10.0,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
        norm_obs_keys: Optional[List[str]] = None,
    ):
        VecEnvWrapper.__init__(self, venv)

        self.norm_obs = norm_obs
        self.norm_obs_keys = norm_obs_keys
        # Check observation spaces
        if self.norm_obs:
            # May fill in ``self.norm_obs_keys`` (all keys) for Dict spaces
            self._sanity_checks()

            if isinstance(self.observation_space, spaces.Dict):
                self.obs_spaces = self.observation_space.spaces
                # One set of running statistics per normalized key
                self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}
                # Update observation space when using image
                # See explanation below and GH #1214
                for key in self.obs_rms:
                    if is_image_space(self.obs_spaces[key]):
                        self.observation_space.spaces[key] = spaces.Box(
                            low=-clip_obs,
                            high=clip_obs,
                            shape=self.obs_spaces[key].shape,
                            dtype=np.float32,
                        )

            else:
                self.obs_spaces = None
                self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
                # Update observation space when using image
                # See GH #1214
                # This is to raise proper error when
                # VecNormalize is used with an image-like input and
                # normalize_images=True.
                # For correctness, we should also update the bounds
                # in other cases but this will cause backward-incompatible change
                # and break already saved policies.
                if is_image_space(self.observation_space):
                    self.observation_space = spaces.Box(
                        low=-clip_obs,
                        high=clip_obs,
                        shape=self.observation_space.shape,
                        dtype=np.float32,
                    )

        # Running statistics of the (scalar) discounted return
        self.ret_rms = RunningMeanStd(shape=())
        self.clip_obs = clip_obs
        self.clip_reward = clip_reward
        # Returns: discounted rewards, one accumulator per environment
        self.returns = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon
        self.training = training
        self.norm_reward = norm_reward
        # Last raw (unnormalized) observation/reward received from the wrapped env
        self.old_obs = np.array([])
        self.old_reward = np.array([])

    def _sanity_checks(self) -> None:
        """
        Check the observations that are going to be normalized are of the correct type (spaces.Box).

        For ``spaces.Dict`` observation spaces, ``norm_obs_keys`` defaults to all the keys
        of the space when it was left to ``None``.

        :raises ValueError: if a key to normalize is not a ``spaces.Box``,
            if ``norm_obs_keys`` is given for a non-Dict space,
            or if the observation space is neither ``spaces.Box`` nor ``spaces.Dict``.
        """
        if isinstance(self.observation_space, spaces.Dict):
            # By default, we normalize all keys
            if self.norm_obs_keys is None:
                self.norm_obs_keys = list(self.observation_space.spaces.keys())
            # Check that all keys are of type Box
            for obs_key in self.norm_obs_keys:
                if not isinstance(self.observation_space.spaces[obs_key], spaces.Box):
                    raise ValueError(
                        f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} "
                        f"is of type {self.observation_space.spaces[obs_key]}. "
                        "You should probably explicitely pass the observation keys "
                        " that should be normalized via the `norm_obs_keys` parameter."
                    )

        elif isinstance(self.observation_space, spaces.Box):
            if self.norm_obs_keys is not None:
                raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces")

        else:
            raise ValueError(
                "VecNormalize only supports `gym.spaces.Box` and `gym.spaces.Dict` observation spaces, "
                f"not {self.observation_space}"
            )

    def __getstate__(self) -> Dict[str, Any]:
        """
        Gets state for pickling.

        Excludes self.venv, as in general VecEnv's may not be pickleable.

        :return: a shallow copy of the instance ``__dict__`` without the excluded attributes
        """
        state = self.__dict__.copy()
        # "venv" and "class_attributes" are not pickleable,
        # "returns" depends on the venv and so we would prefer not to pickle it.
        # ``pop`` with a default (instead of ``del``): an instance restored by
        # ``__setstate__`` has neither "class_attributes" nor "returns"
        # until ``set_venv()`` is called, and must still be picklable/copyable.
        for key in ("venv", "class_attributes", "returns"):
            state.pop(key, None)
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restores pickled state.

        User must call set_venv() after unpickling before using.

        :param state: the attribute dictionary produced by ``__getstate__``
        """
        # Backward compatibility: states saved without "norm_obs_keys"
        # normalize every key of a Dict observation space
        if "norm_obs_keys" not in state and isinstance(state["observation_space"], spaces.Dict):
            state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
        self.__dict__.update(state)
        assert "venv" not in state
        self.venv = None

    def set_venv(self, venv: VecEnv) -> None:
        """
        Sets the vector environment to wrap to venv.

        Also sets attributes derived from this such as `num_env`.

        :param venv: the vectorized environment to wrap
        :raises ValueError: if a venv is already set
        """
        if self.venv is not None:
            raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")

        # Check that the observation_space shape match.
        # Done before any assignment so that, should the check raise,
        # the wrapper is left untouched and ``set_venv`` can be called again.
        utils.check_shape_equal(self.observation_space, venv.observation_space)

        self.venv = venv
        self.num_envs = venv.num_envs
        self.class_attributes = dict(inspect.getmembers(self.__class__))
        self.returns = np.zeros(self.num_envs)

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the step of the wrapped environment and normalize its outputs.

        The raw observations and rewards are kept (see ``get_original_obs()``
        and ``get_original_reward()``). When ``training`` is True, the running
        statistics are updated before normalizing.

        :return: (normalized observations, normalized rewards, dones, infos)
        """
        obs, rewards, dones, infos = self.venv.step_wait()
        self.old_obs = obs
        self.old_reward = rewards

        if self.training and self.norm_obs:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                for key in self.obs_rms:
                    self.obs_rms[key].update(obs[key])
            else:
                self.obs_rms.update(obs)

        obs = self.normalize_obs(obs)

        if self.training:
            self._update_reward(rewards)
        rewards = self.normalize_reward(rewards)

        # Normalize the terminal observations
        for idx, done in enumerate(dones):
            if not done:
                continue
            if "terminal_observation" in infos[idx]:
                infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])

        # Restart the discounted return of the environments that are done
        self.returns[dones] = 0
        return obs, rewards, dones, infos

    def _update_reward(self, reward: np.ndarray) -> None:
        """
        Update reward normalization statistics.

        :param reward: the raw rewards of the current step
        """
        self.returns = self.returns * self.gamma + reward
        self.ret_rms.update(self.returns)

    def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
        """
        Helper to normalize observation.

        :param obs: the observation to normalize
        :param obs_rms: associated statistics
        :return: normalized observation, clipped to ``[-clip_obs, clip_obs]``
        """
        return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)

    def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
        """
        Helper to unnormalize observation.

        The clipping applied by ``_normalize_obs`` is not undone.

        :param obs: the normalized observation
        :param obs_rms: associated statistics
        :return: unnormalized observation
        """
        return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean

    def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Normalize observations using this VecNormalize's observations statistics.
        Calling this method does not update statistics.

        :param obs: the observation(s) to normalize
        :return: a copy of ``obs``, normalized (as float32) when ``norm_obs`` is True
        """
        # Avoid modifying by reference the original object
        obs_ = deepcopy(obs)
        if self.norm_obs:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                # Only normalize the specified keys
                for key in self.norm_obs_keys:
                    obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
            else:
                obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
        return obs_

    def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Normalize rewards using this VecNormalize's rewards statistics.
        Calling this method does not update statistics.

        :param reward: the reward(s) to normalize
        :return: the reward scaled by the standard deviation of the discounted return
            and clipped to ``[-clip_reward, clip_reward]`` when ``norm_reward`` is True,
            the unchanged reward otherwise
        """
        if self.norm_reward:
            reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
        return reward

    def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Invert the scaling and shifting done by ``normalize_obs()`` (clipping is not undone).
        Calling this method does not update statistics.

        :param obs: the normalized observation(s)
        :return: a copy of ``obs``, unnormalized when ``norm_obs`` is True
        """
        # Avoid modifying by reference the original object
        obs_ = deepcopy(obs)
        if self.norm_obs:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                # Only the normalized keys need to be unnormalized
                for key in self.norm_obs_keys:
                    obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])
            else:
                obs_ = self._unnormalize_obs(obs, self.obs_rms)
        return obs_

    def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Invert the scaling done by ``normalize_reward()`` (clipping is not undone).

        :param reward: the normalized reward(s)
        :return: the rescaled reward when ``norm_reward`` is True, the unchanged reward otherwise
        """
        if self.norm_reward:
            return reward * np.sqrt(self.ret_rms.var + self.epsilon)
        return reward

    def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Returns an unnormalized version of the observations from the most recent
        step or reset.

        :return: a deep copy of the last raw observation
        """
        return deepcopy(self.old_obs)

    def get_original_reward(self) -> np.ndarray:
        """
        Returns an unnormalized version of the rewards from the most recent step.

        :return: a copy of the last raw rewards
        """
        return self.old_reward.copy()

    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Reset all environments

        The discounted returns are reset to zero and, when ``training`` and
        ``norm_obs`` are True, the observation statistics are updated.

        :return: first observation of the episode
        """
        obs = self.venv.reset()
        self.old_obs = obs
        self.returns = np.zeros(self.num_envs)
        if self.training and self.norm_obs:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                for key in self.obs_rms:
                    self.obs_rms[key].update(obs[key])
            else:
                self.obs_rms.update(obs)
        return self.normalize_obs(obs)

    @staticmethod
    def load(load_path: str, venv: VecEnv) -> "VecNormalize":
        """
        Loads a saved VecNormalize object.

        :param load_path: the path to load from.
        :param venv: the VecEnv to wrap.
        :return: the loaded wrapper, with ``venv`` set
        """
        # SECURITY NOTE: unpickling can execute arbitrary code,
        # only load files that come from a trusted source.
        with open(load_path, "rb") as file_handler:
            vec_normalize = pickle.load(file_handler)
        vec_normalize.set_venv(venv)
        return vec_normalize

    def save(self, save_path: str) -> None:
        """
        Save current VecNormalize object with
        all running statistics and settings (e.g. clip_obs)

        :param save_path: The path to save to
        """
        with open(save_path, "wb") as file_handler:
            pickle.dump(self, file_handler)