diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index d122acd..c8c87e8 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -859,6 +859,8 @@ class ContinuousCritic(BaseModel): normalize_images: bool = True, n_critics: int = 2, share_features_extractor: bool = True, + dropout_rate: float = 0.0, + layer_norm: bool = False, ): super().__init__( observation_space, @@ -873,7 +875,9 @@ class ContinuousCritic(BaseModel): self.n_critics = n_critics self.q_networks = [] for idx in range(n_critics): - q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) + q_net = create_mlp( + features_dim + action_dim, 1, net_arch, activation_fn, dropout_rate=dropout_rate, layer_norm=layer_norm + ) q_net = nn.Sequential(*q_net) self.add_module(f"qf{idx}", q_net) self.q_networks.append(q_net) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index f87337c..525ffd4 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -99,6 +99,8 @@ def create_mlp( net_arch: List[int], activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False, + dropout_rate: float = 0.0, + layer_norm: bool = False, ) -> List[nn.Module]: """ Create a multi layer perceptron (MLP), which is @@ -117,12 +119,22 @@ def create_mlp( """ if len(net_arch) > 0: - modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()] + additional_modules = [] + if dropout_rate > 0.0: + additional_modules.append(nn.Dropout(p=dropout_rate)) + if layer_norm: + additional_modules.append(nn.LayerNorm(net_arch[0])) + modules = [nn.Linear(input_dim, net_arch[0])] + additional_modules + [activation_fn()] + else: modules = [] for idx in range(len(net_arch) - 1): modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1])) + if dropout_rate > 0.0: + modules.append(nn.Dropout(p=dropout_rate)) + if layer_norm: + modules.append(nn.LayerNorm(net_arch[idx + 1])) modules.append(activation_fn()) if output_dim > 0: diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 255bd75..392f088 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -236,6 +236,9 @@ class SACPolicy(BasePolicy): optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, + # For the critic only + dropout_rate: float = 0.0, + layer_norm: bool = False, ): super().__init__( observation_space, @@ -279,6 +282,8 @@ class SACPolicy(BasePolicy): "n_critics": n_critics, "net_arch": critic_arch, "share_features_extractor": share_features_extractor, + "dropout_rate": dropout_rate, + "layer_norm": layer_norm, } ) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index b7cbcf6..6c238c7 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -239,10 +239,11 @@ class SAC(OffPolicyAlgorithm): with th.no_grad(): # Select action according to policy next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) - q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2] - q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices] + # For REDQ, sample q networks to be used + # q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2] + # q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices] # Compute the next Q values: min over all critics targets - next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions, q_networks), dim=1) + next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True) # add entropy term next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1) @@ -264,12 +265,13 @@ class SAC(OffPolicyAlgorithm): # Compute actor loss # Alternative: actor_loss = th.mean(log_prob - qf1_pi) - # Mean over all critic networks + # Min over all critic networks if update_actor: q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1) - # Note: REDQ does a mean here - min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True) - actor_loss = (ent_coef * log_prob - min_qf_pi).mean() + # Note: REDQ and DropQ does a mean here + # min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True) + mean_qf_pi = th.mean(q_values_pi, dim=1, keepdim=True) + actor_loss = (ent_coef * log_prob - mean_qf_pi).mean() actor_losses.append(actor_loss.item()) # Optimize the actor diff --git a/tests/test_run.py b/tests/test_run.py index b0a9a11..7919d62 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -89,6 +89,17 @@ def test_sac(ent_coef): model.learn(total_timesteps=300, eval_freq=250) +def test_dropq(): + model = SAC( + "MlpPolicy", + "Pendulum-v1", + policy_kwargs=dict(net_arch=[64, 64], layer_norm=True, dropout_rate=0.005), + verbose=1, + buffer_size=250, + ) + model.learn(total_timesteps=300) + + @pytest.mark.parametrize("n_critics", [1, 3]) def test_n_critics(n_critics): # Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG