diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 9073296..6200bc1 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -511,7 +511,7 @@ class BaseRLModel(object): self.actor.reset_noise() # Reset rollout data if self.on_policy_exploration: - self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones']} + self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones', 'values']} while total_steps < n_steps or total_episodes < n_episodes: done = False @@ -574,6 +574,7 @@ class BaseRLModel(object): self.rollout_data['actions'].append(scaled_action[0].copy()) self.rollout_data['rewards'].append(reward[0].copy()) self.rollout_data['dones'].append(np.array(done_bool[0]).copy()) + self.rollout_data['values'].append(self.vf_net(th.FloatTensor(obs).to(self.device))[0].cpu().detach().numpy()) obs = new_obs # Save the true unnormalized observation @@ -615,19 +616,24 @@ class BaseRLModel(object): # Post processing if self.rollout_data is not None: - for key in ['observations', 'actions', 'rewards', 'dones']: + for key in ['observations', 'actions', 'rewards', 'dones', 'values']: self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device) self.rollout_data['returns'] = self.rollout_data['rewards'].clone() - # Compute return + self.rollout_data['advantage'] = self.rollout_data['rewards'].clone() + + # Compute return and advantage last_return = 0.0 for step in reversed(range(len(self.rollout_data['rewards']))): if step == len(self.rollout_data['rewards']) - 1: - last_return = self.rollout_data['rewards'][step] + next_non_terminal = 1.0 - done[0] + next_value = self.vf_net(th.FloatTensor(obs).to(self.device))[0].detach() + last_return = self.rollout_data['rewards'][step] + next_non_terminal * next_value else: next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1] last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal self.rollout_data['returns'][step] = last_return + self.rollout_data['advantage'] = self.rollout_data['returns'] - self.rollout_data['values'] return mean_reward, total_steps, total_episodes, obs diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index a472b58..1fd6a20 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -158,6 +158,27 @@ class Critic(BaseNetwork): return self.q_networks[0](th.cat([obs, action], dim=1)) +class ValueFunction(BaseNetwork): + """ + Value function for TD3 when doing on-policy exploration with SDE. + + :param obs_dim: (int) Dimension of the observation + :param net_arch: ([int]) Network architecture + :param activation_fn: (nn.Module) Activation function + """ + def __init__(self, obs_dim, net_arch=None, activation_fn=nn.Tanh): + super(ValueFunction, self).__init__() + + if net_arch is None: + net_arch = [64, 64] + + vf_net = create_mlp(obs_dim, 1, net_arch, activation_fn) + self.vf_net = nn.Sequential(*vf_net) + + def forward(self, obs): + return self.vf_net(obs) + + class TD3Policy(BasePolicy): """ Policy class (with both actor and critic) for TD3. @@ -206,7 +227,9 @@ class TD3Policy(BasePolicy): self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None + # For SDE only self.use_sde = use_sde + self.vf_net = None self.log_std_init = log_std_init self._build(learning_rate) @@ -221,6 +244,10 @@ class TD3Policy(BasePolicy): self.critic_target.load_state_dict(self.critic.state_dict()) self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate(1)) + if self.use_sde: + self.vf_net = ValueFunction(self.obs_dim) + self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()}) + def reset_noise(self): return self.actor.reset_noise() diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 3e1dfb7..9b5cafd 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -9,6 +9,7 @@ from torchy_baselines.common.buffers import ReplayBuffer from torchy_baselines.common.evaluation import evaluate_policy from torchy_baselines.td3.policies import TD3Policy from torchy_baselines.common.vec_env import sync_envs_normalization +from torchy_baselines.ppo.policies import MlpPolicy class TD3(BaseRLModel): @@ -85,6 +86,7 @@ class TD3(BaseRLModel): self.sde_ent_coef = sde_ent_coef self.sde_log_std_scheduler = sde_log_std_scheduler self.on_policy_exploration = True + self.sde_vf = None if _init_setup_model: self._setup_model() @@ -104,6 +106,7 @@ class TD3(BaseRLModel): self.actor_target = self.policy.actor_target self.critic = self.policy.critic self.critic_target = self.policy.critic_target + self.vf_net = self.policy.vf_net def select_action(self, observation, deterministic=True): # Normally not needed @@ -209,25 +212,26 @@ class TD3(BaseRLModel): # self._update_learning_rate(self.policy.optimizer) # Unpack - obs, action, returns = [self.rollout_data[key] for key in ['observations', 'actions', 'returns']] + obs, action, advantage, returns = [self.rollout_data[key] for key in ['observations', 'actions', 'advantage', 'returns']] - # TODO: avoid second computation of everything because of the gradient log_prob, entropy = self.actor.evaluate_actions(obs, action) + values = self.vf_net(obs).flatten() - # Normalize returns - # returns = (returns - returns.mean()) / (returns.std() + 1e-8) - # returns = (returns - returns.mean()) - with th.no_grad(): - current_q1, current_q2 = self.critic(obs, action) - # Alternatively use the q value - returns = (returns - th.min(current_q1, current_q2)) + # Normalize advantage + # if self.normalize_advantage: + # advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) - policy_loss = -(returns * log_prob).mean() + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(returns, values) + + # A2C loss + policy_loss = -(advantage * log_prob).mean() # Entropy loss favor exploration entropy_loss = -th.mean(entropy) - loss = policy_loss + self.sde_ent_coef * entropy_loss + vf_coef = 0.5 + loss = policy_loss + self.sde_ent_coef * entropy_loss + vf_coef * value_loss # Optimization step self.actor.sde_optimizer.zero_grad()