Bug fix in choosing the distribution

This commit is contained in:
Antonin Raffin 2019-11-25 15:02:10 +01:00
parent 5d6649d92b
commit 0885dbe74b
2 changed files with 2 additions and 2 deletions

View file

@ -49,7 +49,7 @@ def test_state_dependent_exploration():
@pytest.mark.parametrize("model_class", [A2C])
@pytest.mark.parametrize("sde_net_arch", [None, [64, 64]])
@pytest.mark.parametrize("sde_net_arch", [None, [32, 16]])
def test_state_dependent_noise(model_class, sde_net_arch):
env_id = 'MountainCarContinuous-v0'

View file

@ -89,7 +89,7 @@ class PPOPolicy(BasePolicy):
activation_fn=self.activation_fn, squash_out=False)
self.sde_feature_extractor = nn.Sequential(*latent_sde)
if isinstance(self.action_dist, (DiagGaussianDistribution, StateDependentNoiseDistribution)):
if isinstance(self.action_dist, DiagGaussianDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi,
log_std_init=self.log_std_init)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):