mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Merge branch 'master' into fix-common-vec_env-__init__-type-hint
This commit is contained in:
commit
55faed59ca
20 changed files with 154 additions and 35 deletions
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
|
|
@ -22,9 +22,9 @@ jobs:
|
|||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
|
|
|
|||
|
|
@ -333,11 +333,11 @@ If your task requires even more granular control over the policy/value architect
|
|||
:return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
|
||||
If all layers are shared, then ``latent_policy == latent_value``
|
||||
"""
|
||||
return self.policy_net(features), self.value_net(features)
|
||||
|
||||
return self.forward_actor(features), self.forward_critic(features)
|
||||
|
||||
def forward_actor(self, features: th.Tensor) -> th.Tensor:
|
||||
return self.policy_net(features)
|
||||
|
||||
|
||||
def forward_critic(self, features: th.Tensor) -> th.Tensor:
|
||||
return self.value_net(features)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.7.0a4 (WIP)
|
||||
Release 1.7.0a5 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -18,14 +18,15 @@ New Features:
|
|||
^^^^^^^^^^^^^
|
||||
- Introduced mypy type checking
|
||||
- Added ``with_bias`` argument to ``create_mlp``
|
||||
- Added support for multidimensional ``spaces.MultiBinary`` observations
|
||||
|
||||
SB3-Contrib
|
||||
^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fix return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
|
||||
- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
|
||||
- Fixed return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
|
||||
- Fixed type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
|
||||
- Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround
|
||||
- Raise an error when the same gym environment instance is passed as separate environments when creating a vectorized environment with more than one environment. (@Rocamonde)
|
||||
- Fix type annotation of ``model`` in ``evaluate_policy``
|
||||
|
|
@ -42,13 +43,17 @@ Others:
|
|||
- Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests
|
||||
- Fixed ``tests/test_distributions.py`` type hint
|
||||
- Fixed ``stable_baselines3/common/type_aliases.py`` type hint
|
||||
- Fixed ``stable_baselines3/common/torch_layers.py`` type hint
|
||||
- Fixed ``stable_baselines3/common/env_util.py`` type hint
|
||||
- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong)
|
||||
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
||||
- Fixed ``stable_baselines3/common/vec_env/__init__.py`` type hint
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Updated Hugging Face Integration page (@simoninithomas)
|
||||
- Changed ``env`` to ``vec_env`` when environment is vectorized
|
||||
- Update custom policy documentation (@athatheo)
|
||||
|
||||
Release 1.6.2 (2022-10-10)
|
||||
--------------------------
|
||||
|
|
@ -1127,4 +1132,4 @@ And all the contributors:
|
|||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
|
||||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
|
||||
|
|
|
|||
12
setup.cfg
12
setup.cfg
|
|
@ -47,7 +47,6 @@ exclude = (?x)(
|
|||
| stable_baselines3/common/preprocessing.py$
|
||||
| stable_baselines3/common/save_util.py$
|
||||
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
|
||||
| stable_baselines3/common/torch_layers.py$
|
||||
| stable_baselines3/common/utils.py$
|
||||
| stable_baselines3/common/vec_env/base_vec_env.py$
|
||||
| stable_baselines3/common/vec_env/dummy_vec_env.py$
|
||||
|
|
@ -80,17 +79,6 @@ exclude = (?x)(
|
|||
ignore = W503,W504,E203,E231
|
||||
# Ignore import not used when aliases are defined
|
||||
per-file-ignores =
|
||||
./stable_baselines3/__init__.py:F401
|
||||
./stable_baselines3/common/__init__.py:F401
|
||||
./stable_baselines3/common/envs/__init__.py:F401
|
||||
./stable_baselines3/a2c/__init__.py:F401
|
||||
./stable_baselines3/ddpg/__init__.py:F401
|
||||
./stable_baselines3/dqn/__init__.py:F401
|
||||
./stable_baselines3/her/__init__.py:F401
|
||||
./stable_baselines3/ppo/__init__.py:F401
|
||||
./stable_baselines3/sac/__init__.py:F401
|
||||
./stable_baselines3/td3/__init__.py:F401
|
||||
./stable_baselines3/common/vec_env/__init__.py:F401
|
||||
# Default implementation in abstract methods
|
||||
./stable_baselines3/common/callbacks.py:B027
|
||||
./stable_baselines3/common/noise.py:B027
|
||||
|
|
|
|||
|
|
@ -20,3 +20,15 @@ def HER(*args, **kwargs):
|
|||
"Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n "
|
||||
"Please check the documentation for more information: https://stable-baselines3.readthedocs.io/"
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"A2C",
|
||||
"DDPG",
|
||||
"DQN",
|
||||
"PPO",
|
||||
"SAC",
|
||||
"TD3",
|
||||
"HerReplayBuffer",
|
||||
"get_system_info",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,4 @@
|
|||
from stable_baselines3.a2c.a2c import A2C
|
||||
from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
|
||||
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "A2C"]
|
||||
|
|
|
|||
|
|
@ -7,3 +7,14 @@ from stable_baselines3.common.envs.identity_env import (
|
|||
IdentityEnvMultiDiscrete,
|
||||
)
|
||||
from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv
|
||||
|
||||
__all__ = [
|
||||
"BitFlippingEnv",
|
||||
"FakeImageEnv",
|
||||
"IdentityEnv",
|
||||
"IdentityEnvBox",
|
||||
"IdentityEnvMultiBinary",
|
||||
"IdentityEnvMultiDiscrete",
|
||||
"SimpleMultiObsEnv",
|
||||
"SimpleMultiObsEnv",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -150,7 +150,10 @@ def get_obs_shape(
|
|||
return (int(len(observation_space.nvec)),)
|
||||
elif isinstance(observation_space, spaces.MultiBinary):
|
||||
# Number of binary features
|
||||
return (int(observation_space.n),)
|
||||
if type(observation_space.n) in [tuple, list, np.ndarray]:
|
||||
return tuple(observation_space.n)
|
||||
else:
|
||||
return (int(observation_space.n),)
|
||||
elif isinstance(observation_space, spaces.Dict):
|
||||
return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()}
|
||||
|
||||
|
|
|
|||
|
|
@ -28,9 +28,6 @@ class BaseFeaturesExtractor(nn.Module):
|
|||
def features_dim(self) -> int:
|
||||
return self._features_dim
|
||||
|
||||
def forward(self, observations: th.Tensor) -> th.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class FlattenExtractor(BaseFeaturesExtractor):
|
||||
"""
|
||||
|
|
@ -173,9 +170,11 @@ class MlpExtractor(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
device = get_device(device)
|
||||
shared_net, policy_net, value_net = [], [], []
|
||||
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
|
||||
value_only_layers = [] # Layer sizes of the network that only belongs to the value network
|
||||
shared_net: List[nn.Module] = []
|
||||
policy_net: List[nn.Module] = []
|
||||
value_net: List[nn.Module] = []
|
||||
policy_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the policy network
|
||||
value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network
|
||||
last_layer_dim_shared = feature_dim
|
||||
|
||||
# Iterate through the shared layers and build the shared parts of the network
|
||||
|
|
@ -254,7 +253,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
|
|||
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
|
||||
super().__init__(observation_space, features_dim=1)
|
||||
|
||||
extractors = {}
|
||||
extractors: Dict[str, nn.Module] = {}
|
||||
|
||||
total_concat_size = 0
|
||||
for key, subspace in observation_space.spaces.items():
|
||||
|
|
|
|||
|
|
@ -299,14 +299,14 @@ def is_vectorized_multibinary_observation(observation: np.ndarray, observation_s
|
|||
:param observation_space: the observation space
|
||||
:return: whether the given observation is vectorized or not
|
||||
"""
|
||||
if observation.shape == (observation_space.n,):
|
||||
if observation.shape == observation_space.shape:
|
||||
return False
|
||||
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
|
||||
elif len(observation.shape) == len(observation_space.shape) + 1 and observation.shape[1:] == observation_space.shape:
|
||||
return True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
|
||||
+ f"environment, please use ({observation_space.n},) or "
|
||||
+ f"environment, please use {observation_space.shape} or "
|
||||
+ f"(n_env, {observation_space.n}) for the observation shape."
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# flake8: noqa F401
|
||||
import typing
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Type, TypeVar, Union
|
||||
|
|
@ -76,3 +75,25 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
|
|||
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
|
||||
env_tmp = env_tmp.venv
|
||||
eval_env_tmp = eval_env_tmp.venv
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CloudpickleWrapper",
|
||||
"VecEnv",
|
||||
"VecEnvWrapper",
|
||||
"DummyVecEnv",
|
||||
"StackedDictObservations",
|
||||
"StackedObservations",
|
||||
"SubprocVecEnv",
|
||||
"VecCheckNan",
|
||||
"VecExtractDictObs",
|
||||
"VecFrameStack",
|
||||
"VecMonitor",
|
||||
"VecNormalize",
|
||||
"VecTransposeImage",
|
||||
"VecVideoRecorder",
|
||||
"unwrap_vec_wrapper",
|
||||
"unwrap_vec_normalize",
|
||||
"is_vecenv_wrapped",
|
||||
"sync_envs_normalization",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,4 @@
|
|||
from stable_baselines3.ddpg.ddpg import DDPG
|
||||
from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
|
||||
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "DDPG"]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,4 @@
|
|||
from stable_baselines3.dqn.dqn import DQN
|
||||
from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
|
||||
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "DQN"]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,4 @@
|
|||
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
|
||||
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
||||
|
||||
__all__ = ["GoalSelectionStrategy", "HerReplayBuffer"]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,4 @@
|
|||
from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
from stable_baselines3.ppo.ppo import PPO
|
||||
|
||||
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "PPO"]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,4 @@
|
|||
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
from stable_baselines3.sac.sac import SAC
|
||||
|
||||
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "SAC"]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,4 @@
|
|||
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
from stable_baselines3.td3.td3 import TD3
|
||||
|
||||
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TD3"]
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a4
|
||||
1.7.0a5
|
||||
|
|
|
|||
66
tests/test_preprocessing.py
Normal file
66
tests/test_preprocessing.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import torch
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs
|
||||
|
||||
|
||||
def test_get_obs_shape_discrete():
|
||||
assert get_obs_shape(spaces.Discrete(3)) == (1,)
|
||||
|
||||
|
||||
def test_get_obs_shape_multidiscrete():
|
||||
assert get_obs_shape(spaces.MultiDiscrete([3, 2])) == (2,)
|
||||
|
||||
|
||||
def test_get_obs_shape_multibinary():
|
||||
assert get_obs_shape(spaces.MultiBinary(3)) == (3,)
|
||||
|
||||
|
||||
def test_get_obs_shape_multidimensional_multibinary():
|
||||
assert get_obs_shape(spaces.MultiBinary([3, 2])) == (3, 2)
|
||||
|
||||
|
||||
def test_get_obs_shape_box():
|
||||
assert get_obs_shape(spaces.Box(-2, 2, shape=(3,))) == (3,)
|
||||
|
||||
|
||||
def test_get_obs_shape_multidimensional_box():
|
||||
assert get_obs_shape(spaces.Box(-2, 2, shape=(3, 2))) == (3, 2)
|
||||
|
||||
|
||||
def test_preprocess_obs_discrete():
|
||||
actual = preprocess_obs(torch.tensor([2], dtype=torch.long), spaces.Discrete(3))
|
||||
expected = torch.tensor([[0.0, 0.0, 1.0]], dtype=torch.float32)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_preprocess_obs_multidiscrete():
|
||||
actual = preprocess_obs(torch.tensor([[2, 0]], dtype=torch.long), spaces.MultiDiscrete([3, 2]))
|
||||
expected = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0]], dtype=torch.float32)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_preprocess_obs_multibinary():
|
||||
actual = preprocess_obs(torch.tensor([[1, 0, 1]], dtype=torch.long), spaces.MultiBinary(3))
|
||||
expected = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_preprocess_obs_multidimensional_multibinary():
|
||||
actual = preprocess_obs(torch.tensor([[[1, 0], [1, 1], [0, 1]]], dtype=torch.long), spaces.MultiBinary([3, 2]))
|
||||
expected = torch.tensor([[[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]], dtype=torch.float32)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_preprocess_obs_box():
|
||||
actual = preprocess_obs(torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3,)))
|
||||
expected = torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_preprocess_obs_multidimensional_box():
|
||||
actual = preprocess_obs(
|
||||
torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3, 2))
|
||||
)
|
||||
expected = torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
|
@ -47,7 +47,7 @@ class DummyMultidimensionalAction(gym.Env):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
|
||||
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
|
||||
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))])
|
||||
def test_identity_spaces(model_class, env):
|
||||
"""
|
||||
Additional tests for DQ/SAC/TD3 to check observation space support
|
||||
|
|
|
|||
Loading…
Reference in a new issue