mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-22 22:10:16 +00:00
Improve initialization
This commit is contained in:
parent
dfe1ab9690
commit
e8ddd1f901
2 changed files with 22 additions and 6 deletions
|
|
@ -1,6 +1,9 @@
|
|||
from functools import partial
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Normal
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp
|
||||
|
||||
|
|
@ -27,21 +30,29 @@ class PPOPolicy(BasePolicy):
|
|||
self._build(learning_rate)
|
||||
|
||||
@staticmethod
|
||||
def init_weights(module):
|
||||
def init_weights(module, gain=1):
|
||||
if type(module) == nn.Linear:
|
||||
nn.init.orthogonal_(module.weight, gain=1)
|
||||
nn.init.orthogonal_(module.weight, gain=gain)
|
||||
module.bias.data.fill_(0.0)
|
||||
|
||||
def _build(self, learning_rate):
|
||||
# TODO: support non-shared network
|
||||
shared_net = create_mlp(self.state_dim, output_dim=-1, net_arch=self.net_arch, activation_fn=self.activation_fn)
|
||||
self.shared_net = nn.Sequential(*shared_net).to(self.device)
|
||||
self.actor_net = nn.Linear(self.net_arch[-1], self.action_dim)
|
||||
self.value_net = nn.Linear(self.net_arch[-1], 1)
|
||||
self.log_std = nn.Parameter(th.zeros(self.action_dim))
|
||||
# Init weights:
|
||||
# Init weights: use orthogonal initialization
|
||||
for module in [self.shared_net, self.actor_net, self.value_net]:
|
||||
module.apply(self.init_weights)
|
||||
|
||||
gain = 0.01 if module == self.actor_net else 1.0
|
||||
# Values from stable-baselines check why
|
||||
gain = {
|
||||
self.shared_net: np.sqrt(2),
|
||||
self.actor_net: 0.01,
|
||||
self.value_net: 1
|
||||
}[module]
|
||||
module.apply(partial(self.init_weights, gain=gain))
|
||||
# TODO: support linear decay of the learning rate
|
||||
self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate, eps=self.adam_epsilon)
|
||||
|
||||
def forward(self, state, deterministic=False):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
|
|
@ -7,9 +8,10 @@ import numpy as np
|
|||
|
||||
from torchy_baselines.common.base_class import BaseRLModel
|
||||
from torchy_baselines.common.evaluation import evaluate_policy
|
||||
from torchy_baselines.ppo.policies import PPOPolicy
|
||||
from torchy_baselines.common.buffers import RolloutBuffer
|
||||
from torchy_baselines.common.utils import explained_variance
|
||||
from torchy_baselines.common.vec_env import VecNormalize
|
||||
from torchy_baselines.ppo.policies import PPOPolicy
|
||||
|
||||
|
||||
class PPO(BaseRLModel):
|
||||
|
|
@ -188,6 +190,9 @@ class PPO(BaseRLModel):
|
|||
# Evaluate agent
|
||||
if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
|
||||
timesteps_since_eval %= eval_freq
|
||||
# Sync eval env and train env when using VecNormalize
|
||||
if isinstance(self.env, VecNormalize):
|
||||
eval_env.obs_rms = deepcopy(self.env.obs_rms)
|
||||
mean_reward, _ = evaluate_policy(self, eval_env, n_eval_episodes)
|
||||
evaluations.append(mean_reward)
|
||||
if self.verbose > 0:
|
||||
|
|
|
|||
Loading…
Reference in a new issue