diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 9e6af14..41ca60c 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -123,7 +123,8 @@ class A2C(PPO): logger.logkv("entropy", entropy.mean().item()) logger.logkv("policy_loss", policy_loss.item()) logger.logkv("value_loss", value_loss.item()) - logger.logkv("std", th.exp(self.policy.log_std).mean().item()) + if hasattr(self.policy, 'log_std'): + logger.logkv("std", th.exp(self.policy.log_std).mean().item()) def learn(self, total_timesteps, callback=None, log_interval=100, eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="A2C", reset_num_timesteps=True): diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 32e5f95..2b2a5e6 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -253,7 +253,8 @@ class PPO(BaseRLModel): logger.logkv("entropy", entropy.mean().item()) logger.logkv("policy_loss", policy_loss.item()) logger.logkv("value_loss", value_loss.item()) - logger.logkv("std", th.exp(self.policy.log_std).mean().item()) + if hasattr(self.policy, 'log_std'): + logger.logkv("std", th.exp(self.policy.log_std).mean().item()) def learn(self, total_timesteps, callback=None, log_interval=1, eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO", reset_num_timesteps=True):