stable-baselines3/tests/test_vec_normalize.py
Jaden Travnik 75b6f3b3b0
Dictionary Observations (#243)
2021-05-11 12:29:30 +02:00


import gym
import numpy as np
import pytest
from gym import spaces

from stable_baselines3 import SAC, TD3, HerReplayBuffer
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env import (
    DummyVecEnv,
    VecFrameStack,
    VecNormalize,
    sync_envs_normalization,
    unwrap_vec_normalize,
)

ENV_ID = "Pendulum-v0"


class DummyDictEnv(gym.GoalEnv):
    """
    Dummy gym goal env for testing purposes
    """

    def __init__(self):
        super(DummyDictEnv, self).__init__()
        self.observation_space = spaces.Dict(
            {
                "observation": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32),
                "achieved_goal": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32),
                "desired_goal": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32),
            }
        )
        self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        obs = self.observation_space.sample()
        reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {})
        done = np.random.rand() > 0.8
        return obs, reward, done, {}

    def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32:
        distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
        return -(distance > 0).astype(np.float32)
def allclose(obs_1, obs_2):
    """
    Generalized np.allclose() to work with dict spaces.
    """
    if isinstance(obs_1, dict):
        all_close = True
        for key in obs_1.keys():
            if not np.allclose(obs_1[key], obs_2[key]):
                all_close = False
                break
        return all_close
    return np.allclose(obs_1, obs_2)
def make_env():
    return Monitor(gym.make(ENV_ID))


def make_dict_env():
    return Monitor(DummyDictEnv())


def check_rms_equal(rmsa, rmsb):
    if isinstance(rmsa, dict):
        for key in rmsa.keys():
            assert np.all(rmsa[key].mean == rmsb[key].mean)
            assert np.all(rmsa[key].var == rmsb[key].var)
            assert np.all(rmsa[key].count == rmsb[key].count)
    else:
        assert np.all(rmsa.mean == rmsb.mean)
        assert np.all(rmsa.var == rmsb.var)
        assert np.all(rmsa.count == rmsb.count)
def check_vec_norm_equal(norma, normb):
    assert norma.observation_space == normb.observation_space
    assert norma.action_space == normb.action_space
    assert norma.num_envs == normb.num_envs

    check_rms_equal(norma.obs_rms, normb.obs_rms)
    check_rms_equal(norma.ret_rms, normb.ret_rms)
    assert norma.clip_obs == normb.clip_obs
    assert norma.clip_reward == normb.clip_reward
    assert norma.norm_obs == normb.norm_obs
    assert norma.norm_reward == normb.norm_reward

    assert np.all(norma.ret == normb.ret)
    assert norma.gamma == normb.gamma
    assert norma.epsilon == normb.epsilon
    assert norma.training == normb.training
def _make_warmstart_cartpole():
    """Warm-start VecNormalize by stepping through CartPole"""
    venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    venv = VecNormalize(venv)
    venv.reset()
    venv.get_original_obs()

    for _ in range(100):
        actions = [venv.action_space.sample()]
        venv.step(actions)
    return venv


def _make_warmstart_dict_env():
    """Warm-start VecNormalize by stepping through DummyDictEnv"""
    venv = DummyVecEnv([make_dict_env])
    venv = VecNormalize(venv)
    venv.reset()
    venv.get_original_obs()

    for _ in range(100):
        actions = [venv.action_space.sample()]
        venv.step(actions)
    return venv
def test_runningmeanstd():
    """Test RunningMeanStd object"""
    for (x_1, x_2, x_3) in [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]:
        rms = RunningMeanStd(epsilon=0.0, shape=x_1.shape[1:])

        x_cat = np.concatenate([x_1, x_2, x_3], axis=0)
        moments_1 = [x_cat.mean(axis=0), x_cat.var(axis=0)]

        rms.update(x_1)
        rms.update(x_2)
        rms.update(x_3)
        moments_2 = [rms.mean, rms.var]

        assert np.allclose(moments_1, moments_2)
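The batched update that `test_runningmeanstd` exercises can be sketched with the standard parallel-moments formulas (Chan et al.). This is an illustrative reimplementation, not the library's actual `RunningMeanStd.update` code; `combine_moments` is a hypothetical helper:

```python
import numpy as np


def combine_moments(mean_a, var_a, count_a, batch):
    """Merge a new batch into a running mean/var (parallel-variance formula)."""
    mean_b, var_b, count_b = batch.mean(axis=0), batch.var(axis=0), batch.shape[0]
    delta = mean_b - mean_a
    total = count_a + count_b
    new_mean = mean_a + delta * count_b / total
    # Combine the sums of squared deviations, then renormalize to a variance
    m2 = var_a * count_a + var_b * count_b + delta ** 2 * count_a * count_b / total
    return new_mean, m2 / total, total


# Feeding batches one by one matches computing moments over the concatenation,
# which is exactly the invariant the test above asserts.
mean, var, count = np.zeros(2), np.zeros(2), 0.0
batches = [np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)]
for b in batches:
    mean, var, count = combine_moments(mean, var, count, b)

x_cat = np.concatenate(batches, axis=0)
assert np.allclose(mean, x_cat.mean(axis=0))
assert np.allclose(var, x_cat.var(axis=0))
```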
@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
def test_vec_env(tmp_path, make_env):
    """Test VecNormalize Object"""
    clip_obs = 0.5
    clip_reward = 5.0

    orig_venv = DummyVecEnv([make_env])
    norm_venv = VecNormalize(orig_venv, norm_obs=True, norm_reward=True, clip_obs=clip_obs, clip_reward=clip_reward)

    _, done = norm_venv.reset(), [False]
    while not done[0]:
        actions = [norm_venv.action_space.sample()]
        obs, rew, done, _ = norm_venv.step(actions)
        if isinstance(obs, dict):
            for key in obs.keys():
                assert np.max(np.abs(obs[key])) <= clip_obs
        else:
            assert np.max(np.abs(obs)) <= clip_obs
        assert np.max(np.abs(rew)) <= clip_reward

    path = tmp_path / "vec_normalize"
    norm_venv.save(path)
    deserialized = VecNormalize.load(path, venv=orig_venv)
    check_vec_norm_equal(norm_venv, deserialized)
def test_get_original():
    venv = _make_warmstart_cartpole()
    for _ in range(3):
        actions = [venv.action_space.sample()]
        obs, rewards, _, _ = venv.step(actions)
        obs = obs[0]
        orig_obs = venv.get_original_obs()[0]
        rewards = rewards[0]
        orig_rewards = venv.get_original_reward()[0]

        assert np.all(orig_rewards == 1)
        assert orig_obs.shape == obs.shape
        assert orig_rewards.dtype == rewards.dtype
        assert not np.array_equal(orig_obs, obs)
        assert not np.array_equal(orig_rewards, rewards)
        np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs)
        np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)
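The `normalize_obs`/`get_original_obs` round trip that this test checks boils down to standardize-then-clip and its inverse. The sketch below uses illustrative standalone functions (`normalize_obs`/`unnormalize_obs` here are local helpers, not the library methods), assuming the usual running-mean/variance form:

```python
import numpy as np


def normalize_obs(obs, mean, var, clip_obs=10.0, epsilon=1e-8):
    """Standardize with running stats, then clip to [-clip_obs, clip_obs]."""
    return np.clip((obs - mean) / np.sqrt(var + epsilon), -clip_obs, clip_obs)


def unnormalize_obs(norm_obs, mean, var, epsilon=1e-8):
    """Inverse mapping (exact only when no clipping occurred)."""
    return norm_obs * np.sqrt(var + epsilon) + mean


mean, var = np.array([1.0, -2.0]), np.array([4.0, 0.25])
obs = np.array([3.0, -1.5])
norm = normalize_obs(obs, mean, var)
recovered = unnormalize_obs(norm, mean, var)
assert np.allclose(recovered, obs)
```

Note the inverse is only exact when the normalized value was not clipped, which is why tests keep observations well inside the clip range.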
def test_get_original_dict():
    venv = _make_warmstart_dict_env()
    for _ in range(3):
        actions = [venv.action_space.sample()]
        obs, rewards, _, _ = venv.step(actions)
        # obs = obs[0]
        orig_obs = venv.get_original_obs()
        rewards = rewards[0]
        orig_rewards = venv.get_original_reward()[0]

        for key in orig_obs.keys():
            assert orig_obs[key].shape == obs[key].shape
        assert orig_rewards.dtype == rewards.dtype

        assert not allclose(orig_obs, obs)
        assert not np.array_equal(orig_rewards, rewards)
        assert allclose(venv.normalize_obs(orig_obs), obs)
        np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)
def test_normalize_external():
    venv = _make_warmstart_cartpole()

    rewards = np.array([1, 1])
    norm_rewards = venv.normalize_reward(rewards)
    assert norm_rewards.shape == rewards.shape
    # Episode return is almost always >= 1 in CartPole. So reward should shrink.
    assert np.all(norm_rewards < 1)
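Why should the reward shrink? Reward normalization scales by the running standard deviation of the discounted return (no mean subtraction), and CartPole's returns have a std well above 1 after warm-starting. A minimal sketch, assuming that scaling form (`normalize_reward` here is an illustrative helper, not the library method):

```python
import numpy as np


def normalize_reward(reward, return_var, clip_reward=10.0, epsilon=1e-8):
    """Scale rewards by the running std of the discounted return, then clip."""
    return np.clip(reward / np.sqrt(return_var + epsilon), -clip_reward, clip_reward)


# With a return variance well above 1 (typical after 100 warm-start steps on
# CartPole), a raw reward of 1 shrinks below 1 -- the assertion made above.
assert np.all(normalize_reward(np.array([1.0, 1.0]), return_var=25.0) < 1)
```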
@pytest.mark.parametrize("model_class", [SAC, TD3, HerReplayBuffer])
@pytest.mark.parametrize("online_sampling", [False, True])
def test_offpolicy_normalization(model_class, online_sampling):
    if online_sampling and model_class != HerReplayBuffer:
        pytest.skip()

    make_env_ = make_dict_env if model_class == HerReplayBuffer else make_env
    env = DummyVecEnv([make_env_])
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0)

    eval_env = DummyVecEnv([make_env_])
    eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0, clip_reward=10.0)

    if model_class == HerReplayBuffer:
        model = SAC(
            "MultiInputPolicy",
            env,
            verbose=1,
            learning_starts=100,
            policy_kwargs=dict(net_arch=[64]),
            replay_buffer_kwargs=dict(
                max_episode_length=100,
                online_sampling=online_sampling,
                n_sampled_goal=2,
            ),
            replay_buffer_class=HerReplayBuffer,
            seed=2,
        )
    else:
        model = model_class("MlpPolicy", env, verbose=1, learning_starts=100, policy_kwargs=dict(net_arch=[64]))
    model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75)
    # Check getter
    assert isinstance(model.get_vec_normalize_env(), VecNormalize)
@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
def test_sync_vec_normalize(make_env):
    env = DummyVecEnv([make_env])

    assert unwrap_vec_normalize(env) is None

    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)

    assert isinstance(unwrap_vec_normalize(env), VecNormalize)

    if not isinstance(env.observation_space, spaces.Dict):
        env = VecFrameStack(env, 1)
        assert isinstance(unwrap_vec_normalize(env), VecNormalize)

    eval_env = DummyVecEnv([make_env])
    eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)

    if not isinstance(env.observation_space, spaces.Dict):
        eval_env = VecFrameStack(eval_env, 1)

    env.seed(0)
    env.action_space.seed(0)

    env.reset()
    # Initialize running mean
    latest_reward = None
    for _ in range(100):
        _, latest_reward, _, _ = env.step([env.action_space.sample()])

    # Check that unnormalized reward is same as original reward
    original_latest_reward = env.get_original_reward()
    assert np.allclose(original_latest_reward, env.unnormalize_reward(latest_reward))

    obs = env.reset()
    dummy_rewards = np.random.rand(10)
    original_obs = env.get_original_obs()
    # Check that unnormalization works
    assert allclose(original_obs, env.unnormalize_obs(obs))
    # Normalization must be different (between different environments)
    assert not allclose(obs, eval_env.normalize_obs(original_obs))

    # Test syncing of parameters
    sync_envs_normalization(env, eval_env)
    # Now they must be synced
    assert allclose(obs, eval_env.normalize_obs(original_obs))
    assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards))
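Conceptually, syncing as tested above amounts to copying the training wrapper's running statistics onto the eval wrapper so both normalize identically afterwards. A minimal sketch of that idea with plain dicts (`sync_stats` is a hypothetical helper standing in for `sync_envs_normalization`, which operates on VecNormalize wrappers):

```python
import numpy as np


def sync_stats(train_stats, eval_stats):
    """Copy running statistics (per key) from the training side to the eval side."""
    for key, stats in train_stats.items():
        eval_stats[key] = {name: np.copy(value) for name, value in stats.items()}


train = {"obs": {"mean": np.array([1.0]), "var": np.array([2.0]), "count": np.array(10.0)}}
evaluation = {"obs": {"mean": np.zeros(1), "var": np.ones(1), "count": np.array(0.0)}}

# Before syncing the two sides normalize differently; after, they agree.
sync_stats(train, evaluation)
assert np.allclose(evaluation["obs"]["mean"], 1.0)
assert np.allclose(evaluation["obs"]["var"], 2.0)
```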