stable-baselines3/tests/test_deterministic.py

import pytest

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.noise import NormalActionNoise

N_STEPS_TRAINING = 500
SEED = 0


@pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3])
def test_deterministic_training_common(algo):
    results = [[], []]
    rewards = [[], []]
    # Smaller network
    kwargs = {"policy_kwargs": dict(net_arch=[64])}
    env_id = "Pendulum-v1"
    if algo in [TD3, SAC]:
        kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
    else:
        if algo == DQN:
            env_id = "CartPole-v1"
            kwargs.update({"learning_starts": 100, "target_update_interval": 100})
        elif algo == PPO:
            kwargs.update({"n_steps": 64, "n_epochs": 4})

    # Train two identically-seeded models and record their rollouts
    for i in range(2):
        model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
        model.learn(N_STEPS_TRAINING)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)
    # Identical seeds must yield identical actions and rewards
    assert sum(results[0]) == sum(results[1]), results
    assert sum(rewards[0]) == sum(rewards[1]), rewards
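
A minimal standalone sketch of the same determinism check, assuming stable-baselines3 (with its PyTorch backend) and Pendulum-v1 are available. It is not part of the repository; the hyperparameters simply mirror the PPO branch of the test above, and comparing policy state dicts is a stricter variant of the rollout comparison the test performs.

# Minimal sketch, not part of the test file: two PPO runs with the same
# seed should produce identical policy parameters on CPU.
import torch as th

from stable_baselines3 import PPO

params = []
for _ in range(2):
    model = PPO(
        "MlpPolicy",
        "Pendulum-v1",
        seed=0,
        n_steps=64,
        n_epochs=4,
        policy_kwargs=dict(net_arch=[64]),
    )
    model.learn(500)
    params.append(model.policy.state_dict())

# Every weight tensor should match exactly across the two runs
assert all(th.equal(params[0][key], params[1][key]) for key in params[0])

Note that a check this strict is only expected to pass on CPU; CUDA kernels can be non-deterministic unless PyTorch is explicitly configured for deterministic execution.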