mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-14 01:13:45 +00:00
Fix default arguments + add bugbear (#363)
* Fix potential bug + add bug bear * Remove unused variables * Minor: version bump
This commit is contained in:
parent
e1ee87fef7
commit
8a08078ea2
10 changed files with 17 additions and 13 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.1.0a0 (WIP)
|
||||
Release 1.1.0a1 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -15,12 +15,14 @@ New Features:
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Added ``flake8-bugbear`` to tests dependencies to find likely bugs
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
6
setup.py
6
setup.py
|
|
@ -20,8 +20,8 @@ These algorithms will make it easier for the research community and industry to
|
|||
Repository:
|
||||
https://github.com/DLR-RM/stable-baselines3
|
||||
|
||||
Medium article:
|
||||
https://medium.com/@araffin/df87c4b2fc82
|
||||
Blog post:
|
||||
https://araffin.github.io/post/sb3/
|
||||
|
||||
Documentation:
|
||||
https://stable-baselines3.readthedocs.io/en/master/
|
||||
|
|
@ -94,6 +94,8 @@ setup(
|
|||
"pytype",
|
||||
# Lint code
|
||||
"flake8>=3.8",
|
||||
# Find likely bugs
|
||||
"flake8-bugbear",
|
||||
# Sort imports
|
||||
"isort>=5.0",
|
||||
# Reformat
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
env: Union[GymEnv, str],
|
||||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, Schedule],
|
||||
buffer_size: int = int(1e6),
|
||||
buffer_size: int = 1000000,
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 256,
|
||||
tau: float = 0.005,
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ def plot_curves(
|
|||
plt.figure(title, figsize=figsize)
|
||||
max_x = max(xy[0][-1] for xy in xy_list)
|
||||
min_x = 0
|
||||
for (i, (x, y)) in enumerate(xy_list):
|
||||
for (_, (x, y)) in enumerate(xy_list):
|
||||
plt.scatter(x, y, s=2)
|
||||
# Do not plot the smoothed curve at all if the timeseries is shorter than window size.
|
||||
if x.shape[0] >= EPISODES_WINDOW:
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ class MlpExtractor(nn.Module):
|
|||
last_layer_dim_shared = feature_dim
|
||||
|
||||
# Iterate through the shared layers and build the shared parts of the network
|
||||
for idx, layer in enumerate(net_arch):
|
||||
for layer in net_arch:
|
||||
if isinstance(layer, int): # Check that this is a shared layer
|
||||
layer_size = layer
|
||||
# TODO: give layer a meaningful name
|
||||
|
|
@ -192,7 +192,7 @@ class MlpExtractor(nn.Module):
|
|||
last_layer_dim_vf = last_layer_dim_shared
|
||||
|
||||
# Build the non-shared part of the network
|
||||
for idx, (pi_layer_size, vf_layer_size) in enumerate(zip_longest(policy_only_layers, value_only_layers)):
|
||||
for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers):
|
||||
if pi_layer_size is not None:
|
||||
assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
|
||||
policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size))
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class DDPG(TD3):
|
|||
policy: Union[str, Type[TD3Policy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule] = 1e-3,
|
||||
buffer_size: int = int(1e6),
|
||||
buffer_size: int = 1000000,
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 100,
|
||||
tau: float = 0.005,
|
||||
|
|
|
|||
|
|
@ -147,7 +147,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
self._update_learning_rate(self.policy.optimizer)
|
||||
|
||||
losses = []
|
||||
for gradient_step in range(gradient_steps):
|
||||
for _ in range(gradient_steps):
|
||||
# Sample replay buffer
|
||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
policy: Union[str, Type[SACPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule] = 3e-4,
|
||||
buffer_size: int = int(1e6),
|
||||
buffer_size: int = 1000000,
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 256,
|
||||
tau: float = 0.005,
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
policy: Union[str, Type[TD3Policy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule] = 1e-3,
|
||||
buffer_size: int = int(1e6),
|
||||
buffer_size: int = 1000000, # 1e6
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 100,
|
||||
tau: float = 0.005,
|
||||
|
|
@ -131,7 +131,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
|
||||
actor_losses, critic_losses = [], []
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
for _ in range(gradient_steps):
|
||||
|
||||
self._n_updates += 1
|
||||
# Sample replay buffer
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.1.0a0
|
||||
1.1.0a1
|
||||
|
|
|
|||
Loading…
Reference in a new issue