diff --git a/tests/test_sde.py b/tests/test_sde.py index 7143e9c..c565269 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -49,7 +49,7 @@ def test_state_dependent_exploration(): @pytest.mark.parametrize("model_class", [A2C]) -@pytest.mark.parametrize("sde_net_arch", [None, [64, 64]]) +@pytest.mark.parametrize("sde_net_arch", [None, [32, 16]]) def test_state_dependent_noise(model_class, sde_net_arch): env_id = 'MountainCarContinuous-v0' diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index ed5c493..a0abfb4 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -89,7 +89,7 @@ class PPOPolicy(BasePolicy): activation_fn=self.activation_fn, squash_out=False) self.sde_feature_extractor = nn.Sequential(*latent_sde) - if isinstance(self.action_dist, (DiagGaussianDistribution, StateDependentNoiseDistribution)): + if isinstance(self.action_dist, DiagGaussianDistribution): self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi, log_std_init=self.log_std_init) elif isinstance(self.action_dist, StateDependentNoiseDistribution):