stable-baselines3/tests/test_identity.py

41 lines
1.2 KiB
Python

import numpy as np
import pytest
from torchy_baselines import A2C, PPO, SAC, TD3
from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv
from torchy_baselines.common.evaluation import evaluate_policy
from torchy_baselines.common.noise import NormalActionNoise
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_discrete(model_class):
env = IdentityEnv(10)
model = model_class('MlpPolicy', env, gamma=0.5, seed=0).learn(3000)
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
def test_continuous(model_class):
env = IdentityEnvBox(eps=0.5)
n_steps = {
A2C: 3500,
PPO: 3000,
SAC: 700,
TD3: 500
}[model_class]
kwargs = dict(
policy_kwargs=dict(net_arch=[64, 64]),
seed=0,
gamma=0.95
)
if model_class in [TD3]:
n_actions = 1
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
kwargs['action_noise'] = action_noise
model = model_class('MlpPolicy', env, **kwargs).learn(n_steps)
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)