mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
15 lines
377 B
Python
15 lines
377 B
Python
import pytest
|
|
|
|
from torchy_baselines import PPO
|
|
|
|
|
|
@pytest.mark.parametrize(
    "net_arch",
    [
        [12, {"vf": [16], "pi": [8]}],
        [4],
        [4, 4],
        [12, {"vf": [8, 4], "pi": [8]}],
        [12, {"vf": [8], "pi": [8, 4]}],
        [12, {"pi": [8]}],
    ],
)
def test_flexible_mlp(net_arch):
    """Smoke-test PPO training with a range of ``net_arch`` specifications.

    For every parametrized case a PPO model is built with the 'MlpPolicy'
    policy on 'CartPole-v1', receiving ``net_arch`` through
    ``policy_kwargs``, and is then trained via ``learn(1000)``. The test
    makes no assertions: it passes when construction and training finish
    without raising.

    :param net_arch: architecture specification forwarded unchanged to the
        policy. NOTE(review): presumably the leading ints are shared layer
        sizes and the optional dict holds separate value ('vf') and policy
        ('pi') layer sizes -- confirm against the policy implementation.
    """
    policy_kwargs = {"net_arch": net_arch}
    model = PPO(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=policy_kwargs,
        n_steps=100,
    )
    # The value returned by learn() is not needed; only successful completion matters.
    model.learn(1000)
|