stable-baselines3/tests/test_save_load.py
Antonin RAFFIN 40e0b9d2c8
Add Gymnasium support (#1327)
* Fix failing set_env test

* Fix test failing due to deprecation of env.seed

* Adjust mean reward threshold in failing test

* Fix her test failing due to rng

* Change seed and revert reward threshold to 90

* Pin gym version

* Make VecEnv compatible with gym seeding change

* Revert change to VecEnv reset signature

* Change subprocenv seed cmd to call reset instead

* Fix type check

* Add backward compat

* Add `compat_gym_seed` helper

* Add goal env checks in env_checker

* Add docs on HER requirements for envs

* Capture user warning in test with inverted box space

* Update ale-py version

* Fix randint

* Allow noop_max to be zero

* Update changelog

* Update docker image

* Update doc conda env and dockerfile

* Custom envs should not have any warnings

* Fix test for numpy >= 1.21

* Add check for vectorized compute reward

* Bump to gym 0.24

* Fix gym default step docstring

* Test downgrading gym

* Revert "Test downgrading gym"

This reverts commit 0072b77156c006ada8a1d6e26ce347ed85a83eeb.

* Fix protobuf error

* Fix in dependencies

* Fix protobuf dep

* Use newest version of cartpole

* Update gym

* Fix warning

* Loosen required scipy version

* Scipy no longer needed

* Try gym 0.25

* Silence warnings from gym

* Filter warnings during tests

* Update doc

* Update requirements

* Add gym 26 compat in vec env

* Fixes in envs and tests for gym 0.26+

* Enforce gym 0.26 api

* format

* Fix formatting

* Fix dependencies

* Fix syntax

* Cleanup doc and warnings

* Faster tests

* Higher budget for HER perf test (revert prev change)

* Fixes and update doc

* Fix doc build

* Fix breaking change

* Fixes for rendering

* Rename variables in monitor

* update render method for gym 0.26 API

backwards compatible (mode argument is allowed) while using the gym 0.26 API (render mode is determined at environment creation)

* update tests and docs to new gym render API

* undo removal of render modes metadata check

* set rgb_array as default render mode for gym.make

* undo changes & raise warning if not 'rgb_array'

* Fix type check

* Remove recursion and fix type checking

* Remove hacks for protobuf and gym 0.24

* Fix type annotations

* reuse existing render_mode attribute

* return tiled images for 'human' render mode

* Allow to use opencv for human render, fix typos

* Add warning when using non-zero start with Discrete (fixes #1197)

* Fix type checking

* Bug fixes and handle more cases

* Throw proper warnings

* Update test

* Fix new metadata name

* Ignore numpy warnings

* Fixes in vec recorder

* Global ignore

* Filter local warning too

* Monkey patch not needed for gym 26

* Add doc of VecEnv vs Gym API

* Add render test

* Fix return type

* Update VecEnv vs Gym API doc

* Fix for custom render mode

* Fix return type

* Fix type checking

* check test env test_buffer

* skip render check

* check env test_dict_env

* test_env test_gae

* check envs in remaining tests

* Update tests

* Add warning for Discrete action space with non-zero (#1295)

* Fix atari annotation

* ignore get_action_meanings [attr-defined]

* Fix mypy issues

* Add patch for gym/gymnasium transition

* Switch to gymnasium

* Rely on signature instead of version

* More patches

* Type ignore because of https://github.com/Farama-Foundation/Gymnasium/pull/39

* Fix doc build

* Fix pytype errors

* Fix atari requirement

* Update env checker due to change in dtype for Discrete

* Fix type hint

* Convert spaces for saved models

* Ignore pytype

* Remove gitlab CI

* Disable pytype for convert space

* Fix undefined info

* Fix undefined info

* Upgrade shimmy

* Fix wrappers type annotation (need PR from Gymnasium)

* Fix gymnasium dependency

* Fix dependency declaration

* Cap pygame version for python 3.7

* Point to master branch (v0.28.0)

* Fix: use main not master branch

* Rename done to terminated

* Fix pygame dependency for python 3.7

* Rename gym to gymnasium

* Update Gymnasium

* Fix test

* Fix tests

* Forks don't have access to private variables

* Fix linter warnings

* Update read the doc env

* Fix env checker for GoalEnv

* Fix import

* Update env checker (more info) and fix dtype

* Use micromamba for Docker

* Update dependencies

* Clarify VecEnv doc

* Fix Gymnasium version

* Copy file only after mamba install

* [ci skip] Update docker doc

* Polish code

* Reformat

* Remove deprecated features

* Ignore warning

* Update doc

* Update examples and changelog

* Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) (#1436)

* Fix SAC type hints, improve DQN ones

* Fix A2C and TD3 type hints

* Fix PPO type hints

* Fix on-policy type hints

* Fix base class type annotation, do not use defaults

* Update version

* Disable mypy for python 3.7

* Rename Gym26StepReturn

* Update continuous critic type annotation

* Fix pytype complaint

---------

Co-authored-by: Carlos Luis <carlos.luisgonc@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Thomas Lips <37955681+tlpss@users.noreply.github.com>
Co-authored-by: tlips <thomas.lips@ugent.be>
Co-authored-by: tlpss <thomas17.lips@gmail.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
2023-04-14 13:13:59 +02:00

732 lines
26 KiB
Python

import base64
import io
import json
import os
import pathlib
import warnings
import zipfile
from collections import OrderedDict
from copy import deepcopy

import gymnasium as gym
import numpy as np
import pytest
import torch as th

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv

MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]


def select_env(model_class: BaseAlgorithm) -> gym.Env:
    """
    Selects an environment with the correct action space as DQN only supports discrete action space
    """
    if model_class == DQN:
        return IdentityEnv(10)
    else:
        return IdentityEnvBox(-10, 10)


@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_save_load(tmp_path, model_class):
    """
    Test if 'save' and 'load' save and load the model correctly
    and if 'get_parameters' and 'set_parameters' work correctly.

    Warning: this does not test the loading of optimizer parameters.

    :param model_class: (BaseAlgorithm) A RL model
    """
    env = DummyVecEnv([lambda: select_env(model_class)])

    # create model
    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
    model.learn(total_timesteps=500)

    env.reset()
    observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

    # Get parameters of different objects
    # deepcopy to avoid referencing to tensors we are about to modify
    original_params = deepcopy(model.get_parameters())

    # Test different error cases of set_parameters.
    # Test that invalid object names throw errors
    invalid_object_params = deepcopy(original_params)
    invalid_object_params["I_should_not_be_a_valid_object"] = "and_I_am_an_invalid_tensor"
    with pytest.raises(ValueError):
        model.set_parameters(invalid_object_params, exact_match=True)
    with pytest.raises(ValueError):
        model.set_parameters(invalid_object_params, exact_match=False)

    # Test that exact_match catches when something was missed.
    missing_object_params = {k: v for k, v in list(original_params.items())[:-1]}
    with pytest.raises(ValueError):
        model.set_parameters(missing_object_params, exact_match=True)

    # Test that exact_match catches when something inside state-dict
    # is missing but we have exact_match.
    missing_state_dict_tensor_params = {}
    for object_name in original_params:
        object_params = {}
        missing_state_dict_tensor_params[object_name] = object_params
        # Skip last item in state-dict
        for k, v in list(original_params[object_name].items())[:-1]:
            object_params[k] = v
    with pytest.raises(RuntimeError):
        # PyTorch load_state_dict throws RuntimeError if strict but
        # invalid state-dict.
        model.set_parameters(missing_state_dict_tensor_params, exact_match=True)

    # Test that parameters do indeed change.
    random_params = {}
    for object_name, params in original_params.items():
        # Do not randomize optimizer parameters (custom layout)
        if "optim" in object_name:
            random_params[object_name] = params
        else:
            # Again, skip the last item in state-dict
            random_params[object_name] = OrderedDict(
                (param_name, th.rand_like(param)) for param_name, param in list(params.items())[:-1]
            )

    # Update model parameters with the new random values
    model.set_parameters(random_params, exact_match=False)

    new_params = model.get_parameters()
    # Check that all params except the final item in each state-dict are different.
    for object_name in original_params:
        # Skip optimizers (no valid comparison with just th.allclose)
        if "optim" in object_name:
            continue
        # state-dicts use ordered dictionaries, so key order
        # is guaranteed.
        last_key = list(original_params[object_name].keys())[-1]
        for k in original_params[object_name]:
            if k == last_key:
                # Should be same as before
                assert th.allclose(
                    original_params[object_name][k], new_params[object_name][k]
                ), "Parameter changed despite not included in the loaded parameters."
            else:
                # Should be different
                assert not th.allclose(
                    original_params[object_name][k], new_params[object_name][k]
                ), "Parameters did not change as expected."

    params = new_params

    # get selected actions
    selected_actions, _ = model.predict(observations, deterministic=True)

    # Save the model, then delete it so that it must be reloaded from disk
    model.save(tmp_path / "test_save.zip")
    del model

    # Check if the model loads as expected for every possible choice of device:
    for device in ["auto", "cpu", "cuda"]:
        model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device)

        # check if the model was loaded to the correct device
        assert model.device.type == get_device(device).type
        assert model.policy.device.type == get_device(device).type

        # check if params are still the same after load
        new_params = model.get_parameters()

        # Check that all params are the same as before save load procedure now
        for object_name in new_params:
            # Skip optimizers (no valid comparison with just th.allclose)
            if "optim" in object_name:
                continue
            for key in params[object_name]:
                assert new_params[object_name][key].device.type == get_device(device).type
                assert th.allclose(
                    params[object_name][key].to("cpu"), new_params[object_name][key].to("cpu")
                ), "Model parameters not the same after save and load."

        # check if model still selects the same actions
        new_selected_actions, _ = model.predict(observations, deterministic=True)
        assert np.allclose(selected_actions, new_selected_actions, 1e-4)

        # check if learn still works
        model.learn(total_timesteps=500)

        del model

    # clear file from os
    os.remove(tmp_path / "test_save.zip")


@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_set_env(tmp_path, model_class):
    """
    Test if the set_env function works correctly

    :param model_class: (BaseAlgorithm) A RL model
    """
    # use discrete for DQN
    env = DummyVecEnv([lambda: select_env(model_class)])
    env2 = DummyVecEnv([lambda: select_env(model_class)])
    env3 = select_env(model_class)
    env4 = DummyVecEnv([lambda: select_env(model_class) for _ in range(2)])

    kwargs = {}
    if model_class in {DQN, DDPG, SAC, TD3}:
        kwargs = dict(learning_starts=50, train_freq=4)
    elif model_class in {A2C, PPO}:
        kwargs = dict(n_steps=64)

    # create model
    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), **kwargs)
    # learn
    model.learn(total_timesteps=64)

    # change env
    model.set_env(env2, force_reset=True)
    # Check that last obs was discarded
    assert model._last_obs is None
    # learn again
    model.learn(total_timesteps=64, reset_num_timesteps=True)
    assert model.num_timesteps == 64

    # change env test wrapping
    model.set_env(env3)
    # learn again
    model.learn(total_timesteps=64)

    # num_env must be the same
    with pytest.raises(AssertionError):
        model.set_env(env4)

    # Keep the same env, disable reset
    model.set_env(model.get_env(), force_reset=False)
    assert model._last_obs is not None
    # learn again
    model.learn(total_timesteps=64, reset_num_timesteps=False)
    assert model.num_timesteps == 2 * 64

    current_env = model.get_env()
    model.save(tmp_path / "test_save.zip")
    del model
    # Check that we can keep the number of timesteps after loading
    # Here the env kept its state so we don't have to reset
    model = model_class.load(tmp_path / "test_save.zip", env=current_env, force_reset=False)
    assert model._last_obs is not None
    model.learn(total_timesteps=64, reset_num_timesteps=False)
    assert model.num_timesteps == 3 * 64

    del model
    # We are changing the env, the env must reset but we should keep the number of timesteps
    model = model_class.load(tmp_path / "test_save.zip", env=env3, force_reset=True)
    assert model._last_obs is None
    model.learn(total_timesteps=64, reset_num_timesteps=False)
    assert model.num_timesteps == 3 * 64

    del model
    # Load the model with a different number of environments
    model = model_class.load(tmp_path / "test_save.zip", env=env4)
    model.learn(total_timesteps=64)

    # Clear saved file
    os.remove(tmp_path / "test_save.zip")


@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_exclude_include_saved_params(tmp_path, model_class):
    """
    Test if exclude and include parameters of save() work

    :param model_class: (BaseAlgorithm) A RL model
    """
    env = DummyVecEnv([lambda: select_env(model_class)])

    # create model, set verbose as 2, which is not standard
    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=2)

    # Check if exclude works
    model.save(tmp_path / "test_save", exclude=["verbose"])
    del model
    model = model_class.load(str(tmp_path / "test_save.zip"))
    # check if verbose was not saved
    assert model.verbose != 2

    # set verbose to something different than the standard settings
    model.verbose = 2
    # Check if include works
    model.save(tmp_path / "test_save", exclude=["verbose"], include=["verbose"])
    del model
    # Load with custom objects
    custom_objects = dict(learning_rate=2e-5, dummy=1.0)
    model = model_class.load(
        str(tmp_path / "test_save.zip"),
        custom_objects=custom_objects,
        print_system_info=True,
    )
    assert model.verbose == 2
    # Check that the custom object was taken into account
    assert model.learning_rate == custom_objects["learning_rate"]
    # Check that only parameters that are here already are replaced
    assert not hasattr(model, "dummy")

    # clear file from os
    os.remove(tmp_path / "test_save.zip")


def test_save_load_pytorch_var(tmp_path):
    model = SAC("MlpPolicy", "Pendulum-v1", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
    model.learn(200)
    save_path = str(tmp_path / "sac_pendulum")
    model.save(save_path)
    env = model.get_env()
    log_ent_coef_before = model.log_ent_coef

    del model

    model = SAC.load(save_path, env=env)
    assert th.allclose(log_ent_coef_before, model.log_ent_coef)
    model.learn(200)
    log_ent_coef_after = model.log_ent_coef
    # Check that the entropy coefficient is still optimized
    assert not th.allclose(log_ent_coef_before, log_ent_coef_after)

    # With a fixed entropy coef
    model = SAC("MlpPolicy", "Pendulum-v1", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
    model.learn(200)
    save_path = str(tmp_path / "sac_pendulum")
    model.save(save_path)
    env = model.get_env()
    assert model.log_ent_coef is None
    ent_coef_before = model.ent_coef_tensor

    del model

    model = SAC.load(save_path, env=env)
    assert th.allclose(ent_coef_before, model.ent_coef_tensor)
    model.learn(200)
    ent_coef_after = model.ent_coef_tensor
    assert model.log_ent_coef is None
    # Check that the entropy coefficient is still the same
    assert th.allclose(ent_coef_before, ent_coef_after)


@pytest.mark.parametrize("model_class", [A2C, TD3])
def test_save_load_env_cnn(tmp_path, model_class):
    """
    Test loading with an env that requires a ``CnnPolicy``.
    This is to test wrapping and observation space check.
    We test one on-policy and one off-policy
    algorithm as the rest share the loading part.
    """
    env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=False)
    kwargs = dict(policy_kwargs=dict(net_arch=[32]))
    if model_class == TD3:
        kwargs.update(dict(buffer_size=100, learning_starts=50, train_freq=4))

    model = model_class("CnnPolicy", env, **kwargs).learn(100)
    model.save(tmp_path / "test_save")
    # Test loading with env and continuing training
    model = model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100)
    # clear file from os
    os.remove(tmp_path / "test_save.zip")

    # Check we can load models saved with SB3 < 1.7.0
    if model_class == A2C:
        del model.policy.pi_features_extractor
        model.save(tmp_path / "test_save")
        with pytest.warns(UserWarning):
            model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100)
        os.remove(tmp_path / "test_save.zip")


@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
def test_save_load_replay_buffer(tmp_path, model_class):
    path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
    path.parent.mkdir(exist_ok=True, parents=True)  # to not raise a warning
    model = model_class(
        "MlpPolicy", select_env(model_class), buffer_size=1000, policy_kwargs=dict(net_arch=[64]), learning_starts=200
    )
    model.learn(300)
    old_replay_buffer = deepcopy(model.replay_buffer)
    model.save_replay_buffer(path)
    model.replay_buffer = None
    model.load_replay_buffer(path)

    assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations)
    assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
    assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
    assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
    assert np.allclose(old_replay_buffer.timeouts, model.replay_buffer.timeouts)
    infos = [[{"TimeLimit.truncated": truncated}] for truncated in old_replay_buffer.timeouts]

    # test extending replay buffer
    model.replay_buffer.extend(
        old_replay_buffer.observations,
        old_replay_buffer.observations,
        old_replay_buffer.actions,
        old_replay_buffer.rewards,
        old_replay_buffer.dones,
        infos,
    )


@pytest.mark.parametrize("model_class", [DQN, SAC, TD3])
@pytest.mark.parametrize("optimize_memory_usage", [False, True])
def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
    """
    When using memory efficient replay buffer,
    a warning must be emitted when calling `.learn()`
    multiple times.
    See https://github.com/DLR-RM/stable-baselines3/issues/46
    """
    # remove gym warnings
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)
    warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")

    model = model_class(
        "MlpPolicy",
        select_env(model_class),
        buffer_size=100,
        optimize_memory_usage=optimize_memory_usage,
        # we cannot use optimize_memory_usage and handle_timeout_termination
        # at the same time
        replay_buffer_kwargs={"handle_timeout_termination": not optimize_memory_usage},
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=10,
    )

    model.learn(150)

    model.learn(150, reset_num_timesteps=False)

    # Check that there is no warning
    assert len(recwarn) == 0

    model.learn(150)

    if optimize_memory_usage:
        assert len(recwarn) == 1
        warning = recwarn.pop(UserWarning)
        assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
    else:
        assert len(recwarn) == 0


@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
@pytest.mark.parametrize("use_sde", [False, True])
def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
    """
    Test saving and loading policy only.

    :param model_class: (BaseAlgorithm) A RL model
    :param policy_str: (str) Name of the policy.
    """
    kwargs = dict(policy_kwargs=dict(net_arch=[16]))

    # gSDE is only applicable for A2C, PPO and SAC
    if use_sde and model_class not in [A2C, PPO, SAC]:
        pytest.skip()

    if policy_str == "MlpPolicy":
        env = select_env(model_class)
    else:
        if model_class in [SAC, TD3, DQN, DDPG]:
            # Avoid memory error when using replay buffer
            # Reduce the size of the features
            kwargs = dict(
                buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))
            )
        env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)

    if use_sde:
        kwargs["use_sde"] = True

    env = DummyVecEnv([lambda: env])

    # create model
    model = model_class(policy_str, env, verbose=1, **kwargs)
    model.learn(total_timesteps=300)

    env.reset()
    observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

    policy = model.policy
    policy_class = policy.__class__
    actor, actor_class = None, None
    if model_class in [SAC, TD3]:
        actor = policy.actor
        actor_class = actor.__class__

    # Get dictionary of current parameters
    params = deepcopy(policy.state_dict())

    # Modify all parameters to be random values
    random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}

    # Update model parameters with the new random values
    policy.load_state_dict(random_params)

    new_params = policy.state_dict()
    # Check that all params are different now
    for k in params:
        assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."

    params = new_params

    # get selected actions
    selected_actions, _ = policy.predict(observations, deterministic=True)
    # Should also work with the actor only
    if actor is not None:
        selected_actions_actor, _ = actor.predict(observations, deterministic=True)

    # Save and load policy
    policy.save(tmp_path / "policy.pkl")
    # Save and load actor
    if actor is not None:
        actor.save(tmp_path / "actor.pkl")

    del policy, actor

    policy = policy_class.load(tmp_path / "policy.pkl")
    if actor_class is not None:
        actor = actor_class.load(tmp_path / "actor.pkl")

    # check if params are still the same after load
    new_params = policy.state_dict()

    # Check that all params are the same as before save load procedure now
    for key in params:
        assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."

    # check if model still selects the same actions
    new_selected_actions, _ = policy.predict(observations, deterministic=True)
    assert np.allclose(selected_actions, new_selected_actions, 1e-4)

    if actor_class is not None:
        new_selected_actions_actor, _ = actor.predict(observations, deterministic=True)
        assert np.allclose(selected_actions_actor, new_selected_actions_actor, 1e-4)
        assert np.allclose(selected_actions_actor, new_selected_actions, 1e-4)

    # clear file from os
    os.remove(tmp_path / "policy.pkl")
    if actor_class is not None:
        os.remove(tmp_path / "actor.pkl")


@pytest.mark.parametrize("model_class", [DQN])
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_q_net(tmp_path, model_class, policy_str):
    """
    Test saving and loading q-network/quantile net only.

    :param model_class: (BaseAlgorithm) A RL model
    :param policy_str: (str) Name of the policy.
    """
    kwargs = dict(policy_kwargs=dict(net_arch=[16]))
    if policy_str == "MlpPolicy":
        env = select_env(model_class)
    else:
        if model_class in [DQN]:
            # Avoid memory error when using replay buffer
            # Reduce the size of the features
            kwargs = dict(
                buffer_size=250,
                learning_starts=100,
                policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
            )
        env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)

    env = DummyVecEnv([lambda: env])

    # create model
    model = model_class(policy_str, env, verbose=1, **kwargs)
    model.learn(total_timesteps=300)

    env.reset()
    observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

    q_net = model.q_net
    q_net_class = q_net.__class__

    # Get dictionary of current parameters
    params = deepcopy(q_net.state_dict())

    # Modify all parameters to be random values
    random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}

    # Update model parameters with the new random values
    q_net.load_state_dict(random_params)

    new_params = q_net.state_dict()
    # Check that all params are different now
    for k in params:
        assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."

    params = new_params

    # get selected actions
    selected_actions, _ = q_net.predict(observations, deterministic=True)

    # Save and load q_net
    q_net.save(tmp_path / "q_net.pkl")

    del q_net

    q_net = q_net_class.load(tmp_path / "q_net.pkl")

    # check if params are still the same after load
    new_params = q_net.state_dict()

    # Check that all params are the same as before save load procedure now
    for key in params:
        assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."

    # check if model still selects the same actions
    new_selected_actions, _ = q_net.predict(observations, deterministic=True)
    assert np.allclose(selected_actions, new_selected_actions, 1e-4)

    # clear file from os
    os.remove(tmp_path / "q_net.pkl")


@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
def test_open_file_str_pathlib(tmp_path, pathtype):
    # check that suffix isn't added because we used open_path first
    with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
    assert not record

    # test custom suffix
    with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
    assert not record

    # test without suffix
    with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
    assert not record

    # test that a warning is raised when the path doesn't exist
    with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
    assert len(record) == 0

    with warnings.catch_warnings(record=True) as record:
        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
    assert len(record) == 1

    fp = pathlib.Path(f"{tmp_path}/t2").open("w")
    fp.write("rubbish")
    fp.close()
    # test that exactly one of the three calls warns when overwriting an
    # existing file (expected to be the verbose=2 call, as in the read case above)
    with warnings.catch_warnings(record=True) as record:
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
    assert len(record) == 1


def test_open_file(tmp_path):
    # path must match the type
    with pytest.raises(TypeError):
        open_path(123, None, None, None)

    p1 = tmp_path / "test1"
    fp = p1.open("wb")

    # provided path must match the mode
    with pytest.raises(ValueError):
        open_path(fp, "r")
    with pytest.raises(ValueError):
        open_path(fp, "randomstuff")

    # test identity
    _ = open_path(fp, "w")
    assert _ is not None
    assert fp is _

    # Can't use a closed path
    with pytest.raises(ValueError):
        fp.close()
        open_path(fp, "w")

    buff = io.BytesIO()
    assert buff.writable()
    assert buff.readable() is ("w" == "w")
    _ = open_path(buff, "w")
    assert _ is buff
    with pytest.raises(ValueError):
        buff.close()
        open_path(buff, "w")


@pytest.mark.expensive
def test_save_load_large_model(tmp_path):
    """
    Test saving and loading a model with a large policy that is greater than 2GB. We
    test only one algorithm since all algorithms share the same code for loading and
    saving the model.
    """
    env = select_env(TD3)
    kwargs = dict(policy_kwargs=dict(net_arch=[8192, 8192, 8192]), device="cpu")
    model = TD3("MlpPolicy", env, **kwargs)

    # test saving
    model.save(tmp_path / "test_save")

    # test loading
    model = TD3.load(str(tmp_path / "test_save.zip"), env=env, **kwargs)

    # clear file from os
    os.remove(tmp_path / "test_save.zip")


def test_load_invalid_object(tmp_path):
    # See GH Issue #1122 for an example
    # of invalid object loading
    path = str(tmp_path / "ppo_pendulum.zip")
    PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0).save(path)

    with zipfile.ZipFile(path, mode="r") as archive:
        json_data = json.loads(archive.read("data").decode())

    # Intentionally corrupt the data
    serialization = json_data["learning_rate"][":serialized:"]
    base64_object = base64.b64decode(serialization.encode())
    new_bytes = base64_object.replace(b"CodeType", b"CodeTyps")
    base64_encoded = base64.b64encode(new_bytes).decode()
    json_data["learning_rate"][":serialized:"] = base64_encoded
    serialized_data = json.dumps(json_data, indent=4)

    with open(tmp_path / "data", "w") as f:
        f.write(serialized_data)
    # Replace with the corrupted file
    # probably doesn't work on windows
    os.system(f"cd {tmp_path}; zip ppo_pendulum.zip data")
    with pytest.warns(UserWarning, match=r"custom_objects"):
        PPO.load(path)
    # Load with custom object, no warnings
    with warnings.catch_warnings(record=True) as record:
        PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0))
    assert len(record) == 0
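
Portability note: `test_load_invalid_object` shells out to the `zip` command to swap the corrupted `data` entry into the archive, and its own comment warns that this probably does not work on Windows. The following is a minimal sketch of a standard-library alternative, not part of the test file. The helper name `replace_zip_member` is hypothetical, and the sketch assumes the archive is small enough to copy member by member.

```python
import zipfile
from pathlib import Path


def replace_zip_member(archive_path, member_name, new_data):
    """Rewrite ``archive_path`` so that ``member_name`` contains ``new_data``.

    Hypothetical helper: zipfile cannot overwrite an entry in place, so every
    other member is copied into a temporary archive that then replaces the original.
    """
    archive_path = Path(archive_path)
    tmp_archive = archive_path.with_name(archive_path.name + ".tmp")
    with zipfile.ZipFile(archive_path, mode="r") as src, zipfile.ZipFile(tmp_archive, mode="w") as dst:
        for item in src.infolist():
            if item.filename != member_name:
                # Copy untouched members, keeping their metadata
                dst.writestr(item, src.read(item.filename))
        # Write the replacement entry (str is encoded as UTF-8 by zipfile)
        dst.writestr(member_name, new_data)
    # Swap the rewritten archive into place
    tmp_archive.replace(archive_path)
```

Under these assumptions, the write to `tmp_path / "data"` and the `os.system(...)` call could be replaced by `replace_zip_member(path, "data", serialized_data)`.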