mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
* First commit
* Fixing missing refs from a quick merge from master
* Reformat
* Adding DictBuffers
* Reformat
* Minor reformat
* added slow dict test. Added SACMultiInputPolicy for future. Added private static image transpose helper to common policy
* Ran black on buffers
* Ran isort
* Adding StackedObservations classes used within VecStackEnvs wrappers. Made test_dict_env shorter and removed slow
* Running isort :facepalm
* Fixed typing issues
* Adding docstrings and typing. Using util for moving data to device.
* Fixed trailing commas
* Fix types
* Minor edits
* Avoid duplicating code
* Fix calls to parents
* Adding assert to buffers. Updating changelog
* Running format on buffers
* Adding multi-input policies to dqn,td3,a2c. Fixing warnings. Fixed bug with DictReplayBuffer as Replay buffers use only 1 env
* Fixing warnings, splitting is_vectorized_observation into multiple functions based on space type
* Created envs folder in common. Updated imports. Moved stacked_obs to vec_env folder
* Moved envs to envs directory. Moved stacked obs to vec_envs. Started update on documentation
* Fixes
* Running code style
* Update docstrings on torch_layers
* Decapitalize non-constant variables
* Using NatureCNN architecture in combined extractor. Increasing img size in multi input env. Adding memory reduction in test
* Update doc
* Update doc
* Fix format
* Removing NineRoom env. Using nested preprocess. Removing mutable default args
* running code style
* Passing channel check through to stacked dict observations.
* Running black
* Adding channel control to SimpleMultiObsEnv. Passing check_channels to CombinedExtractor
* Remove optimize memory for dict buffers
* Update doc
* Move identity env
* Minor edits + bump version
* Update doc
* Fix doc build
* Bug fixes + add support for more type of dict env
* Fixes + add multi env test
* Add support for vectranspose
* Fix stacked obs for dict and add tests
* Add check for nested spaces. Fix dict-subprocvecenv test
* Fix (single) pytype error
* Simplify CombinedExtractor
* Fix tests
* Fix check
* Merge branch 'master' into feat/dict_observations
* Fix for net_arch with dict and vector obs
* Fixes
* Add consistency test
* Update env checker
* Add some docs on dict obs
* Update default CNN feature vector size
* Refactor HER (#351)
* Start refactoring HER
* Fixes
* Additional fixes
* Faster tests
* WIP: HER as a custom replay buffer
* New replay only version (working with DQN)
* Add support for all off-policy algorithms
* Fix saving/loading
* Remove ObsDictWrapper and add VecNormalize tests with dict
* Stable-Baselines3 v1.0 (#354)
* Bump version and update doc
* Fix name
* Apply suggestions from code review Co-authored-by: Adam Gleave <adam@gleave.me>
* Update docs/index.rst Co-authored-by: Adam Gleave <adam@gleave.me>
* Update wording for RL zoo Co-authored-by: Adam Gleave <adam@gleave.me>
* Add gym-pybullet-drones project (#358)
* Update projects.rst Added gym-pybullet-drones
* Update projects.rst Longer title underline
* Update changelog Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
* Include SuperSuit in projects (#359)
* include supersuit
* longer title underline
* Update changelog.rst
* Fix default arguments + add bugbear (#363)
* Fix potential bug + add bug bear
* Remove unused variables
* Minor: version bump
* Add code of conduct + update doc (#373)
* Add code of conduct
* Fix DQN doc example
* Update doc (channel-last/first)
* Apply suggestions from code review Co-authored-by: Anssi <kaneran21@hotmail.com>
* Apply suggestions from code review Co-authored-by: Adam Gleave <adam@gleave.me> Co-authored-by: Anssi <kaneran21@hotmail.com> Co-authored-by: Adam Gleave <adam@gleave.me>
* Make installation command compatible with ZSH (#376)
* Add quotes
* Add Zsh bracket info
* Add clarify pip installation line
* Make note bold
* Add Zsh pip installation note
* Add handle timeouts param
* Fixes
* Fixes (buffer size, extend test)
* Fix `max_episode_length` redefinition
* Fix potential issue
* Add some docs on dict obs
* Fix performance bug
* Fix slowdown
* Add package to install (#378)
* Add package to install
* Update docs packages installation command Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
* Fix backward compat + add test
* Fix VecEnv detection
* Update doc
* Fix vec env check
* Support for `VecMonitor` for gym3-style environments (#311)
* add vectorized monitor
* auto format of the code
* add documentation and VecExtractDictObs
* refactor and add test cases
* add test cases and format
* avoid circular import and fix doc
* fix type
* fix type
* oops
* Update stable_baselines3/common/monitor.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
* Update stable_baselines3/common/monitor.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
* add test cases
* update changelog
* fix mutable argument
* quick fix
* Apply suggestions from code review
* fix terminal observation for gym3 envs
* delete comment
* Update doc and bump version
* Add warning when already using `Monitor` wrapper
* Update vecmonitor tests
* Fixes Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
* Reformat
* Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``, it was not optimized anymore (#392)
* Fix ent coef loading bug
* Add test
* Add comment
* Reuse save path
* Add test for GAE + rename `RolloutBuffer.dones` for clarification (#375)
* Fix return computation + add test for GAE
* Rename `last_dones` to `episode_starts` for clarification
* Revert advantage
* Cleanup test
* Rename variable
* Clarify return computation
* Clarify docs
* Add multi-episode rollout test
* Reformat Co-authored-by: Anssi "Miffyli" Kanervisto <kaneran21@hotmail.com>
* Fixed saving of `A2C` and `PPO` policy when using gSDE (#401)
* Improve doc and replay buffer loading
* Add support for images
* Fix doc
* Update Procgen doc
* Update changelog
* Update docstrings Co-authored-by: Adam Gleave <adam@gleave.me> Co-authored-by: Jacopo Panerati <jacopo.panerati@utoronto.ca> Co-authored-by: Justin Terry <justinkterry@gmail.com> Co-authored-by: Anssi <kaneran21@hotmail.com> Co-authored-by: Tom Dörr <tomdoerr96@gmail.com> Co-authored-by: Tom Dörr <tom.doerr@tum.de> Co-authored-by: Costa Huang <costa.huang@outlook.com>
* Update doc and minor fixes
* Update doc
* Added note about MultiInputPolicy in error of NatureCNN
* Merge branch 'master' into feat/dict_observations
* Address comments
* Naming clarifications
* Actually saving the file would be nice
* Fix edge case when doing online sampling with HER
* Cleanup
* Add sanity check

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Anssi "Miffyli" Kanervisto <kaneran21@hotmail.com>
Co-authored-by: Adam Gleave <adam@gleave.me>
Co-authored-by: Jacopo Panerati <jacopo.panerati@utoronto.ca>
Co-authored-by: Justin Terry <justinkterry@gmail.com>
Co-authored-by: Tom Dörr <tomdoerr96@gmail.com>
Co-authored-by: Tom Dörr <tom.doerr@tum.de>
Co-authored-by: Costa Huang <costa.huang@outlook.com>
302 lines
10 KiB
Python
import gym
import numpy as np
import pytest
from gym import spaces

from stable_baselines3 import SAC, TD3, HerReplayBuffer
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env import (
    DummyVecEnv,
    VecFrameStack,
    VecNormalize,
    sync_envs_normalization,
    unwrap_vec_normalize,
)

ENV_ID = "Pendulum-v0"


class DummyDictEnv(gym.GoalEnv):
    """
    Dummy gym goal env for testing purposes
    """

    def __init__(self):
        super(DummyDictEnv, self).__init__()
        self.observation_space = spaces.Dict(
            {
                "observation": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32),
                "achieved_goal": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32),
                "desired_goal": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32),
            }
        )
        self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        obs = self.observation_space.sample()
        reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {})
        done = np.random.rand() > 0.8
        return obs, reward, done, {}

    def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32:
        distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
        return -(distance > 0).astype(np.float32)


def allclose(obs_1, obs_2):
    """
    Generalized np.allclose() to work with dict spaces.
    """
    if isinstance(obs_1, dict):
        all_close = True
        for key in obs_1.keys():
            if not np.allclose(obs_1[key], obs_2[key]):
                all_close = False
                break
        return all_close
    return np.allclose(obs_1, obs_2)


def make_env():
    return Monitor(gym.make(ENV_ID))


def make_dict_env():
    return Monitor(DummyDictEnv())


def check_rms_equal(rmsa, rmsb):
    if isinstance(rmsa, dict):
        for key in rmsa.keys():
            assert np.all(rmsa[key].mean == rmsb[key].mean)
            assert np.all(rmsa[key].var == rmsb[key].var)
            assert np.all(rmsa[key].count == rmsb[key].count)
    else:
        assert np.all(rmsa.mean == rmsb.mean)
        assert np.all(rmsa.var == rmsb.var)
        assert np.all(rmsa.count == rmsb.count)


def check_vec_norm_equal(norma, normb):
    assert norma.observation_space == normb.observation_space
    assert norma.action_space == normb.action_space
    assert norma.num_envs == normb.num_envs

    check_rms_equal(norma.obs_rms, normb.obs_rms)
    check_rms_equal(norma.ret_rms, normb.ret_rms)
    assert norma.clip_obs == normb.clip_obs
    assert norma.clip_reward == normb.clip_reward
    assert norma.norm_obs == normb.norm_obs
    assert norma.norm_reward == normb.norm_reward

    assert np.all(norma.ret == normb.ret)
    assert norma.gamma == normb.gamma
    assert norma.epsilon == normb.epsilon
    assert norma.training == normb.training


def _make_warmstart_cartpole():
    """Warm-start VecNormalize by stepping through CartPole"""
    venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    venv = VecNormalize(venv)
    venv.reset()
    venv.get_original_obs()

    for _ in range(100):
        actions = [venv.action_space.sample()]
        venv.step(actions)
    return venv


def _make_warmstart_dict_env():
    """Warm-start VecNormalize by stepping through DummyDictEnv"""
    venv = DummyVecEnv([make_dict_env])
    venv = VecNormalize(venv)
    venv.reset()
    venv.get_original_obs()

    for _ in range(100):
        actions = [venv.action_space.sample()]
        venv.step(actions)
    return venv


def test_runningmeanstd():
    """Test RunningMeanStd object"""
    for (x_1, x_2, x_3) in [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]:
        rms = RunningMeanStd(epsilon=0.0, shape=x_1.shape[1:])

        x_cat = np.concatenate([x_1, x_2, x_3], axis=0)
        moments_1 = [x_cat.mean(axis=0), x_cat.var(axis=0)]
        rms.update(x_1)
        rms.update(x_2)
        rms.update(x_3)
        moments_2 = [rms.mean, rms.var]

        assert np.allclose(moments_1, moments_2)
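
The check above works because streaming updates of a `RunningMeanStd` agree with the moments of the concatenated batches. For reference, here is a minimal NumPy sketch of the parallel-moments merge rule (Chan et al.) that such a running-statistics update computes; the helper name `combine_moments` is illustrative, not part of the library:

```python
import numpy as np


def combine_moments(mean_a, var_a, count_a, batch):
    """Merge a new batch into running (mean, var, count) using the
    parallel-moments formula (Chan et al.)."""
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    batch_count = batch.shape[0]

    delta = batch_mean - mean_a
    tot_count = count_a + batch_count
    # New mean is the count-weighted average of the two means
    new_mean = mean_a + delta * batch_count / tot_count
    # Combine second moments, correcting for the shift between the means
    m_a = var_a * count_a
    m_b = batch_var * batch_count
    m_2 = m_a + m_b + delta ** 2 * count_a * batch_count / tot_count
    return new_mean, m_2 / tot_count, tot_count


# Two streaming updates match the moments of the concatenated data
x_1, x_2 = np.random.randn(3, 2), np.random.randn(5, 2)
mean, var, count = combine_moments(np.zeros(2), np.zeros(2), 0.0, x_1)
mean, var, count = combine_moments(mean, var, count, x_2)
x_cat = np.concatenate([x_1, x_2], axis=0)
assert np.allclose(mean, x_cat.mean(axis=0))
assert np.allclose(var, x_cat.var(axis=0))
```

This is why `test_runningmeanstd` can compare against `x_cat.mean(axis=0)` and `x_cat.var(axis=0)` exactly (up to floating-point tolerance) rather than against an approximation.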


@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
def test_vec_env(tmp_path, make_env):
    """Test VecNormalize Object"""
    clip_obs = 0.5
    clip_reward = 5.0

    orig_venv = DummyVecEnv([make_env])
    norm_venv = VecNormalize(orig_venv, norm_obs=True, norm_reward=True, clip_obs=clip_obs, clip_reward=clip_reward)
    _, done = norm_venv.reset(), [False]
    while not done[0]:
        actions = [norm_venv.action_space.sample()]
        obs, rew, done, _ = norm_venv.step(actions)
        if isinstance(obs, dict):
            for key in obs.keys():
                assert np.max(np.abs(obs[key])) <= clip_obs
        else:
            assert np.max(np.abs(obs)) <= clip_obs
        assert np.max(np.abs(rew)) <= clip_reward

    path = tmp_path / "vec_normalize"
    norm_venv.save(path)
    deserialized = VecNormalize.load(path, venv=orig_venv)
    check_vec_norm_equal(norm_venv, deserialized)


def test_get_original():
    venv = _make_warmstart_cartpole()
    for _ in range(3):
        actions = [venv.action_space.sample()]
        obs, rewards, _, _ = venv.step(actions)
        obs = obs[0]
        orig_obs = venv.get_original_obs()[0]
        rewards = rewards[0]
        orig_rewards = venv.get_original_reward()[0]

        assert np.all(orig_rewards == 1)
        assert orig_obs.shape == obs.shape
        assert orig_rewards.dtype == rewards.dtype
        assert not np.array_equal(orig_obs, obs)
        assert not np.array_equal(orig_rewards, rewards)
        np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs)
        np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)


def test_get_original_dict():
    venv = _make_warmstart_dict_env()
    for _ in range(3):
        actions = [venv.action_space.sample()]
        obs, rewards, _, _ = venv.step(actions)
        # obs = obs[0]
        orig_obs = venv.get_original_obs()
        rewards = rewards[0]
        orig_rewards = venv.get_original_reward()[0]

        for key in orig_obs.keys():
            assert orig_obs[key].shape == obs[key].shape
        assert orig_rewards.dtype == rewards.dtype

        assert not allclose(orig_obs, obs)
        assert not np.array_equal(orig_rewards, rewards)
        assert allclose(venv.normalize_obs(orig_obs), obs)
        np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)


def test_normalize_external():
    venv = _make_warmstart_cartpole()

    rewards = np.array([1, 1])
    norm_rewards = venv.normalize_reward(rewards)
    assert norm_rewards.shape == rewards.shape
    # Episode return is almost always >= 1 in CartPole. So reward should shrink.
    assert np.all(norm_rewards < 1)


@pytest.mark.parametrize("model_class", [SAC, TD3, HerReplayBuffer])
@pytest.mark.parametrize("online_sampling", [False, True])
def test_offpolicy_normalization(model_class, online_sampling):

    if online_sampling and model_class != HerReplayBuffer:
        pytest.skip()

    make_env_ = make_dict_env if model_class == HerReplayBuffer else make_env
    env = DummyVecEnv([make_env_])
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0)

    eval_env = DummyVecEnv([make_env_])
    eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0, clip_reward=10.0)

    if model_class == HerReplayBuffer:
        model = SAC(
            "MultiInputPolicy",
            env,
            verbose=1,
            learning_starts=100,
            policy_kwargs=dict(net_arch=[64]),
            replay_buffer_kwargs=dict(
                max_episode_length=100,
                online_sampling=online_sampling,
                n_sampled_goal=2,
            ),
            replay_buffer_class=HerReplayBuffer,
            seed=2,
        )
    else:
        model = model_class("MlpPolicy", env, verbose=1, learning_starts=100, policy_kwargs=dict(net_arch=[64]))

    model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75)
    # Check getter
    assert isinstance(model.get_vec_normalize_env(), VecNormalize)


@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
def test_sync_vec_normalize(make_env):
    env = DummyVecEnv([make_env])

    assert unwrap_vec_normalize(env) is None

    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)

    assert isinstance(unwrap_vec_normalize(env), VecNormalize)

    if not isinstance(env.observation_space, spaces.Dict):
        env = VecFrameStack(env, 1)
        assert isinstance(unwrap_vec_normalize(env), VecNormalize)

    eval_env = DummyVecEnv([make_env])
    eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)

    if not isinstance(env.observation_space, spaces.Dict):
        eval_env = VecFrameStack(eval_env, 1)

    env.seed(0)
    env.action_space.seed(0)

    env.reset()
    # Initialize running mean
    latest_reward = None
    for _ in range(100):
        _, latest_reward, _, _ = env.step([env.action_space.sample()])

    # Check that unnormalized reward is same as original reward
    original_latest_reward = env.get_original_reward()
    assert np.allclose(original_latest_reward, env.unnormalize_reward(latest_reward))

    obs = env.reset()
    dummy_rewards = np.random.rand(10)
    original_obs = env.get_original_obs()
    # Check that unnormalization works
    assert allclose(original_obs, env.unnormalize_obs(obs))
    # Normalization must be different (between different environments)
    assert not allclose(obs, eval_env.normalize_obs(original_obs))

    # Test syncing of parameters
    sync_envs_normalization(env, eval_env)
    # Now they must be synced
    assert allclose(obs, eval_env.normalize_obs(original_obs))
    assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards))