Merge branch 'master' into fix-common-vec_env-__init__-type-hint

This commit is contained in:
Quentin Gallouédec 2022-12-12 13:32:21 +01:00 committed by GitHub
commit 55faed59ca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 154 additions and 35 deletions

View file

@ -22,9 +22,9 @@ jobs:
python-version: [3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies

View file

@ -333,11 +333,11 @@ If your task requires even more granular control over the policy/value architect
:return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
If all layers are shared, then ``latent_policy == latent_value``
"""
return self.policy_net(features), self.value_net(features)
return self.forward_actor(features), self.forward_critic(features)
def forward_actor(self, features: th.Tensor) -> th.Tensor:
return self.policy_net(features)
def forward_critic(self, features: th.Tensor) -> th.Tensor:
return self.value_net(features)

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.7.0a4 (WIP)
Release 1.7.0a5 (WIP)
--------------------------
Breaking Changes:
@ -18,14 +18,15 @@ New Features:
^^^^^^^^^^^^^
- Introduced mypy type checking
- Added ``with_bias`` argument to ``create_mlp``
- Added support for multidimensional ``spaces.MultiBinary`` observations
SB3-Contrib
^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
- Fix return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
- Fixed return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
- Fixed type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
- Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround
- Raise an error when the same gym environment instance is passed as separate environments when creating a vectorized environment with more than one environment. (@Rocamonde)
- Fix type annotation of ``model`` in ``evaluate_policy``
@ -42,13 +43,17 @@ Others:
- Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests
- Fixed ``tests/test_distributions.py`` type hint
- Fixed ``stable_baselines3/common/type_aliases.py`` type hint
- Fixed ``stable_baselines3/common/torch_layers.py`` type hint
- Fixed ``stable_baselines3/common/env_util.py`` type hint
- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong)
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
- Fixed ``stable_baselines3/common/vec_env/__init__.py`` type hint
Documentation:
^^^^^^^^^^^^^^
- Updated Hugging Face Integration page (@simoninithomas)
- Changed ``env`` to ``vec_env`` when environment is vectorized
- Update custom policy documentation (@athatheo)
Release 1.6.2 (2022-10-10)
--------------------------
@ -1127,4 +1132,4 @@ And all the contributors:
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong

View file

@ -47,7 +47,6 @@ exclude = (?x)(
| stable_baselines3/common/preprocessing.py$
| stable_baselines3/common/save_util.py$
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
| stable_baselines3/common/torch_layers.py$
| stable_baselines3/common/utils.py$
| stable_baselines3/common/vec_env/base_vec_env.py$
| stable_baselines3/common/vec_env/dummy_vec_env.py$
@ -80,17 +79,6 @@ exclude = (?x)(
ignore = W503,W504,E203,E231
# Ignore import not used when aliases are defined
per-file-ignores =
./stable_baselines3/__init__.py:F401
./stable_baselines3/common/__init__.py:F401
./stable_baselines3/common/envs/__init__.py:F401
./stable_baselines3/a2c/__init__.py:F401
./stable_baselines3/ddpg/__init__.py:F401
./stable_baselines3/dqn/__init__.py:F401
./stable_baselines3/her/__init__.py:F401
./stable_baselines3/ppo/__init__.py:F401
./stable_baselines3/sac/__init__.py:F401
./stable_baselines3/td3/__init__.py:F401
./stable_baselines3/common/vec_env/__init__.py:F401
# Default implementation in abstract methods
./stable_baselines3/common/callbacks.py:B027
./stable_baselines3/common/noise.py:B027

View file

@ -20,3 +20,15 @@ def HER(*args, **kwargs):
"Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n "
"Please check the documentation for more information: https://stable-baselines3.readthedocs.io/"
)
__all__ = [
"A2C",
"DDPG",
"DQN",
"PPO",
"SAC",
"TD3",
"HerReplayBuffer",
"get_system_info",
]

View file

@ -1,2 +1,4 @@
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "A2C"]

View file

@ -7,3 +7,14 @@ from stable_baselines3.common.envs.identity_env import (
IdentityEnvMultiDiscrete,
)
from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv
__all__ = [
"BitFlippingEnv",
"FakeImageEnv",
"IdentityEnv",
"IdentityEnvBox",
"IdentityEnvMultiBinary",
"IdentityEnvMultiDiscrete",
"SimpleMultiObsEnv",
"SimpleMultiObsEnv",
]

View file

@ -150,7 +150,10 @@ def get_obs_shape(
return (int(len(observation_space.nvec)),)
elif isinstance(observation_space, spaces.MultiBinary):
# Number of binary features
return (int(observation_space.n),)
if type(observation_space.n) in [tuple, list, np.ndarray]:
return tuple(observation_space.n)
else:
return (int(observation_space.n),)
elif isinstance(observation_space, spaces.Dict):
return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()}

View file

@ -28,9 +28,6 @@ class BaseFeaturesExtractor(nn.Module):
def features_dim(self) -> int:
return self._features_dim
def forward(self, observations: th.Tensor) -> th.Tensor:
raise NotImplementedError()
class FlattenExtractor(BaseFeaturesExtractor):
"""
@ -173,9 +170,11 @@ class MlpExtractor(nn.Module):
):
super().__init__()
device = get_device(device)
shared_net, policy_net, value_net = [], [], []
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
value_only_layers = [] # Layer sizes of the network that only belongs to the value network
shared_net: List[nn.Module] = []
policy_net: List[nn.Module] = []
value_net: List[nn.Module] = []
policy_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the policy network
value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network
last_layer_dim_shared = feature_dim
# Iterate through the shared layers and build the shared parts of the network
@ -254,7 +253,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
super().__init__(observation_space, features_dim=1)
extractors = {}
extractors: Dict[str, nn.Module] = {}
total_concat_size = 0
for key, subspace in observation_space.spaces.items():

