diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index e5ac7af..7eff3d5 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -16,8 +16,12 @@ class PPOPolicy(BasePolicy): ortho_init=True, use_sde=False, log_std_init=0.0): super(PPOPolicy, self).__init__(observation_space, action_space, device) self.obs_dim = self.observation_space.shape[0] + + + # Default network architecture, from stable-baselines if net_arch is None: - net_arch = [dict(pi=[64], vf=[64])] + net_arch = [dict(pi=[64, 64], vf=[64, 64])] + self.net_arch = net_arch self.activation_fn = activation_fn self.adam_epsilon = adam_epsilon