stable-baselines3/tests/test_callbacks.py
Antonin RAFFIN 15ff6d47ee
Documentation update and style fixes (#21)
* Update doc: add gSDE

* Fix codestyle

* Remove travis script

* Add lint check to gitlab
2020-05-15 13:54:06 +02:00

44 lines
1.8 KiB
Python

import os
import shutil
import pytest
import gym
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
EveryNTimesteps, StopTrainingOnRewardThreshold)
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
def test_callbacks(model_class):
log_folder = './logs/callbacks/'
# Create RL model
# Small network for fast test
model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[32]))
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)
eval_env = gym.make('Pendulum-v0')
# Stop training if the performance is good enough
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
best_model_save_path=log_folder,
log_path=log_folder, eval_freq=100)
# Equivalent to the `checkpoint_callback`
# but here in an event-driven manner
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder,
name_prefix='event')
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
callback = CallbackList([checkpoint_callback, eval_callback, event_callback])
model.learn(500, callback=callback)
model.learn(500, callback=None)
# Transform callback into a callback list automatically
model.learn(500, callback=[checkpoint_callback, eval_callback])
# Automatic wrapping, old way of doing callbacks
model.learn(500, callback=lambda _locals, _globals: True)
if os.path.exists(log_folder):
shutil.rmtree(log_folder)