__all__ = ['Monitor', 'get_monitor_files', 'load_results']

import csv
import json
import os
import time
from glob import glob
from typing import Tuple, Dict, Any, List, Optional

import gym
import pandas
import numpy as np


class Monitor(gym.Wrapper):
    """
    Gym wrapper that records the total reward, length and elapsed time of every episode,
    and optionally appends one CSV row per finished episode to a log file.
    """
    # Suffix of every monitor log file; also used by get_monitor_files() to find logs.
    EXT = "monitor.csv"

    def __init__(self,
                 env: gym.Env,
                 filename: Optional[str] = None,
                 allow_early_resets: bool = True,
                 reset_keywords: Tuple[str, ...] = (),
                 info_keywords: Tuple[str, ...] = ()):
        """
        A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.

        :param env: (gym.Env) The environment
        :param filename: (Optional[str]) the location to save a log file, can be None for no log.
            If it does not already end with ``monitor.csv``, that name is appended
            (joined as a file name when ``filename`` is an existing directory, as a ``.monitor.csv`` suffix otherwise)
        :param allow_early_resets: (bool) allows the reset of the environment before it is done
        :param reset_keywords: (Tuple[str, ...]) extra keywords for the reset call, if extra parameters are needed at reset
        :param info_keywords: (Tuple[str, ...]) extra information to log, from the information return of env.step()
        """
        super(Monitor, self).__init__(env=env)
        self.t_start = time.time()
        if filename is None:
            self.file_handler = None
            self.logger = None
        else:
            if not filename.endswith(Monitor.EXT):
                if os.path.isdir(filename):
                    filename = os.path.join(filename, Monitor.EXT)
                else:
                    filename = filename + "." + Monitor.EXT
            # newline="" is what the csv module requires of files it writes to: without it,
            # platforms that translate "\n" on write would turn the writer's "\r\n" into "\r\r\n".
            self.file_handler = open(filename, "wt", newline="")
            # First line is a '#'-prefixed JSON header; load_results() relies on it to recover t_start.
            self.file_handler.write('#%s\n' % json.dumps({"t_start": self.t_start, 'env_id': env.spec and env.spec.id}))
            self.logger = csv.DictWriter(self.file_handler,
                                         fieldnames=('r', 'l', 't') + reset_keywords + info_keywords)
            self.logger.writeheader()
            self.file_handler.flush()

        self.reset_keywords = reset_keywords
        self.info_keywords = info_keywords
        self.allow_early_resets = allow_early_resets
        self.rewards = None  # per-step rewards of the running episode, (re)created by reset()
        self.needs_reset = True  # step() refuses to run until reset() has been called
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_times = []
        self.total_steps = 0
        self.current_reset_info = {}  # extra info about the current episode, that was passed in during reset()

    def reset(self, **kwargs) -> np.ndarray:
        """
        Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True

        :param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords
        :return: (np.ndarray) the first observation of the environment
        :raises RuntimeError: if the episode is not over and allow_early_resets is False
        :raises ValueError: if a keyword listed in reset_keywords is missing from kwargs or is None
        """
        if not self.allow_early_resets and not self.needs_reset:
            raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, "
                               "wrap your env with Monitor(env, path, allow_early_resets=True)")
        self.rewards = []
        self.needs_reset = False
        for key in self.reset_keywords:
            value = kwargs.get(key)
            if value is None:
                raise ValueError('Expected you to pass kwarg {} into reset'.format(key))
            self.current_reset_info[key] = value
        return self.env.reset(**kwargs)

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, Any]]:
        """
        Step the environment with the given action

        When the wrapped environment reports ``done``, the episode summary is stored,
        written to the log file (if any) and attached to ``info`` under the ``'episode'`` key.

        :param action: (np.ndarray) the action
        :return: (Tuple[np.ndarray, float, bool, Dict[Any, Any]]) observation, reward, done, information
        :raises RuntimeError: if called before reset() or after an episode ended without a new reset()
        """
        if self.needs_reset:
            raise RuntimeError("Tried to step environment that needs reset")
        observation, reward, done, info = self.env.step(action)
        self.rewards.append(reward)
        if done:
            self.needs_reset = True
            ep_rew = sum(self.rewards)
            ep_len = len(self.rewards)
            # Read the clock once so the logged 't' and episode_times describe the same instant.
            elapsed = time.time() - self.t_start
            ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(elapsed, 6)}
            for key in self.info_keywords:
                ep_info[key] = info[key]
            self.episode_rewards.append(ep_rew)
            self.episode_lengths.append(ep_len)
            self.episode_times.append(elapsed)
            ep_info.update(self.current_reset_info)
            if self.logger is not None:
                self.logger.writerow(ep_info)
                # Flush per episode so the log is readable while training is still running.
                self.file_handler.flush()
            info['episode'] = ep_info
        self.total_steps += 1
        return observation, reward, done, info

    def close(self):
        """
        Closes the environment, then the log file if one was opened
        """
        try:
            super(Monitor, self).close()
        finally:
            # Release the log file even when closing the wrapped environment raises.
            if self.file_handler is not None:
                self.file_handler.close()

    def get_total_steps(self) -> int:
        """
        Returns the total number of timesteps

        :return: (int)
        """
        return self.total_steps

    def get_episode_rewards(self) -> List[float]:
        """
        Returns the rewards of all the episodes

        :return: ([float])
        """
        return self.episode_rewards

    def get_episode_lengths(self) -> List[int]:
        """
        Returns the number of timesteps of all the episodes

        :return: ([int])
        """
        return self.episode_lengths

    def get_episode_times(self) -> List[float]:
        """
        Returns the runtime in seconds of all the episodes

        :return: ([float])
        """
        return self.episode_times


