From 3ececcd3a933167709e2ba6b7fe18dc84a7780b6 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 21 Sep 2019 17:09:26 +0200 Subject: [PATCH] Add tensorboard example --- torchy_baselines/ppo/ppo.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 4fb2ee1..ad920f8 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -1,9 +1,14 @@ +import os import time from copy import deepcopy import gym import torch as th import torch.nn.functional as F +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + SummaryWriter = None import numpy as np from torchy_baselines.common.base_class import BaseRLModel @@ -29,6 +34,7 @@ class PPO(BaseRLModel): gamma=0.99, lambda_=0.95, clip_range=0.2, ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, target_kl=None, clip_range_vf=None, create_eval_env=False, + tensorboard_log=None, _init_setup_model=True): super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs, @@ -48,6 +54,8 @@ class PPO(BaseRLModel): self.rollout_buffer = None self.target_kl = target_kl self.clip_range_vf = clip_range_vf + self.tensorboard_log = tensorboard_log + self.tb_writer = None if _init_setup_model: self._setup_model() @@ -172,6 +180,9 @@ class PPO(BaseRLModel): obs = self.env.reset() eval_env = self._get_eval_env(eval_env) + if self.tensorboard_log is not None: + self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name)) + while self.num_timesteps < total_timesteps: if callback is not None: @@ -194,6 +205,9 @@ class PPO(BaseRLModel): if isinstance(self.env, VecNormalize): eval_env.obs_rms = deepcopy(self.env.obs_rms) mean_reward, _ = evaluate_policy(self, eval_env, n_eval_episodes) + if self.tb_writer is not None: + self.tb_writer.add_scalar('Eval/reward', mean_reward, self.num_timesteps) + evaluations.append(mean_reward) if self.verbose > 0: print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1]))