mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
64 lines
2.3 KiB
Python
64 lines
2.3 KiB
Python
import os
|
|
import pytest
|
|
from copy import deepcopy
|
|
|
|
import torch as th
|
|
|
|
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
|
|
from torchy_baselines.common.vec_env import DummyVecEnv
|
|
from torchy_baselines.common.identity_env import IdentityEnvBox
|
|
|
|
MODEL_LIST = [
|
|
PPO
|
|
]
|
|
|
|
|
|
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
|
def test_save_load(model_class):
|
|
"""
|
|
Test if 'save' and 'load' saves and loads model correctly
|
|
and if 'load_parameters' and 'get_policy_parameters' work correctly
|
|
|
|
''warning does not test function of optimizer parameter load
|
|
|
|
:param model_class: (BaseRLModel) A RL model
|
|
"""
|
|
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
|
|
|
# create model
|
|
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
|
|
|
# Get dictionary of current parameters
|
|
params = deepcopy(model.get_policy_parameters())
|
|
opt_params = deepcopy(model.get_opt_parameters())
|
|
|
|
# Modify all parameters to be random values
|
|
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
|
|
|
# Update model parameters with the new random values
|
|
model.load_parameters(random_params, opt_params)
|
|
|
|
# Get items that are the same in params and new_params
|
|
new_params = model.get_policy_parameters()
|
|
shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))}
|
|
|
|
# Check that the there are at least some parameters new random parameters
|
|
#for k in params.key():
|
|
# assert not th.allclose(params[k], new_params[k])
|
|
assert not len(shared_items) == len(new_params), "Selected actions did not change " \
|
|
"after changing model parameters."
|
|
|
|
params = new_params
|
|
|
|
# Check
|
|
model.learn(total_timesteps=1000, eval_freq=500)
|
|
model.save("test_save.zip")
|
|
del model
|
|
model = model_class.load("test_save")
|
|
|
|
#check if params are still the same after load
|
|
new_params = model.get_policy_parameters()
|
|
shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))}
|
|
# Check that at least some actions are chosen different now
|
|
assert len(shared_items) == len(new_params), "Parameters not the same after save and load."
|
|
os.remove("test_save.zip")
|