From 15d32c6a4a0ddc19e64baaa16e4afd5a045e23ce Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 27 Aug 2020 23:02:59 +0200 Subject: [PATCH] Update black version + update docker image (#151) * Update docker image * Update black and reformat --- .gitlab-ci.yml | 2 +- stable_baselines3/common/buffers.py | 6 +++--- stable_baselines3/common/policies.py | 5 ++++- stable_baselines3/common/save_util.py | 4 +++- stable_baselines3/dqn/policies.py | 5 ++++- tests/test_custom_policy.py | 2 +- tests/test_distributions.py | 6 +++++- tests/test_vec_envs.py | 2 +- 8 files changed, 22 insertions(+), 10 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 695c145..71826e9 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: stablebaselines/stable-baselines3-cpu:0.8.0a4 +image: stablebaselines/stable-baselines3-cpu:0.9.0a1 type-check: script: diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 4534063..6c58953 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -171,12 +171,12 @@ class ReplayBuffer(BaseBuffer): mem_available = psutil.virtual_memory().available self.optimize_memory_usage = optimize_memory_usage - self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=observation_space.dtype) + self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype) if optimize_memory_usage: # `observations` contains also the next observation self.next_observations = None else: - self.next_observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=observation_space.dtype) + self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype) self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) @@ -284,7 +284,7 @@ class RolloutBuffer(BaseBuffer): self.reset() def reset(self) -> None: - self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=np.float32) + self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32) self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 41280a8..0478342 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -699,7 +699,10 @@ class ContinuousCritic(BaseModel): n_critics: int = 2, ): super().__init__( - observation_space, action_space, features_extractor=features_extractor, normalize_images=normalize_images, + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, ) action_dim = get_action_dim(self.action_space) diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 51fa8cd..326db1e 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -350,7 +350,9 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0) def load_from_zip_file( - load_path: Union[str, pathlib.Path, io.BufferedIOBase], load_data: bool = True, verbose=0, + load_path: Union[str, pathlib.Path, io.BufferedIOBase], + load_data: bool = True, + verbose=0, ) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]): """ Load model data from a .zip archive diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index f5001c7..ebbcd34 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -31,7 +31,10 @@ class QNetwork(BasePolicy): normalize_images: bool = True, ): super(QNetwork, self).__init__( - observation_space, action_space, features_extractor=features_extractor, normalize_images=normalize_images, + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, ) if net_arch is None: diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index c1e08df..95f4a7c 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -22,7 +22,7 @@ def test_flexible_mlp(model_class, net_arch): _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000) -@pytest.mark.parametrize("net_arch", [[4], [4, 4],]) +@pytest.mark.parametrize("net_arch", [[4], [4, 4]]) @pytest.mark.parametrize("model_class", [SAC, TD3]) def test_custom_offpolicy(model_class, net_arch): _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=net_arch)).learn(1000) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index a73b81e..490f80e 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -67,7 +67,11 @@ def test_sde_distribution(): # TODO: analytical form for squashed Gaussian? @pytest.mark.parametrize( - "dist", [DiagGaussianDistribution(N_ACTIONS), StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),] + "dist", + [ + DiagGaussianDistribution(N_ACTIONS), + StateDependentNoiseDistribution(N_ACTIONS, squash_output=False), + ], ) def test_entropy(dist): # The entropy can be approximated by averaging the negative log likelihood diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 8c33341..141ca6a 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -225,7 +225,7 @@ def check_vecenv_spaces(vec_env_class, space, obs_assert): def check_vecenv_obs(obs, space): """Helper method to check observations from multiple environments each belong to - the appropriate observation space.""" + the appropriate observation space.""" assert obs.shape[0] == N_ENVS for value in obs: assert space.contains(value)