Add logger for PPO

This commit is contained in:
Antonin RAFFIN 2019-10-17 13:44:48 +02:00
parent 53898f3d1a
commit 3bc746c6ee
2 changed files with 31 additions and 14 deletions

View file

@ -56,6 +56,7 @@ class BaseRLModel(object):
self.eval_env = None
self.replay_buffer = None
self.seed = seed
self.action_noise = None
if env is not None:
if isinstance(env, str):
@ -270,6 +271,16 @@ class BaseRLModel(object):
obs = self.env.reset()
return timesteps_since_eval, episode_num, evaluations, obs, eval_env
def _update_info_buffer(self, infos):
"""
Retrieve reward and episode length if using Monitor wrapper.
:param infos: ([dict])
"""
for info in infos:
maybe_ep_info = info.get('episode')
if maybe_ep_info is not None:
self.ep_info_buffer.extend([maybe_ep_info])
def collect_rollouts(self, env, n_episodes=1, n_steps=-1, action_noise=None,
deterministic=False, callback=None,
learning_starts=0, num_timesteps=0,
@ -310,10 +321,7 @@ class BaseRLModel(object):
episode_reward += reward
# Retrieve reward and episode length if using Monitor wrapper
for info in infos:
maybe_ep_info = info.get('episode')
if maybe_ep_info is not None:
self.ep_info_buffer.extend([maybe_ep_info])
self._update_info_buffer(infos)
# Store data in replay buffer
if replay_buffer is not None:

View file

@ -18,6 +18,7 @@ from torchy_baselines.common.evaluation import evaluate_policy
from torchy_baselines.common.buffers import RolloutBuffer
from torchy_baselines.common.utils import explained_variance
from torchy_baselines.common.vec_env import VecNormalize
from torchy_baselines.common import logger
from torchy_baselines.ppo.policies import PPOPolicy
@ -153,8 +154,9 @@ class PPO(BaseRLModel):
# Clip the actions to avoid out of bound error
if isinstance(self.action_space, gym.spaces.Box):
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
new_obs, rewards, dones, _ = env.step(clipped_actions)
new_obs, rewards, dones, infos = env.step(clipped_actions)
self._update_info_buffer(infos)
n_steps += 1
rollout_buffer.add(obs, actions, rewards, dones, values, log_probs)
obs = new_obs
@ -219,15 +221,10 @@ class PPO(BaseRLModel):
# print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(),
# self.rollout_buffer.values.flatten().cpu().numpy()))
def learn(self, total_timesteps, callback=None, log_interval=100,
def learn(self, total_timesteps, callback=None, log_interval=1,
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO", reset_num_timesteps=True):
timesteps_since_eval = 0
episode_num = 0
evaluations = []
start_time = time.time()
obs = self.env.reset()
eval_env = self._get_eval_env(eval_env)
timesteps_since_eval, iteration, evaluations, obs, eval_env = self._setup_learn(eval_env)
if self.tensorboard_log is not None and SummaryWriter is not None:
self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name))
@ -241,10 +238,22 @@ class PPO(BaseRLModel):
obs = self.collect_rollouts(self.env, self.rollout_buffer, n_rollout_steps=self.n_steps,
obs=obs)
episode_num += 1
iteration += 1
self.num_timesteps += self.n_steps * self.n_envs
timesteps_since_eval += self.n_steps * self.n_envs
# Display training infos
if self.verbose >= 1 and log_interval is not None and iteration % log_interval == 0:
fps = int(self.num_timesteps / (time.time() - self.start_time))
logger.logkv("iterations", iteration)
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
logger.logkv("fps", fps)
logger.logkv('time_elapsed', int(time.time() - self.start_time))
logger.logkv("total timesteps", self.num_timesteps)
logger.dumpkvs()
self.train(self.n_epochs, batch_size=self.batch_size)
# Evaluate agent
@ -260,7 +269,7 @@ class PPO(BaseRLModel):
evaluations.append(mean_reward)
if self.verbose > 0:
print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1]))
print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - start_time)))
print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - self.start_time)))
return self