Add test for actor

This commit is contained in:
Antonin RAFFIN 2020-03-31 18:26:26 +02:00
parent 2c5e41ec47
commit 71ce9ef2f4

View file

@ -180,6 +180,9 @@ def test_save_load_policy(model_class):
observations = observations.reshape(10, -1)
policy = model.policy
actor = None
if model_class in [SAC, TD3]:
actor = policy.actor
# Get dictionary of current parameters
params = deepcopy(policy.state_dict())
@ -199,11 +202,19 @@ def test_save_load_policy(model_class):
# get selected actions
selected_actions, _ = policy.predict(observations, deterministic=True)
# Should also work with the actor only
if actor is not None:
selected_actions_actor, _ = actor.predict(observations, deterministic=True)
# Save and load policy
policy.save("./logs/policy_weights.pkl")
# del policy
# Save and load actor
if actor is not None:
actor.save("./logs/actor_weights.pkl")
policy.load("./logs/policy_weights.pkl")
if actor is not None:
actor.load("./logs/actor_weights.pkl")
# check if params are still the same after load
new_params = policy.state_dict()
@ -216,5 +227,12 @@ def test_save_load_policy(model_class):
new_selected_actions, _ = policy.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
if actor is not None:
new_selected_actions_actor, _ = actor.predict(observations, deterministic=True)
assert np.allclose(selected_actions_actor, new_selected_actions_actor, 1e-4)
assert np.allclose(selected_actions_actor, new_selected_actions, 1e-4)
# clear file from os
os.remove("./logs/policy_weights.pkl")
if actor is not None:
os.remove("./logs/actor_weights.pkl")