stable-baselines3/tests/test_custom_policy.py
2019-11-21 13:01:03 +01:00

18 lines
403 B
Python

import os
import gym
import pytest
from torchy_baselines import PPO
@pytest.mark.parametrize(
    'net_arch',
    [
        [12, {'vf': [16], 'pi': [8]}],
        [4],
        [4, 4],
        [12, {'vf': [8, 4], 'pi': [8]}],
        [12, {'vf': [8], 'pi': [8, 4]}],
        [12, {'pi': [8]}],
    ],
)
def test_flexible_mlp(net_arch):
    """Smoke-test PPO's ``MlpPolicy`` with several ``net_arch`` layouts.

    Each parametrized case builds a PPO model on ``'CartPole-v1'`` with
    ``n_steps=100``, forwarding ``net_arch`` through ``policy_kwargs``, and
    then calls ``learn(1000)``. There are no assertions: a case passes as
    long as construction and the ``learn`` call finish without raising.

    :param net_arch: Architecture spec handed to the policy. The cases cover
        plain lists of ints as well as lists ending in a dict with ``vf``
        and/or ``pi`` keys.
        NOTE(review): presumably ints are shared layer sizes and the dict
        holds separate value-function / policy layers -- confirm against the
        policy implementation.
    """
    policy_kwargs = {'net_arch': net_arch}
    model = PPO('MlpPolicy', 'CartPole-v1', policy_kwargs=policy_kwargs, n_steps=100)
    model.learn(1000)