2020-07-02 22:14:21 +00:00
|
|
|
import io
|
2020-07-16 14:12:16 +00:00
|
|
|
import os
|
|
|
|
|
import pathlib
|
2020-06-29 09:16:54 +00:00
|
|
|
import warnings
|
2020-09-24 12:28:27 +00:00
|
|
|
from collections import OrderedDict
|
2020-03-25 15:42:05 +00:00
|
|
|
from copy import deepcopy
|
|
|
|
|
|
2020-06-29 09:16:54 +00:00
|
|
|
import gym
|
2020-03-25 15:42:05 +00:00
|
|
|
import numpy as np
|
2020-07-16 14:12:16 +00:00
|
|
|
import pytest
|
2019-11-21 10:44:37 +00:00
|
|
|
import torch as th
|
|
|
|
|
|
2020-07-16 14:12:16 +00:00
|
|
|
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
|
2020-06-29 09:16:54 +00:00
|
|
|
from stable_baselines3.common.base_class import BaseAlgorithm
|
2020-07-16 14:12:16 +00:00
|
|
|
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
|
|
|
|
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
|
2020-09-20 17:13:18 +00:00
|
|
|
from stable_baselines3.common.utils import get_device
|
2020-05-05 13:02:35 +00:00
|
|
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
2020-07-16 14:12:16 +00:00
|
|
|
|
|
|
|
|
# Algorithm classes that the parametrized tests in this file are run against
MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
|
2019-11-12 16:03:57 +00:00
|
|
|
|
2020-03-31 14:40:53 +00:00
|
|
|
|
2020-06-29 09:16:54 +00:00
|
|
|
def select_env(model_class: BaseAlgorithm) -> gym.Env:
    """
    Pick an identity environment whose action space suits ``model_class``.

    DQN only supports discrete action spaces, so it is given ``IdentityEnv``;
    every other algorithm is given ``IdentityEnvBox``.

    :param model_class: (BaseAlgorithm) A RL model
    :return: (gym.Env) the environment to train/test the model on
    """
    env_class = IdentityEnv if model_class == DQN else IdentityEnvBox
    return env_class(10)
|
|
|
|
|
|
|
|
|
|
|
2019-11-12 16:03:57 +00:00
|
|
|
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_save_load(tmp_path, model_class):
    """
    Round-trip a trained model through ``save``/``load`` and exercise
    ``get_parameters``/``set_parameters`` along the way.

    Warning: loading of the optimizer parameters is not verified here.

    :param model_class: (BaseAlgorithm) A RL model
    """
    vec_env = DummyVecEnv([lambda: select_env(model_class)])

    # Build a small model and train it briefly
    model = model_class("MlpPolicy", vec_env, policy_kwargs=dict(net_arch=[16]), verbose=1)
    model.learn(total_timesteps=500)

    # Gather a batch of observations used later to compare predictions
    vec_env.reset()
    stepped_obs = [vec_env.step([vec_env.action_space.sample()])[0] for _ in range(10)]
    observations = np.concatenate(stepped_obs, axis=0)

    # Snapshot the parameters; deepcopy so that we do not keep references
    # to tensors that are about to be modified
    original_params = deepcopy(model.get_parameters())

    # --- Error cases of set_parameters ---
    # An unknown object name must be rejected, with and without exact_match
    invalid_object_params = deepcopy(original_params)
    invalid_object_params["I_should_not_be_a_valid_object"] = "and_I_am_an_invalid_tensor"
    for exact_match in (True, False):
        with pytest.raises(ValueError):
            model.set_parameters(invalid_object_params, exact_match=exact_match)

    # A missing object must be caught when exact_match is requested
    missing_object_params = dict(list(original_params.items())[:-1])
    with pytest.raises(ValueError):
        model.set_parameters(missing_object_params, exact_match=True)

    # Same when the last entry of every state-dict is missing:
    # PyTorch load_state_dict throws RuntimeError if strict but
    # invalid state-dict.
    missing_state_dict_tensor_params = {
        object_name: dict(list(state.items())[:-1]) for object_name, state in original_params.items()
    }
    with pytest.raises(RuntimeError):
        model.set_parameters(missing_state_dict_tensor_params, exact_match=True)

    # --- Parameters must really change after set_parameters ---
    random_params = {}
    for object_name, state in original_params.items():
        if "optim" in object_name:
            # Optimizer parameters have a custom layout: pass them through untouched
            random_params[object_name] = state
            continue
        # Randomize everything but the last item of the state-dict
        random_params[object_name] = OrderedDict(
            (param_name, th.rand_like(param)) for param_name, param in list(state.items())[:-1]
        )

    # Push the random values into the model
    model.set_parameters(random_params, exact_match=False)
    new_params = model.get_parameters()

    # All entries except the final one of each state-dict must now differ
    for object_name, original_state in original_params.items():
        # Optimizers cannot be compared with a plain th.allclose
        if "optim" in object_name:
            continue
        updated_state = new_params[object_name]
        # state-dicts use ordered dictionaries, so key order is guaranteed
        last_key = list(original_state.keys())[-1]
        for key, original_tensor in original_state.items():
            is_same = th.allclose(original_tensor, updated_state[key])
            if key == last_key:
                assert is_same, "Parameter changed despite not included in the loaded parameters."
            else:
                assert not is_same, "Parameters did not change as expected."

    reference_params = new_params

    # Actions chosen before saving, used as reference after loading
    selected_actions, _ = model.predict(observations, deterministic=True)

    save_file = tmp_path / "test_save.zip"
    model.save(save_file)
    del model

    # The model must load as expected for every possible choice of device
    for device in ["auto", "cpu", "cuda"]:
        model = model_class.load(str(save_file), env=vec_env, device=device)

        # The model and its policy must live on the requested device
        expected_device_type = get_device(device).type
        assert model.device.type == expected_device_type
        assert model.policy.device.type == expected_device_type

        # Parameters must be identical to the ones that were saved
        loaded_params = model.get_parameters()
        for object_name in loaded_params:
            # Optimizers cannot be compared with a plain th.allclose
            if "optim" in object_name:
                continue
            for key, reference_tensor in reference_params[object_name].items():
                loaded_tensor = loaded_params[object_name][key]
                assert loaded_tensor.device.type == expected_device_type
                assert th.allclose(
                    reference_tensor.to("cpu"), loaded_tensor.to("cpu")
                ), "Model parameters not the same after save and load."

        # The loaded model must still select the same actions
        new_selected_actions, _ = model.predict(observations, deterministic=True)
        assert np.allclose(selected_actions, new_selected_actions, 1e-4)

        # Training must still work after loading
        model.learn(total_timesteps=500)

        del model

    # clear file from os
    os.remove(save_file)
