mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-24 22:25:13 +00:00
Implement DropQ
This commit is contained in:
parent
ebf6ed1d0a
commit
e7be8dc052
5 changed files with 43 additions and 9 deletions
|
|
@ -859,6 +859,8 @@ class ContinuousCritic(BaseModel):
|
|||
normalize_images: bool = True,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
dropout_rate: float = 0.0,
|
||||
layer_norm: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space,
|
||||
|
|
@ -873,7 +875,9 @@ class ContinuousCritic(BaseModel):
|
|||
self.n_critics = n_critics
|
||||
self.q_networks = []
|
||||
for idx in range(n_critics):
|
||||
q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
|
||||
q_net = create_mlp(
|
||||
features_dim + action_dim, 1, net_arch, activation_fn, dropout_rate=dropout_rate, layer_norm=layer_norm
|
||||
)
|
||||
q_net = nn.Sequential(*q_net)
|
||||
self.add_module(f"qf{idx}", q_net)
|
||||
self.q_networks.append(q_net)
|
||||
|
|
|
|||
|
|
@ -99,6 +99,8 @@ def create_mlp(
|
|||
net_arch: List[int],
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
squash_output: bool = False,
|
||||
dropout_rate: float = 0.0,
|
||||
layer_norm: bool = False,
|
||||
) -> List[nn.Module]:
|
||||
"""
|
||||
Create a multi layer perceptron (MLP), which is
|
||||
|
|
@ -117,12 +119,22 @@ def create_mlp(
|
|||
"""
|
||||
|
||||
if len(net_arch) > 0:
|
||||
modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()]
|
||||
additional_modules = []
|
||||
if dropout_rate > 0.0:
|
||||
additional_modules.append(nn.Dropout(p=dropout_rate))
|
||||
if layer_norm:
|
||||
additional_modules.append(nn.LayerNorm(net_arch[0]))
|
||||
modules = [nn.Linear(input_dim, net_arch[0])] + additional_modules + [activation_fn()]
|
||||
|
||||
else:
|
||||
modules = []
|
||||
|
||||
for idx in range(len(net_arch) - 1):
|
||||
modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1]))
|
||||
if dropout_rate > 0.0:
|
||||
modules.append(nn.Dropout(p=dropout_rate))
|
||||
if layer_norm:
|
||||
modules.append(nn.LayerNorm(net_arch[idx + 1]))
|
||||
modules.append(activation_fn())
|
||||
|
||||
if output_dim > 0:
|
||||
|
|
|
|||
|
|
@ -236,6 +236,9 @@ class SACPolicy(BasePolicy):
|
|||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = False,
|
||||
# For the critic only
|
||||
dropout_rate: float = 0.0,
|
||||
layer_norm: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space,
|
||||
|
|
@ -279,6 +282,8 @@ class SACPolicy(BasePolicy):
|
|||
"n_critics": n_critics,
|
||||
"net_arch": critic_arch,
|
||||
"share_features_extractor": share_features_extractor,
|
||||
"dropout_rate": dropout_rate,
|
||||
"layer_norm": layer_norm,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -239,10 +239,11 @@ class SAC(OffPolicyAlgorithm):
|
|||
with th.no_grad():
|
||||
# Select action according to policy
|
||||
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
|
||||
q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2]
|
||||
q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices]
|
||||
# For REDQ, sample q networks to be used
|
||||
# q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2]
|
||||
# q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices]
|
||||
# Compute the next Q values: min over all critics targets
|
||||
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions, q_networks), dim=1)
|
||||
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
|
||||
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
|
||||
# add entropy term
|
||||
next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
|
||||
|
|
@ -264,12 +265,13 @@ class SAC(OffPolicyAlgorithm):
|
|||
|
||||
# Compute actor loss
|
||||
# Alternative: actor_loss = th.mean(log_prob - qf1_pi)
|
||||
# Mean over all critic networks
|
||||
# Mean over all critic networks (REDQ/DropQ); vanilla SAC takes the min
|
||||
if update_actor:
|
||||
q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
|
||||
# Note: REDQ does a mean here
|
||||
min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
|
||||
actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
|
||||
# Note: REDQ and DropQ both take a mean here
|
||||
# min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
|
||||
mean_qf_pi = th.mean(q_values_pi, dim=1, keepdim=True)
|
||||
actor_loss = (ent_coef * log_prob - mean_qf_pi).mean()
|
||||
actor_losses.append(actor_loss.item())
|
||||
|
||||
# Optimize the actor
|
||||
|
|
|
|||
|
|
@ -89,6 +89,17 @@ def test_sac(ent_coef):
|
|||
model.learn(total_timesteps=300, eval_freq=250)
|
||||
|
||||
|
||||
def test_dropq():
|
||||
model = SAC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v1",
|
||||
policy_kwargs=dict(net_arch=[64, 64], layer_norm=True, dropout_rate=0.005),
|
||||
verbose=1,
|
||||
buffer_size=250,
|
||||
)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_critics", [1, 3])
|
||||
def test_n_critics(n_critics):
|
||||
# Test SAC with a different number of critics; for TD3, n_critics=1 corresponds to DDPG
|
||||
|
|
|
|||
Loading…
Reference in a new issue