mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
* fix Atari in CI * fix dtype and atari extra * Update setup.py * remove 3.6 * note about how to install Atari * pendulum-v1 * atari v5 * black * fix pendulum capitalization * add minimum version * moved things in changelog to breaking changes * partial v5 fix * env update to pass tests * mismatch env version fixed * Fix tests after merge * Include autorom in setup.py * Blacken code * Fix dtype issue in more robust way * Fix GitLab CI: switch to Docker container with new black version * Remove workaround from GitLab. (May need to rebuild Docker for this though.) * Revert to v4 * Update setup.py * Apply suggestions from code review * Remove unnecessary autorom * Consistent gym versions Co-authored-by: J K Terry <justinkterry@gmail.com> Co-authored-by: Anssi <kaneran21@hotmail.com> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> Co-authored-by: modanesh <mohamad4danesh@gmail.com> Co-authored-by: Adam Gleave <adam@gleave.me>
37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
import os
|
|
|
|
import pytest
|
|
|
|
from stable_baselines3 import A2C, PPO, SAC, TD3
|
|
|
|
# Algorithms under test, keyed by the lower-case name used as the pytest
# parameter. Each value is an ``(algorithm class, Gym environment id)`` pair.
MODEL_DICT = {
    # A2C and PPO are trained on CartPole-v1 (discrete actions).
    "a2c": (A2C, "CartPole-v1"),
    "ppo": (PPO, "CartPole-v1"),
    # SAC and TD3 are trained on Pendulum-v1 (continuous actions).
    "sac": (SAC, "Pendulum-v1"),
    "td3": (TD3, "Pendulum-v1"),
}

# Timesteps passed to every ``learn`` call; kept tiny so the test runs quickly.
N_STEPS = 100
@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
    """Check the TensorBoard run-directory naming performed by ``learn``.

    Two scenarios are covered:

    1. Calling ``learn`` again with ``reset_num_timesteps=False`` continues the
       current run, so only ``<NAME>_1`` exists for the default log name.
    2. Calling ``learn`` twice with the same explicit ``tb_log_name`` (and the
       default timestep reset) starts a new run each time, producing exactly
       ``<name>_1`` and ``<name>_2``.
    """
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")

    algo, env_id = MODEL_DICT[model_name]
    model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path)

    # This test expects the default ``tb_log_name`` to be the upper-cased
    # algorithm name (e.g. "a2c" -> "A2C").
    default_log_name = model_name.upper()
    model.learn(N_STEPS)
    model.learn(N_STEPS, reset_num_timesteps=False)

    assert os.path.isdir(tmp_path / f"{default_log_name}_1")
    # Continuing training must reuse the first run directory, not open a new one
    assert not os.path.isdir(tmp_path / f"{default_log_name}_2")

    custom_log_name = "tb_multiple_runs_" + model_name
    model.learn(N_STEPS, tb_log_name=custom_log_name)
    model.learn(N_STEPS, tb_log_name=custom_log_name)

    assert os.path.isdir(tmp_path / f"{custom_log_name}_1")
    # Check that the log dir name increments correctly...
    assert os.path.isdir(tmp_path / f"{custom_log_name}_2")
    # ...and by exactly one per new run
    assert not os.path.isdir(tmp_path / f"{custom_log_name}_3")