mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-24 02:48:02 +00:00
Solve NaN issue and reduce number of parameters
This commit is contained in:
parent
d725d01186
commit
da325a0ba7
3 changed files with 38 additions and 8 deletions
|
|
@ -21,6 +21,9 @@ TODO:
|
|||
- save/load
|
||||
- better predict
|
||||
- complete logger
|
||||
- SDE: reduce the number of parameters (only n_features instead of n_features x n_actions) for A2C
|
||||
(done for TD3)
|
||||
- SDE: learn the feature extractor?
|
||||
|
||||
Later:
|
||||
- get_parameters / set_parameters
|
||||
|
|
|
|||
|
|
@ -7,18 +7,27 @@ from torchy_baselines.common.policies import BasePolicy, register_policy, create
|
|||
|
||||
class Actor(BaseNetwork):
|
||||
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
|
||||
use_sde=False, log_std_init=-2, clip_noise=None, lr_sde=3e-4):
|
||||
use_sde=False, log_std_init=-2, clip_noise=None,
|
||||
lr_sde=3e-4, full_std=False):
|
||||
super(Actor, self).__init__()
|
||||
|
||||
self.latent_pi, self.log_std = None, None
|
||||
self.weights_dist, self.exploration_mat = None, None
|
||||
self.use_sde, self.sde_optimizer = use_sde, None
|
||||
self.action_dim = action_dim
|
||||
self.full_std = full_std
|
||||
|
||||
if use_sde:
|
||||
latent_dim = net_arch[-1]
|
||||
latent_pi = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
|
||||
self.latent_pi = nn.Sequential(*latent_pi)
|
||||
self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init)
|
||||
if full_std:
|
||||
self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init)
|
||||
else:
|
||||
# Reduce the number of parameters:
|
||||
self.log_std = nn.Parameter(th.ones(latent_dim, 1) * log_std_init)
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.actor_net = nn.Sequential(nn.Linear(net_arch[-1], action_dim), nn.Tanh())
|
||||
self.clip_noise = clip_noise
|
||||
self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
|
||||
|
|
@ -27,12 +36,19 @@ class Actor(BaseNetwork):
|
|||
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
|
||||
self.actor_net = nn.Sequential(*actor_net)
|
||||
|
||||
def get_log_std(self):
    """
    Return the log standard deviation as a (latent_dim, action_dim) matrix.

    With ``full_std`` enabled, ``log_std`` is returned unchanged.
    Otherwise ``log_std`` is multiplied by a matrix of ones of shape
    (latent_dim, action_dim), which broadcasts the reduced parameter over
    the action dimension while keeping the result differentiable
    with respect to ``log_std``.

    :return: (th.Tensor) the log standard deviation matrix
    """
    if self.full_std:
        return self.log_std
    # Reduce the number of parameters: the learned tensor is smaller than
    # the full matrix and is expanded here by broadcasting.
    # The matrix of ones is allocated on the same device as the parameter,
    # so the product does not fail with a device mismatch after the module
    # has been moved off the CPU. The dtype is left at the default on purpose
    # to keep the result's type promotion unchanged.
    ones = th.ones(self.latent_dim, self.action_dim, device=self.log_std.device)
    return ones * self.log_std
|
||||
|
||||
def get_distribution_stats(self, obs, action):
|
||||
with th.no_grad():
|
||||
latent_pi = self.latent_pi(obs)
|
||||
mean_actions = self.actor_net(latent_pi)
|
||||
variance = th.mm(latent_pi ** 2, th.exp(self.log_std) ** 2)
|
||||
distribution = Normal(mean_actions, th.sqrt(variance))
|
||||
|
||||
variance = th.mm(latent_pi ** 2, th.exp(self.get_log_std()) ** 2)
|
||||
distribution = Normal(mean_actions, th.sqrt(variance + 1e-5))
|
||||
log_prob = distribution.log_prob(action)
|
||||
if len(log_prob.shape) > 1:
|
||||
log_prob = log_prob.sum(axis=1)
|
||||
|
|
@ -41,7 +57,7 @@ class Actor(BaseNetwork):
|
|||
return log_prob, distribution.entropy()
|
||||
|
||||
def reset_noise(self):
|
||||
self.weights_dist = Normal(th.zeros_like(self.log_std), th.exp(self.log_std))
|
||||
self.weights_dist = Normal(th.zeros_like(self.get_log_std()), th.exp(self.get_log_std()))
|
||||
self.exploration_mat = self.weights_dist.rsample()
|
||||
|
||||
def forward(self, obs, deterministic=True):
|
||||
|
|
|
|||
|
|
@ -206,7 +206,11 @@ class TD3(BaseRLModel):
|
|||
|
||||
# Normalize returns
|
||||
# returns = (returns - returns.mean()) / (returns.std() + 1e-8)
|
||||
returns = (returns - returns.mean())
|
||||
# returns = (returns - returns.mean())
|
||||
with th.no_grad():
|
||||
current_q1, current_q2 = self.critic(obs, action)
|
||||
# Alternatively use the q value
|
||||
returns = (returns - th.min(current_q1, current_q2))
|
||||
|
||||
policy_loss = -(returns * log_prob).mean()
|
||||
|
||||
|
|
@ -218,6 +222,11 @@ class TD3(BaseRLModel):
|
|||
# Optimization step
|
||||
self.actor.sde_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
assert not th.isnan(log_prob).any(), log_prob
|
||||
assert not th.isnan(entropy).any()
|
||||
assert not th.isnan(self.actor.log_std.grad).any()
|
||||
assert not th.isnan(self.actor.log_std).any()
|
||||
# print(self.actor.log_std.grad.mean().item(), self.actor.log_std.grad.max().item(), self.actor.log_std.grad.min().item())
|
||||
# print(self.actor.log_std.mean().item(), self.actor.log_std.max().item(), self.actor.log_std.min().item())
|
||||
# Clip grad norm
|
||||
|
|
@ -259,11 +268,13 @@ class TD3(BaseRLModel):
|
|||
print("Total T: {} Episode Num: {} Episode T: {} Reward: {}".format(
|
||||
self.num_timesteps, episode_num, episode_timesteps, episode_reward))
|
||||
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
|
||||
self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
|
||||
if self.use_sde:
|
||||
self.train_sde()
|
||||
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
|
||||
self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
|
||||
|
||||
|
||||
# Evaluate episode
|
||||
if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
|
||||
timesteps_since_eval %= eval_freq
|
||||
|
|
|
|||
Loading…
Reference in a new issue