mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-05 00:00:04 +00:00
* Add DDPG + TD3 with any number of critics * Allow any number of critics for SAC * Update doc * [ci skip] Update DDPG example * Remove unused parameter * Add DDPG to identity test * Fix computation with n_critics=1,3 * Update doc * Apply suggestions from code review Co-authored-by: Adam Gleave <adam@gleave.me> * Update docstrings for off-policy algos * Add check for sde Co-authored-by: Adam Gleave <adam@gleave.me>
60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
import numpy as np
|
|
import pytest
|
|
|
|
from stable_baselines3 import A2C, PPO, SAC, TD3, DQN, DDPG
|
|
from stable_baselines3.common.identity_env import (IdentityEnvBox, IdentityEnv,
|
|
IdentityEnvMultiBinary, IdentityEnvMultiDiscrete)
|
|
|
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
|
from stable_baselines3.common.evaluation import evaluate_policy
|
|
from stable_baselines3.common.noise import NormalActionNoise
|
|
|
|
# Size argument handed to the discrete identity environments that parametrize test_discrete.
DIM = 4
|
|
|
|
|
|
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
    """
    Train ``model_class`` on an identity environment, evaluate it, and check
    that a predicted action has the same shape as an observation.

    DQN is skipped (not silently passed) on the multi-discrete and
    multi-binary environments, since it only supports discrete actions.

    :param model_class: algorithm class under test (A2C, PPO or DQN)
    :param env: identity environment instance supplied by the parametrization
    """
    kwargs = {}
    n_steps = 3000
    if model_class is DQN:
        # DQN only supports discrete actions: report the combination as
        # skipped so the test summary does not count it as a pass.
        if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
            pytest.skip("DQN only supports discrete actions")
        kwargs = dict(learning_starts=0)
        n_steps = 4000

    # Wrap the environment only once we know the combination will be trained.
    env_ = DummyVecEnv([lambda: env])

    model = model_class('MlpPolicy', env_, gamma=0.5, seed=1, **kwargs).learn(n_steps)

    # NOTE(review): reward_threshold presumably makes evaluate_policy fail when
    # the mean reward is too low -- confirm against evaluate_policy itself.
    evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90)

    # Predict on an observation from the raw (unwrapped) environment.
    obs = env.reset()

    assert np.shape(model.predict(obs)[0]) == np.shape(obs)
|
|
|
|
|
|
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, DDPG, TD3])
def test_continuous(model_class):
    """
    Train ``model_class`` on ``IdentityEnvBox`` and evaluate the result.

    Each algorithm gets its own step count for ``learn``; TD3 is additionally
    given a ``NormalActionNoise`` instance as ``action_noise``.

    :param model_class: algorithm class under test
    """
    # Per-algorithm value passed to ``learn``.
    step_budget = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500, DDPG: 500}
    total_steps = step_budget[model_class]

    # Constructor arguments shared by every algorithm.
    model_kwargs = {
        'policy_kwargs': {'net_arch': [64, 64]},
        'seed': 0,
        'gamma': 0.95,
    }

    if model_class is TD3:
        # Only TD3 receives explicit action noise (zero mean, sigma 0.1).
        noise_dim = 1
        model_kwargs['action_noise'] = NormalActionNoise(
            mean=np.zeros(noise_dim),
            sigma=np.full(noise_dim, 0.1),
        )

    env = IdentityEnvBox(eps=0.5)
    model = model_class('MlpPolicy', env, **model_kwargs).learn(total_steps)

    evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
|