Add a visualization of cross validation prediction performance vs. horizon

This commit is contained in:
Ben Letham 2018-05-04 11:21:40 -07:00
parent 7179ae3a38
commit 8198afe17a
3 changed files with 334 additions and 63 deletions

File diff suppressed because one or more lines are too long

View file

@ -196,7 +196,7 @@ def prophet_copy(m, cutoff=None):
return m2
def performance_metrics(df, metrics=None, rolling_window=0.05):
def performance_metrics(df, metrics=None, rolling_window=0.1):
"""Compute performance metrics from cross-validation results.
Computes a suite of performance metrics on the output of cross-validation.
@ -216,7 +216,7 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
which specifies a proportion of simulated forecast points to include in
each window. rolling_window=0 will compute it separately for each simulated
forecast point (i.e., 'mse' will actually be squared error with no mean).
The default of rolling_window=0.05 will use 5% of the rows in df in each
The default of rolling_window=0.1 will use 10% of the rows in df in each
window. rolling_window=1 will compute the metric across all simulated forecast
points. The results are set to the right edge of the window.
@ -227,9 +227,9 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
----------
df: The dataframe returned by cross_validation.
metrics: A list of performance metrics to compute. If not provided, will
use ['mse', 'mae', 'mape', 'coverage', 'rmse'].
use ['mse', 'rmse', 'mae', 'mape', 'coverage'].
rolling_window: Proportion of data to use in each rolling window for
computing the metrics.
computing the metrics. Should be in [0, 1].
Returns
-------

View file

@ -15,6 +15,9 @@ import logging
import numpy as np
import pandas as pd
from fbprophet.diagnostics import performance_metrics
logging.basicConfig()
logger = logging.getLogger(__name__)
@ -367,3 +370,78 @@ def add_changepoints_to_plot(
for cp in signif_changepoints:
artists.append(ax.axvline(x=cp, c=cp_color, ls=cp_linestyle))
return artists
def plot_cross_validation_metric(df_cv, metric, rolling_window=0.1, ax=None):
    """Plot a performance metric vs. forecast horizon from cross validation.

    Cross validation produces a collection of out-of-sample model predictions
    that can be compared to actual values, at a range of different horizons
    (distance from the cutoff). This computes a specified performance metric
    for each prediction, and aggregates it over a rolling window with horizon.

    This uses fbprophet.diagnostics.performance_metrics to compute the metrics.
    Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.

    rolling_window is the proportion of data included in the rolling window of
    aggregation. The default value of 0.1 means 10% of data are included in the
    aggregation for computing the metric.

    As a concrete example, if metric='mse', then this plot will show the
    squared error for each cross validation prediction, along with the MSE
    averaged over rolling windows of 10% of the data.

    Parameters
    ----------
    df_cv: The output from fbprophet.diagnostics.cross_validation.
    metric: Metric name, one of ['mse', 'rmse', 'mae', 'mape', 'coverage'].
    rolling_window: Proportion of data to use for rolling average of metric.
        In [0, 1]. Defaults to 0.1.
    ax: Optional matplotlib axis on which to plot. If not given, a new figure
        will be created.

    Returns
    -------
    a matplotlib figure.
    """
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig = plt.figure(facecolor='w', figsize=(10, 6))
        ax = fig.add_subplot(111)
    # Metric at the level of individual predictions (rolling_window=0), and
    # smoothed with the rolling window requested by the caller.
    per_point = performance_metrics(df_cv, metrics=[metric], rolling_window=0)
    rolled = performance_metrics(
        df_cv, metrics=[metric], rolling_window=rolling_window
    )
    # matplotlib cannot plot timedeltas directly, so horizons are converted
    # to a float count of some time unit. Target ~10 axis ticks and pick the
    # coarsest unit that is still smaller than one tick width.
    horizon_ns = per_point['horizon'].astype('timedelta64[ns]')
    tick_w = max(horizon_ns) / 10.
    # (numpy unit code, axis-label name, nanoseconds per unit)
    unit_table = [
        ('D', 'days', 24 * 60 * 60 * 10 ** 9),
        ('h', 'hours', 60 * 60 * 10 ** 9),
        ('m', 'minutes', 60 * 10 ** 9),
        ('s', 'seconds', 10 ** 9),
        ('ms', 'milliseconds', 10 ** 6),
        ('us', 'microseconds', 10 ** 3),
        ('ns', 'nanoseconds', 1.),
    ]
    for unit, unit_name, to_ns in unit_table:
        if np.timedelta64(1, unit) < np.timedelta64(tick_w, 'ns'):
            break
    # If no unit qualifies, the loop falls through leaving nanoseconds bound.
    x_per_point = horizon_ns.astype(int) / float(to_ns)
    x_rolled = (
        rolled['horizon'].astype('timedelta64[ns]').astype(int) / float(to_ns)
    )
    ax.plot(x_per_point, per_point[metric], '.', alpha=0.5, c='gray')
    ax.plot(x_rolled, rolled[metric], '-', c='b')
    ax.grid(True)
    ax.set_xlabel('Horizon ({})'.format(unit_name))
    ax.set_ylabel(metric)
    return fig