stable-baselines3/tests/test_deterministic.py
Sidney Tio 489b1fdaf2
Add the argument dtype (default to float32) to the noise (#1301)
* Fixed noise to return float32

* Updated changelog

* Fixed test to use numpy arrays instead of python floats

* Sorted imports for tests

* Added dtype to constructor

* Removed dtype parameter for VectorizedActionNoise

* __init__ -> None; Capitalize and period in docstring when needed; fix dtype type hint; dtype in docstring

* fix dtype type hint

* Update version

* Clarify changelog [skip ci]

* empty commit to run ci

* Update docs/misc/changelog.rst

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2023-02-07 13:42:14 +01:00

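The commit above (#1301) adds a dtype argument, defaulting to float32, to the action-noise classes. As a minimal sketch of what that looks like from the caller's side, assuming the post-#1301 NormalActionNoise signature (mean, sigma, dtype=np.float32):

import numpy as np

from stable_baselines3.common.noise import NormalActionNoise

# After #1301, noise samples come back as float32 by default
# instead of float64, matching the dtype SB3 uses for actions.
noise = NormalActionNoise(mean=np.zeros(1), sigma=0.1 * np.ones(1))
assert noise().dtype == np.float32

# The dtype can be overridden through the new constructor argument.
noise64 = NormalActionNoise(mean=np.zeros(1), sigma=0.1 * np.ones(1), dtype=np.float64)
assert noise64().dtype == np.float64

Per the commit bullets, VectorizedActionNoise deliberately takes no dtype parameter of its own; it uses the dtype of the base noise it wraps.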

import numpy as np
import pytest

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.noise import NormalActionNoise

N_STEPS_TRAINING = 500
SEED = 0


@pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3])
def test_deterministic_training_common(algo):
    """Train the same algorithm twice with the same seed and check that
    both runs produce identical actions and rewards."""
    results = [[], []]
    rewards = [[], []]
    # Smaller network
    kwargs = {"policy_kwargs": dict(net_arch=[64])}
    env_id = "Pendulum-v1"
    if algo in [TD3, SAC]:
        kwargs.update(
            {"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), "learning_starts": 100, "train_freq": 4}
        )
    else:
        if algo == DQN:
            env_id = "CartPole-v1"
            kwargs.update({"learning_starts": 100, "target_update_interval": 100})
        elif algo == PPO:
            kwargs.update({"n_steps": 64, "n_epochs": 4})

    for i in range(2):
        model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
        model.learn(N_STEPS_TRAINING)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)

    assert sum(results[0]) == sum(results[1]), results
    assert sum(rewards[0]) == sum(rewards[1]), rewards