mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-01 23:30:53 +00:00
Add tensorboard example
This commit is contained in:
parent
e8ddd1f901
commit
3ececcd3a9
1 changed files with 14 additions and 0 deletions
|
|
@ -1,9 +1,14 @@
|
|||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
SummaryWriter = None
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.base_class import BaseRLModel
|
||||
|
|
@ -29,6 +34,7 @@ class PPO(BaseRLModel):
|
|||
gamma=0.99, lambda_=0.95, clip_range=0.2,
|
||||
ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
|
||||
target_kl=None, clip_range_vf=None, create_eval_env=False,
|
||||
tensorboard_log=None,
|
||||
_init_setup_model=True):
|
||||
|
||||
super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs,
|
||||
|
|
@ -48,6 +54,8 @@ class PPO(BaseRLModel):
|
|||
self.rollout_buffer = None
|
||||
self.target_kl = target_kl
|
||||
self.clip_range_vf = clip_range_vf
|
||||
self.tensorboard_log = tensorboard_log
|
||||
self.tb_writer = None
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
|
@ -172,6 +180,9 @@ class PPO(BaseRLModel):
|
|||
obs = self.env.reset()
|
||||
eval_env = self._get_eval_env(eval_env)
|
||||
|
||||
if self.tensorboard_log is not None:
|
||||
self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name))
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
|
||||
if callback is not None:
|
||||
|
|
@ -194,6 +205,9 @@ class PPO(BaseRLModel):
|
|||
if isinstance(self.env, VecNormalize):
|
||||
eval_env.obs_rms = deepcopy(self.env.obs_rms)
|
||||
mean_reward, _ = evaluate_policy(self, eval_env, n_eval_episodes)
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_scalar('Eval/reward', mean_reward, self.num_timesteps)
|
||||
|
||||
evaluations.append(mean_reward)
|
||||
if self.verbose > 0:
|
||||
print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1]))
|
||||
|
|
|
|||
Loading…
Reference in a new issue