mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-01 23:30:53 +00:00
Add logger for PPO
This commit is contained in:
parent
53898f3d1a
commit
3bc746c6ee
2 changed files with 31 additions and 14 deletions
|
|
@ -56,6 +56,7 @@ class BaseRLModel(object):
|
|||
self.eval_env = None
|
||||
self.replay_buffer = None
|
||||
self.seed = seed
|
||||
self.action_noise = None
|
||||
|
||||
if env is not None:
|
||||
if isinstance(env, str):
|
||||
|
|
@ -270,6 +271,16 @@ class BaseRLModel(object):
|
|||
obs = self.env.reset()
|
||||
return timesteps_since_eval, episode_num, evaluations, obs, eval_env
|
||||
|
||||
def _update_info_buffer(self, infos):
|
||||
"""
|
||||
Retrieve reward and episode length if using Monitor wrapper.
|
||||
:param infos: ([dict])
|
||||
"""
|
||||
for info in infos:
|
||||
maybe_ep_info = info.get('episode')
|
||||
if maybe_ep_info is not None:
|
||||
self.ep_info_buffer.extend([maybe_ep_info])
|
||||
|
||||
def collect_rollouts(self, env, n_episodes=1, n_steps=-1, action_noise=None,
|
||||
deterministic=False, callback=None,
|
||||
learning_starts=0, num_timesteps=0,
|
||||
|
|
@ -310,10 +321,7 @@ class BaseRLModel(object):
|
|||
episode_reward += reward
|
||||
|
||||
# Retrieve reward and episode length if using Monitor wrapper
|
||||
for info in infos:
|
||||
maybe_ep_info = info.get('episode')
|
||||
if maybe_ep_info is not None:
|
||||
self.ep_info_buffer.extend([maybe_ep_info])
|
||||
self._update_info_buffer(infos)
|
||||
|
||||
# Store data in replay buffer
|
||||
if replay_buffer is not None:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from torchy_baselines.common.evaluation import evaluate_policy
|
|||
from torchy_baselines.common.buffers import RolloutBuffer
|
||||
from torchy_baselines.common.utils import explained_variance
|
||||
from torchy_baselines.common.vec_env import VecNormalize
|
||||
from torchy_baselines.common import logger
|
||||
from torchy_baselines.ppo.policies import PPOPolicy
|
||||
|
||||
|
||||
|
|
@ -153,8 +154,9 @@ class PPO(BaseRLModel):
|
|||
# Clip the actions to avoid out of bound error
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
new_obs, rewards, dones, _ = env.step(clipped_actions)
|
||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
n_steps += 1
|
||||
rollout_buffer.add(obs, actions, rewards, dones, values, log_probs)
|
||||
obs = new_obs
|
||||
|
|
@ -219,15 +221,10 @@ class PPO(BaseRLModel):
|
|||
# print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(),
|
||||
# self.rollout_buffer.values.flatten().cpu().numpy()))
|
||||
|
||||
def learn(self, total_timesteps, callback=None, log_interval=100,
|
||||
def learn(self, total_timesteps, callback=None, log_interval=1,
|
||||
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO", reset_num_timesteps=True):
|
||||
|
||||
timesteps_since_eval = 0
|
||||
episode_num = 0
|
||||
evaluations = []
|
||||
start_time = time.time()
|
||||
obs = self.env.reset()
|
||||
eval_env = self._get_eval_env(eval_env)
|
||||
timesteps_since_eval, iteration, evaluations, obs, eval_env = self._setup_learn(eval_env)
|
||||
|
||||
if self.tensorboard_log is not None and SummaryWriter is not None:
|
||||
self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name))
|
||||
|
|
@ -241,10 +238,22 @@ class PPO(BaseRLModel):
|
|||
|
||||
obs = self.collect_rollouts(self.env, self.rollout_buffer, n_rollout_steps=self.n_steps,
|
||||
obs=obs)
|
||||
episode_num += 1
|
||||
iteration += 1
|
||||
self.num_timesteps += self.n_steps * self.n_envs
|
||||
timesteps_since_eval += self.n_steps * self.n_envs
|
||||
|
||||
# Display training infos
|
||||
if self.verbose >= 1 and log_interval is not None and iteration % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
logger.logkv("iterations", iteration)
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
|
||||
logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
|
||||
logger.logkv("fps", fps)
|
||||
logger.logkv('time_elapsed', int(time.time() - self.start_time))
|
||||
logger.logkv("total timesteps", self.num_timesteps)
|
||||
logger.dumpkvs()
|
||||
|
||||
self.train(self.n_epochs, batch_size=self.batch_size)
|
||||
|
||||
# Evaluate agent
|
||||
|
|
@ -260,7 +269,7 @@ class PPO(BaseRLModel):
|
|||
evaluations.append(mean_reward)
|
||||
if self.verbose > 0:
|
||||
print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1]))
|
||||
print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - start_time)))
|
||||
print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - self.start_time)))
|
||||
|
||||
return self
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue