mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-01 23:30:53 +00:00
Update plotter
This commit is contained in:
parent
cf89cac3e9
commit
d4ddb3d021
1 changed files with 16 additions and 16 deletions
|
|
@ -45,25 +45,25 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray,
|
|||
return var_1[window - 1:], function_on_var2
|
||||
|
||||
|
||||
def ts2xy(timesteps: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Decompose a timesteps variable to x ans ys
|
||||
Decompose a data frame variable to x ans ys
|
||||
|
||||
:param timesteps: (pd.DataFrame) the input data
|
||||
:param data_frame: (pd.DataFrame) the input data
|
||||
:param x_axis: (str) the axis for the x and y output
|
||||
(can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
|
||||
:return: (Tuple[np.ndarray, np.ndarray]) the x and y output
|
||||
"""
|
||||
if x_axis == X_TIMESTEPS:
|
||||
x_var = np.cumsum(timesteps.l.values)
|
||||
y_var = timesteps.r.values
|
||||
x_var = np.cumsum(data_frame.l.values)
|
||||
y_var = data_frame.r.values
|
||||
elif x_axis == X_EPISODES:
|
||||
x_var = np.arange(len(timesteps))
|
||||
y_var = timesteps.r.values
|
||||
x_var = np.arange(len(data_frame))
|
||||
y_var = data_frame.r.values
|
||||
elif x_axis == X_WALLTIME:
|
||||
# Convert to hours
|
||||
x_var = timesteps.t.values / 3600.
|
||||
y_var = timesteps.r.values
|
||||
x_var = data_frame.t.values / 3600.
|
||||
y_var = data_frame.r.values
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return x_var, y_var
|
||||
|
|
@ -81,7 +81,7 @@ def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]],
|
|||
:param figsize: (Tuple[int, int]) Size of the figure (width, height)
|
||||
"""
|
||||
|
||||
plt.figure(figsize=figsize)
|
||||
plt.figure(title, figsize=figsize)
|
||||
max_x = max(xy[0][-1] for xy in xy_list)
|
||||
min_x = 0
|
||||
for (i, (x, y)) in enumerate(xy_list):
|
||||
|
|
@ -101,7 +101,7 @@ def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]],
|
|||
def plot_results(dirs: List[str], num_timesteps: Optional[int],
|
||||
x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)) -> None:
|
||||
"""
|
||||
plot the results
|
||||
Plot the results using csv files from ``Monitor`` wrapper.
|
||||
|
||||
:param dirs: ([str]) the save location of the results to plot
|
||||
:param num_timesteps: (int or None) only plot the points below this value
|
||||
|
|
@ -111,11 +111,11 @@ def plot_results(dirs: List[str], num_timesteps: Optional[int],
|
|||
:param figsize: (Tuple[int, int]) Size of the figure (width, height)
|
||||
"""
|
||||
|
||||
timesteps_list = []
|
||||
data_frames = []
|
||||
for folder in dirs:
|
||||
timesteps = load_results(folder)
|
||||
data_frame = load_results(folder)
|
||||
if num_timesteps is not None:
|
||||
timesteps = timesteps[timesteps.l.cumsum() <= num_timesteps]
|
||||
timesteps_list.append(timesteps)
|
||||
xy_list = [ts2xy(timesteps_item, x_axis) for timesteps_item in timesteps_list]
|
||||
data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps]
|
||||
data_frames.append(data_frame)
|
||||
xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames]
|
||||
plot_curves(xy_list, x_axis, task_name, figsize)
|
||||
|
|
|
|||
Loading…
Reference in a new issue