Mirror of https://github.com/saymrwulf/stable-baselines3.git
Synced 2026-05-14 20:58:03 +00:00
Upgrade black formatting (#1310)
* apply black
* Reformat tests

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
parent bea3c44ba5
commit 82bc63fca4
22 changed files with 4 additions and 44 deletions
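For orientation (this note is not part of the commit): the hunks below are the mechanical rewrites that newer black releases apply. They delete the blank line that older versions tolerated directly after a function signature, strip redundant parentheses around for-loop targets (for (i, key) in ... becomes for i, key in ...), and unwrap parenthesized return annotations, so ") -> (Tuple[...]):" becomes ") -> Tuple[...]:". A minimal sketch of code already in the new style, using a hypothetical helper that is not part of stable-baselines3:

from typing import List


def discounted_return(rewards: List[float], gamma: float = 0.99) -> float:
    # Newer black keeps no blank line directly after the signature, which is
    # exactly what most hunks in this commit delete.
    ret = 0.0
    # It also strips the redundant parentheses formerly written as
    # "for (i, reward) in enumerate(rewards):".
    for i, reward in enumerate(rewards):
        ret += (gamma**i) * reward
    return ret


print(discounted_return([1.0, 0.0, 1.0]))  # 1.9801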

@@ -81,7 +81,6 @@ class A2C(OnPolicyAlgorithm):
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):

        super().__init__(
            policy,
            env,

@@ -132,7 +131,6 @@ class A2C(OnPolicyAlgorithm):

        # This will only loop once (get all data in one go)
        for rollout_data in self.rollout_buffer.get(batch_size=None):

            actions = rollout_data.actions
            if isinstance(self.action_space, spaces.Discrete):
                # Convert discrete action from float to long

@@ -189,7 +187,6 @@ class A2C(OnPolicyAlgorithm):
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfA2C:

        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,

@@ -240,7 +240,6 @@ class ReplayBuffer(BaseBuffer):
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):

@@ -346,7 +345,6 @@ class RolloutBuffer(BaseBuffer):
        gamma: float = 0.99,
        n_envs: int = 1,
    ):

        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma

@@ -356,7 +354,6 @@ class RolloutBuffer(BaseBuffer):
        self.reset()

    def reset(self) -> None:

        self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

@@ -451,7 +448,6 @@ class RolloutBuffer(BaseBuffer):
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:

            _tensor_names = [
                "observations",
                "actions",

@@ -688,7 +684,6 @@ class DictRolloutBuffer(RolloutBuffer):
        gamma: float = 0.99,
        n_envs: int = 1,
    ):

        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

@@ -763,7 +758,6 @@ class DictRolloutBuffer(RolloutBuffer):
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:

            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

@@ -787,7 +781,6 @@ class DictRolloutBuffer(RolloutBuffer):
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> DictRolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME

        return DictRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            actions=self.to_torch(self.actions[batch_inds]),

@@ -429,11 +429,9 @@ class EvalCallback(EventCallback):
                self._is_success_buffer.append(maybe_is_success)

    def _on_step(self) -> bool:

        continue_training = True

        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:

            # Sync training and eval env if there is VecNormalize
            if self.model.get_vec_normalize_env() is not None:
                try:

@@ -91,7 +91,6 @@ def evaluate_policy(
        current_lengths += 1
        for i in range(n_envs):
            if episode_counts[i] < episode_count_targets[i]:

                # unpack values so that the callback can access the local variables
                reward = rewards[i]
                done = dones[i]

@@ -173,7 +173,6 @@ class HumanOutputFormat(KVWriter, SeqWriter):
        key2str = {}
        tag = None
        for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):

            if excluded is not None and ("stdout" in excluded or "log" in excluded):
                continue

@@ -342,7 +341,7 @@ class CSVOutputFormat(KVWriter):
        self.file.seek(0)
        lines = self.file.readlines()
        self.file.seek(0)
        for (i, key) in enumerate(self.keys):
        for i, key in enumerate(self.keys):
            if i > 0:
                self.file.write(",")
            self.file.write(key)

@@ -399,9 +398,7 @@ class TensorBoardOutputFormat(KVWriter):
        self.writer = SummaryWriter(log_dir=folder)

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:

        for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):

            if excluded is not None and "tensorboard" in excluded:
                continue

@@ -102,7 +102,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
        sde_support: bool = True,
        supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
    ):

        super().__init__(
            policy=policy,
            env=env,

@@ -319,7 +318,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfOffPolicyAlgorithm:

        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            callback,

@@ -72,7 +72,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
        _init_setup_model: bool = True,
        supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
    ):

        super().__init__(
            policy=policy,
            env=env,

@@ -244,7 +243,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:

            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)

            if continue_training is False:

@@ -433,7 +433,6 @@ class ActorCriticPolicy(BasePolicy):
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):

        if optimizer_kwargs is None:
            optimizer_kwargs = {}
            # Small values to avoid NaN in Adam optimizer

@@ -84,7 +84,7 @@ def plot_curves(
    plt.figure(title, figsize=figsize)
    max_x = max(xy[0][-1] for xy in xy_list)
    min_x = 0
    for (_, (x, y)) in enumerate(xy_list):
    for _, (x, y) in enumerate(xy_list):
        plt.scatter(x, y, s=2)
        # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
        if x.shape[0] >= EPISODES_WINDOW:

@@ -367,7 +367,7 @@ def load_from_zip_file(
    device: Union[th.device, str] = "auto",
    verbose: int = 0,
    print_system_info: bool = False,
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
) -> Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]:
    """
    Load model data from a .zip archive

@@ -30,7 +30,6 @@ class StackedObservations:
        observation_space: spaces.Space,
        channels_order: Optional[str] = None,
    ):

        self.n_stack = n_stack
        (
            self.channels_first,

@@ -44,7 +44,6 @@ class VecFrameStack(VecEnvWrapper):
    def step_wait(
        self,
    ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:

        observations, rewards, dones, infos = self.venv.step_wait()

        observations, infos = self.stackedobs.update(observations, dones, infos)

@@ -30,7 +30,6 @@ class VecVideoRecorder(VecEnvWrapper):
        video_length: int = 200,
        name_prefix: str = "rl-video",
    ):

        VecEnvWrapper.__init__(self, venv)

        self.env = venv

@@ -76,7 +76,6 @@ class DDPG(TD3):
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):

        super().__init__(
            policy=policy,
            env=env,

@@ -121,7 +120,6 @@ class DDPG(TD3):
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfDDPG:

        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,

@@ -94,7 +94,6 @@ class DQN(OffPolicyAlgorithm):
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):

        super().__init__(
            policy,
            env,

@@ -261,7 +260,6 @@ class DQN(OffPolicyAlgorithm):
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfDQN:

        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,

@@ -81,7 +81,6 @@ class HerReplayBuffer(DictReplayBuffer):
        online_sampling: bool = True,
        handle_timeout_termination: bool = True,
    ):

        super().__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs)

        # convert goal_selection_strategy into GoalSelectionStrategy if string

@@ -389,7 +388,6 @@ class HerReplayBuffer(DictReplayBuffer):
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:

        if self.current_idx == 0 and self.full:
            # Clear info buffer
            self.info_buffer[self.pos] = deque(maxlen=self.max_episode_length)

@@ -98,7 +98,6 @@ class PPO(OnPolicyAlgorithm):
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):

        super().__init__(
            policy,
            env,

@@ -303,7 +302,6 @@ class PPO(OnPolicyAlgorithm):
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfPPO:

        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,

@@ -109,7 +109,6 @@ class SAC(OffPolicyAlgorithm):
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):

        super().__init__(
            policy,
            env,

@@ -295,7 +294,6 @@ class SAC(OffPolicyAlgorithm):
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfSAC:

        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,

@@ -94,7 +94,6 @@ class TD3(OffPolicyAlgorithm):
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):

        super().__init__(
            policy,
            env,

@@ -151,7 +150,6 @@ class TD3(OffPolicyAlgorithm):

        actor_losses, critic_losses = [], []
        for _ in range(gradient_steps):

            self._n_updates += 1
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

@@ -210,7 +208,6 @@ class TD3(OffPolicyAlgorithm):
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfTD3:

        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,

@@ -115,7 +115,6 @@ def test_dqn():

@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
def test_train_freq(tmp_path, train_freq):

    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",

@@ -648,7 +648,6 @@ def test_open_file_str_pathlib(tmp_path, pathtype):


def test_open_file(tmp_path):

    # path must much the type
    with pytest.raises(TypeError):
        open_path(123, None, None, None)

@@ -178,7 +178,7 @@ def _make_warmstart_dict_env(**kwargs):

def test_runningmeanstd():
    """Test RunningMeanStd object"""
    for (x_1, x_2, x_3) in [
    for x_1, x_2, x_3 in [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]:

@@ -336,7 +336,6 @@ def test_normalize_dict_selected_keys():
@pytest.mark.parametrize("model_class", [SAC, TD3, HerReplayBuffer])
@pytest.mark.parametrize("online_sampling", [False, True])
def test_offpolicy_normalization(model_class, online_sampling):

    if online_sampling and model_class != HerReplayBuffer:
        pytest.skip()