mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Add test for actor
This commit is contained in:
parent
2c5e41ec47
commit
71ce9ef2f4
1 changed files with 19 additions and 1 deletions
|
|
@ -180,6 +180,9 @@ def test_save_load_policy(model_class):
|
|||
observations = observations.reshape(10, -1)
|
||||
|
||||
policy = model.policy
|
||||
actor = None
|
||||
if model_class in [SAC, TD3]:
|
||||
actor = policy.actor
|
||||
|
||||
# Get dictionary of current parameters
|
||||
params = deepcopy(policy.state_dict())
|
||||
|
|
@ -199,11 +202,19 @@ def test_save_load_policy(model_class):
|
|||
|
||||
# get selected actions
|
||||
selected_actions, _ = policy.predict(observations, deterministic=True)
|
||||
# Should also work with the actor only
|
||||
if actor is not None:
|
||||
selected_actions_actor, _ = actor.predict(observations, deterministic=True)
|
||||
|
||||
# Save and load policy
|
||||
policy.save("./logs/policy_weights.pkl")
|
||||
# del policy
|
||||
# Save and load actor
|
||||
if actor is not None:
|
||||
actor.save("./logs/actor_weights.pkl")
|
||||
|
||||
policy.load("./logs/policy_weights.pkl")
|
||||
if actor is not None:
|
||||
actor.load("./logs/actor_weights.pkl")
|
||||
|
||||
# check if params are still the same after load
|
||||
new_params = policy.state_dict()
|
||||
|
|
@ -216,5 +227,12 @@ def test_save_load_policy(model_class):
|
|||
new_selected_actions, _ = policy.predict(observations, deterministic=True)
|
||||
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
|
||||
|
||||
if actor is not None:
|
||||
new_selected_actions_actor, _ = actor.predict(observations, deterministic=True)
|
||||
assert np.allclose(selected_actions_actor, new_selected_actions_actor, 1e-4)
|
||||
assert np.allclose(selected_actions_actor, new_selected_actions, 1e-4)
|
||||
|
||||
# clear file from os
|
||||
os.remove("./logs/policy_weights.pkl")
|
||||
if actor is not None:
|
||||
os.remove("./logs/actor_weights.pkl")
|
||||
|
|
|
|||
Loading…
Reference in a new issue