stable-baselines3/tests/test_identity.py
Roland Gavrilescu 91adefdb4b
Support for MultiBinary / MultiDiscrete spaces (#13)
* multicategorical dist and test

* fixed List annotation

* bernoulli dist and test

* added distributions to preprocessing (needs testing)

* fixed and tested distributions

* added changelog and fixed ppo policy

* minor fix

* dist fixes, added test_spaces

* clean up

* modified changelog

* additional fixes

* minor changelog mod

* one-hot encoding fix, flake8 clean up

* lint tests

* preprocessing fix

* fixed bernoulli bug

* removed commented prints

* Update changelog.rst

* included suggested modifications

* linting fix

* increased space dim

* Update doc and tests

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-05-18 14:42:13 +02:00

51 lines
1.6 KiB
Python

import numpy as np
import pytest
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.identity_env import (IdentityEnvBox, IdentityEnv,
IdentityEnvMultiBinary, IdentityEnvMultiDiscrete)
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.noise import NormalActionNoise
# Size of the Discrete / MultiDiscrete / MultiBinary identity spaces exercised by test_discrete.
DIM = 4


@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
    """Train an on-policy agent on a discrete identity environment and check it learns.

    The agent (A2C or PPO) is trained for 3000 timesteps and then passed to
    ``evaluate_policy`` over 20 episodes with ``reward_threshold=90``
    (presumably failing when the mean reward is below that threshold -- see the
    ``evaluate_policy`` docs). Finally, the action predicted for a freshly reset
    observation must have the same shape as that observation.

    :param model_class: algorithm class under test (A2C or PPO).
    :param env: a single, non-vectorized identity environment instance.
    """
    # Keep the raw env and the VecEnv under different names: the factory lambda
    # looks ``env`` up lazily, so rebinding ``env`` to the VecEnv would make any
    # later call of the factory return the VecEnv instead of the raw env.
    vec_env = DummyVecEnv([lambda: env])
    model = model_class('MlpPolicy', vec_env, gamma=0.5, seed=1).learn(3000)
    evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90)

    obs = vec_env.reset()
    action = model.predict(obs)[0]
    # The predicted action batch is expected to mirror the observation batch's shape.
    assert np.shape(action) == np.shape(obs)
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
def test_continuous(model_class):
    """Train an agent on the continuous ``IdentityEnvBox`` and check it learns.

    Each algorithm gets its own training budget (timesteps) from the table below.
    All share the same network architecture, seed and discount factor. TD3
    additionally receives Gaussian action noise. The trained model is passed to
    ``evaluate_policy`` over 20 episodes with ``reward_threshold=90``.

    :param model_class: algorithm class under test (A2C, PPO, SAC or TD3).
    """
    env = IdentityEnvBox(eps=0.5)

    # Training budget per algorithm; SAC and TD3 are given far fewer steps.
    n_steps = {
        A2C: 3500,
        PPO: 3000,
        SAC: 700,
        TD3: 500,
    }[model_class]

    kwargs = dict(
        policy_kwargs=dict(net_arch=[64, 64]),
        seed=0,
        gamma=0.95,
    )
    if model_class in [TD3]:
        # Size the noise from the env's action space instead of hard-coding the
        # action dimension, so it stays correct if IdentityEnvBox changes.
        n_actions = env.action_space.shape[0]
        action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
        kwargs['action_noise'] = action_noise

    model = model_class('MlpPolicy', env, **kwargs).learn(n_steps)
    evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)