Review notes (free text before the first "diff --git" header).

What this patch does, as far as the hunks show:
- README.md: adds two SDE items to the TODO list.
- torchy_baselines/td3/policies.py: adds a `full_std` option to `Actor`.
  When it is False, `log_std` is created with shape (latent_dim, 1) and
  `get_log_std()` expands it to (latent_dim, action_dim) by multiplying with
  a ones tensor. `get_distribution_stats()` and `reset_noise()` now read the
  parameter through `get_log_std()`, and 1e-5 is added to the variance
  before the square root.
- torchy_baselines/td3/td3.py: subtracts min(Q1, Q2) from the critic instead
  of the mean return, adds NaN assertions after `loss.backward()`, and calls
  `train_sde()` before `train()` instead of after it.

Change made in review:
- `get_log_std()` created the ones tensor on the default device, so the
  multiplication with `log_std` fails with a device-mismatch error once the
  parameter lives on another device. The ones tensor is now created on
  `self.log_std.device`. The post-image hash on the policies.py `index` line
  therefore no longer matches the patched content.

NOTE(review): the source of this patch had its line breaks and indentation
collapsed. Line breaks, indentation and the position of blank context lines
were reconstructed so that every hunk agrees with its @@ line counts; verify
against the repository before applying.
NOTE(review): the NaN checks use `assert`, which is skipped under
`python -O`; presumably debugging aids - confirm whether they should stay.

diff --git a/README.md b/README.md
index b5624ec..74b55c6 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,9 @@ TODO:
 - save/load
 - better predict
 - complete logger
+- SDE: reduce the number of parameters (only n_features instead of n_features x n_actions) for A2C
+(done for TD3)
+- SDE: learn the feature extractor?
 
 Later:
 - get_parameters / set_parameters
diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py
index 9c93f10..89dc019 100644
--- a/torchy_baselines/td3/policies.py
+++ b/torchy_baselines/td3/policies.py
@@ -7,18 +7,27 @@ from torchy_baselines.common.policies import BasePolicy, register_policy, create
 
 class Actor(BaseNetwork):
     def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
-                 use_sde=False, log_std_init=-2, clip_noise=None, lr_sde=3e-4):
+                 use_sde=False, log_std_init=-2, clip_noise=None,
+                 lr_sde=3e-4, full_std=False):
         super(Actor, self).__init__()
 
         self.latent_pi, self.log_std = None, None
         self.weights_dist, self.exploration_mat = None, None
         self.use_sde, self.sde_optimizer = use_sde, None
+        self.action_dim = action_dim
+        self.full_std = full_std
 
         if use_sde:
             latent_dim = net_arch[-1]
             latent_pi = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
             self.latent_pi = nn.Sequential(*latent_pi)
-            self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init)
+            if full_std:
+                self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init)
+            else:
+                # Reduce the number of parameters:
+                self.log_std = nn.Parameter(th.ones(latent_dim, 1) * log_std_init)
+
+            self.latent_dim = latent_dim
             self.actor_net = nn.Sequential(nn.Linear(net_arch[-1], action_dim), nn.Tanh())
             self.clip_noise = clip_noise
             self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
@@ -27,12 +36,19 @@ class Actor(BaseNetwork):
             actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
             self.actor_net = nn.Sequential(*actor_net)
 
+    def get_log_std(self):
+        if self.full_std:
+            return self.log_std
+        # Reduce the number of parameters:
+        return th.ones((self.latent_dim, self.action_dim), device=self.log_std.device) * self.log_std
+
     def get_distribution_stats(self, obs, action):
         with th.no_grad():
             latent_pi = self.latent_pi(obs)
             mean_actions = self.actor_net(latent_pi)
-        variance = th.mm(latent_pi ** 2, th.exp(self.log_std) ** 2)
-        distribution = Normal(mean_actions, th.sqrt(variance))
+
+        variance = th.mm(latent_pi ** 2, th.exp(self.get_log_std()) ** 2)
+        distribution = Normal(mean_actions, th.sqrt(variance + 1e-5))
         log_prob = distribution.log_prob(action)
         if len(log_prob.shape) > 1:
             log_prob = log_prob.sum(axis=1)
@@ -41,7 +57,7 @@ class Actor(BaseNetwork):
         return log_prob, distribution.entropy()
 
     def reset_noise(self):
-        self.weights_dist = Normal(th.zeros_like(self.log_std), th.exp(self.log_std))
+        self.weights_dist = Normal(th.zeros_like(self.get_log_std()), th.exp(self.get_log_std()))
         self.exploration_mat = self.weights_dist.rsample()
 
     def forward(self, obs, deterministic=True):
diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py
index 2dbdd5f..ce6b9a4 100644
--- a/torchy_baselines/td3/td3.py
+++ b/torchy_baselines/td3/td3.py
@@ -206,7 +206,11 @@ class TD3(BaseRLModel):
 
         # Normalize returns
         # returns = (returns - returns.mean()) / (returns.std() + 1e-8)
-        returns = (returns - returns.mean())
+        # returns = (returns - returns.mean())
+        with th.no_grad():
+            current_q1, current_q2 = self.critic(obs, action)
+        # Alternatively use the q value
+        returns = (returns - th.min(current_q1, current_q2))
 
         policy_loss = -(returns * log_prob).mean()
 
@@ -218,6 +222,11 @@ class TD3(BaseRLModel):
         # Optimization step
         self.actor.sde_optimizer.zero_grad()
         loss.backward()
+
+        assert not th.isnan(log_prob).any(), log_prob
+        assert not th.isnan(entropy).any()
+        assert not th.isnan(self.actor.log_std.grad).any()
+        assert not th.isnan(self.actor.log_std).any()
         # print(self.actor.log_std.grad.mean().item(), self.actor.log_std.grad.max().item(), self.actor.log_std.grad.min().item())
         # print(self.actor.log_std.mean().item(), self.actor.log_std.max().item(), self.actor.log_std.min().item())
         # Clip grad norm
@@ -259,11 +268,13 @@ class TD3(BaseRLModel):
                         print("Total T: {} Episode Num: {} Episode T: {} Reward: {}".format(
                             self.num_timesteps, episode_num, episode_timesteps, episode_reward))
 
-                    gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
-                    self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
                     if self.use_sde:
                         self.train_sde()
 
+                    gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
+                    self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
+
+
                 # Evaluate episode
                 if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
                     timesteps_since_eval %= eval_freq