mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
Update black version + update docker image (#151)
* Update docker image * Update black and reformat
This commit is contained in:
parent
a1afc5e42f
commit
15d32c6a4a
8 changed files with 22 additions and 10 deletions
|
|
@@ -1,4 +1,4 @@
|
|||
image: stablebaselines/stable-baselines3-cpu:0.8.0a4
|
||||
image: stablebaselines/stable-baselines3-cpu:0.9.0a1
|
||||
|
||||
type-check:
|
||||
script:
|
||||
|
|
|
|||
|
|
@@ -171,12 +171,12 @@ class ReplayBuffer(BaseBuffer):
|
|||
mem_available = psutil.virtual_memory().available
|
||||
|
||||
self.optimize_memory_usage = optimize_memory_usage
|
||||
self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=observation_space.dtype)
|
||||
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
|
||||
if optimize_memory_usage:
|
||||
# `observations` contains also the next observation
|
||||
self.next_observations = None
|
||||
else:
|
||||
self.next_observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=observation_space.dtype)
|
||||
self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
|
||||
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
|
||||
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
|
|
@@ -284,7 +284,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=np.float32)
|
||||
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
|
||||
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
|
||||
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
|
|
|
|||
|
|
@@ -699,7 +699,10 @@ class ContinuousCritic(BaseModel):
|
|||
n_critics: int = 2,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space, action_space, features_extractor=features_extractor, normalize_images=normalize_images,
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor=features_extractor,
|
||||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
action_dim = get_action_dim(self.action_space)
|
||||
|
|
|
|||
|
|
@@ -350,7 +350,9 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0)
|
|||
|
||||
|
||||
def load_from_zip_file(
|
||||
load_path: Union[str, pathlib.Path, io.BufferedIOBase], load_data: bool = True, verbose=0,
|
||||
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
load_data: bool = True,
|
||||
verbose=0,
|
||||
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
|
||||
"""
|
||||
Load model data from a .zip archive
|
||||
|
|
|
|||
|
|
@@ -31,7 +31,10 @@ class QNetwork(BasePolicy):
|
|||
normalize_images: bool = True,
|
||||
):
|
||||
super(QNetwork, self).__init__(
|
||||
observation_space, action_space, features_extractor=features_extractor, normalize_images=normalize_images,
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor=features_extractor,
|
||||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
if net_arch is None:
|
||||
|
|
|
|||
|
|
@@ -22,7 +22,7 @@ def test_flexible_mlp(model_class, net_arch):
|
|||
_ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("net_arch", [[4], [4, 4],])
|
||||
@pytest.mark.parametrize("net_arch", [[4], [4, 4]])
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3])
|
||||
def test_custom_offpolicy(model_class, net_arch):
|
||||
_ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=net_arch)).learn(1000)
|
||||
|
|
|
|||
|
|
@@ -67,7 +67,11 @@ def test_sde_distribution():
|
|||
|
||||
# TODO: analytical form for squashed Gaussian?
|
||||
@pytest.mark.parametrize(
|
||||
"dist", [DiagGaussianDistribution(N_ACTIONS), StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),]
|
||||
"dist",
|
||||
[
|
||||
DiagGaussianDistribution(N_ACTIONS),
|
||||
StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),
|
||||
],
|
||||
)
|
||||
def test_entropy(dist):
|
||||
# The entropy can be approximated by averaging the negative log likelihood
|
||||
|
|
|
|||
|
|
@@ -225,7 +225,7 @@ def check_vecenv_spaces(vec_env_class, space, obs_assert):
|
|||
|
||||
def check_vecenv_obs(obs, space):
|
||||
"""Helper method to check observations from multiple environments each belong to
|
||||
the appropriate observation space."""
|
||||
the appropriate observation space."""
|
||||
assert obs.shape[0] == N_ENVS
|
||||
for value in obs:
|
||||
assert space.contains(value)
|
||||
|
|
|
|||
Loading…
Reference in a new issue