mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-30 03:38:13 +00:00
Fix tests and change log_path behavior for EvalCallback
This commit is contained in:
parent
5d4e73544c
commit
ec657cc34e
4 changed files with 27 additions and 16 deletions
|
|
@ -1,37 +1,49 @@
|
|||
import pytest
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from torchy_baselines import SAC
|
||||
import pytest
|
||||
import gym
|
||||
|
||||
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
|
||||
from torchy_baselines.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
|
||||
EveryNTimesteps, StopTrainingOnRewardThreshold)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC])
|
||||
@pytest.mark.parametrize("model_class", [A2C, CEMRL, PPO, SAC, TD3])
|
||||
def test_callbacks(model_class):
|
||||
log_folder = './logs/callbacks/'
|
||||
kwargs = {}
|
||||
if model_class == CEMRL:
|
||||
kwargs['pop_size'] = 2
|
||||
kwargs['n_grad'] = 1
|
||||
|
||||
# Create RL model
|
||||
model = model_class('MlpPolicy', 'Pendulum-v0')
|
||||
# Small network for fast test
|
||||
model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[32]), **kwargs)
|
||||
|
||||
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/')
|
||||
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)
|
||||
|
||||
# For testing: use the same training env
|
||||
eval_env = model.get_env()
|
||||
eval_env = gym.make('Pendulum-v0')
|
||||
# Stop training if the performance is good enough
|
||||
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
|
||||
|
||||
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
|
||||
best_model_save_path='./logs/best_model',
|
||||
log_path='./logs/results', eval_freq=100)
|
||||
best_model_save_path=log_folder,
|
||||
log_path=log_folder, eval_freq=100)
|
||||
|
||||
# Equivalent to the `checkpoint_callback`
|
||||
# but here in an event-driven manner
|
||||
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs/',
|
||||
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder,
|
||||
name_prefix='event')
|
||||
event_callback = EveryNTimesteps(n_steps=1000, callback=checkpoint_on_event)
|
||||
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
|
||||
|
||||
callback = CallbackList([checkpoint_callback, eval_callback, event_callback])
|
||||
|
||||
model.learn(1000, callback=callback)
|
||||
model.learn(500, callback=callback)
|
||||
model.learn(500, callback=None)
|
||||
# Transform callback into a callback list automatically
|
||||
model.learn(500, callback=[checkpoint_callback, eval_callback])
|
||||
# Automatic wrapping, old way of doing callbacks
|
||||
model.learn(500, callback=lambda _locals, _globals : True)
|
||||
if os.path.exists(log_folder):
|
||||
shutil.rmtree(log_folder)
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ def test_save_load(model_class):
|
|||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
|
||||
assert np.allclose(selected_actions, new_selected_actions)
|
||||
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
|
||||
|
||||
# check if learn still works
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
|
|
|||
|
|
@ -160,7 +160,6 @@ class CEMRL(TD3):
|
|||
actor_steps = 0
|
||||
# evaluate all actors
|
||||
for params in self.es_params:
|
||||
|
||||
self.actor.load_from_vector(params)
|
||||
|
||||
rollout = self.collect_rollouts(self.env, n_episodes=self.n_episodes_rollout,
|
||||
|
|
|
|||
|
|
@ -210,7 +210,7 @@ class EvalCallback(EventCallback):
|
|||
when there is a new best model according to the `mean_reward`
|
||||
:param n_eval_episodes: (int) The number of episodes to test the agent
|
||||
:param eval_freq: (int) Evaluate the agent every eval_freq call of the callback.
|
||||
:param log_path: (str) Path to a log file (.npz) where the evaluations
|
||||
:param log_path: (str) Path to a folder where the evaluations (`evaluations.npz`)
|
||||
will be saved. It will be updated at each evaluation.
|
||||
:param best_model_save_path: (str) Path to a folder where the best model
|
||||
according to performance on the eval env will be saved.
|
||||
|
|
@ -240,7 +240,7 @@ class EvalCallback(EventCallback):
|
|||
|
||||
self.eval_env = eval_env
|
||||
self.best_model_save_path = best_model_save_path
|
||||
self.log_path = log_path
|
||||
self.log_path = os.path.join(log_path, 'evaluations')
|
||||
self.evaluations_results = []
|
||||
self.evaluations_timesteps = []
|
||||
self.evaluations_length = []
|
||||
|
|
|
|||
Loading…
Reference in a new issue