diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index f542788..0844d78 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -1,6 +1,9 @@ +from functools import partial + import torch as th import torch.nn as nn from torch.distributions import Normal +import numpy as np from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp @@ -27,21 +30,29 @@ class PPOPolicy(BasePolicy): self._build(learning_rate) @staticmethod - def init_weights(module): + def init_weights(module, gain=1): if type(module) == nn.Linear: - nn.init.orthogonal_(module.weight, gain=1) + nn.init.orthogonal_(module.weight, gain=gain) module.bias.data.fill_(0.0) def _build(self, learning_rate): + # TODO: support non-shared network shared_net = create_mlp(self.state_dim, output_dim=-1, net_arch=self.net_arch, activation_fn=self.activation_fn) self.shared_net = nn.Sequential(*shared_net).to(self.device) self.actor_net = nn.Linear(self.net_arch[-1], self.action_dim) self.value_net = nn.Linear(self.net_arch[-1], 1) self.log_std = nn.Parameter(th.zeros(self.action_dim)) - # Init weights: + # Init weights: use orthogonal initialization for module in [self.shared_net, self.actor_net, self.value_net]: - module.apply(self.init_weights) - + gain = 0.01 if module == self.actor_net else 1.0 + # Values from stable-baselines check why + gain = { + self.shared_net: np.sqrt(2), + self.actor_net: 0.01, + self.value_net: 1 + }[module] + module.apply(partial(self.init_weights, gain=gain)) + # TODO: support linear decay of the learning rate self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate, eps=self.adam_epsilon) def forward(self, state, deterministic=False): diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 9367403..4fb2ee1 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -1,4 +1,5 @@ import time +from copy import deepcopy import gym import torch as th @@ -7,9 +8,10 @@ import numpy as np from torchy_baselines.common.base_class import BaseRLModel from torchy_baselines.common.evaluation import evaluate_policy -from torchy_baselines.ppo.policies import PPOPolicy from torchy_baselines.common.buffers import RolloutBuffer from torchy_baselines.common.utils import explained_variance +from torchy_baselines.common.vec_env import VecNormalize +from torchy_baselines.ppo.policies import PPOPolicy class PPO(BaseRLModel): @@ -188,6 +190,9 @@ class PPO(BaseRLModel): # Evaluate agent if 0 < eval_freq <= timesteps_since_eval and eval_env is not None: timesteps_since_eval %= eval_freq + # Sync eval env and train env when using VecNormalize + if isinstance(self.env, VecNormalize): + eval_env.obs_rms = deepcopy(self.env.obs_rms) mean_reward, _ = evaluate_policy(self, eval_env, n_eval_episodes) evaluations.append(mean_reward) if self.verbose > 0: