Fix default arguments + add bugbear (#363)

* Fix potential bug + add bug bear

* Remove unused variables

* Minor: version bump
This commit is contained in:
Antonin RAFFIN 2021-03-25 10:35:21 +01:00 committed by GitHub
parent e1ee87fef7
commit 8a08078ea2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 17 additions and 13 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.1.0a0 (WIP)
Release 1.1.0a1 (WIP)
---------------------------
Breaking Changes:
@ -15,12 +15,14 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Added ``flake8-bugbear`` to tests dependencies to find likely bugs
Documentation:
^^^^^^^^^^^^^^

View file

@ -20,8 +20,8 @@ These algorithms will make it easier for the research community and industry to
Repository:
https://github.com/DLR-RM/stable-baselines3
Medium article:
https://medium.com/@araffin/df87c4b2fc82
Blog post:
https://araffin.github.io/post/sb3/
Documentation:
https://stable-baselines3.readthedocs.io/en/master/
@ -94,6 +94,8 @@ setup(
"pytype",
# Lint code
"flake8>=3.8",
# Find likely bugs
"flake8-bugbear",
# Sort imports
"isort>=5.0",
# Reformat

View file

@ -76,7 +76,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
env: Union[GymEnv, str],
policy_base: Type[BasePolicy],
learning_rate: Union[float, Schedule],
buffer_size: int = int(1e6),
buffer_size: int = 1000000,
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,

View file

@ -84,7 +84,7 @@ def plot_curves(
plt.figure(title, figsize=figsize)
max_x = max(xy[0][-1] for xy in xy_list)
min_x = 0
for (i, (x, y)) in enumerate(xy_list):
for (_, (x, y)) in enumerate(xy_list):
plt.scatter(x, y, s=2)
# Do not plot the smoothed curve at all if the timeseries is shorter than window size.
if x.shape[0] >= EPISODES_WINDOW:

View file

@ -170,7 +170,7 @@ class MlpExtractor(nn.Module):
last_layer_dim_shared = feature_dim
# Iterate through the shared layers and build the shared parts of the network
for idx, layer in enumerate(net_arch):
for layer in net_arch:
if isinstance(layer, int): # Check that this is a shared layer
layer_size = layer
# TODO: give layer a meaningful name
@ -192,7 +192,7 @@ class MlpExtractor(nn.Module):
last_layer_dim_vf = last_layer_dim_shared
# Build the non-shared part of the network
for idx, (pi_layer_size, vf_layer_size) in enumerate(zip_longest(policy_only_layers, value_only_layers)):
for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers):
if pi_layer_size is not None:
assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size))

View file

@ -54,7 +54,7 @@ class DDPG(TD3):
policy: Union[str, Type[TD3Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-3,
buffer_size: int = int(1e6),
buffer_size: int = 1000000,
learning_starts: int = 100,
batch_size: int = 100,
tau: float = 0.005,

View file

@ -147,7 +147,7 @@ class DQN(OffPolicyAlgorithm):
self._update_learning_rate(self.policy.optimizer)
losses = []
for gradient_step in range(gradient_steps):
for _ in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

View file

@ -74,7 +74,7 @@ class SAC(OffPolicyAlgorithm):
policy: Union[str, Type[SACPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
buffer_size: int = int(1e6),
buffer_size: int = 1000000,
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,

View file

@ -62,7 +62,7 @@ class TD3(OffPolicyAlgorithm):
policy: Union[str, Type[TD3Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-3,
buffer_size: int = int(1e6),
buffer_size: int = 1000000, # 1e6
learning_starts: int = 100,
batch_size: int = 100,
tau: float = 0.005,
@ -131,7 +131,7 @@ class TD3(OffPolicyAlgorithm):
actor_losses, critic_losses = [], []
for gradient_step in range(gradient_steps):
for _ in range(gradient_steps):
self._n_updates += 1
# Sample replay buffer

View file

@ -1 +1 @@
1.1.0a0
1.1.0a1