mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
Bug fix in choosing the distribution
This commit is contained in:
parent
5d6649d92b
commit
0885dbe74b
2 changed files with 2 additions and 2 deletions
|
|
@ -49,7 +49,7 @@ def test_state_dependent_exploration():
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C])
|
||||
@pytest.mark.parametrize("sde_net_arch", [None, [64, 64]])
|
||||
@pytest.mark.parametrize("sde_net_arch", [None, [32, 16]])
|
||||
def test_state_dependent_noise(model_class, sde_net_arch):
|
||||
env_id = 'MountainCarContinuous-v0'
|
||||
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class PPOPolicy(BasePolicy):
|
|||
activation_fn=self.activation_fn, squash_out=False)
|
||||
self.sde_feature_extractor = nn.Sequential(*latent_sde)
|
||||
|
||||
if isinstance(self.action_dist, (DiagGaussianDistribution, StateDependentNoiseDistribution)):
|
||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
|
||||
log_std_init=self.log_std_init)
|
||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
|
|
|
|||
Loading…
Reference in a new issue