diff --git a/torchy_baselines/common/results_plotter.py b/torchy_baselines/common/results_plotter.py index 6c2cd6b..d447344 100644 --- a/torchy_baselines/common/results_plotter.py +++ b/torchy_baselines/common/results_plotter.py @@ -45,25 +45,25 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, return var_1[window - 1:], function_on_var2 -def ts2xy(timesteps: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]: +def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]: """ - Decompose a timesteps variable to x ans ys + Decompose a data frame variable to x ans ys - :param timesteps: (pd.DataFrame) the input data + :param data_frame: (pd.DataFrame) the input data :param x_axis: (str) the axis for the x and y output (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') :return: (Tuple[np.ndarray, np.ndarray]) the x and y output """ if x_axis == X_TIMESTEPS: - x_var = np.cumsum(timesteps.l.values) - y_var = timesteps.r.values + x_var = np.cumsum(data_frame.l.values) + y_var = data_frame.r.values elif x_axis == X_EPISODES: - x_var = np.arange(len(timesteps)) - y_var = timesteps.r.values + x_var = np.arange(len(data_frame)) + y_var = data_frame.r.values elif x_axis == X_WALLTIME: # Convert to hours - x_var = timesteps.t.values / 3600. - y_var = timesteps.r.values + x_var = data_frame.t.values / 3600. + y_var = data_frame.r.values else: raise NotImplementedError return x_var, y_var @@ -81,7 +81,7 @@ def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]], :param figsize: (Tuple[int, int]) Size of the figure (width, height) """ - plt.figure(figsize=figsize) + plt.figure(title, figsize=figsize) max_x = max(xy[0][-1] for xy in xy_list) min_x = 0 for (i, (x, y)) in enumerate(xy_list): @@ -101,7 +101,7 @@ def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]], def plot_results(dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)) -> None: """ - plot the results + Plot the results using csv files from ``Monitor`` wrapper. :param dirs: ([str]) the save location of the results to plot :param num_timesteps: (int or None) only plot the points below this value @@ -111,11 +111,11 @@ def plot_results(dirs: List[str], num_timesteps: Optional[int], :param figsize: (Tuple[int, int]) Size of the figure (width, height) """ - timesteps_list = [] + data_frames = [] for folder in dirs: - timesteps = load_results(folder) + data_frame = load_results(folder) if num_timesteps is not None: - timesteps = timesteps[timesteps.l.cumsum() <= num_timesteps] - timesteps_list.append(timesteps) - xy_list = [ts2xy(timesteps_item, x_axis) for timesteps_item in timesteps_list] + data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps] + data_frames.append(data_frame) + xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames] plot_curves(xy_list, x_axis, task_name, figsize)