mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
import numpy as np
|
|
import pytest
|
|
|
|
from torchy_baselines import A2C, PPO, SAC, TD3
|
|
from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv
|
|
from torchy_baselines.common.evaluation import evaluate_policy
|
|
from torchy_baselines.common.noise import NormalActionNoise
|
|
|
|
|
|
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_discrete(model_class):
    """Train ``model_class`` on ``IdentityEnv(10)`` and evaluate the result.

    The agent is built with the ``'MlpPolicy'`` policy, ``gamma=0.5`` and a
    fixed seed, trained via ``learn(3000)``, then handed to
    ``evaluate_policy`` for 20 episodes with ``reward_threshold=90``.

    :param model_class: algorithm class under test (supplied by parametrize)
    """
    identity_env = IdentityEnv(10)

    agent = model_class('MlpPolicy', identity_env, gamma=0.5, seed=0)
    # Keep whatever ``learn`` returns: that object is what gets evaluated.
    trained = agent.learn(3000)

    # NOTE(review): presumably ``evaluate_policy`` fails the test when the
    # mean reward is below ``reward_threshold`` -- confirm in its definition.
    evaluate_policy(
        trained,
        identity_env,
        n_eval_episodes=20,
        reward_threshold=90,
    )
|
|
|
|
|
|
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
def test_continuous(model_class):
    """Train ``model_class`` on ``IdentityEnvBox(eps=0.5)`` and evaluate it.

    Each algorithm gets its own training budget, a ``[64, 64]`` network
    architecture, a fixed seed and ``gamma=0.95``. TD3 additionally receives
    Gaussian action noise. The trained model is passed to
    ``evaluate_policy`` for 20 episodes with ``reward_threshold=90``.

    :param model_class: algorithm class under test (supplied by parametrize)
    """
    box_env = IdentityEnvBox(eps=0.5)

    # Argument passed to ``learn`` for each algorithm.
    step_budget = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500}
    total_steps = step_budget[model_class]

    model_kwargs = {
        'policy_kwargs': {'net_arch': [64, 64]},
        'seed': 0,
        'gamma': 0.95,
    }

    # Only TD3 is given explicit action noise here.
    if model_class is TD3:
        action_dim = 1
        noise_mean = np.zeros(action_dim)
        noise_sigma = 0.1 * np.ones(action_dim)
        model_kwargs['action_noise'] = NormalActionNoise(
            mean=noise_mean,
            sigma=noise_sigma,
        )

    agent = model_class('MlpPolicy', box_env, **model_kwargs)
    # Keep whatever ``learn`` returns: that object is what gets evaluated.
    trained = agent.learn(total_steps)

    # NOTE(review): presumably ``evaluate_policy`` fails the test when the
    # mean reward is below ``reward_threshold`` -- confirm in its definition.
    evaluate_policy(
        trained,
        box_env,
        n_eval_episodes=20,
        reward_threshold=90,
    )
|