From d496cd4d950ff37aefede718240c56169cbe1562 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 22 Dec 2021 11:43:59 +0100 Subject: [PATCH] Consistent use of `device` as keyword argument (#702) * consistent device as keyword arg * Fixed ``device`` arg inconsistency in changelog --- docs/misc/changelog.rst | 1 + stable_baselines3/common/off_policy_algorithm.py | 6 +++--- stable_baselines3/common/on_policy_algorithm.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 50fd9fd..3e4c067 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -38,6 +38,7 @@ Others: - Added a warning in the env checker when not using ``np.float32`` for continuous actions - Improved test coverage and error message when checking shape of observation - Added ``newline="\n"`` when opening CSV monitor files so that each line ends with ``\r\n`` instead of ``\r\r\n`` on Windows while Linux environments are not affected (@hsuehch) +- Fixed ``device`` argument inconsistency (@qgallouedec) Documentation: ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index b1528f5..015a32b 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -197,14 +197,14 @@ class OffPolicyAlgorithm(BaseAlgorithm): self.buffer_size, self.observation_space, self.action_space, - self.device, + device=self.device, optimize_memory_usage=self.optimize_memory_usage, ) self.replay_buffer = HerReplayBuffer( self.env, self.buffer_size, - self.device, + device=self.device, replay_buffer=replay_buffer, **self.replay_buffer_kwargs, ) @@ -214,7 +214,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): self.buffer_size, self.observation_space, self.action_space, - self.device, + device=self.device, n_envs=self.n_envs, optimize_memory_usage=self.optimize_memory_usage, **self.replay_buffer_kwargs, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index a58d331..062db93 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -112,7 +112,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): self.n_steps, self.observation_space, self.action_space, - self.device, + device=self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs,