mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
* Added working her version, Online sampling is missing. * Updated test_her. * Added first version of online her sampling. Still problems with tensor dimensions. * Reformat * Fixed tests * Added some comments. * Updated changelog. * Add missing init file * Fixed some small bugs. * Reduced arguments for HER, small changes. * Added getattr. Fixed bug for online sampling. * Updated save/load functions. Small changes. * Added her to init. * Updated save method. * Updated her ratio. * Move obs_wrapper * Added DQN test. * Fix potential bug * Offline and online her share same sample_goal function. * Changed lists into arrays. * Updated her test. * Fix online sampling * Fixed action bug. Updated time limit for episodes. * Updated convert_dict method to take keys as arguments. * Renamed obs dict wrapper. * Seed bit flipping env * Remove get_episode_dict * Add fast online sampling version * Added documentation. * Vectorized reward computation * Vectorized goal sampling * Update time limit for episodes in online her sampling. * Fix max episode length inference * Bug fix for Fetch envs * Fix for HER + gSDE * Reformat (new black version) * Added info dict to compute new reward. Check her_replay_buffer again. * Fix info buffer * Updated done flag. * Fixes for gSDE * Offline her version uses now HerReplayBuffer as episode storage. * Fix num_timesteps computation * Fix get torch params * Vectorized version for offline sampling. * Modified offline her sampling to use sample method of her_replay_buffer * Updated HER tests. * Updated documentation * Cleanup docstrings * Updated to review comments * Fix pytype * Update according to review comments. * Removed random goal strategy. Updated sample transitions. * Updated migration. Removed time signal removal. * Update doc * Fix potential load issue * Add VecNormalize support for dict obs * Updated saving/loading replay buffer for HER. * Fix test memory usage * Fixed save/load replay buffer. 
* Fixed save/load replay buffer * Fixed transition index after loading replay buffer in online sampling * Better error handling * Add tests for get_time_limit * More tests for VecNormalize with dict obs * Update doc * Improve HER description * Add test for sde support * Add comments * Add comments * Remove check that was always valid * Fix for terminal observation * Updated buffer size in offline version and reset of HER buffer * Reformat * Update doc * Remove np.empty + add doc * Fix loading * Updated loading replay buffer * Separate online and offline sampling + bug fixes * Update tensorboard log name * Version bump * Bug fix for special case Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
243 lines
8.7 KiB
Python
243 lines
8.7 KiB
Python
import pickle
|
|
from copy import deepcopy
|
|
from typing import Any, Dict, Union
|
|
|
|
import gym
|
|
import numpy as np
|
|
|
|
from stable_baselines3.common import utils
|
|
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
|
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
|
|
|
|
|
|
class VecNormalize(VecEnvWrapper):
    """
    A moving average, normalizing wrapper for vectorized environment.
    It supports saving/loading the moving averages.

    Observations are standardized with a running mean/std and clipped to
    ``[-clip_obs, clip_obs]``. For ``gym.spaces.Dict`` observation spaces,
    one set of statistics is kept per key. Rewards are divided by the running
    std of the discounted return (they are not mean-centered) and clipped to
    ``[-clip_reward, clip_reward]``.

    :param venv: the vectorized environment to wrap
    :param training: Whether to update or not the moving average
    :param norm_obs: Whether to normalize observation or not (default: True)
    :param norm_reward: Whether to normalize rewards or not (default: True)
    :param clip_obs: Max absolute value for observation
    :param clip_reward: Max absolute value for discounted reward
    :param gamma: discount factor
    :param epsilon: To avoid division by zero
    """

    def __init__(
        self,
        venv: VecEnv,
        training: bool = True,
        norm_obs: bool = True,
        norm_reward: bool = True,
        clip_obs: float = 10.0,
        clip_reward: float = 10.0,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
    ):
        VecEnvWrapper.__init__(self, venv)

        assert isinstance(
            self.observation_space, (gym.spaces.Box, gym.spaces.Dict)
        ), "VecNormalize only support `gym.spaces.Box` and `gym.spaces.Dict` observation spaces"

        if isinstance(self.observation_space, gym.spaces.Dict):
            self.obs_keys = set(self.observation_space.spaces.keys())
            self.obs_spaces = self.observation_space.spaces
            # One independent running mean/std per dict key
            self.obs_rms = {key: RunningMeanStd(shape=space.shape) for key, space in self.obs_spaces.items()}
        else:
            self.obs_keys, self.obs_spaces = None, None
            self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)

        # Scalar statistics of the discounted return, used to scale rewards
        self.ret_rms = RunningMeanStd(shape=())
        self.clip_obs = clip_obs
        self.clip_reward = clip_reward
        # Returns: discounted rewards (one running value per sub-environment)
        self.ret = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon
        self.training = training
        self.norm_obs = norm_obs
        self.norm_reward = norm_reward
        # Raw (unnormalized) values from the most recent step/reset
        self.old_obs = np.array([])
        self.old_reward = np.array([])

    def __getstate__(self) -> Dict[str, Any]:
        """
        Gets state for pickling.

        Excludes self.venv, as in general VecEnv's may not be pickleable.

        :return: a copy of the instance dict without the wrapped env,
            ``class_attributes`` and the per-env running returns
        """
        state = self.__dict__.copy()
        # these attributes are not pickleable
        del state["venv"]
        del state["class_attributes"]
        # these attributes depend on the above and so we would prefer not to pickle
        del state["ret"]
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restores pickled state.

        User must call set_venv() after unpickling before using.

        :param state: the attribute dict produced by ``__getstate__``
        """
        self.__dict__.update(state)
        assert "venv" not in state
        self.venv = None

    def set_venv(self, venv: VecEnv) -> None:
        """
        Sets the vector environment to wrap to venv.

        Also sets attributes derived from this such as `num_env`.

        :param venv: the vectorized environment to wrap
        :raises ValueError: if this wrapper already wraps an environment
        """
        if self.venv is not None:
            raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
        VecEnvWrapper.__init__(self, venv)

        # Check only that the observation_space match
        utils.check_for_correct_spaces(venv, self.observation_space, venv.action_space)
        self.ret = np.zeros(self.num_envs)

    def step_wait(self) -> VecEnvStepReturn:
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, news)

        where 'news' is a boolean vector indicating whether each element is new.

        The returned observations and rewards are normalized. This includes any
        ``terminal_observation`` entry in ``infos`` for environments that just
        finished. The raw observations and rewards of this step remain available
        via ``get_original_obs()`` and ``get_original_reward()``.
        """
        obs, rews, news, infos = self.venv.step_wait()
        self.old_obs = obs
        self.old_reward = rews

        if self.training:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                for key in self.obs_rms.keys():
                    self.obs_rms[key].update(obs[key])
            else:
                self.obs_rms.update(obs)

        obs = self.normalize_obs(obs)

        if self.training:
            self._update_reward(rews)
        rews = self.normalize_reward(rews)

        # A finished sub-env may report its last observation in
        # infos[idx]["terminal_observation"]; presumably the returned obs for
        # that env is then the first obs of the next episode. Normalize the
        # terminal observation too, so it is on the same scale as `obs`.
        for idx, done in enumerate(news):
            if done and "terminal_observation" in infos[idx]:
                infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])

        # Finished episodes start accumulating their discounted return from zero
        self.ret[news] = 0
        return obs, rews, news, infos

    def _update_reward(self, reward: np.ndarray) -> None:
        """Update reward normalization statistics."""
        self.ret = self.ret * self.gamma + reward
        self.ret_rms.update(self.ret)

    def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
        """
        Helper to normalize observation.

        :param obs: raw observation
        :param obs_rms: associated statistics
        :return: normalized observation, clipped to ``[-clip_obs, clip_obs]``
        """
        return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)

    def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
        """
        Helper to unnormalize observation.

        Note: values that were clipped during normalization are not recovered.

        :param obs: normalized observation
        :param obs_rms: associated statistics
        :return: unnormalized observation
        """
        return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean

    def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Normalize observations using this VecNormalize's observations statistics.
        Calling this method does not update statistics.

        :param obs: an array, or a dict of arrays for dict observation spaces
        :return: a normalized float32 copy (an unmodified copy if ``norm_obs`` is False)
        """
        # Avoid modifying by reference the original object
        obs_ = deepcopy(obs)
        if self.norm_obs:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                for key in self.obs_rms.keys():
                    obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
            else:
                obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
        return obs_

    def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Normalize rewards using this VecNormalize's rewards statistics.
        Calling this method does not update statistics.

        :param reward: raw rewards
        :return: rewards divided by the return std and clipped, or unchanged
            if ``norm_reward`` is False
        """
        if self.norm_reward:
            reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
        return reward

    def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Invert ``normalize_obs`` using the current statistics (clipping is not undone).

        :param obs: normalized observation (array or dict of arrays)
        :return: an unnormalized copy (an unmodified copy if ``norm_obs`` is False)
        """
        # Avoid modifying by reference the original object
        obs_ = deepcopy(obs)
        if self.norm_obs:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                for key in self.obs_rms.keys():
                    obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])
            else:
                obs_ = self._unnormalize_obs(obs, self.obs_rms)
        return obs_

    def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Invert the scaling applied by ``normalize_reward`` (clipping is not undone).

        :param reward: normalized rewards
        :return: rescaled rewards, or ``reward`` unchanged if ``norm_reward`` is False
        """
        if self.norm_reward:
            return reward * np.sqrt(self.ret_rms.var + self.epsilon)
        return reward

    def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Returns an unnormalized version of the observations from the most recent
        step or reset.
        """
        return deepcopy(self.old_obs)

    def get_original_reward(self) -> np.ndarray:
        """
        Returns an unnormalized version of the rewards from the most recent step.
        """
        return self.old_reward.copy()

    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Reset all environments.

        Also zeroes the running discounted returns. When training, those
        zeroed returns are fed to the reward statistics.

        :return: first observation of the episode (normalized)
        """
        obs = self.venv.reset()
        self.old_obs = obs
        self.ret = np.zeros(self.num_envs)
        if self.training:
            self._update_reward(self.ret)
        return self.normalize_obs(obs)

    @staticmethod
    def load(load_path: str, venv: VecEnv) -> "VecNormalize":
        """
        Loads a saved VecNormalize object.

        Security note: this unpickles ``load_path``. Unpickling can execute
        arbitrary code, so only load files from a trusted source.

        :param load_path: the path to load from.
        :param venv: the VecEnv to wrap.
        :return: the restored wrapper, already wrapping ``venv``
        """
        with open(load_path, "rb") as file_handler:
            # NOTE(review): pickle.load on a user-supplied path; trusted input only
            vec_normalize = pickle.load(file_handler)
        vec_normalize.set_venv(venv)
        return vec_normalize

    def save(self, save_path: str) -> None:
        """
        Save current VecNormalize object with
        all running statistics and settings (e.g. clip_obs)

        :param save_path: The path to save to
        """
        with open(save_path, "wb") as file_handler:
            pickle.dump(self, file_handler)
|