stable-baselines3/tests/test_callbacks.py

45 lines
1.8 KiB
Python
Raw Normal View History

import os
import shutil
2020-01-27 13:32:31 +00:00
import pytest
import gym
2020-01-27 13:32:31 +00:00
2020-05-05 13:02:35 +00:00
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
2020-03-12 10:12:10 +00:00
EveryNTimesteps, StopTrainingOnRewardThreshold)
2020-01-27 13:32:31 +00:00
2020-03-23 13:48:38 +00:00
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
def test_callbacks(model_class):
    """Smoke-test the built-in callbacks with each parametrized algorithm.

    A small model is trained on ``Pendulum-v0`` four times. Each run passes
    the callbacks in a different form:

    * an explicit ``CallbackList``;
    * ``None``;
    * a plain Python list, which ``learn`` is expected to wrap automatically;
    * an old-style ``(locals, globals)`` callable.

    The test makes no assertions; it passes if none of these calls raise.

    Every file the callbacks write goes under ``log_folder``. That folder is
    removed in a ``finally`` clause, so a failing run does not leave
    artifacts behind for later runs.
    """
    log_folder = './logs/callbacks/'
    # Created before the ``try`` so the ``finally`` clause can always close it.
    eval_env = gym.make('Pendulum-v0')
    try:
        # Create RL model
        # Small network for fast test
        model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[32]))

        checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)

        # Stop training if the performance is good enough
        callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
        eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
                                     best_model_save_path=log_folder,
                                     log_path=log_folder, eval_freq=100)

        # Equivalent to the `checkpoint_callback`
        # but here in an event-driven manner
        checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder,
                                                 name_prefix='event')
        event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

        callback = CallbackList([checkpoint_callback, eval_callback, event_callback])
        model.learn(500, callback=callback)

        model.learn(500, callback=None)
        # Transform callback into a callback list automatically
        model.learn(500, callback=[checkpoint_callback, eval_callback])
        # Automatic wrapping, old way of doing callbacks
        model.learn(500, callback=lambda _locals, _globals: True)
    finally:
        eval_env.close()
        # ignore_errors: the folder may never have been created if training
        # failed early; cleanup must not mask the original exception.
        shutil.rmtree(log_folder, ignore_errors=True)