From 9555dd9859025561db12f331fc33be279b023a0a Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 7 Jan 2025 14:19:05 +0100 Subject: [PATCH] Fix tests and warnings when running locally with a GPU (#2069) * Fix test when GPU is available * Sort file list for consistent results * Ignore A2C warnings too --- tests/test_save_load.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 9620882..61c9fea 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -758,16 +758,16 @@ def test_no_resource_warning(tmp_path): # check that files are properly closed # Create a PPO agent and save it - PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole") - PPO.load(tmp_path / "dqn_cartpole") + PPO("MlpPolicy", "CartPole-v1", device="cpu").save(tmp_path / "dqn_cartpole") + PPO.load(tmp_path / "dqn_cartpole", device="cpu") - PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "dqn_cartpole")) - PPO.load(str(tmp_path / "dqn_cartpole")) + PPO("MlpPolicy", "CartPole-v1", device="cpu").save(str(tmp_path / "dqn_cartpole")) + PPO.load(str(tmp_path / "dqn_cartpole"), device="cpu") # Do the same but in memory, should not close the file with tempfile.TemporaryFile() as fp: - PPO("MlpPolicy", "CartPole-v1").save(fp) - PPO.load(fp) + PPO("MlpPolicy", "CartPole-v1", device="cpu").save(fp) + PPO.load(fp, device="cpu") assert not fp.closed # Same but with replay buffer