Solve NaN issue and reduce number of parameters

This commit is contained in:
Antonin Raffin 2019-11-13 13:02:37 +01:00
parent d725d01186
commit da325a0ba7
3 changed files with 38 additions and 8 deletions

View file

@ -21,6 +21,9 @@ TODO:
- save/load
- better predict
- complete logger
- SDE: reduce the number of parameters (only n_features instead of n_features x n_actions) for A2C
(done for TD3)
- SDE: learn the feature extractor?
Later:
- get_parameters / set_parameters

View file

@ -7,18 +7,27 @@ from torchy_baselines.common.policies import BasePolicy, register_policy, create
class Actor(BaseNetwork):
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
use_sde=False, log_std_init=-2, clip_noise=None, lr_sde=3e-4):
use_sde=False, log_std_init=-2, clip_noise=None,
lr_sde=3e-4, full_std=False):
super(Actor, self).__init__()
self.latent_pi, self.log_std = None, None
self.weights_dist, self.exploration_mat = None, None
self.use_sde, self.sde_optimizer = use_sde, None
self.action_dim = action_dim
self.full_std = full_std
if use_sde:
latent_dim = net_arch[-1]
latent_pi = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
self.latent_pi = nn.Sequential(*latent_pi)
self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init)
if full_std:
self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init)
else:
# Reduce the number of parameters:
self.log_std = nn.Parameter(th.ones(latent_dim, 1) * log_std_init)
self.latent_dim = latent_dim
self.actor_net = nn.Sequential(nn.Linear(net_arch[-1], action_dim), nn.Tanh())
self.clip_noise = clip_noise
self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
@ -27,12 +36,19 @@ class Actor(BaseNetwork):
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
self.actor_net = nn.Sequential(*actor_net)
def get_log_std(self):
if self.full_std:
return self.log_std
# Reduce the number of parameters:
return th.ones((self.latent_dim, self.action_dim)) * self.log_std
def get_distribution_stats(self, obs, action):
with th.no_grad():
latent_pi = self.latent_pi(obs)
mean_actions = self.actor_net(latent_pi)
variance = th.mm(latent_pi ** 2, th.exp(self.log_std) ** 2)
distribution = Normal(mean_actions, th.sqrt(variance))
variance = th.mm(latent_pi ** 2, th.exp(self.get_log_std()) ** 2)
distribution = Normal(mean_actions, th.sqrt(variance + 1e-5))
log_prob = distribution.log_prob(action)
if len(log_prob.shape) > 1:
log_prob = log_prob.sum(axis=1)
@ -41,7 +57,7 @@ class Actor(BaseNetwork):
return log_prob, distribution.entropy()
def reset_noise(self):
self.weights_dist = Normal(th.zeros_like(self.log_std), th.exp(self.log_std))
self.weights_dist = Normal(th.zeros_like(self.get_log_std()), th.exp(self.get_log_std()))
self.exploration_mat = self.weights_dist.rsample()
def forward(self, obs, deterministic=True):

View file

@ -206,7 +206,11 @@ class TD3(BaseRLModel):
# Normalize returns
# returns = (returns - returns.mean()) / (returns.std() + 1e-8)
returns = (returns - returns.mean())
# returns = (returns - returns.mean())
with th.no_grad():
current_q1, current_q2 = self.critic(obs, action)
# Alternatively use the q value
returns = (returns - th.min(current_q1, current_q2))
policy_loss = -(returns * log_prob).mean()
@ -218,6 +222,11 @@ class TD3(BaseRLModel):
# Optimization step
self.actor.sde_optimizer.zero_grad()
loss.backward()
assert not th.isnan(log_prob).any(), log_prob
assert not th.isnan(entropy).any()
assert not th.isnan(self.actor.log_std.grad).any()
assert not th.isnan(self.actor.log_std).any()
# print(self.actor.log_std.grad.mean().item(), self.actor.log_std.grad.max().item(), self.actor.log_std.grad.min().item())
# print(self.actor.log_std.mean().item(), self.actor.log_std.max().item(), self.actor.log_std.min().item())
# Clip grad norm
@ -259,11 +268,13 @@ class TD3(BaseRLModel):
print("Total T: {} Episode Num: {} Episode T: {} Reward: {}".format(
self.num_timesteps, episode_num, episode_timesteps, episode_reward))
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
if self.use_sde:
self.train_sde()
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
# Evaluate episode
if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
timesteps_since_eval %= eval_freq