Update black version + update docker image (#151)

* Update docker image

* Update black and reformat
This commit is contained in:
Antonin RAFFIN 2020-08-27 23:02:59 +02:00 committed by GitHub
parent a1afc5e42f
commit 15d32c6a4a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 22 additions and 10 deletions

View file

@@ -1,4 +1,4 @@
-image: stablebaselines/stable-baselines3-cpu:0.8.0a4
+image: stablebaselines/stable-baselines3-cpu:0.9.0a1
type-check:
script:

View file

@@ -171,12 +171,12 @@ class ReplayBuffer(BaseBuffer):
mem_available = psutil.virtual_memory().available
self.optimize_memory_usage = optimize_memory_usage
-self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=observation_space.dtype)
+self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
if optimize_memory_usage:
# `observations` contains also the next observation
self.next_observations = None
else:
-self.next_observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=observation_space.dtype)
+self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -284,7 +284,7 @@ class RolloutBuffer(BaseBuffer):
self.reset()
def reset(self) -> None:
-self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=np.float32)
+self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

View file

@@ -699,7 +699,10 @@ class ContinuousCritic(BaseModel):
n_critics: int = 2,
):
super().__init__(
-observation_space, action_space, features_extractor=features_extractor, normalize_images=normalize_images,
+observation_space,
+action_space,
+features_extractor=features_extractor,
+normalize_images=normalize_images,
)
action_dim = get_action_dim(self.action_space)

View file

@@ -350,7 +350,9 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0)
def load_from_zip_file(
-load_path: Union[str, pathlib.Path, io.BufferedIOBase], load_data: bool = True, verbose=0,
+load_path: Union[str, pathlib.Path, io.BufferedIOBase],
+load_data: bool = True,
+verbose=0,
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
"""
Load model data from a .zip archive

View file

@@ -31,7 +31,10 @@ class QNetwork(BasePolicy):
normalize_images: bool = True,
):
super(QNetwork, self).__init__(
-observation_space, action_space, features_extractor=features_extractor, normalize_images=normalize_images,
+observation_space,
+action_space,
+features_extractor=features_extractor,
+normalize_images=normalize_images,
)
if net_arch is None:

View file

@@ -22,7 +22,7 @@ def test_flexible_mlp(model_class, net_arch):
_ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
@pytest.mark.parametrize("net_arch", [[4], [4, 4],])
@pytest.mark.parametrize("net_arch", [[4], [4, 4]])
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_custom_offpolicy(model_class, net_arch):
_ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=net_arch)).learn(1000)

View file

@@ -67,7 +67,11 @@ def test_sde_distribution():
# TODO: analytical form for squashed Gaussian?
@pytest.mark.parametrize(
"dist", [DiagGaussianDistribution(N_ACTIONS), StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),]
"dist",
[
DiagGaussianDistribution(N_ACTIONS),
StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),
],
)
def test_entropy(dist):
# The entropy can be approximated by averaging the negative log likelihood

View file

@@ -225,7 +225,7 @@ def check_vecenv_spaces(vec_env_class, space, obs_assert):
def check_vecenv_obs(obs, space):
"""Helper method to check observations from multiple environments each belong to
the appropriate observation space."""
the appropriate observation space."""
assert obs.shape[0] == N_ENVS
for value in obs:
assert space.contains(value)