mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-01 23:30:43 +00:00
Added optional colour to plot_cross_validation_metric (#1758)
* Added optional colour to plot_cross_validation_metric to facilitate the comparison of different model's performance on the one plot * use color Change from colour to color for consistency. Co-authored-by: Ben Letham <bletham@gmail.com>
This commit is contained in:
parent
20f590b726
commit
9bc7fb77b5
1 changed files with 5 additions and 3 deletions
|
|
@ -469,7 +469,7 @@ def add_changepoints_to_plot(
|
|||
|
||||
|
||||
def plot_cross_validation_metric(
|
||||
df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6)
|
||||
df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6), color='b'
|
||||
):
|
||||
"""Plot a performance metric vs. forecast horizon from cross validation.
|
||||
|
||||
|
|
@ -498,6 +498,8 @@ def plot_cross_validation_metric(
|
|||
ax: Optional matplotlib axis on which to plot. If not given, a new figure
|
||||
will be created.
|
||||
figsize: Optional tuple width, height in inches.
|
||||
color: Optional color for plot and error points, useful when plotting
|
||||
multiple model performances on one axis for comparison.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -537,8 +539,8 @@ def plot_cross_validation_metric(
|
|||
x_plt = df_none['horizon'].astype('timedelta64[ns]').astype(np.int64) / float(dt_conversions[i])
|
||||
x_plt_h = df_h['horizon'].astype('timedelta64[ns]').astype(np.int64) / float(dt_conversions[i])
|
||||
|
||||
ax.plot(x_plt, df_none[metric], '.', alpha=0.5, c='gray')
|
||||
ax.plot(x_plt_h, df_h[metric], '-', c='b')
|
||||
ax.plot(x_plt, df_none[metric], '.', alpha=0.1, c=color)
|
||||
ax.plot(x_plt_h, df_h[metric], '-', c=color)
|
||||
ax.grid(True)
|
||||
|
||||
ax.set_xlabel('Horizon ({})'.format(dt_names[i]))
|
||||
|
|
|
|||
Loading…
Reference in a new issue