Update plotter

This commit is contained in:
Antonin RAFFIN 2020-03-16 12:04:57 +01:00
parent cf89cac3e9
commit d4ddb3d021

View file

@ -45,25 +45,25 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray,
return var_1[window - 1:], function_on_var2
def ts2xy(timesteps: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
"""
Decompose a timesteps variable to x ans ys
Decompose a data frame variable to x ans ys
:param timesteps: (pd.DataFrame) the input data
:param data_frame: (pd.DataFrame) the input data
:param x_axis: (str) the axis for the x and y output
(can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
:return: (Tuple[np.ndarray, np.ndarray]) the x and y output
"""
if x_axis == X_TIMESTEPS:
x_var = np.cumsum(timesteps.l.values)
y_var = timesteps.r.values
x_var = np.cumsum(data_frame.l.values)
y_var = data_frame.r.values
elif x_axis == X_EPISODES:
x_var = np.arange(len(timesteps))
y_var = timesteps.r.values
x_var = np.arange(len(data_frame))
y_var = data_frame.r.values
elif x_axis == X_WALLTIME:
# Convert to hours
x_var = timesteps.t.values / 3600.
y_var = timesteps.r.values
x_var = data_frame.t.values / 3600.
y_var = data_frame.r.values
else:
raise NotImplementedError
return x_var, y_var
@ -81,7 +81,7 @@ def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]],
:param figsize: (Tuple[int, int]) Size of the figure (width, height)
"""
plt.figure(figsize=figsize)
plt.figure(title, figsize=figsize)
max_x = max(xy[0][-1] for xy in xy_list)
min_x = 0
for (i, (x, y)) in enumerate(xy_list):
@ -101,7 +101,7 @@ def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]],
def plot_results(dirs: List[str], num_timesteps: Optional[int],
x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)) -> None:
"""
plot the results
Plot the results using csv files from ``Monitor`` wrapper.
:param dirs: ([str]) the save location of the results to plot
:param num_timesteps: (int or None) only plot the points below this value
@ -111,11 +111,11 @@ def plot_results(dirs: List[str], num_timesteps: Optional[int],
:param figsize: (Tuple[int, int]) Size of the figure (width, height)
"""
timesteps_list = []
data_frames = []
for folder in dirs:
timesteps = load_results(folder)
data_frame = load_results(folder)
if num_timesteps is not None:
timesteps = timesteps[timesteps.l.cumsum() <= num_timesteps]
timesteps_list.append(timesteps)
xy_list = [ts2xy(timesteps_item, x_axis) for timesteps_item in timesteps_list]
data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps]
data_frames.append(data_frame)
xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames]
plot_curves(xy_list, x_axis, task_name, figsize)