|
2019-11-28 15:07:15 +00:00
|
|
|
|
|
|
|
|
|
2019-12-05 12:36:19 +00:00
|
|
|
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_set_env(model_class):
    """
    Check that training keeps working after the environment is swapped
    with ``set_env``, first with an already vectorized env and then with
    an env that is not wrapped in a ``DummyVecEnv``.

    :param model_class: (BaseAlgorithm) A RL model
    """
    # select_env takes care of using a discrete env for DQN
    first_vec_env = DummyVecEnv([lambda: select_env(model_class)])
    second_vec_env = DummyVecEnv([lambda: select_env(model_class)])
    unwrapped_env = select_env(model_class)

    # Algorithm-specific settings
    if model_class in {DQN, DDPG, SAC, TD3}:
        algo_kwargs = {"learning_starts": 100, "train_freq": 4}
    elif model_class in {A2C, PPO}:
        algo_kwargs = {"n_steps": 64}
    else:
        algo_kwargs = {}

    # create model and learn on the initial env
    model = model_class("MlpPolicy", first_vec_env, policy_kwargs=dict(net_arch=[16]), **algo_kwargs)
    model.learn(total_timesteps=128)

    # change env and learn again; the last env also tests the wrapping
    for replacement_env in (second_vec_env, unwrapped_env):
        model.set_env(replacement_env)
        model.learn(total_timesteps=128)
|
2019-12-05 12:36:19 +00:00
|
|
|
|
|
|
|
|
|
2019-11-28 15:07:15 +00:00
|
|
|
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_exclude_include_saved_params(tmp_path, model_class):
    """
    Check the ``exclude`` and ``include`` arguments of ``save()`` and the
    ``custom_objects`` argument of ``load()``.

    :param model_class: (BaseAlgorithm) A RL model
    """
    save_target = tmp_path / "test_save"
    saved_zip = tmp_path / "test_save.zip"

    vec_env = DummyVecEnv([lambda: select_env(model_class)])

    # verbose=2 is not the standard value, so we can tell whether it was stored
    model = model_class("MlpPolicy", vec_env, policy_kwargs=dict(net_arch=[16]), verbose=2)

    # Excluded attribute: verbose must not come back after loading
    model.save(save_target, exclude=["verbose"])
    del model
    model = model_class.load(str(saved_zip))
    assert model.verbose != 2

    # Set verbose to a non-standard value again, then list it in both
    # exclude and include: it must be saved this time
    model.verbose = 2
    model.save(save_target, exclude=["verbose"], include=["verbose"])
    del model

    # Load with custom objects
    custom_objects = {"learning_rate": 2e-5, "dummy": 1.0}
    model = model_class.load(str(saved_zip), custom_objects=custom_objects)
    assert model.verbose == 2
    # The custom object must have been taken into account
    assert model.learning_rate == custom_objects["learning_rate"]
    # Only attributes that already exist are replaced
    assert not hasattr(model, "dummy")

    # clear file from os
    os.remove(saved_zip)
