2020-01-31 12:42:04 +00:00
|
|
|
import os
|
|
|
|
|
import shutil
|
|
|
|
|
|
2020-01-27 13:32:31 +00:00
|
|
|
import pytest
|
2020-01-31 12:42:04 +00:00
|
|
|
import gym
|
2020-01-27 13:32:31 +00:00
|
|
|
|
2020-06-29 09:16:54 +00:00
|
|
|
from stable_baselines3 import A2C, PPO, SAC, TD3, DQN
|
2020-05-05 13:02:35 +00:00
|
|
|
from stable_baselines3.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
|
2020-05-15 11:54:06 +00:00
|
|
|
EveryNTimesteps, StopTrainingOnRewardThreshold)
|
2020-01-27 13:32:31 +00:00
|
|
|
|
|
|
|
|
|
2020-06-29 09:16:54 +00:00
|
|
|
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
def test_callbacks(tmp_path, model_class):
    """Smoke-test the built-in callbacks and callback wrapping for each algorithm.

    Runs several short ``learn`` calls. They use a ``CallbackList`` of
    checkpoint, evaluation and event-driven callbacks. They also cover the
    other accepted forms of the ``callback`` argument: ``None``, a plain list,
    and an old-style ``(locals, globals)`` function.

    :param tmp_path: pytest-provided temporary directory used for logs/checkpoints.
    :param model_class: algorithm class under test.
    """
    log_folder = tmp_path / 'logs/callbacks/'

    # DQN only supports discrete actions, so the env depends on the algorithm
    env_name = select_env(model_class)

    # Create RL model
    # Small network for fast test
    model = model_class('MlpPolicy', env_name, policy_kwargs=dict(net_arch=[32]))

    checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)

    eval_env = gym.make(env_name)
    try:
        # Stop training if the performance is good enough
        callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)

        eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
                                     best_model_save_path=log_folder,
                                     log_path=log_folder, eval_freq=100)

        # Equivalent to the `checkpoint_callback`
        # but here in an event-driven manner
        checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder,
                                                 name_prefix='event')
        event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

        callback = CallbackList([checkpoint_callback, eval_callback, event_callback])

        model.learn(500, callback=callback)
        model.learn(500, callback=None)
        # Transform callback into a callback list automatically
        model.learn(500, callback=[checkpoint_callback, eval_callback])
        # Automatic wrapping, old way of doing callbacks
        model.learn(500, callback=lambda _locals, _globals: True)
    finally:
        # Release the evaluation environment even if one of the runs fails
        eval_env.close()

    if os.path.exists(log_folder):
        shutil.rmtree(log_folder)
|
2020-06-29 09:16:54 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def select_env(model_class) -> str:
    """Pick a Gym environment id that ``model_class`` can train on.

    DQN needs a discrete action space, so it gets ``'CartPole-v0'``.
    Every other algorithm gets ``'Pendulum-v0'``.

    :param model_class: algorithm class (compared by identity against ``DQN``).
    :return: the Gym environment id.
    """
    return 'CartPole-v0' if model_class is DQN else 'Pendulum-v0'
|