modified cross_validation to allow custom cutoffs (#1402)

* modified cross_validation to allow custom cutoffs

* moved the steps that set period, set initial and identify the largest seasonality

* modified the diagnostics and added the test

* reverted cv default value tests and added a new custom cutoff test

* reorganized to raise the seasonality period warning message even if cutoffs are manually specified

* moved the initials vs. seasonality check

* changed assertCountEqual to assertItemsEqual in cv

* modified to test lengths instead of cutoff values

Co-authored-by: Fusi Marco <Marco.Fusi@valuelab.it>
This commit is contained in:
Marco Fusi 2020-03-27 00:36:02 +01:00 committed by GitHub
parent d22922d08c
commit 3c69ce3312
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 22 deletions

View file

@ -59,12 +59,12 @@ def generate_cutoffs(df, horizon, initial, period):
return list(reversed(result))
def cross_validation(model, horizon, period=None, initial=None, multiprocess=False):
def cross_validation(model, horizon, period=None, initial=None, multiprocess=False, cutoffs=None):
"""Cross-Validation for time series.
Computes forecasts from historical cutoff points. Beginning from
(end - horizon), works backwards making cutoffs with a spacing of period
until initial is reached.
Computes forecasts from historical cutoff points, which the user can supply.
If cutoffs are not provided, begins from (end - horizon) and works backwards,
making cutoffs with a spacing of period until initial is reached.
When period is equal to the time interval of the data, this is the
technique described in https://robjhyndman.com/hyndsight/tscv/ .
@ -78,6 +78,10 @@ def cross_validation(model, horizon, period=None, initial=None, multiprocess=Fal
be done at every this period. If not provided, 0.5 * horizon is used.
initial: string with pd.Timedelta compatible style. The first training
period will begin here. If not provided, 3 * horizon is used.
cutoffs: list of pd.Timestamp specifying the cutoffs to be used during
cross-validation. If not provided, cutoffs are generated starting from
(end - horizon) and working backwards with a spacing of period
until initial is reached.
multiprocess: True, False, Optional (defaults to False). If `True`, use the
`multiprocessing` module to distribute each task to a different processor
core.
@ -89,29 +93,39 @@ def cross_validation(model, horizon, period=None, initial=None, multiprocess=Fal
df = model.history.copy().reset_index(drop=True)
horizon = pd.Timedelta(horizon)
# Set period
period = 0.5 * horizon if period is None else pd.Timedelta(period)
# Identify largest seasonality period
period_max = 0.
for s in model.seasonalities.values():
period_max = max(period_max, s['period'])
seasonality_dt = pd.Timedelta(str(period_max) + ' days')
# Set initial
if initial is None:
initial = max(3 * horizon, seasonality_dt)
else:
initial = pd.Timedelta(initial)
if initial < seasonality_dt:
msg = 'Seasonality has period of {} days '.format(period_max)
msg += 'which is larger than initial window. '
msg += 'Consider increasing initial.'
logger.warning(msg)
predict_columns = ['ds', 'yhat']
if model.uncertainty_samples:
predict_columns.extend(['yhat_lower', 'yhat_upper'])
# Identify largest seasonality period
period_max = 0.
for s in model.seasonalities.values():
period_max = max(period_max, s['period'])
seasonality_dt = pd.Timedelta(str(period_max) + ' days')
cutoffs = generate_cutoffs(df, horizon, initial, period)
if cutoffs is None:
# Set period
period = 0.5 * horizon if period is None else pd.Timedelta(period)
# Set initial
if initial is None:
initial = max(3 * horizon, seasonality_dt)
else:
initial = pd.Timedelta(initial)
# Compute Cutoffs
cutoffs = generate_cutoffs(df, horizon, initial, period)
else:
initial = cutoffs[0] - df['ds'].min()
# Check if the initial window
# (that is, the amount of time between the start of the history and the first cutoff)
# is less than the maximum seasonality period
if initial < seasonality_dt:
msg = 'Seasonality has period of {} days '.format(period_max)
msg += 'which is larger than initial window. '
msg += 'Consider increasing initial.'
logger.warning(msg)
if multiprocess is True:
with Pool() as pool:

View file

@ -131,6 +131,18 @@ class TestDiagnostics(TestCase):
self.assertAlmostEqual(
((df_cv1['yhat'] - df_cv2['yhat']) ** 2).sum(), 0.0)
def test_cross_validation_custom_cutoffs(self):
    """Check that explicit cutoffs yield that many distinct cutoff dates."""
    model = Prophet()
    model.fit(self.__df)
    # Two cutoffs are passed in explicitly, so the cross-validation
    # output is expected to contain exactly two distinct cutoff dates.
    custom_cutoffs = [pd.Timestamp('2012-07-31'), pd.Timestamp('2012-08-31')]
    df_cv = diagnostics.cross_validation(
        model,
        horizon='32 days',
        period='10 days',
        cutoffs=custom_cutoffs)
    distinct_cutoffs = df_cv['cutoff'].unique()
    self.assertEqual(len(distinct_cutoffs), 2)
def test_cross_validation_uncertainty_disabled(self):
df = self.__df.copy()
for uncertainty in [0, False]:
@ -313,3 +325,4 @@ class TestDiagnostics(TestCase):
self.assertTrue((changepoints == m2.changepoints).all())
self.assertTrue('custom' in m2.seasonalities)
self.assertTrue('binary_feature' in m2.extra_regressors)