View file

@ -299,14 +299,14 @@ def is_vectorized_multibinary_observation(observation: np.ndarray, observation_s
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
if observation.shape == (observation_space.n,):
if observation.shape == observation_space.shape:
return False
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
elif len(observation.shape) == len(observation_space.shape) + 1 and observation.shape[1:] == observation_space.shape:
return True
else:
raise ValueError(
f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
+ f"environment, please use ({observation_space.n},) or "
+ f"environment, please use {observation_space.shape} or "
+ f"(n_env, {observation_space.n}) for the observation shape."
)

View file

@ -1,4 +1,3 @@
# flake8: noqa F401
import typing
from copy import deepcopy
from typing import Optional, Type, TypeVar, Union
@ -76,3 +75,25 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
env_tmp = env_tmp.venv
eval_env_tmp = eval_env_tmp.venv
__all__ = [
"CloudpickleWrapper",
"VecEnv",
"VecEnvWrapper",
"DummyVecEnv",
"StackedDictObservations",
"StackedObservations",
"SubprocVecEnv",
"VecCheckNan",
"VecExtractDictObs",
"VecFrameStack",
"VecMonitor",
"VecNormalize",
"VecTransposeImage",
"VecVideoRecorder",
"unwrap_vec_wrapper",
"unwrap_vec_normalize",
"is_vecenv_wrapped",
"sync_envs_normalization",
]

View file

@ -1,2 +1,4 @@
from stable_baselines3.ddpg.ddpg import DDPG
from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "DDPG"]

View file

@ -1,2 +1,4 @@
from stable_baselines3.dqn.dqn import DQN
from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "DQN"]

View file

@ -1,2 +1,4 @@
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
__all__ = ["GoalSelectionStrategy", "HerReplayBuffer"]

View file

@ -1,2 +1,4 @@
from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from stable_baselines3.ppo.ppo import PPO
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "PPO"]

View file

@ -1,2 +1,4 @@
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from stable_baselines3.sac.sac import SAC
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "SAC"]

View file

@ -1,2 +1,4 @@
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from stable_baselines3.td3.td3 import TD3
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TD3"]

View file

@ -1 +1 @@
1.7.0a4
1.7.0a5

View file

@ -0,0 +1,66 @@
import torch
from gym import spaces
from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs
def test_get_obs_shape_discrete():
assert get_obs_shape(spaces.Discrete(3)) == (1,)
def test_get_obs_shape_multidiscrete():
assert get_obs_shape(spaces.MultiDiscrete([3, 2])) == (2,)
def test_get_obs_shape_multibinary():
assert get_obs_shape(spaces.MultiBinary(3)) == (3,)
def test_get_obs_shape_multidimensional_multibinary():
assert get_obs_shape(spaces.MultiBinary([3, 2])) == (3, 2)
def test_get_obs_shape_box():
assert get_obs_shape(spaces.Box(-2, 2, shape=(3,))) == (3,)
def test_get_obs_shape_multidimensional_box():
assert get_obs_shape(spaces.Box(-2, 2, shape=(3, 2))) == (3, 2)
def test_preprocess_obs_discrete():
actual = preprocess_obs(torch.tensor([2], dtype=torch.long), spaces.Discrete(3))
expected = torch.tensor([[0.0, 0.0, 1.0]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)
def test_preprocess_obs_multidiscrete():
actual = preprocess_obs(torch.tensor([[2, 0]], dtype=torch.long), spaces.MultiDiscrete([3, 2]))
expected = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)
def test_preprocess_obs_multibinary():
actual = preprocess_obs(torch.tensor([[1, 0, 1]], dtype=torch.long), spaces.MultiBinary(3))
expected = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)
def test_preprocess_obs_multidimensional_multibinary():
actual = preprocess_obs(torch.tensor([[[1, 0], [1, 1], [0, 1]]], dtype=torch.long), spaces.MultiBinary([3, 2]))
expected = torch.tensor([[[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)
def test_preprocess_obs_box():
actual = preprocess_obs(torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3,)))
expected = torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)
def test_preprocess_obs_multidimensional_box():
actual = preprocess_obs(
torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3, 2))
)
expected = torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32)
torch.testing.assert_close(actual, expected)

View file

@ -47,7 +47,7 @@ class DummyMultidimensionalAction(gym.Env):
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))])
def test_identity_spaces(model_class, env):
"""
Additional tests for DQ/SAC/TD3 to check observation space support