diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b50cc62..9d22a0a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,9 +22,9 @@ jobs: python-version: [3.7, 3.8, 3.9] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 1a3ae34..4ba3203 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -333,11 +333,11 @@ If your task requires even more granular control over the policy/value architect :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network. If all layers are shared, then ``latent_policy == latent_value`` """ - return self.policy_net(features), self.value_net(features) - + return self.forward_actor(features), self.forward_critic(features) + def forward_actor(self, features: th.Tensor) -> th.Tensor: return self.policy_net(features) - + def forward_critic(self, features: th.Tensor) -> th.Tensor: return self.value_net(features) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index f113ea5..935bd49 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.7.0a4 (WIP) +Release 1.7.0a5 (WIP) -------------------------- Breaking Changes: @@ -18,14 +18,15 @@ New Features: ^^^^^^^^^^^^^ - Introduced mypy type checking - Added ``with_bias`` argument to ``create_mlp`` +- Added support for multidimensional ``spaces.MultiBinary`` observations SB3-Contrib ^^^^^^^^^^^ Bug Fixes: ^^^^^^^^^^ -- Fix return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde) -- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm`` +- Fixed return type of 
``evaluate_actions`` in ``ActorCriticPolicy`` to reflect that entropy is an optional tensor (@Rocamonde) +- Fixed type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm`` - Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround - Raise an error when the same gym environment instance is passed as separate environments when creating a vectorized environment with more than one environment. (@Rocamonde) - Fix type annotation of ``model`` in ``evaluate_policy`` @@ -42,13 +43,17 @@ Others: - Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests - Fixed ``tests/test_distributions.py`` type hint - Fixed ``stable_baselines3/common/type_aliases.py`` type hint +- Fixed ``stable_baselines3/common/torch_layers.py`` type hint - Fixed ``stable_baselines3/common/env_util.py`` type hint +- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong) +- Upgraded GitHub CI/setup-python to v4 and checkout to v3 - Fixed ``stable_baselines3/common/vec_env/__init__.py`` type hint Documentation: ^^^^^^^^^^^^^^ - Updated Hugging Face Integration page (@simoninithomas) - Changed ``env`` to ``vec_env`` when environment is vectorized +- Updated custom policy documentation (@athatheo) Release 1.6.2 (2022-10-10) -------------------------- @@ -1127,4 +1132,4 @@ And all the contributors: @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 -@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer +@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong diff --git a/setup.cfg b/setup.cfg index 36c451a..505a449 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,6 @@ exclude = (?x)( | 
stable_baselines3/common/preprocessing.py$ | stable_baselines3/common/save_util.py$ | stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$ - | stable_baselines3/common/torch_layers.py$ | stable_baselines3/common/utils.py$ | stable_baselines3/common/vec_env/base_vec_env.py$ | stable_baselines3/common/vec_env/dummy_vec_env.py$ @@ -80,17 +79,6 @@ exclude = (?x)( ignore = W503,W504,E203,E231 # Ignore import not used when aliases are defined per-file-ignores = - ./stable_baselines3/__init__.py:F401 - ./stable_baselines3/common/__init__.py:F401 - ./stable_baselines3/common/envs/__init__.py:F401 - ./stable_baselines3/a2c/__init__.py:F401 - ./stable_baselines3/ddpg/__init__.py:F401 - ./stable_baselines3/dqn/__init__.py:F401 - ./stable_baselines3/her/__init__.py:F401 - ./stable_baselines3/ppo/__init__.py:F401 - ./stable_baselines3/sac/__init__.py:F401 - ./stable_baselines3/td3/__init__.py:F401 - ./stable_baselines3/common/vec_env/__init__.py:F401 # Default implementation in abstract methods ./stable_baselines3/common/callbacks.py:B027 ./stable_baselines3/common/noise.py:B027 diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index d73f5f0..0775a8e 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -20,3 +20,15 @@ def HER(*args, **kwargs): "Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n " "Please check the documentation for more information: https://stable-baselines3.readthedocs.io/" ) + + +__all__ = [ + "A2C", + "DDPG", + "DQN", + "PPO", + "SAC", + "TD3", + "HerReplayBuffer", + "get_system_info", +] diff --git a/stable_baselines3/a2c/__init__.py b/stable_baselines3/a2c/__init__.py index 7e99964..78fc54f 100644 --- a/stable_baselines3/a2c/__init__.py +++ b/stable_baselines3/a2c/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.a2c.a2c import A2C from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["CnnPolicy", "MlpPolicy", 
"MultiInputPolicy", "A2C"] diff --git a/stable_baselines3/common/envs/__init__.py b/stable_baselines3/common/envs/__init__.py index 23bd575..3ff0221 100644 --- a/stable_baselines3/common/envs/__init__.py +++ b/stable_baselines3/common/envs/__init__.py @@ -7,3 +7,14 @@ from stable_baselines3.common.envs.identity_env import ( IdentityEnvMultiDiscrete, ) from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv + +__all__ = [ + "BitFlippingEnv", + "FakeImageEnv", + "IdentityEnv", + "IdentityEnvBox", + "IdentityEnvMultiBinary", + "IdentityEnvMultiDiscrete", + "SimpleMultiObsEnv", + "SimpleMultiObsEnv", +] diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 01422aa..a406a7d 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -150,7 +150,10 @@ def get_obs_shape( return (int(len(observation_space.nvec)),) elif isinstance(observation_space, spaces.MultiBinary): # Number of binary features - return (int(observation_space.n),) + if type(observation_space.n) in [tuple, list, np.ndarray]: + return tuple(observation_space.n) + else: + return (int(observation_space.n),) elif isinstance(observation_space, spaces.Dict): return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 2ce0cc1..5de2af1 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -28,9 +28,6 @@ class BaseFeaturesExtractor(nn.Module): def features_dim(self) -> int: return self._features_dim - def forward(self, observations: th.Tensor) -> th.Tensor: - raise NotImplementedError() - class FlattenExtractor(BaseFeaturesExtractor): """ @@ -173,9 +170,11 @@ class MlpExtractor(nn.Module): ): super().__init__() device = get_device(device) - shared_net, policy_net, value_net = [], [], [] - policy_only_layers = [] # 
Layer sizes of the network that only belongs to the policy network - value_only_layers = [] # Layer sizes of the network that only belongs to the value network + shared_net: List[nn.Module] = [] + policy_net: List[nn.Module] = [] + value_net: List[nn.Module] = [] + policy_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the policy network + value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network last_layer_dim_shared = feature_dim # Iterate through the shared layers and build the shared parts of the network @@ -254,7 +253,7 @@ class CombinedExtractor(BaseFeaturesExtractor): # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! super().__init__(observation_space, features_dim=1) - extractors = {} + extractors: Dict[str, nn.Module] = {} total_concat_size = 0 for key, subspace in observation_space.spaces.items(): diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 53c642c..1a3f871 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -299,14 +299,14 @@ def is_vectorized_multibinary_observation(observation: np.ndarray, observation_s :param observation_space: the observation space :return: whether the given observation is vectorized or not """ - if observation.shape == (observation_space.n,): + if observation.shape == observation_space.shape: return False - elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: + elif len(observation.shape) == len(observation_space.shape) + 1 and observation.shape[1:] == observation_space.shape: return True else: raise ValueError( f"Error: Unexpected observation shape {observation.shape} for MultiBinary " - + f"environment, please use ({observation_space.n},) or " + + f"environment, please use {observation_space.shape} or " + f"(n_env, {observation_space.n}) for the observation shape." 
) diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 6e704d9..1211c01 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa F401 import typing from copy import deepcopy from typing import Optional, Type, TypeVar, Union @@ -76,3 +75,25 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) env_tmp = env_tmp.venv eval_env_tmp = eval_env_tmp.venv + + +__all__ = [ + "CloudpickleWrapper", + "VecEnv", + "VecEnvWrapper", + "DummyVecEnv", + "StackedDictObservations", + "StackedObservations", + "SubprocVecEnv", + "VecCheckNan", + "VecExtractDictObs", + "VecFrameStack", + "VecMonitor", + "VecNormalize", + "VecTransposeImage", + "VecVideoRecorder", + "unwrap_vec_wrapper", + "unwrap_vec_normalize", + "is_vecenv_wrapped", + "sync_envs_normalization", +] diff --git a/stable_baselines3/ddpg/__init__.py b/stable_baselines3/ddpg/__init__.py index 262e7f1..257a3e3 100644 --- a/stable_baselines3/ddpg/__init__.py +++ b/stable_baselines3/ddpg/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.ddpg.ddpg import DDPG from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "DDPG"] diff --git a/stable_baselines3/dqn/__init__.py b/stable_baselines3/dqn/__init__.py index f36f96e..2e5e2db 100644 --- a/stable_baselines3/dqn/__init__.py +++ b/stable_baselines3/dqn/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.dqn.dqn import DQN from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "DQN"] diff --git a/stable_baselines3/her/__init__.py b/stable_baselines3/her/__init__.py index 1f58921..dc4c8c2 100644 --- a/stable_baselines3/her/__init__.py +++ b/stable_baselines3/her/__init__.py @@ -1,2 +1,4 @@ from 
stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy from stable_baselines3.her.her_replay_buffer import HerReplayBuffer + +__all__ = ["GoalSelectionStrategy", "HerReplayBuffer"] diff --git a/stable_baselines3/ppo/__init__.py b/stable_baselines3/ppo/__init__.py index e5c23fc..cd91257 100644 --- a/stable_baselines3/ppo/__init__.py +++ b/stable_baselines3/ppo/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from stable_baselines3.ppo.ppo import PPO + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "PPO"] diff --git a/stable_baselines3/sac/__init__.py b/stable_baselines3/sac/__init__.py index 5a84dde..bdf780d 100644 --- a/stable_baselines3/sac/__init__.py +++ b/stable_baselines3/sac/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from stable_baselines3.sac.sac import SAC + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "SAC"] diff --git a/stable_baselines3/td3/__init__.py b/stable_baselines3/td3/__init__.py index 0b903cd..428141e 100644 --- a/stable_baselines3/td3/__init__.py +++ b/stable_baselines3/td3/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from stable_baselines3.td3.td3 import TD3 + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TD3"] diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 0952a4b..5d819d4 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.7.0a4 +1.7.0a5 diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py new file mode 100644 index 0000000..89f869b --- /dev/null +++ b/tests/test_preprocessing.py @@ -0,0 +1,66 @@ +import torch +from gym import spaces + +from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs + + +def test_get_obs_shape_discrete(): + assert get_obs_shape(spaces.Discrete(3)) == 
(1,) + + +def test_get_obs_shape_multidiscrete(): + assert get_obs_shape(spaces.MultiDiscrete([3, 2])) == (2,) + + +def test_get_obs_shape_multibinary(): + assert get_obs_shape(spaces.MultiBinary(3)) == (3,) + + +def test_get_obs_shape_multidimensional_multibinary(): + assert get_obs_shape(spaces.MultiBinary([3, 2])) == (3, 2) + + +def test_get_obs_shape_box(): + assert get_obs_shape(spaces.Box(-2, 2, shape=(3,))) == (3,) + + +def test_get_obs_shape_multidimensional_box(): + assert get_obs_shape(spaces.Box(-2, 2, shape=(3, 2))) == (3, 2) + + +def test_preprocess_obs_discrete(): + actual = preprocess_obs(torch.tensor([2], dtype=torch.long), spaces.Discrete(3)) + expected = torch.tensor([[0.0, 0.0, 1.0]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) + + +def test_preprocess_obs_multidiscrete(): + actual = preprocess_obs(torch.tensor([[2, 0]], dtype=torch.long), spaces.MultiDiscrete([3, 2])) + expected = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) + + +def test_preprocess_obs_multibinary(): + actual = preprocess_obs(torch.tensor([[1, 0, 1]], dtype=torch.long), spaces.MultiBinary(3)) + expected = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) + + +def test_preprocess_obs_multidimensional_multibinary(): + actual = preprocess_obs(torch.tensor([[[1, 0], [1, 1], [0, 1]]], dtype=torch.long), spaces.MultiBinary([3, 2])) + expected = torch.tensor([[[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) + + +def test_preprocess_obs_box(): + actual = preprocess_obs(torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3,))) + expected = torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) + + +def test_preprocess_obs_multidimensional_box(): + actual = preprocess_obs( + torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, 
-1.4]]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3, 2)) + ) + expected = torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 0696492..6f530b7 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -47,7 +47,7 @@ class DummyMultidimensionalAction(gym.Env): @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) -@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)]) +@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))]) def test_identity_spaces(model_class, env): """ Additional tests for DQ/SAC/TD3 to check observation space support