mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Fix logger for discrete actions
This commit is contained in:
parent
c6f90b9c3c
commit
95c741c707
2 changed files with 4 additions and 2 deletions
|
|
@@ -123,7 +123,8 @@ class A2C(PPO):
|
|||
logger.logkv("entropy", entropy.mean().item())
|
||||
logger.logkv("policy_loss", policy_loss.item())
|
||||
logger.logkv("value_loss", value_loss.item())
|
||||
logger.logkv("std", th.exp(self.policy.log_std).mean().item())
|
||||
if hasattr(self.policy, 'log_std'):
|
||||
logger.logkv("std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
def learn(self, total_timesteps, callback=None, log_interval=100,
|
||||
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="A2C", reset_num_timesteps=True):
|
||||
|
|
|
|||
|
|
@@ -253,7 +253,8 @@ class PPO(BaseRLModel):
|
|||
logger.logkv("entropy", entropy.mean().item())
|
||||
logger.logkv("policy_loss", policy_loss.item())
|
||||
logger.logkv("value_loss", value_loss.item())
|
||||
logger.logkv("std", th.exp(self.policy.log_std).mean().item())
|
||||
if hasattr(self.policy, 'log_std'):
|
||||
logger.logkv("std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
def learn(self, total_timesteps, callback=None, log_interval=1,
|
||||
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO", reset_num_timesteps=True):
|
||||
|
|
|
|||
Loading…
Reference in a new issue