class LoadMonitorResultsError(Exception):
    """
    Raised when loading the monitor log fails.
    """


def get_monitor_files(path: str) -> List[str]:
    """
    get all the monitor files in the given path

    :param path: (str) the logging folder
    :return: ([str]) the log files, i.e. every entry of ``path`` whose name ends with ``monitor.csv``
    """
    return glob(os.path.join(path, "*" + Monitor.EXT))


def load_results(path: str) -> pandas.DataFrame:
    """
    Load all Monitor logs from a given directory path matching ``*monitor.csv``

    The ``t`` column of each log is made absolute using that log's ``t_start`` header value,
    the rows of all logs are merged and sorted by ``t``, and ``t`` is finally re-expressed
    relative to the earliest ``t_start`` found.

    :param path: (str) the directory path containing the log file(s)
    :return: (pandas.DataFrame) the logged data
    :raises LoadMonitorResultsError: if no monitor file is found, or if a monitor file
        does not start with the '#'-prefixed header line
    """
    monitor_files = get_monitor_files(path)
    if not monitor_files:
        raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, path))
    data_frames, headers = [], []
    for file_name in monitor_files:
        with open(file_name, 'rt') as file_handler:
            first_line = file_handler.readline()
            # Explicit check instead of `assert`: asserts vanish under `python -O`,
            # and indexing first_line[0] crashed with IndexError on an empty file.
            if not first_line.startswith('#'):
                raise LoadMonitorResultsError("monitor file %s does not start with a '#' header line" % file_name)
            header = json.loads(first_line[1:])
            data_frame = pandas.read_csv(file_handler, index_col=None)
            headers.append(header)
            # Convert per-file relative times to absolute so rows of different files can be ordered.
            data_frame['t'] += header['t_start']
        data_frames.append(data_frame)
    data_frame = pandas.concat(data_frames)
    data_frame.sort_values('t', inplace=True)
    # Called without drop=True, so the previous per-file row index is kept as an 'index' column.
    data_frame.reset_index(inplace=True)
    data_frame['t'] -= min(header['t_start'] for header in headers)
    return data_frame