diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 9620882..61c9fea 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -758,16 +758,16 @@ def test_no_resource_warning(tmp_path): # check that files are properly closed # Create a PPO agent and save it - PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole") - PPO.load(tmp_path / "dqn_cartpole") + PPO("MlpPolicy", "CartPole-v1", device="cpu").save(tmp_path / "dqn_cartpole") + PPO.load(tmp_path / "dqn_cartpole", device="cpu") - PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "dqn_cartpole")) - PPO.load(str(tmp_path / "dqn_cartpole")) + PPO("MlpPolicy", "CartPole-v1", device="cpu").save(str(tmp_path / "dqn_cartpole")) + PPO.load(str(tmp_path / "dqn_cartpole"), device="cpu") # Do the same but in memory, should not close the file with tempfile.TemporaryFile() as fp: - PPO("MlpPolicy", "CartPole-v1").save(fp) - PPO.load(fp) + PPO("MlpPolicy", "CartPole-v1", device="cpu").save(fp) + PPO.load(fp, device="cpu") assert not fp.closed # Same but with replay buffer