Upgrade black formatting (#1310)

* apply black

* Reformat tests

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2023-02-02 11:58:41 +01:00 committed by GitHub
parent bea3c44ba5
commit 82bc63fca4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 4 additions and 44 deletions

View file

@ -81,7 +81,6 @@ class A2C(OnPolicyAlgorithm):
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
@ -132,7 +131,6 @@ class A2C(OnPolicyAlgorithm):
# This will only loop once (get all data in one go)
for rollout_data in self.rollout_buffer.get(batch_size=None):
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long
@ -189,7 +187,6 @@ class A2C(OnPolicyAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfA2C:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -240,7 +240,6 @@ class ReplayBuffer(BaseBuffer):
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space, spaces.Discrete):
@ -346,7 +345,6 @@ class RolloutBuffer(BaseBuffer):
gamma: float = 0.99,
n_envs: int = 1,
):
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
@ -356,7 +354,6 @@ class RolloutBuffer(BaseBuffer):
self.reset()
def reset(self) -> None:
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@ -451,7 +448,6 @@ class RolloutBuffer(BaseBuffer):
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
_tensor_names = [
"observations",
"actions",
@ -688,7 +684,6 @@ class DictRolloutBuffer(RolloutBuffer):
gamma: float = 0.99,
n_envs: int = 1,
):
super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
@ -763,7 +758,6 @@ class DictRolloutBuffer(RolloutBuffer):
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
for key, obs in self.observations.items():
self.observations[key] = self.swap_and_flatten(obs)
@ -787,7 +781,6 @@ class DictRolloutBuffer(RolloutBuffer):
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
return DictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]),

View file

@ -429,11 +429,9 @@ class EvalCallback(EventCallback):
self._is_success_buffer.append(maybe_is_success)
def _on_step(self) -> bool:
continue_training = True
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
# Sync training and eval env if there is VecNormalize
if self.model.get_vec_normalize_env() is not None:
try:

View file

@ -91,7 +91,6 @@ def evaluate_policy(
current_lengths += 1
for i in range(n_envs):
if episode_counts[i] < episode_count_targets[i]:
# unpack values so that the callback can access the local variables
reward = rewards[i]
done = dones[i]

View file

@ -173,7 +173,6 @@ class HumanOutputFormat(KVWriter, SeqWriter):
key2str = {}
tag = None
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and ("stdout" in excluded or "log" in excluded):
continue
@ -342,7 +341,7 @@ class CSVOutputFormat(KVWriter):
self.file.seek(0)
lines = self.file.readlines()
self.file.seek(0)
for (i, key) in enumerate(self.keys):
for i, key in enumerate(self.keys):
if i > 0:
self.file.write(",")
self.file.write(key)
@ -399,9 +398,7 @@ class TensorBoardOutputFormat(KVWriter):
self.writer = SummaryWriter(log_dir=folder)
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and "tensorboard" in excluded:
continue

View file

@ -102,7 +102,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
sde_support: bool = True,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):
super().__init__(
policy=policy,
env=env,
@ -319,7 +318,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfOffPolicyAlgorithm:
total_timesteps, callback = self._setup_learn(
total_timesteps,
callback,

View file

@ -72,7 +72,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
_init_setup_model: bool = True,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):
super().__init__(
policy=policy,
env=env,
@ -244,7 +243,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
if continue_training is False:

View file

@ -433,7 +433,6 @@ class ActorCriticPolicy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in Adam optimizer

View file

@ -84,7 +84,7 @@ def plot_curves(
plt.figure(title, figsize=figsize)
max_x = max(xy[0][-1] for xy in xy_list)
min_x = 0
for (_, (x, y)) in enumerate(xy_list):
for _, (x, y) in enumerate(xy_list):
plt.scatter(x, y, s=2)
# Do not plot the smoothed curve at all if the timeseries is shorter than window size.
if x.shape[0] >= EPISODES_WINDOW:

View file

@ -367,7 +367,7 @@ def load_from_zip_file(
device: Union[th.device, str] = "auto",
verbose: int = 0,
print_system_info: bool = False,
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
) -> Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]:
"""
Load model data from a .zip archive

View file

@ -30,7 +30,6 @@ class StackedObservations:
observation_space: spaces.Space,
channels_order: Optional[str] = None,
):
self.n_stack = n_stack
(
self.channels_first,

View file

@ -44,7 +44,6 @@ class VecFrameStack(VecEnvWrapper):
def step_wait(
self,
) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
observations, rewards, dones, infos = self.venv.step_wait()
observations, infos = self.stackedobs.update(observations, dones, infos)

View file

@ -30,7 +30,6 @@ class VecVideoRecorder(VecEnvWrapper):
video_length: int = 200,
name_prefix: str = "rl-video",
):
VecEnvWrapper.__init__(self, venv)
self.env = venv

View file

@ -76,7 +76,6 @@ class DDPG(TD3):
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy=policy,
env=env,
@ -121,7 +120,6 @@ class DDPG(TD3):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfDDPG:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -94,7 +94,6 @@ class DQN(OffPolicyAlgorithm):
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
@ -261,7 +260,6 @@ class DQN(OffPolicyAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfDQN:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -81,7 +81,6 @@ class HerReplayBuffer(DictReplayBuffer):
online_sampling: bool = True,
handle_timeout_termination: bool = True,
):
super().__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs)
# convert goal_selection_strategy into GoalSelectionStrategy if string
@ -389,7 +388,6 @@ class HerReplayBuffer(DictReplayBuffer):
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:
if self.current_idx == 0 and self.full:
# Clear info buffer
self.info_buffer[self.pos] = deque(maxlen=self.max_episode_length)

View file

@ -98,7 +98,6 @@ class PPO(OnPolicyAlgorithm):
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
@ -303,7 +302,6 @@ class PPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfPPO:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -109,7 +109,6 @@ class SAC(OffPolicyAlgorithm):
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
@ -295,7 +294,6 @@ class SAC(OffPolicyAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfSAC:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -94,7 +94,6 @@ class TD3(OffPolicyAlgorithm):
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
@ -151,7 +150,6 @@ class TD3(OffPolicyAlgorithm):
actor_losses, critic_losses = [], []
for _ in range(gradient_steps):
self._n_updates += 1
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
@ -210,7 +208,6 @@ class TD3(OffPolicyAlgorithm):
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfTD3:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,

View file

@ -115,7 +115,6 @@ def test_dqn():
@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
def test_train_freq(tmp_path, train_freq):
model = SAC(
"MlpPolicy",
"Pendulum-v1",

View file

@ -648,7 +648,6 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
def test_open_file(tmp_path):
# path must much the type
with pytest.raises(TypeError):
open_path(123, None, None, None)

View file

@ -178,7 +178,7 @@ def _make_warmstart_dict_env(**kwargs):
def test_runningmeanstd():
"""Test RunningMeanStd object"""
for (x_1, x_2, x_3) in [
for x_1, x_2, x_3 in [
(np.random.randn(3), np.random.randn(4), np.random.randn(5)),
(np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
]:
@ -336,7 +336,6 @@ def test_normalize_dict_selected_keys():
@pytest.mark.parametrize("model_class", [SAC, TD3, HerReplayBuffer])
@pytest.mark.parametrize("online_sampling", [False, True])
def test_offpolicy_normalization(model_class, online_sampling):
if online_sampling and model_class != HerReplayBuffer:
pytest.skip()