stable-baselines3/stable_baselines3/common/vec_env/vec_normalize.py
Antonin RAFFIN a1e055695c
Improve typing coverage (#175)
* Improve typing coverage

* Even more types

* Fixes

* Update changelog

* Unified docstrings

* Improve error messages for unsupported spaces
2020-10-07 10:51:49 +02:00

191 lines
6.3 KiB
Python

import pickle
from typing import Any, Dict
import numpy as np
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecNormalize(VecEnvWrapper):
    """
    A moving average, normalizing wrapper for vectorized environment.
    It has support for saving/loading the moving average statistics.

    :param venv: the vectorized environment to wrap
    :param training: Whether to update or not the moving average
    :param norm_obs: Whether to normalize observation or not (default: True)
    :param norm_reward: Whether to normalize rewards or not (default: True)
    :param clip_obs: Max absolute value for the normalized observation
    :param clip_reward: Max absolute value for the normalized reward
    :param gamma: discount factor
    :param epsilon: To avoid division by zero
    """

    def __init__(
        self,
        venv: VecEnv,
        training: bool = True,
        norm_obs: bool = True,
        norm_reward: bool = True,
        clip_obs: float = 10.0,
        clip_reward: float = 10.0,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
    ):
        VecEnvWrapper.__init__(self, venv)
        # Running statistics of the observations and of the discounted returns
        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
        self.ret_rms = RunningMeanStd(shape=())
        self.clip_obs = clip_obs
        self.clip_reward = clip_reward
        # Returns: discounted rewards (one accumulator per environment)
        self.ret = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon
        self.training = training
        self.norm_obs = norm_obs
        self.norm_reward = norm_reward
        # Last raw (unnormalized) observations / rewards received from the wrapped env
        self.old_obs = np.array([])
        self.old_reward = np.array([])

    def __getstate__(self) -> Dict[str, Any]:
        """
        Gets state for pickling.

        Excludes self.venv, as in general VecEnv's may not be pickleable.

        :return: a copy of the instance dict without the venv-dependent entries
        """
        state = self.__dict__.copy()
        # `pop` with a default (rather than `del`) so that an instance which was
        # unpickled but not yet re-attached with `set_venv()` -- and therefore has
        # no "class_attributes" / "ret" entries -- can still be pickled or copied.
        # these attributes are not pickleable
        state.pop("venv", None)
        state.pop("class_attributes", None)
        # these attributes depend on the above and so we would prefer not to pickle
        state.pop("ret", None)
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restores pickled state.

        User must call set_venv() after unpickling before using.

        :param state: the dict produced by ``__getstate__``
        """
        # Sanity check on the invariant established by __getstate__,
        # done before touching the instance.
        assert "venv" not in state
        self.__dict__.update(state)
        self.venv = None

    def set_venv(self, venv: VecEnv) -> None:
        """
        Sets the vector environment to wrap to venv.

        Also sets attributes derived from this such as `num_envs`.

        :param venv: the vectorized environment to attach
        :raises ValueError: if a venv is already attached, or if the observation
            shape of ``venv`` does not match the stored observation statistics
        """
        if self.venv is not None:
            raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
        VecEnvWrapper.__init__(self, venv)
        if self.obs_rms.mean.shape != self.observation_space.shape:
            # Detach again: otherwise the wrapper would stay bound to the
            # incompatible venv and a retry with a valid one would be rejected
            # by the "already initialized" guard above.
            self.venv = None
            raise ValueError("venv is incompatible with current statistics.")
        self.ret = np.zeros(self.num_envs)

    def step_wait(self) -> VecEnvStepReturn:
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, news)

        where 'news' is a boolean vector indicating whether each element is new.
        The raw observations and rewards are kept and can be retrieved through
        ``get_original_obs()`` / ``get_original_reward()``.

        :return: normalized observations, normalized rewards, news and infos
        """
        obs, rews, news, infos = self.venv.step_wait()
        self.old_obs = obs
        self.old_reward = rews

        if self.training:
            self.obs_rms.update(obs)
        obs = self.normalize_obs(obs)

        if self.training:
            self._update_reward(rews)
        rews = self.normalize_reward(rews)

        # Restart the discounted-return accumulator of every flagged environment
        self.ret[news] = 0
        return obs, rews, news, infos

    def _update_reward(self, reward: np.ndarray) -> None:
        """Update reward normalization statistics."""
        self.ret = self.ret * self.gamma + reward
        self.ret_rms.update(self.ret)

    def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
        """
        Normalize observations using this VecNormalize's observations statistics.
        Calling this method does not update statistics.

        :param obs: raw observations
        :return: standardized observations clipped to ``[-clip_obs, clip_obs]``,
            or ``obs`` unchanged when ``norm_obs`` is False
        """
        if self.norm_obs:
            obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)
        return obs

    def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Normalize rewards using this VecNormalize's rewards statistics.
        Calling this method does not update statistics.

        :param reward: raw rewards
        :return: rewards divided by the standard deviation of the discounted
            returns and clipped to ``[-clip_reward, clip_reward]``, or ``reward``
            unchanged when ``norm_reward`` is False
        """
        if self.norm_reward:
            reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
        return reward

    def unnormalize_obs(self, obs: np.ndarray) -> np.ndarray:
        """
        Invert the scaling and shifting applied by ``normalize_obs``.
        Values that were clipped during normalization are not recovered.

        :param obs: normalized observations
        :return: rescaled observations, or ``obs`` unchanged when ``norm_obs`` is False
        """
        if self.norm_obs:
            return (obs * np.sqrt(self.obs_rms.var + self.epsilon)) + self.obs_rms.mean
        return obs

    def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Invert the scaling applied by ``normalize_reward``.
        Values that were clipped during normalization are not recovered.

        :param reward: normalized rewards
        :return: rescaled rewards, or ``reward`` unchanged when ``norm_reward`` is False
        """
        if self.norm_reward:
            return reward * np.sqrt(self.ret_rms.var + self.epsilon)
        return reward

    def get_original_obs(self) -> np.ndarray:
        """
        Returns an unnormalized version of the observations from the most recent
        step or reset.

        :return: a copy of the stored raw observations
        """
        return self.old_obs.copy()

    def get_original_reward(self) -> np.ndarray:
        """
        Returns an unnormalized version of the rewards from the most recent step.

        :return: a copy of the stored raw rewards
        """
        return self.old_reward.copy()

    def reset(self) -> np.ndarray:
        """
        Reset all environments

        :return: the first observations, normalized
        """
        obs = self.venv.reset()
        self.old_obs = obs
        self.ret = np.zeros(self.num_envs)
        if self.training:
            # NOTE(review): this feeds an all-zero return vector into `ret_rms`
            # and never feeds the reset observation into `obs_rms`. It looks like
            # `self.obs_rms.update(obs)` may have been intended -- confirm before
            # changing, since it alters the running statistics.
            self._update_reward(self.ret)
        return self.normalize_obs(obs)

    @staticmethod
    def load(load_path: str, venv: VecEnv) -> "VecNormalize":
        """
        Loads a saved VecNormalize object.

        WARNING: the file is read with ``pickle``, which can execute arbitrary
        code while loading. Only load files that come from a trusted source.

        :param load_path: the path to load from.
        :param venv: the VecEnv to wrap.
        :return: the restored wrapper, attached to ``venv``
        """
        with open(load_path, "rb") as file_handler:
            vec_normalize = pickle.load(file_handler)
        vec_normalize.set_venv(venv)
        return vec_normalize

    def save(self, save_path: str) -> None:
        """
        Save current VecNormalize object with
        all running statistics and settings (e.g. clip_obs)

        :param save_path: The path to save to
        """
        with open(save_path, "wb") as file_handler:
            pickle.dump(self, file_handler)