mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-14 20:48:08 +00:00
Add a visualization of cross validation prediction performance vs. horizon
This commit is contained in:
parent
7179ae3a38
commit
8198afe17a
3 changed files with 334 additions and 63 deletions
File diff suppressed because one or more lines are too long
|
|
@ -196,7 +196,7 @@ def prophet_copy(m, cutoff=None):
|
|||
return m2
|
||||
|
||||
|
||||
def performance_metrics(df, metrics=None, rolling_window=0.05):
|
||||
def performance_metrics(df, metrics=None, rolling_window=0.1):
|
||||
"""Compute performance metrics from cross-validation results.
|
||||
|
||||
Computes a suite of performance metrics on the output of cross-validation.
|
||||
|
|
@ -216,7 +216,7 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
|
|||
which specifies a proportion of simulated forecast points to include in
|
||||
each window. rolling_window=0 will compute it separately for each simulated
|
||||
forecast point (i.e., 'mse' will actually be squared error with no mean).
|
||||
The default of rolling_window=0.05 will use 5% of the rows in df in each
|
||||
The default of rolling_window=0.1 will use 10% of the rows in df in each
|
||||
window. rolling_window=1 will compute the metric across all simulated forecast
|
||||
points. The results are set to the right edge of the window.
|
||||
|
||||
|
|
@ -227,9 +227,9 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
|
|||
----------
|
||||
df: The dataframe returned by cross_validation.
|
||||
metrics: A list of performance metrics to compute. If not provided, will
|
||||
use ['mse', 'mae', 'mape', 'coverage', 'rmse'].
|
||||
use ['mse', 'rmse', 'mae', 'mape', 'coverage'].
|
||||
rolling_window: Proportion of data to use in each rolling window for
|
||||
computing the metrics.
|
||||
computing the metrics. Should be in [0, 1].
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@ import logging
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from fbprophet.diagnostics import performance_metrics
|
||||
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -367,3 +370,78 @@ def add_changepoints_to_plot(
|
|||
for cp in signif_changepoints:
|
||||
artists.append(ax.axvline(x=cp, c=cp_color, ls=cp_linestyle))
|
||||
return artists
|
||||
|
||||
|
||||
def plot_cross_validation_metric(df_cv, metric, rolling_window=0.1, ax=None):
    """Plot a performance metric vs. forecast horizon from cross validation.

    Cross validation produces a collection of out-of-sample model predictions
    that can be compared to actual values, at a range of different horizons
    (distance from the cutoff). This computes a specified performance metric
    for each prediction, and aggregated over a rolling window with horizon.

    This uses fbprophet.diagnostics.performance_metrics to compute the metrics.
    Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.

    rolling_window is the proportion of data included in the rolling window of
    aggregation. The default value of 0.1 means 10% of data are included in the
    aggregation for computing the metric.

    As a concrete example, if metric='mse', then this plot will show the
    squared error for each cross validation prediction, along with the MSE
    averaged over rolling windows of 10% of the data.

    Parameters
    ----------
    df_cv: The output from fbprophet.diagnostics.cross_validation.
    metric: Metric name, one of ['mse', 'rmse', 'mae', 'mape', 'coverage'].
    rolling_window: Proportion of data to use for rolling average of metric.
        In [0, 1]. Defaults to 0.1.
    ax: Optional matplotlib axis on which to plot. If not given, a new figure
        will be created.

    Returns
    -------
    a matplotlib figure.
    """
    # Reuse the caller's axis when provided; otherwise make a fresh figure.
    if ax is None:
        fig = plt.figure(facecolor='w', figsize=(10, 6))
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()
    # Metric at two resolutions: per individual prediction (rolling_window=0,
    # i.e. no averaging) and aggregated over the requested rolling window.
    df_none = performance_metrics(df_cv, metrics=[metric], rolling_window=0)
    df_h = performance_metrics(df_cv, metrics=[metric], rolling_window=rolling_window)

    # matplotlib does not handle timedelta x-values, so convert the horizon to
    # a float count of some time unit. Target ~10 ticks across the axis and
    # pick the largest unit that is smaller than one tick width, falling
    # through to nanoseconds if none is. Each entry: (numpy unit code,
    # axis-label name, nanoseconds per unit).
    tick_w = max(df_none['horizon'].astype('timedelta64[ns]')) / 10.
    unit_table = [
        ('D', 'days', 24 * 60 * 60 * 10 ** 9),
        ('h', 'hours', 60 * 60 * 10 ** 9),
        ('m', 'minutes', 60 * 10 ** 9),
        ('s', 'seconds', 10 ** 9),
        ('ms', 'milliseconds', 10 ** 6),
        ('us', 'microseconds', 10 ** 3),
        ('ns', 'nanoseconds', 1.),
    ]
    for unit_code, unit_name, unit_scale in unit_table:
        if np.timedelta64(1, unit_code) < np.timedelta64(tick_w, 'ns'):
            break

    # Express horizons as floats in the chosen unit for plotting.
    scale = float(unit_scale)
    x_plt = df_none['horizon'].astype('timedelta64[ns]').astype(int) / scale
    x_plt_h = df_h['horizon'].astype('timedelta64[ns]').astype(int) / scale

    # Gray dots: metric for each individual prediction. Blue line: rolling
    # aggregate of the same metric.
    ax.plot(x_plt, df_none[metric], '.', alpha=0.5, c='gray')
    ax.plot(x_plt_h, df_h[metric], '-', c='b')
    ax.grid(True)

    ax.set_xlabel('Horizon ({})'.format(unit_name))
    ax.set_ylabel(metric)
    return fig
|
||||
|
|
|
|||
Loading…
Reference in a new issue