|
2020-02-05 12:10:02 +00:00
|
|
|
|
2020-02-14 13:03:41 +00:00
|
|
|
|
2021-04-15 12:50:43 +00:00
|
|
|
def test_save_load_pytorch_var(tmp_path):
    """
    Check that SAC's ``log_ent_coef`` is restored by ``load`` and that it
    keeps changing when training continues afterwards.
    """
    save_path = str(tmp_path / "sac_pendulum")

    trained = SAC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
    trained.learn(200)
    trained.save(save_path)
    env = trained.get_env()
    ent_coef_before = trained.log_ent_coef
    del trained

    reloaded = SAC.load(save_path, env=env)
    # The value must be the same right after loading
    assert th.allclose(ent_coef_before, reloaded.log_ent_coef)

    reloaded.learn(200)
    ent_coef_after = reloaded.log_ent_coef
    # Check that the entropy coefficient is still optimized
    assert not th.allclose(ent_coef_before, ent_coef_after)
|
|
|
|
|
|
|
|
|
|
|
2020-10-27 21:12:52 +00:00
|
|
|
@pytest.mark.parametrize("model_class", [A2C, TD3])
def test_save_load_env_cnn(tmp_path, model_class):
    """
    Load a model together with an env that requires a ``CnnPolicy``.

    This covers the env wrapping and the observation space check.
    One on-policy and one off-policy algorithm are tested,
    as the rest share the loading part.
    """
    image_env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=False)

    kwargs = {"policy_kwargs": {"net_arch": [32]}}
    if model_class == TD3:
        kwargs.update(buffer_size=100, learning_starts=50, train_freq=4)

    model = model_class("CnnPolicy", image_env, **kwargs)
    model = model.learn(100)
    model.save(tmp_path / "test_save")

    # Load with the env and continue training
    saved_zip = tmp_path / "test_save.zip"
    model = model_class.load(str(saved_zip), env=image_env, **kwargs)
    model = model.learn(100)

    # clear file from os
    os.remove(saved_zip)
|
|
|
|
|
|
|
|
|
|
|
2020-06-29 09:16:54 +00:00
|
|
|
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
def test_save_load_replay_buffer(tmp_path, model_class):
    """
    Save a replay buffer to disk, load it back and compare its content
    with a copy taken before saving; then check that ``extend`` can be called.
    """
    buffer_file = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
    # Create the parent folder beforehand to not raise a warning
    buffer_file.parent.mkdir(exist_ok=True, parents=True)

    model = model_class(
        "MlpPolicy",
        select_env(model_class),
        buffer_size=1000,
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=200,
    )
    model.learn(300)

    # Keep an independent copy, then do the save / drop / load cycle
    old_replay_buffer = deepcopy(model.replay_buffer)
    model.save_replay_buffer(buffer_file)
    model.replay_buffer = None
    model.load_replay_buffer(buffer_file)

    for field in ("observations", "actions", "rewards", "dones"):
        assert np.allclose(getattr(old_replay_buffer, field), getattr(model.replay_buffer, field))

    # test extending replay buffer
    # (the stored observations are deliberately passed for the first two arguments)
    previous_obs = old_replay_buffer.observations
    model.replay_buffer.extend(
        previous_obs,
        previous_obs,
        old_replay_buffer.actions,
        old_replay_buffer.rewards,
        old_replay_buffer.dones,
    )
|
2020-03-31 14:29:13 +00:00
|
|
|
|
|
|
|
|
|
2020-06-29 09:16:54 +00:00
|
|
|
@pytest.mark.parametrize("model_class", [DQN, SAC, TD3])
@pytest.mark.parametrize("optimize_memory_usage", [False, True])
def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
    """
    With the memory efficient replay buffer, calling `.learn()`
    several times must emit a warning.
    See https://github.com/DLR-RM/stable-baselines3/issues/46
    """
    # remove gym warnings
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)
    warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")

    model_kwargs = dict(
        buffer_size=100,
        optimize_memory_usage=optimize_memory_usage,
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=10,
    )
    model = model_class("MlpPolicy", select_env(model_class), **model_kwargs)

    model.learn(150)
    model.learn(150, reset_num_timesteps=False)

    # Nothing must have been recorded up to this point
    assert len(recwarn) == 0

    model.learn(150)

    if not optimize_memory_usage:
        assert len(recwarn) == 0
        return

    assert len(recwarn) == 1
    warning = recwarn.pop(UserWarning)
    assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
