stable-baselines3/tests/test_save_load.py
2019-11-28 11:20:40 +01:00

82 lines
2.6 KiB
Python

import os
import pytest
from copy import deepcopy
import torch as th
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
from torchy_baselines.common.vec_env import DummyVecEnv
from torchy_baselines.common.identity_env import IdentityEnvBox
# Model classes that test_save_load is parametrized over.
# A2C, TD3 and SAC are imported above but are currently left out of this list.
MODEL_LIST = [PPO]
def _assert_opt_params_equal(expected, actual):
    """
    Assert that two optimizer-parameter dicts hold the same per-parameter state.

    Both arguments map optimizer names to a dict that has a ``'state'`` entry
    (presumably the output of ``torch.optim.Optimizer.state_dict()`` -- confirm
    against ``get_opt_parameters``). Only the ``'state'`` entries are compared.

    :param expected: (dict) optimizer parameters captured before save/load
    :param actual: (dict) optimizer parameters read back after save/load
    """
    assert expected.keys() == actual.keys(), "Optimizer names not the same after save and load."
    for opt_name, opt_state in expected.items():
        old_state = opt_state['state']
        new_state = actual[opt_name]['state']
        assert old_state.keys() == new_state.keys(), \
            "Optimizer state entries not the same after save and load."
        for param_id, entry_dict in old_state.items():
            new_entry = new_state[param_id]
            assert entry_dict.keys() == new_entry.keys(), \
                "Optimizer state keys not the same after save and load."
            for value_key, value in entry_dict.items():
                new_value = new_entry[value_key]
                # Moment buffers are tensors; counters such as 'step' may be
                # plain Python numbers depending on the torch version.
                if th.is_tensor(value):
                    assert th.allclose(value, new_value), \
                        "Optimizer state '{}' not the same after save and load.".format(value_key)
                else:
                    assert value == new_value, \
                        "Optimizer state '{}' not the same after save and load.".format(value_key)


@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_save_load(model_class):
    """
    Test if 'save' and 'load' save and load a model correctly,
    and if 'load_parameters', 'get_policy_parameters' and 'get_opt_parameters'
    work correctly.

    Note: the optimizer check only verifies that the optimizer state survives
    the save/load round trip, not that the restored optimizer trains identically.

    :param model_class: (BaseRLModel) A RL model
    """
    env = DummyVecEnv([lambda: IdentityEnvBox(10)])

    # create model
    model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
    model.learn(total_timesteps=1000, eval_freq=500)

    # Get dictionary of current parameters (deep copies, so they are a snapshot
    # independent of the model's live tensors)
    params = deepcopy(model.get_policy_parameters())
    opt_params = deepcopy(model.get_opt_parameters())

    # Modify all parameters to be random values
    random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}

    # Update model parameters with the new random values
    model.load_parameters(random_params, opt_params)

    new_params = model.get_policy_parameters()
    # Check that all params are different now
    for k in params:
        assert not th.allclose(params[k], new_params[k]), "Model parameters did not change " \
                                                             "after loading random parameters."

    # Snapshot the randomized parameters for the post-load comparison
    params = deepcopy(new_params)

    save_path = "test_save.zip"
    try:
        model.save(save_path)
        del model
        model = model_class.load("test_save")

        # check if params are still the same after load
        new_params = model.get_policy_parameters()
        # Check that all params are the same as before save load procedure now
        for k in params:
            assert th.allclose(params[k], new_params[k]), "Model parameters not the same after save and load."

        # check if optimizer params are still the same after load
        # (only exercised with Adam and RMSProp so far)
        _assert_opt_params_equal(opt_params, model.get_opt_parameters())

        # check if learn still works
        model.learn(total_timesteps=1000, eval_freq=500)
    finally:
        # clear file from os, even when an assertion above failed
        if os.path.exists(save_path):
            os.remove(save_path)