diff --git a/tests/test_save_load.py b/tests/test_save_load.py index a10f721..3bc7d23 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -180,6 +180,9 @@ def test_save_load_policy(model_class): observations = observations.reshape(10, -1) policy = model.policy + actor = None + if model_class in [SAC, TD3]: + actor = policy.actor # Get dictionary of current parameters params = deepcopy(policy.state_dict()) @@ -199,11 +202,19 @@ def test_save_load_policy(model_class): # get selected actions selected_actions, _ = policy.predict(observations, deterministic=True) + # Should also work with the actor only + if actor is not None: + selected_actions_actor, _ = actor.predict(observations, deterministic=True) # Save and load policy policy.save("./logs/policy_weights.pkl") - # del policy + # Save and load actor + if actor is not None: + actor.save("./logs/actor_weights.pkl") + policy.load("./logs/policy_weights.pkl") + if actor is not None: + actor.load("./logs/actor_weights.pkl") # check if params are still the same after load new_params = policy.state_dict() @@ -216,5 +227,12 @@ def test_save_load_policy(model_class): new_selected_actions, _ = policy.predict(observations, deterministic=True) assert np.allclose(selected_actions, new_selected_actions, 1e-4) + if actor is not None: + new_selected_actions_actor, _ = actor.predict(observations, deterministic=True) + assert np.allclose(selected_actions_actor, new_selected_actions_actor, 1e-4) + assert np.allclose(selected_actions_actor, new_selected_actions, 1e-4) + # clear file from os os.remove("./logs/policy_weights.pkl") + if actor is not None: + os.remove("./logs/actor_weights.pkl")