diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index ebf2065..8feec01 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -56,6 +56,7 @@ class BaseRLModel(object): self.eval_env = None self.replay_buffer = None self.seed = seed + self.action_noise = None if env is not None: if isinstance(env, str): @@ -270,6 +271,16 @@ class BaseRLModel(object): obs = self.env.reset() return timesteps_since_eval, episode_num, evaluations, obs, eval_env + def _update_info_buffer(self, infos): + """ + Retrieve reward and episode length if using Monitor wrapper. + :param infos: ([dict]) + """ + for info in infos: + maybe_ep_info = info.get('episode') + if maybe_ep_info is not None: + self.ep_info_buffer.extend([maybe_ep_info]) + def collect_rollouts(self, env, n_episodes=1, n_steps=-1, action_noise=None, deterministic=False, callback=None, learning_starts=0, num_timesteps=0, @@ -310,10 +321,7 @@ class BaseRLModel(object): episode_reward += reward # Retrieve reward and episode length if using Monitor wrapper - for info in infos: - maybe_ep_info = info.get('episode') - if maybe_ep_info is not None: - self.ep_info_buffer.extend([maybe_ep_info]) + self._update_info_buffer(infos) # Store data in replay buffer if replay_buffer is not None: diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 8933dda..3d748fb 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -18,6 +18,7 @@ from torchy_baselines.common.evaluation import evaluate_policy from torchy_baselines.common.buffers import RolloutBuffer from torchy_baselines.common.utils import explained_variance from torchy_baselines.common.vec_env import VecNormalize +from torchy_baselines.common import logger from torchy_baselines.ppo.policies import PPOPolicy @@ -153,8 +154,9 @@ class PPO(BaseRLModel): # Clip the actions to avoid out of bound error if isinstance(self.action_space, gym.spaces.Box): clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) - new_obs, rewards, dones, _ = env.step(clipped_actions) + new_obs, rewards, dones, infos = env.step(clipped_actions) + self._update_info_buffer(infos) n_steps += 1 rollout_buffer.add(obs, actions, rewards, dones, values, log_probs) obs = new_obs @@ -219,15 +221,10 @@ class PPO(BaseRLModel): # print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(), # self.rollout_buffer.values.flatten().cpu().numpy())) - def learn(self, total_timesteps, callback=None, log_interval=100, + def learn(self, total_timesteps, callback=None, log_interval=1, eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO", reset_num_timesteps=True): - timesteps_since_eval = 0 - episode_num = 0 - evaluations = [] - start_time = time.time() - obs = self.env.reset() - eval_env = self._get_eval_env(eval_env) + timesteps_since_eval, iteration, evaluations, obs, eval_env = self._setup_learn(eval_env) if self.tensorboard_log is not None and SummaryWriter is not None: self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name)) @@ -241,10 +238,22 @@ class PPO(BaseRLModel): obs = self.collect_rollouts(self.env, self.rollout_buffer, n_rollout_steps=self.n_steps, obs=obs) - episode_num += 1 + iteration += 1 self.num_timesteps += self.n_steps * self.n_envs timesteps_since_eval += self.n_steps * self.n_envs + # Display training infos + if self.verbose >= 1 and log_interval is not None and iteration % log_interval == 0: + fps = int(self.num_timesteps / (time.time() - self.start_time)) + logger.logkv("iterations", iteration) + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer])) + logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer])) + logger.logkv("fps", fps) + logger.logkv('time_elapsed', int(time.time() - self.start_time)) + logger.logkv("total timesteps", self.num_timesteps) + logger.dumpkvs() + self.train(self.n_epochs, batch_size=self.batch_size) # Evaluate agent @@ -260,7 +269,7 @@ class PPO(BaseRLModel): evaluations.append(mean_reward) if self.verbose > 0: print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1])) - print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - start_time))) + print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - self.start_time))) return self