From 82bc63fca4273c72deea247faa655e17d47bdd85 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com>
Date: Thu, 2 Feb 2023 11:58:41 +0100
Subject: [PATCH] Upgrade black formatting (#1310)

* apply black

* Reformat tests

---------

Co-authored-by: Antonin Raffin
---
 stable_baselines3/a2c/a2c.py                              | 3 ---
 stable_baselines3/common/buffers.py                       | 7 -------
 stable_baselines3/common/callbacks.py                     | 2 --
 stable_baselines3/common/evaluation.py                    | 1 -
 stable_baselines3/common/logger.py                        | 5 +----
 stable_baselines3/common/off_policy_algorithm.py          | 2 --
 stable_baselines3/common/on_policy_algorithm.py           | 2 --
 stable_baselines3/common/policies.py                      | 1 -
 stable_baselines3/common/results_plotter.py               | 2 +-
 stable_baselines3/common/save_util.py                     | 2 +-
 stable_baselines3/common/vec_env/stacked_observations.py  | 1 -
 stable_baselines3/common/vec_env/vec_frame_stack.py       | 1 -
 stable_baselines3/common/vec_env/vec_video_recorder.py    | 1 -
 stable_baselines3/ddpg/ddpg.py                            | 2 --
 stable_baselines3/dqn/dqn.py                              | 2 --
 stable_baselines3/her/her_replay_buffer.py                | 2 --
 stable_baselines3/ppo/ppo.py                              | 2 --
 stable_baselines3/sac/sac.py                              | 2 --
 stable_baselines3/td3/td3.py                              | 3 ---
 tests/test_run.py                                         | 1 -
 tests/test_save_load.py                                   | 1 -
 tests/test_vec_normalize.py                               | 3 +--
 22 files changed, 4 insertions(+), 44 deletions(-)

diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py
index ec4ae2e..9e8b40c 100644
--- a/stable_baselines3/a2c/a2c.py
+++ b/stable_baselines3/a2c/a2c.py
@@ -81,7 +81,6 @@ class A2C(OnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
@@ -132,7 +131,6 @@ class A2C(OnPolicyAlgorithm):
 
         # This will only loop once (get all data in one go)
         for rollout_data in self.rollout_buffer.get(batch_size=None):
-
             actions = rollout_data.actions
             if isinstance(self.action_space, spaces.Discrete):
                 # Convert discrete action from float to long
@@ -189,7 +187,6 @@ class A2C(OnPolicyAlgorithm):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfA2C:
-
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index f71dd29..273dba9 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -240,7 +240,6 @@ class ReplayBuffer(BaseBuffer):
         done: np.ndarray,
         infos: List[Dict[str, Any]],
     ) -> None:
-
         # Reshape needed when using multiple envs with discrete observations
         # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
         if isinstance(self.observation_space, spaces.Discrete):
@@ -346,7 +345,6 @@ class RolloutBuffer(BaseBuffer):
         gamma: float = 0.99,
         n_envs: int = 1,
     ):
-
         super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
         self.gae_lambda = gae_lambda
         self.gamma = gamma
@@ -356,7 +354,6 @@ class RolloutBuffer(BaseBuffer):
         self.reset()
 
     def reset(self) -> None:
-
         self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
         self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -451,7 +448,6 @@ class RolloutBuffer(BaseBuffer):
         indices = np.random.permutation(self.buffer_size * self.n_envs)
         # Prepare the data
         if not self.generator_ready:
-
             _tensor_names = [
                 "observations",
                 "actions",
@@ -688,7 +684,6 @@ class DictRolloutBuffer(RolloutBuffer):
         gamma: float = 0.99,
         n_envs: int = 1,
     ):
-
         super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
 
         assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
@@ -763,7 +758,6 @@ class DictRolloutBuffer(RolloutBuffer):
         indices = np.random.permutation(self.buffer_size * self.n_envs)
         # Prepare the data
         if not self.generator_ready:
-
             for key, obs in self.observations.items():
                 self.observations[key] = self.swap_and_flatten(obs)
 
@@ -787,7 +781,6 @@ class DictRolloutBuffer(RolloutBuffer):
         batch_inds: np.ndarray,
         env: Optional[VecNormalize] = None,
     ) -> DictRolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME
-
         return DictRolloutBufferSamples(
             observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
             actions=self.to_torch(self.actions[batch_inds]),
diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py
index a96c52c..69a21ab 100644
--- a/stable_baselines3/common/callbacks.py
+++ b/stable_baselines3/common/callbacks.py
@@ -429,11 +429,9 @@ class EvalCallback(EventCallback):
                 self._is_success_buffer.append(maybe_is_success)
 
     def _on_step(self) -> bool:
-
         continue_training = True
 
         if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
-
             # Sync training and eval env if there is VecNormalize
             if self.model.get_vec_normalize_env() is not None:
                 try:
diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py
index ff18137..b65edf8 100644
--- a/stable_baselines3/common/evaluation.py
+++ b/stable_baselines3/common/evaluation.py
@@ -91,7 +91,6 @@ def evaluate_policy(
         current_lengths += 1
         for i in range(n_envs):
             if episode_counts[i] < episode_count_targets[i]:
-
                 # unpack values so that the callback can access the local variables
                 reward = rewards[i]
                 done = dones[i]
diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py
index 939d924..a8aa766 100644
--- a/stable_baselines3/common/logger.py
+++ b/stable_baselines3/common/logger.py
@@ -173,7 +173,6 @@ class HumanOutputFormat(KVWriter, SeqWriter):
         key2str = {}
         tag = None
         for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
-
             if excluded is not None and ("stdout" in excluded or "log" in excluded):
                 continue
 
@@ -342,7 +341,7 @@ class CSVOutputFormat(KVWriter):
             self.file.seek(0)
             lines = self.file.readlines()
             self.file.seek(0)
-            for (i, key) in enumerate(self.keys):
+            for i, key in enumerate(self.keys):
                 if i > 0:
                     self.file.write(",")
                 self.file.write(key)
@@ -399,9 +398,7 @@ class TensorBoardOutputFormat(KVWriter):
         self.writer = SummaryWriter(log_dir=folder)
 
     def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
-
         for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
-
             if excluded is not None and "tensorboard" in excluded:
                 continue
 
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index 48779ef..c1ab215 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -102,7 +102,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
         sde_support: bool = True,
         supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
     ):
-
         super().__init__(
             policy=policy,
             env=env,
@@ -319,7 +318,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfOffPolicyAlgorithm:
-
         total_timesteps, callback = self._setup_learn(
             total_timesteps,
             callback,
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index bc0dda4..44d8b26 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -72,7 +72,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
         _init_setup_model: bool = True,
         supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
     ):
-
         super().__init__(
             policy=policy,
             env=env,
@@ -244,7 +243,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
         callback.on_training_start(locals(), globals())
 
         while self.num_timesteps < total_timesteps:
-
             continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
 
             if continue_training is False:
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index 793cfc5..457274a 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -433,7 +433,6 @@ class ActorCriticPolicy(BasePolicy):
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-
         if optimizer_kwargs is None:
             optimizer_kwargs = {}
             # Small values to avoid NaN in Adam optimizer
diff --git a/stable_baselines3/common/results_plotter.py b/stable_baselines3/common/results_plotter.py
index 92f67ac..dac2b6c 100644
--- a/stable_baselines3/common/results_plotter.py
+++ b/stable_baselines3/common/results_plotter.py
@@ -84,7 +84,7 @@ def plot_curves(
     plt.figure(title, figsize=figsize)
     max_x = max(xy[0][-1] for xy in xy_list)
     min_x = 0
-    for (_, (x, y)) in enumerate(xy_list):
+    for _, (x, y) in enumerate(xy_list):
         plt.scatter(x, y, s=2)
         # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
         if x.shape[0] >= EPISODES_WINDOW:
diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py
index facc55a..7ae1e22 100644
--- a/stable_baselines3/common/save_util.py
+++ b/stable_baselines3/common/save_util.py
@@ -367,7 +367,7 @@ def load_from_zip_file(
     device: Union[th.device, str] = "auto",
     verbose: int = 0,
     print_system_info: bool = False,
-) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
+) -> Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]:
     """
     Load model data from a .zip archive
 
diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py
index 8583518..d373b87 100644
--- a/stable_baselines3/common/vec_env/stacked_observations.py
+++ b/stable_baselines3/common/vec_env/stacked_observations.py
@@ -30,7 +30,6 @@ class StackedObservations:
         observation_space: spaces.Space,
         channels_order: Optional[str] = None,
     ):
-
         self.n_stack = n_stack
         (
             self.channels_first,
diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py
index e06d512..d933104 100644
--- a/stable_baselines3/common/vec_env/vec_frame_stack.py
+++ b/stable_baselines3/common/vec_env/vec_frame_stack.py
@@ -44,7 +44,6 @@ class VecFrameStack(VecEnvWrapper):
     def step_wait(
         self,
     ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
-
         observations, rewards, dones, infos = self.venv.step_wait()
 
         observations, infos = self.stackedobs.update(observations, dones, infos)
diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py
index 70d74eb..83d058a 100644
--- a/stable_baselines3/common/vec_env/vec_video_recorder.py
+++ b/stable_baselines3/common/vec_env/vec_video_recorder.py
@@ -30,7 +30,6 @@ class VecVideoRecorder(VecEnvWrapper):
         video_length: int = 200,
         name_prefix: str = "rl-video",
     ):
-
         VecEnvWrapper.__init__(self, venv)
 
         self.env = venv
diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py
index 40d67b5..c311b23 100644
--- a/stable_baselines3/ddpg/ddpg.py
+++ b/stable_baselines3/ddpg/ddpg.py
@@ -76,7 +76,6 @@ class DDPG(TD3):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy=policy,
             env=env,
@@ -121,7 +120,6 @@ class DDPG(TD3):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfDDPG:
-
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index dd8794e..ea1946a 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -94,7 +94,6 @@ class DQN(OffPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
@@ -261,7 +260,6 @@ class DQN(OffPolicyAlgorithm):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfDQN:
-
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py
index 1518436..0c3da25 100644
--- a/stable_baselines3/her/her_replay_buffer.py
+++ b/stable_baselines3/her/her_replay_buffer.py
@@ -81,7 +81,6 @@ class HerReplayBuffer(DictReplayBuffer):
         online_sampling: bool = True,
         handle_timeout_termination: bool = True,
     ):
-
         super().__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs)
 
         # convert goal_selection_strategy into GoalSelectionStrategy if string
@@ -389,7 +388,6 @@ class HerReplayBuffer(DictReplayBuffer):
         done: np.ndarray,
         infos: List[Dict[str, Any]],
     ) -> None:
-
         if self.current_idx == 0 and self.full:
             # Clear info buffer
             self.info_buffer[self.pos] = deque(maxlen=self.max_episode_length)
diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py
index bd80736..c934527 100644
--- a/stable_baselines3/ppo/ppo.py
+++ b/stable_baselines3/ppo/ppo.py
@@ -98,7 +98,6 @@ class PPO(OnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
@@ -303,7 +302,6 @@ class PPO(OnPolicyAlgorithm):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfPPO:
-
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index 74285b6..d1a6610 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -109,7 +109,6 @@ class SAC(OffPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
@@ -295,7 +294,6 @@ class SAC(OffPolicyAlgorithm):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfSAC:
-
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py
index ae442e1..c844a99 100644
--- a/stable_baselines3/td3/td3.py
+++ b/stable_baselines3/td3/td3.py
@@ -94,7 +94,6 @@ class TD3(OffPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
@@ -151,7 +150,6 @@ class TD3(OffPolicyAlgorithm):
         actor_losses, critic_losses = [], []
 
         for _ in range(gradient_steps):
-
             self._n_updates += 1
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
@@ -210,7 +208,6 @@ class TD3(OffPolicyAlgorithm):
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfTD3:
-
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
diff --git a/tests/test_run.py b/tests/test_run.py
index 71236a3..ca7548f 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -115,7 +115,6 @@ def test_dqn():
 
 @pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
 def test_train_freq(tmp_path, train_freq):
-
     model = SAC(
         "MlpPolicy",
         "Pendulum-v1",
diff --git a/tests/test_save_load.py b/tests/test_save_load.py
index 2c35e43..9d3d537 100644
--- a/tests/test_save_load.py
+++ b/tests/test_save_load.py
@@ -648,7 +648,6 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
 
 
 def test_open_file(tmp_path):
-
     # path must much the type
     with pytest.raises(TypeError):
         open_path(123, None, None, None)
diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py
index b17d28c..fb37fd3 100644
--- a/tests/test_vec_normalize.py
+++ b/tests/test_vec_normalize.py
@@ -178,7 +178,7 @@ def _make_warmstart_dict_env(**kwargs):
 
 def test_runningmeanstd():
     """Test RunningMeanStd object"""
-    for (x_1, x_2, x_3) in [
+    for x_1, x_2, x_3 in [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]:
@@ -336,7 +336,6 @@ def test_normalize_dict_selected_keys():
 @pytest.mark.parametrize("model_class", [SAC, TD3, HerReplayBuffer])
 @pytest.mark.parametrize("online_sampling", [False, True])
 def test_offpolicy_normalization(model_class, online_sampling):
-
     if online_sampling and model_class != HerReplayBuffer:
         pytest.skip()
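
Note for reviewers: nearly every hunk above is one of two mechanical rewrites.
The first drops the blank line that older black versions tolerated directly
after a block opener (the closing "):", a "for ...:" or "if ...:" header); the
second removes redundant parentheses, as in "for (i, key) in enumerate(...)"
becoming "for i, key in enumerate(...)" and the parenthesized return annotation
in load_from_zip_file. A minimal before/after sketch of both patterns on a toy
function (illustration only; summarize_before/summarize_after are not code from
this patch):

    # Before: blank line right after the block opener, parenthesized
    # tuple target in the for loop.
    def summarize_before(items):

        for (index, name) in enumerate(items):
            print(index, name)

    # After: the leading blank line is removed and the loop target
    # loses its redundant parentheses; behavior is unchanged.
    def summarize_after(items):
        for index, name in enumerate(items):
            print(index, name)

The blank-line removal matches the black 23.x stable style, which strips empty
lines directly after a code-block opener; whether the parenthesis cleanups come
from black itself or from the manual "Reformat tests" pass cannot be determined
from the patch alone.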