mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
18 lines
403 B
Python
import os
|
|
|
|
import gym
|
|
import pytest
|
|
|
|
from torchy_baselines import PPO
|
|
|
|
|
|
@pytest.mark.parametrize('net_arch', [
    [12, dict(vf=[16], pi=[8])],
    [4],
    [4, 4],
    [12, dict(vf=[8, 4], pi=[8])],
    [12, dict(vf=[8], pi=[8, 4])],
    [12, dict(pi=[8])],
])
def test_flexible_mlp(net_arch):
    """Smoke-test PPO's ``MlpPolicy`` with a variety of ``net_arch`` layouts.

    Each parametrized ``net_arch`` is forwarded unchanged as
    ``policy_kwargs=dict(net_arch=net_arch)`` to
    ``PPO('MlpPolicy', 'CartPole-v1', ..., n_steps=100)``, and the model's
    ``learn(1000)`` is then called. The test passes as long as construction
    and training complete without raising; nothing about the learned policy
    is asserted.

    The cases mix plain lists of layer sizes (``[4]``, ``[4, 4]``) with lists
    whose last element is a dict keyed by ``vf`` and/or ``pi``. Presumably the
    integers are shared hidden layers and the dict gives separate value/policy
    heads. NOTE(review): confirm against the policy's ``net_arch``
    documentation. The final case omits ``vf`` entirely, which exercises a
    policy-only branch spec.
    """
    # The trained model is not inspected; this is a "does it build and run" check.
    PPO(
        'MlpPolicy',
        'CartPole-v1',
        policy_kwargs=dict(net_arch=net_arch),
        n_steps=100,
    ).learn(1000)
|