Improve initialization

This commit is contained in:
Antonin RAFFIN 2019-09-21 16:48:51 +02:00
parent dfe1ab9690
commit e8ddd1f901
2 changed files with 22 additions and 6 deletions

View file

@ -1,6 +1,9 @@
from functools import partial
import torch as th
import torch.nn as nn
from torch.distributions import Normal
import numpy as np
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp
@ -27,21 +30,29 @@ class PPOPolicy(BasePolicy):
self._build(learning_rate)
@staticmethod
def init_weights(module):
def init_weights(module, gain=1):
if type(module) == nn.Linear:
nn.init.orthogonal_(module.weight, gain=1)
nn.init.orthogonal_(module.weight, gain=gain)
module.bias.data.fill_(0.0)
def _build(self, learning_rate):
# TODO: support non-shared network
shared_net = create_mlp(self.state_dim, output_dim=-1, net_arch=self.net_arch, activation_fn=self.activation_fn)
self.shared_net = nn.Sequential(*shared_net).to(self.device)
self.actor_net = nn.Linear(self.net_arch[-1], self.action_dim)
self.value_net = nn.Linear(self.net_arch[-1], 1)
self.log_std = nn.Parameter(th.zeros(self.action_dim))
# Init weights:
# Init weights: use orthogonal initialization
for module in [self.shared_net, self.actor_net, self.value_net]:
module.apply(self.init_weights)
gain = 0.01 if module == self.actor_net else 1.0
# Values taken from stable-baselines. TODO: check why these gains are used
gain = {
self.shared_net: np.sqrt(2),
self.actor_net: 0.01,
self.value_net: 1
}[module]
module.apply(partial(self.init_weights, gain=gain))
# TODO: support linear decay of the learning rate
self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate, eps=self.adam_epsilon)
def forward(self, state, deterministic=False):

View file

@ -1,4 +1,5 @@
import time
from copy import deepcopy
import gym
import torch as th
@ -7,9 +8,10 @@ import numpy as np
from torchy_baselines.common.base_class import BaseRLModel
from torchy_baselines.common.evaluation import evaluate_policy
from torchy_baselines.ppo.policies import PPOPolicy
from torchy_baselines.common.buffers import RolloutBuffer
from torchy_baselines.common.utils import explained_variance
from torchy_baselines.common.vec_env import VecNormalize
from torchy_baselines.ppo.policies import PPOPolicy
class PPO(BaseRLModel):
@ -188,6 +190,9 @@ class PPO(BaseRLModel):
# Evaluate agent
if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
timesteps_since_eval %= eval_freq
# Sync eval env and train env when using VecNormalize
if isinstance(self.env, VecNormalize):
eval_env.obs_rms = deepcopy(self.env.obs_rms)
mean_reward, _ = evaluate_policy(self, eval_env, n_eval_episodes)
evaluations.append(mean_reward)
if self.verbose > 0: