Add tensorboard example

This commit is contained in:
Antonin RAFFIN 2019-09-21 17:09:26 +02:00
parent e8ddd1f901
commit 3ececcd3a9

View file

@@ -1,9 +1,14 @@
import os
import time
from copy import deepcopy
import gym
import torch as th
import torch.nn.functional as F
# Treat the TensorBoard writer as optional: attempt the import and fall back
# to None when it fails, so the module itself can still be imported.
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
# None is the sentinel for "SummaryWriter could not be imported".
# NOTE(review): the visible call site only checks `self.tensorboard_log`
# before calling SummaryWriter(...); with this fallback that call would
# fail on None — confirm a guard exists or add one.
SummaryWriter = None
import numpy as np
from torchy_baselines.common.base_class import BaseRLModel
@@ -29,6 +34,7 @@ class PPO(BaseRLModel):
gamma=0.99, lambda_=0.95, clip_range=0.2,
ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
target_kl=None, clip_range_vf=None, create_eval_env=False,
tensorboard_log=None,
_init_setup_model=True):
super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs,
@@ -48,6 +54,8 @@ class PPO(BaseRLModel):
self.rollout_buffer = None
self.target_kl = target_kl
self.clip_range_vf = clip_range_vf
self.tensorboard_log = tensorboard_log
self.tb_writer = None
if _init_setup_model:
self._setup_model()
@@ -172,6 +180,9 @@ class PPO(BaseRLModel):
obs = self.env.reset()
eval_env = self._get_eval_env(eval_env)
if self.tensorboard_log is not None:
self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name))
while self.num_timesteps < total_timesteps:
if callback is not None:
@@ -194,6 +205,9 @@ class PPO(BaseRLModel):
if isinstance(self.env, VecNormalize):
eval_env.obs_rms = deepcopy(self.env.obs_rms)
mean_reward, _ = evaluate_policy(self, eval_env, n_eval_episodes)
if self.tb_writer is not None:
self.tb_writer.add_scalar('Eval/reward', mean_reward, self.num_timesteps)
evaluations.append(mean_reward)
if self.verbose > 0:
print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1]))