Implement DropQ

This commit is contained in:
Antonin Raffin 2022-08-17 15:15:56 +02:00
parent ebf6ed1d0a
commit e7be8dc052
No known key found for this signature in database
GPG key ID: B8B48F65CAD6232C
5 changed files with 43 additions and 9 deletions

View file

@ -859,6 +859,8 @@ class ContinuousCritic(BaseModel):
normalize_images: bool = True,
n_critics: int = 2,
share_features_extractor: bool = True,
dropout_rate: float = 0.0,
layer_norm: bool = False,
):
super().__init__(
observation_space,
@ -873,7 +875,9 @@ class ContinuousCritic(BaseModel):
self.n_critics = n_critics
self.q_networks = []
for idx in range(n_critics):
q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
q_net = create_mlp(
features_dim + action_dim, 1, net_arch, activation_fn, dropout_rate=dropout_rate, layer_norm=layer_norm
)
q_net = nn.Sequential(*q_net)
self.add_module(f"qf{idx}", q_net)
self.q_networks.append(q_net)

View file

@ -99,6 +99,8 @@ def create_mlp(
net_arch: List[int],
activation_fn: Type[nn.Module] = nn.ReLU,
squash_output: bool = False,
dropout_rate: float = 0.0,
layer_norm: bool = False,
) -> List[nn.Module]:
"""
Create a multi layer perceptron (MLP), which is
@ -117,12 +119,22 @@ def create_mlp(
"""
if len(net_arch) > 0:
modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()]
additional_modules = []
if dropout_rate > 0.0:
additional_modules.append(nn.Dropout(p=dropout_rate))
if layer_norm:
additional_modules.append(nn.LayerNorm(net_arch[0]))
modules = [nn.Linear(input_dim, net_arch[0])] + additional_modules + [activation_fn()]
else:
modules = []
for idx in range(len(net_arch) - 1):
modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1]))
if dropout_rate > 0.0:
modules.append(nn.Dropout(p=dropout_rate))
if layer_norm:
modules.append(nn.LayerNorm(net_arch[idx + 1]))
modules.append(activation_fn())
if output_dim > 0:

View file

@ -236,6 +236,9 @@ class SACPolicy(BasePolicy):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
# For the critic only
dropout_rate: float = 0.0,
layer_norm: bool = False,
):
super().__init__(
observation_space,
@ -279,6 +282,8 @@ class SACPolicy(BasePolicy):
"n_critics": n_critics,
"net_arch": critic_arch,
"share_features_extractor": share_features_extractor,
"dropout_rate": dropout_rate,
"layer_norm": layer_norm,
}
)

View file

@ -239,10 +239,11 @@ class SAC(OffPolicyAlgorithm):
with th.no_grad():
# Select action according to policy
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2]
q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices]
# For REDQ, sample q networks to be used
# q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2]
# q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices]
# Compute the next Q values: min over all critics targets
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions, q_networks), dim=1)
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
# add entropy term
next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
@ -264,12 +265,13 @@ class SAC(OffPolicyAlgorithm):
# Compute actor loss
# Alternative: actor_loss = th.mean(log_prob - qf1_pi)
# Mean over all critic networks
# Min over all critic networks
if update_actor:
q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
# Note: REDQ does a mean here
min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
# Note: REDQ and DropQ does a mean here
# min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
mean_qf_pi = th.mean(q_values_pi, dim=1, keepdim=True)
actor_loss = (ent_coef * log_prob - mean_qf_pi).mean()
actor_losses.append(actor_loss.item())
# Optimize the actor

View file

@ -89,6 +89,17 @@ def test_sac(ent_coef):
model.learn(total_timesteps=300, eval_freq=250)
def test_dropq():
model = SAC(
"MlpPolicy",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[64, 64], layer_norm=True, dropout_rate=0.005),
verbose=1,
buffer_size=250,
)
model.learn(total_timesteps=300)
@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
# Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG