diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 50fd9fd..3e4c067 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -38,6 +38,7 @@ Others: - Added a warning in the env checker when not using ``np.float32`` for continuous actions - Improved test coverage and error message when checking shape of observation - Added ``newline="\n"`` when opening CSV monitor files so that each line ends with ``\r\n`` instead of ``\r\r\n`` on Windows while Linux environments are not affected (@hsuehch) +- Fixed ``device`` argument inconsistency (@qgallouedec) Documentation: ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index b1528f5..015a32b 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -197,14 +197,14 @@ class OffPolicyAlgorithm(BaseAlgorithm): self.buffer_size, self.observation_space, self.action_space, - self.device, + device=self.device, optimize_memory_usage=self.optimize_memory_usage, ) self.replay_buffer = HerReplayBuffer( self.env, self.buffer_size, - self.device, + device=self.device, replay_buffer=replay_buffer, **self.replay_buffer_kwargs, ) @@ -214,7 +214,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): self.buffer_size, self.observation_space, self.action_space, - self.device, + device=self.device, n_envs=self.n_envs, optimize_memory_usage=self.optimize_memory_usage, **self.replay_buffer_kwargs, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index a58d331..062db93 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -112,7 +112,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): self.n_steps, self.observation_space, self.action_space, - self.device, + device=self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs,