mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
* Created DQN template according to the paper. Next steps: - Create Policy - Complete Training - Debug * Changed Base Class * refactor save to be consistent with overriding the excluded_save_params function. Do not try to exclude the parameters twice. * Added simple DQN policy * Finished learn and train function - missing correct loss computation * changed collect_rollouts to work with discrete space * moved discrete space collect_rollouts to dqn * basic dqn working * deleted SDE related code * added gradient clipping and moved greedy policy to policy * changed policy to implement target network and added soft update (in fact the standard tau is 1, so it is a hard update) * fixed policy setup * rebase target_update_intervall on _n_updates * adapted all tests, all tests passing * Move to stable-baseline3 * Fixes for DQN * Fix tests + add CNNPolicy * Allow any optimizer for DQN * added some util functions to create an arbitrary linear schedule, fixed pickle problem with old exploration schedule * more documentation * changed buffer dtype * refactor and document * Added Sphinx Documentation, Updated changelog.rst * removed custom collect_rollouts as it is no longer necessary * Implemented suggestions to clean code and documentation. 
* extracted some functions in tests to reduce duplicated code * added support for exploration_fraction * Fixed exploration_fraction * Added documentation * Fixed get_linear_fn -> proper progress scaling * Merged master * Added nature reference * Changed default parameters to https://www.nature.com/articles/nature14236/tables/1 * Fixed n_updates to be incremented correctly * Correct train_freq * Doc update * added special parameter for DQN in tests * different fix for test_discrete * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update docs/modules/dqn.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Added RMSProp in optimizer_kwargs, as described in nature paper * Exploration fraction is inverse of 50.000.000 (total frames) / 1.000.000 (frames with linear schedule) according to nature paper * Changelog update for buffer dtype * standard excluded parameters should always be excluded to ensure proper saving, unless intentionally included via the ``include`` parameter * slightly more iterations on test_discrete to pass the test * added param use_rms_prop instead of mutable default argument * forgot alpha * using huber loss, adam and learning rate 1e-4 * account for train_freq in update_target_network * Added memory check for both buffers * Doc updated for buffer allocation * Added psutil Requirement * Adapted test_identity.py * Fixes with new SB3 version * Fix for tensorboard name * Convert assert to warning and fix tests * Refactor off-policy algorithms * Fixes * test: remove next_obs in replay buffer * Update changelog * Fix tests and use tmp_path where possible * Fix sampling bug in buffer * Do not store next obs on episode termination * Fix replay buffer sampling * Update comment * moved epsilon from policy to model * Update predict method * Update atari wrappers to match SB2 * Minor edit in the buffers * Update changelog * 
Merge branch 'master' into dqn * Update DQN to new structure * Fix tests and remove hardcoded path * Fix for DQN * Disable memory efficient replay buffer by default * Fix docstring * Add tests for memory efficient buffer * Update changelog * Split collect rollout * Move target update outside `train()` for DQN * Update changelog * Update linear schedule doc * Cleanup DQN code * Minor edit * Update version and docker images Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
54 lines
2 KiB
Python
54 lines
2 KiB
Python
import os
|
|
import shutil
|
|
|
|
import pytest
|
|
import gym
|
|
|
|
from stable_baselines3 import A2C, PPO, SAC, TD3, DQN
|
|
from stable_baselines3.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
|
|
EveryNTimesteps, StopTrainingOnRewardThreshold)
|
|
|
|
|
|
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
def test_callbacks(tmp_path, model_class):
    """Smoke-test the callback API for each algorithm.

    Trains the model for a few hundred steps while exercising
    ``CheckpointCallback``, ``EvalCallback`` (chained with
    ``StopTrainingOnRewardThreshold``), ``EveryNTimesteps`` and
    ``CallbackList``, then checks the implicit conversions accepted by
    ``learn()``: ``callback=None``, a plain list of callbacks, and a legacy
    ``(locals, globals)`` function.

    :param tmp_path: per-test temporary directory provided by pytest.
    :param model_class: the algorithm class under test.
    """
    log_folder = tmp_path / 'logs/callbacks/'

    # DQN only supports discrete actions, so it needs a different env
    env_name = select_env(model_class)
    # Create RL model
    # Small network for fast test
    model = model_class('MlpPolicy', env_name, policy_kwargs=dict(net_arch=[32]))

    checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)

    eval_env = gym.make(env_name)
    try:
        # Stop training if the performance is good enough
        callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)

        eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
                                     best_model_save_path=log_folder,
                                     log_path=log_folder, eval_freq=100)

        # Equivalent to the `checkpoint_callback`
        # but here in an event-driven manner
        checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder,
                                                 name_prefix='event')
        event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

        callback = CallbackList([checkpoint_callback, eval_callback, event_callback])

        model.learn(500, callback=callback)
        model.learn(500, callback=None)
        # Transform callback into a callback list automatically
        model.learn(500, callback=[checkpoint_callback, eval_callback])
        # Automatic wrapping, old way of doing callbacks
        model.learn(500, callback=lambda _locals, _globals: True)
    finally:
        # Release the evaluation env even when training fails.
        eval_env.close()
        # Best-effort cleanup: ignore_errors so that a missing folder or a
        # cleanup error never masks the real test failure.
        shutil.rmtree(log_folder, ignore_errors=True)
|
|
|
|
|
|
def select_env(model_class) -> str:
    """Return the id of the gym environment to train ``model_class`` on.

    DQN handles discrete actions only, so it gets ``'CartPole-v0'``; all the
    other algorithms in this test use the continuous ``'Pendulum-v0'``.

    :param model_class: the algorithm class under test.
    :return: a gym environment id.
    """
    return 'CartPole-v0' if model_class is DQN else 'Pendulum-v0'
|