diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 533c548..96db45a 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,37 +1,49 @@ -import pytest +import os +import shutil -from torchy_baselines import SAC +import pytest +import gym + +from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3 from torchy_baselines.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback, EveryNTimesteps, StopTrainingOnRewardThreshold) -@pytest.mark.parametrize("model_class", [SAC]) +@pytest.mark.parametrize("model_class", [A2C, CEMRL, PPO, SAC, TD3]) def test_callbacks(model_class): + log_folder = './logs/callbacks/' + kwargs = {} + if model_class == CEMRL: + kwargs['pop_size'] = 2 + kwargs['n_grad'] = 1 + # Create RL model - model = model_class('MlpPolicy', 'Pendulum-v0') + # Small network for fast test + model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[32]), **kwargs) - checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/') + checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder) - # For testing: use the same training env - eval_env = model.get_env() + eval_env = gym.make('Pendulum-v0') # Stop training if the performance is good enough callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1) eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, - best_model_save_path='./logs/best_model', - log_path='./logs/results', eval_freq=100) + best_model_save_path=log_folder, + log_path=log_folder, eval_freq=100) # Equivalent to the `checkpoint_callback` # but here in an event-driven manner - checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs/', + checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix='event') - event_callback = EveryNTimesteps(n_steps=1000, callback=checkpoint_on_event) + event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event) callback = CallbackList([checkpoint_callback, eval_callback, event_callback]) - model.learn(1000, callback=callback) + model.learn(500, callback=callback) model.learn(500, callback=None) # Transform callback into a callback list automatically model.learn(500, callback=[checkpoint_callback, eval_callback]) # Automatic wrapping, old way of doing callbacks model.learn(500, callback=lambda _locals, _globals : True) + if os.path.exists(log_folder): + shutil.rmtree(log_folder) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index b4dd7b7..5171042 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -69,7 +69,7 @@ def test_save_load(model_class): # check if model still selects the same actions new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations] - assert np.allclose(selected_actions, new_selected_actions) + assert np.allclose(selected_actions, new_selected_actions, 1e-4) # check if learn still works model.learn(total_timesteps=1000, eval_freq=500) diff --git a/torchy_baselines/cem_rl/cem_rl.py b/torchy_baselines/cem_rl/cem_rl.py index 35c6e3c..653b3e2 100644 --- a/torchy_baselines/cem_rl/cem_rl.py +++ b/torchy_baselines/cem_rl/cem_rl.py @@ -160,7 +160,6 @@ class CEMRL(TD3): actor_steps = 0 # evaluate all actors for params in self.es_params: - self.actor.load_from_vector(params) rollout = self.collect_rollouts(self.env, n_episodes=self.n_episodes_rollout, diff --git a/torchy_baselines/common/callbacks.py b/torchy_baselines/common/callbacks.py index 89b1e22..346f115 100644 --- a/torchy_baselines/common/callbacks.py +++ b/torchy_baselines/common/callbacks.py @@ -210,7 +210,7 @@ class EvalCallback(EventCallback): when there is a new best model according to the `mean_reward` :param n_eval_episodes: (int) The number of episodes to test the agent :param eval_freq: (int) Evaluate the agent every eval_freq call of the callback. - :param log_path: (str) Path to a log file (.npz) where the evaluations + :param log_path: (str) Path to a folder where the evaluations (`evaluations.npz`) will be saved. It will be updated at each evaluation. :param best_model_save_path: (str) Path to a folder where the best model according to performance on the eval env will be saved. @@ -240,7 +240,7 @@ class EvalCallback(EventCallback): self.eval_env = eval_env self.best_model_save_path = best_model_save_path - self.log_path = log_path + self.log_path = os.path.join(log_path, 'evaluations') self.evaluations_results = [] self.evaluations_timesteps = [] self.evaluations_length = []