From 9bc7fb77b5f6ff238f6441255ca73065c5058ac2 Mon Sep 17 00:00:00 2001 From: Luke Scales <40272781+LukeScales1@users.noreply.github.com> Date: Thu, 7 Jan 2021 10:50:18 -0800 Subject: [PATCH] Added optional colour to plot_cross_validation_metric (#1758) * Added optional colour to plot_cross_validation_metric to facilitate the comparison of different model's performance on the one plot * use color Change from colour to color for consistency. Co-authored-by: Ben Letham --- python/fbprophet/plot.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/fbprophet/plot.py b/python/fbprophet/plot.py index 9c1d4a9..1a23968 100644 --- a/python/fbprophet/plot.py +++ b/python/fbprophet/plot.py @@ -469,7 +469,7 @@ def add_changepoints_to_plot( def plot_cross_validation_metric( - df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6) + df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6), color='b' ): """Plot a performance metric vs. forecast horizon from cross validation. @@ -498,6 +498,8 @@ def plot_cross_validation_metric( ax: Optional matplotlib axis on which to plot. If not given, a new figure will be created. figsize: Optional tuple width, height in inches. + color: Optional color for plot and error points, useful when plotting + multiple model performances on one axis for comparison. Returns ------- @@ -537,8 +539,8 @@ def plot_cross_validation_metric( x_plt = df_none['horizon'].astype('timedelta64[ns]').astype(np.int64) / float(dt_conversions[i]) x_plt_h = df_h['horizon'].astype('timedelta64[ns]').astype(np.int64) / float(dt_conversions[i]) - ax.plot(x_plt, df_none[metric], '.', alpha=0.5, c='gray') - ax.plot(x_plt_h, df_h[metric], '-', c='b') + ax.plot(x_plt, df_none[metric], '.', alpha=0.1, c=color) + ax.plot(x_plt_h, df_h[metric], '-', c=color) ax.grid(True) ax.set_xlabel('Horizon ({})'.format(dt_names[i]))