|
|
|
|
|
|
|
|
|
|
|
2020-03-31 14:29:13 +00:00
|
|
|
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
@pytest.mark.parametrize("use_sde", [False, True])
def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
    """
    Save and load the policy alone (and, for SAC/TD3, the actor alone).

    :param model_class: (BaseAlgorithm) A RL model
    :param policy_str: (str) Name of the policy.
    :param use_sde: (bool) Whether ``use_sde=True`` is passed to the model.
    """
    model_kwargs = {"policy_kwargs": {"net_arch": [16]}}

    # gSDE is only applicable for A2C, PPO and SAC
    if use_sde and model_class not in [A2C, PPO, SAC]:
        pytest.skip()

    if policy_str == "MlpPolicy":
        base_env = select_env(model_class)
    else:
        if model_class in [SAC, TD3, DQN, DDPG]:
            # Avoid memory error when using replay buffer
            # Reduce the size of the features
            model_kwargs = {
                "buffer_size": 250,
                "learning_starts": 100,
                "policy_kwargs": {"features_extractor_kwargs": {"features_dim": 32}},
            }
        base_env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)

    if use_sde:
        model_kwargs["use_sde"] = True

    env = DummyVecEnv([lambda: base_env])

    # create model
    model = model_class(policy_str, env, verbose=1, **model_kwargs)
    model.learn(total_timesteps=300)

    # Collect observations used to compare predictions before/after loading
    env.reset()
    stepped_obs = [env.step([env.action_space.sample()])[0] for _ in range(10)]
    observations = np.concatenate(stepped_obs, axis=0)

    policy = model.policy
    policy_class = policy.__class__
    if model_class in [SAC, TD3]:
        actor = policy.actor
        actor_class = actor.__class__
    else:
        actor, actor_class = None, None

    # Copy of the current parameters
    params_before = deepcopy(policy.state_dict())

    # Replace every parameter by random values
    random_params = {param_name: th.rand_like(param) for param_name, param in params_before.items()}
    policy.load_state_dict(random_params)

    # All parameters must be different now
    randomized_params = policy.state_dict()
    for param_name, old_param in params_before.items():
        assert not th.allclose(old_param, randomized_params[param_name]), "Parameters did not change as expected."

    # Reference actions before saving
    selected_actions, _ = policy.predict(observations, deterministic=True)
    # Should also work with the actor only
    if actor is not None:
        selected_actions_actor, _ = actor.predict(observations, deterministic=True)

    policy_file = tmp_path / "policy.pkl"
    actor_file = tmp_path / "actor.pkl"

    # Save policy (and actor), drop them, then load them back
    policy.save(policy_file)
    if actor is not None:
        actor.save(actor_file)

    del policy, actor

    policy = policy_class.load(policy_file)
    if actor_class is not None:
        actor = actor_class.load(actor_file)

    # Parameters must be the same as before the save/load procedure
    loaded_params = policy.state_dict()
    for param_name, saved_param in randomized_params.items():
        assert th.allclose(
            saved_param, loaded_params[param_name]
        ), "Policy parameters not the same after save and load."

    # The loaded policy must still select the same actions
    new_selected_actions, _ = policy.predict(observations, deterministic=True)
    assert np.allclose(selected_actions, new_selected_actions, 1e-4)

    if actor_class is not None:
        new_selected_actions_actor, _ = actor.predict(observations, deterministic=True)
        assert np.allclose(selected_actions_actor, new_selected_actions_actor, 1e-4)
        assert np.allclose(selected_actions_actor, new_selected_actions, 1e-4)

    # clear file from os
    os.remove(policy_file)
    if actor_class is not None:
        os.remove(actor_file)
