stable-baselines3/tests/test_callbacks.py

50 lines
1.9 KiB
Python
Raw Normal View History

import os
import shutil
import tempfile

import gym
import pytest

from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
from torchy_baselines.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
                                               EveryNTimesteps, StopTrainingOnRewardThreshold)
@pytest.mark.parametrize("model_class", [A2C, CEMRL, PPO, SAC, TD3])
def test_callbacks(model_class):
    """Smoke-test the callback API for every algorithm.

    Trains ``model_class`` on ``Pendulum-v0`` for a few hundred steps with a
    ``CallbackList`` made of:

    - a periodic checkpoint callback,
    - an evaluation callback chained to a reward-threshold stopper, and
    - an event-driven checkpoint callback.

    It then checks that ``learn`` also accepts ``callback=None``, a plain
    list of callbacks, and a legacy ``(locals, globals)`` function.
    """
    # Use a fresh temporary directory rather than a fixed './logs/callbacks/'
    # under the current working directory. The cleanup in ``finally`` removes
    # the whole folder, so it must never point at pre-existing user data, and
    # concurrent test runs must not share it.
    log_folder = tempfile.mkdtemp(prefix='callbacks_')
    eval_env = None
    try:
        kwargs = {}
        if model_class is CEMRL:
            # Small population for a fast test
            kwargs['pop_size'] = 2
            kwargs['n_grad'] = 1

        # Create RL model
        # Small network for fast test
        model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[32]), **kwargs)

        checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)

        eval_env = gym.make('Pendulum-v0')

        # Stop training if the performance is good enough
        callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
        eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
                                     best_model_save_path=log_folder,
                                     log_path=log_folder, eval_freq=100)

        # Equivalent to the `checkpoint_callback`,
        # but triggered in an event-driven manner by `EveryNTimesteps`
        checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder,
                                                 name_prefix='event')
        event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

        callback = CallbackList([checkpoint_callback, eval_callback, event_callback])
        model.learn(500, callback=callback)

        # Training must also work without any callback
        model.learn(500, callback=None)
        # Transform callback into a callback list automatically
        model.learn(500, callback=[checkpoint_callback, eval_callback])
        # Automatic wrapping, old way of doing callbacks
        model.learn(500, callback=lambda _locals, _globals: True)
    finally:
        # Release the evaluation environment and remove every artifact, even
        # when training raised. The removal is best-effort so that a cleanup
        # problem cannot mask the original test failure.
        if eval_env is not None:
            eval_env.close()
        shutil.rmtree(log_folder, ignore_errors=True)