stable-baselines3/tests/test_callbacks.py
Antonin Raffin 18f38f8cf5 Reformat
2020-03-12 11:12:10 +01:00

49 lines
1.9 KiB
Python

import os
import shutil
import pytest
import gym
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
from torchy_baselines.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
EveryNTimesteps, StopTrainingOnRewardThreshold)
@pytest.mark.parametrize("model_class", [A2C, CEMRL, PPO, SAC, TD3])
def test_callbacks(model_class):
    """Smoke-test the callback API with every algorithm.

    Trains ``model_class`` on Pendulum-v0 in several short ``learn`` calls,
    passing the callback in each supported form:

    * an explicit ``CallbackList`` (checkpointing, evaluation with early
      stopping, and event-driven checkpointing),
    * ``None``,
    * a plain Python list of callbacks (wrapped into a ``CallbackList``
      automatically),
    * a legacy ``(_locals, _globals)`` function.

    The evaluation environment is closed and the log folder is removed in a
    ``finally`` clause, so a failing run does not leave files behind for the
    next parametrization.
    """
    log_folder = './logs/callbacks/'
    kwargs = {}
    if model_class == CEMRL:
        # Reduced CEM-RL settings to keep the test fast.
        kwargs['pop_size'] = 2
        kwargs['n_grad'] = 1

    # Create RL model
    # Small network for fast test
    model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[32]), **kwargs)
    eval_env = gym.make('Pendulum-v0')
    try:
        checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)

        # Stop training if the performance is good enough
        callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
        eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
                                     best_model_save_path=log_folder,
                                     log_path=log_folder, eval_freq=100)

        # Event-driven checkpointing: EveryNTimesteps triggers checkpoint_on_event
        # every 500 steps; save_freq=1 presumably makes it save on each trigger.
        checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder,
                                                 name_prefix='event')
        event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

        callback = CallbackList([checkpoint_callback, eval_callback, event_callback])
        model.learn(500, callback=callback)
        model.learn(500, callback=None)
        # Transform callback into a callback list automatically
        model.learn(500, callback=[checkpoint_callback, eval_callback])
        # Automatic wrapping, old way of doing callbacks
        model.learn(500, callback=lambda _locals, _globals: True)
    finally:
        # Always release the eval env and remove artifacts, even if training failed.
        eval_env.close()
        if os.path.exists(log_folder):
            shutil.rmtree(log_folder)