From 8a08078ea2fd6065dd7e75ff89920c8683d58580 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 25 Mar 2021 10:35:21 +0100 Subject: [PATCH] Fix default arguments + add bugbear (#363) * Fix potential bug + add bug bear * Remove unused variables * Minor: version bump --- docs/misc/changelog.rst | 4 +++- setup.py | 6 ++++-- stable_baselines3/common/off_policy_algorithm.py | 2 +- stable_baselines3/common/results_plotter.py | 2 +- stable_baselines3/common/torch_layers.py | 4 ++-- stable_baselines3/ddpg/ddpg.py | 2 +- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/sac/sac.py | 2 +- stable_baselines3/td3/td3.py | 4 ++-- stable_baselines3/version.txt | 2 +- 10 files changed, 17 insertions(+), 13 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e1bed34..a0737f2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.1.0a0 (WIP) +Release 1.1.0a1 (WIP) --------------------------- Breaking Changes: @@ -15,12 +15,14 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same) Deprecations: ^^^^^^^^^^^^^ Others: ^^^^^^^ +- Added ``flake8-bugbear`` to tests dependencies to find likely bugs Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 0ef4e9b..a68aa18 100644 --- a/setup.py +++ b/setup.py @@ -20,8 +20,8 @@ These algorithms will make it easier for the research community and industry to Repository: https://github.com/DLR-RM/stable-baselines3 -Medium article: -https://medium.com/@araffin/df87c4b2fc82 +Blog post: +https://araffin.github.io/post/sb3/ Documentation: https://stable-baselines3.readthedocs.io/en/master/ @@ -94,6 +94,8 @@ setup( "pytype", # Lint code "flake8>=3.8", + # Find likely bugs + "flake8-bugbear", # Sort imports "isort>=5.0", # Reformat diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 7c7412f..f140fa4 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -76,7 +76,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): env: Union[GymEnv, str], policy_base: Type[BasePolicy], learning_rate: Union[float, Schedule], - buffer_size: int = int(1e6), + buffer_size: int = 1000000, learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, diff --git a/stable_baselines3/common/results_plotter.py b/stable_baselines3/common/results_plotter.py index 7e5b3cd..92f67ac 100644 --- a/stable_baselines3/common/results_plotter.py +++ b/stable_baselines3/common/results_plotter.py @@ -84,7 +84,7 @@ def plot_curves( plt.figure(title, figsize=figsize) max_x = max(xy[0][-1] for xy in xy_list) min_x = 0 - for (i, (x, y)) in enumerate(xy_list): + for (_, (x, y)) in enumerate(xy_list): plt.scatter(x, y, s=2) # Do not plot the smoothed curve at all if the timeseries is shorter than window size. if x.shape[0] >= EPISODES_WINDOW: diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 165d37d..1f0c28d 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -170,7 +170,7 @@ class MlpExtractor(nn.Module): last_layer_dim_shared = feature_dim # Iterate through the shared layers and build the shared parts of the network - for idx, layer in enumerate(net_arch): + for layer in net_arch: if isinstance(layer, int): # Check that this is a shared layer layer_size = layer # TODO: give layer a meaningful name @@ -192,7 +192,7 @@ class MlpExtractor(nn.Module): last_layer_dim_vf = last_layer_dim_shared # Build the non-shared part of the network - for idx, (pi_layer_size, vf_layer_size) in enumerate(zip_longest(policy_only_layers, value_only_layers)): + for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers): if pi_layer_size is not None: assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers." policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size)) diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index e7a1f75..ea08651 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -54,7 +54,7 @@ class DDPG(TD3): policy: Union[str, Type[TD3Policy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-3, - buffer_size: int = int(1e6), + buffer_size: int = 1000000, learning_starts: int = 100, batch_size: int = 100, tau: float = 0.005, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 241f9ae..6cea64d 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -147,7 +147,7 @@ class DQN(OffPolicyAlgorithm): self._update_learning_rate(self.policy.optimizer) losses = [] - for gradient_step in range(gradient_steps): + for _ in range(gradient_steps): # Sample replay buffer replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 63ed10f..73b220c 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -74,7 +74,7 @@ class SAC(OffPolicyAlgorithm): policy: Union[str, Type[SACPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, - buffer_size: int = int(1e6), + buffer_size: int = 1000000, learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index b552e60..6ea0681 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -62,7 +62,7 @@ class TD3(OffPolicyAlgorithm): policy: Union[str, Type[TD3Policy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-3, - buffer_size: int = int(1e6), + buffer_size: int = 1000000, # 1e6 learning_starts: int = 100, batch_size: int = 100, tau: float = 0.005, @@ -131,7 +131,7 @@ class TD3(OffPolicyAlgorithm): actor_losses, critic_losses = [], [] - for gradient_step in range(gradient_steps): + for _ in range(gradient_steps): self._n_updates += 1 # Sample replay buffer diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index ae7fb2d..6b34d9c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.1.0a0 +1.1.0a1