|
2020-07-02 22:14:21 +00:00
|
|
|
|
|
|
|
|
|
2020-12-21 15:17:24 +00:00
|
|
|
@pytest.mark.parametrize("model_class", [DQN])
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_q_net(tmp_path, model_class, policy_str):
    """
    Save and load the q-network/quantile net alone.

    :param model_class: (BaseAlgorithm) A RL model
    :param policy_str: (str) Name of the policy.
    """
    model_kwargs = {"policy_kwargs": {"net_arch": [16]}}
    if policy_str == "MlpPolicy":
        base_env = select_env(model_class)
    else:
        if model_class == DQN:
            # Avoid memory error when using replay buffer
            # Reduce the size of the features
            model_kwargs = {
                "buffer_size": 250,
                "learning_starts": 100,
                "policy_kwargs": {"features_extractor_kwargs": {"features_dim": 32}},
            }
        base_env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)

    env = DummyVecEnv([lambda: base_env])

    # create model
    model = model_class(policy_str, env, verbose=1, **model_kwargs)
    model.learn(total_timesteps=300)

    # Collect observations used to compare predictions before/after loading
    env.reset()
    stepped_obs = [env.step([env.action_space.sample()])[0] for _ in range(10)]
    observations = np.concatenate(stepped_obs, axis=0)

    q_net = model.q_net
    q_net_class = q_net.__class__

    # Copy of the current parameters
    params_before = deepcopy(q_net.state_dict())

    # Replace every parameter by random values
    random_params = {param_name: th.rand_like(param) for param_name, param in params_before.items()}
    q_net.load_state_dict(random_params)

    # All parameters must be different now
    randomized_params = q_net.state_dict()
    for param_name, old_param in params_before.items():
        assert not th.allclose(old_param, randomized_params[param_name]), "Parameters did not change as expected."

    # Reference actions before saving
    selected_actions, _ = q_net.predict(observations, deterministic=True)

    # Save the q_net, drop it, then load it back
    q_net_file = tmp_path / "q_net.pkl"
    q_net.save(q_net_file)

    del q_net

    q_net = q_net_class.load(q_net_file)

    # Parameters must be the same as before the save/load procedure
    loaded_params = q_net.state_dict()
    for param_name, saved_param in randomized_params.items():
        assert th.allclose(
            saved_param, loaded_params[param_name]
        ), "Policy parameters not the same after save and load."

    # The loaded network must still select the same actions
    new_selected_actions, _ = q_net.predict(observations, deterministic=True)
    assert np.allclose(selected_actions, new_selected_actions, 1e-4)

    # clear file from os
    os.remove(q_net_file)
|
|
|
|
|
|
|
|
|
|
|
2020-07-02 22:14:21 +00:00
|
|
|
@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
def test_open_file_str_pathlib(tmp_path, pathtype):
    """
    Check ``open_path``/``save_to_pkl``/``load_from_pkl`` with paths given
    either as ``str`` or as ``pathlib.Path``, including suffix handling
    and the number of warnings emitted.

    Warnings are recorded with ``warnings.catch_warnings(record=True)``
    rather than ``pytest.warns(None)``, which is deprecated since pytest 7
    and rejected by pytest 8.

    :param pathtype: callable used to build the path object (``str`` or ``pathlib.Path``)
    """
    # check that suffix isn't added because we used open_path first
    with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
    assert not record

    # test custom suffix
    with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
    assert not record

    # test without suffix
    with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
    assert not record

    # Reading "t2" while only "t2.pkl" exists: no warning is expected
    # with the default verbosity, exactly one with verbose=2
    with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
    assert len(record) == 0

    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
    assert len(record) == 1

    # Create a file named "t2" (without suffix); the context manager
    # guarantees the handle is closed even if the write fails
    with pathlib.Path(f"{tmp_path}/t2").open("w") as fp:
        fp.write("rubbish")

    # Opening for writing with each verbosity level:
    # exactly one warning is expected over the three calls
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
    assert len(record) == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_open_file(tmp_path):
    """
    Check how ``open_path`` validates its first argument: wrong types,
    file objects whose mode does not match, already opened file objects
    (returned as-is) and closed file objects (rejected).
    """
    # path must match the type
    with pytest.raises(TypeError):
        open_path(123, None, None, None)

    p1 = tmp_path / "test1"
    fp = p1.open("wb")
    try:
        # provided path must match the mode
        with pytest.raises(ValueError):
            open_path(fp, "r")
        with pytest.raises(ValueError):
            open_path(fp, "randomstuff")

        # test identity: an opened, writable file is handed back untouched
        returned = open_path(fp, "w")
        assert returned is not None
        assert fp is returned
    finally:
        # Always release the handle, even if an assertion above fails
        fp.close()

    # Can't use a closed path.
    # Closing happens outside of pytest.raises so that only open_path
    # can satisfy the expected ValueError.
    with pytest.raises(ValueError):
        open_path(fp, "w")

    buff = io.BytesIO()
    assert buff.writable()
    assert buff.readable() is True
    returned = open_path(buff, "w")
    assert returned is buff

    # Same check for a closed in-memory buffer
    buff.close()
    with pytest.raises(ValueError):
        open_path(buff, "w")
|