diff --git a/python/fbprophet/diagnostics.py b/python/fbprophet/diagnostics.py index 7dab24b..313d086 100644 --- a/python/fbprophet/diagnostics.py +++ b/python/fbprophet/diagnostics.py @@ -59,12 +59,12 @@ def generate_cutoffs(df, horizon, initial, period): return list(reversed(result)) -def cross_validation(model, horizon, period=None, initial=None, multiprocess=False): +def cross_validation(model, horizon, period=None, initial=None, multiprocess=False, cutoffs=None): """Cross-Validation for time series. - Computes forecasts from historical cutoff points. Beginning from - (end - horizon), works backwards making cutoffs with a spacing of period - until initial is reached. + Computes forecasts from historical cutoff points, which user can input. + If not provided beginning from (end - horizon), works backwards making + cutoffs with a spacing of period until initial is reached. When period is equal to the time interval of the data, this is the technique described in https://robjhyndman.com/hyndsight/tscv/ . @@ -78,6 +78,10 @@ def cross_validation(model, horizon, period=None, initial=None, multiprocess=Fal be done at every this period. If not provided, 0.5 * horizon is used. initial: string with pd.Timedelta compatible style. The first training period will begin here. If not provided, 3 * horizon is used. + cutoffs: list of pd.Timestamp representing cutoff to be used during + cross-validtation. If not provided works beginning from + (end - horizon), works backwards making cutoffs with a spacing of period + until initial is reached. multiprocess: True, False, Optional (defaults to False). If `True`, use the `multiprocessing` module to distribute each task to a different processor core. @@ -89,29 +93,39 @@ def cross_validation(model, horizon, period=None, initial=None, multiprocess=Fal df = model.history.copy().reset_index(drop=True) horizon = pd.Timedelta(horizon) - # Set period - period = 0.5 * horizon if period is None else pd.Timedelta(period) - # Identify largest seasonality period - period_max = 0. - for s in model.seasonalities.values(): - period_max = max(period_max, s['period']) - seasonality_dt = pd.Timedelta(str(period_max) + ' days') - # Set initial - if initial is None: - initial = max(3 * horizon, seasonality_dt) - else: - initial = pd.Timedelta(initial) - if initial < seasonality_dt: - msg = 'Seasonality has period of {} days '.format(period_max) - msg += 'which is larger than initial window. ' - msg += 'Consider increasing initial.' - logger.warning(msg) predict_columns = ['ds', 'yhat'] if model.uncertainty_samples: predict_columns.extend(['yhat_lower', 'yhat_upper']) + + # Identify largest seasonality period + period_max = 0. + for s in model.seasonalities.values(): + period_max = max(period_max, s['period']) + seasonality_dt = pd.Timedelta(str(period_max) + ' days') - cutoffs = generate_cutoffs(df, horizon, initial, period) + if cutoffs is None: + # Set period + period = 0.5 * horizon if period is None else pd.Timedelta(period) + + # Set initial + if initial is None: + initial = max(3 * horizon, seasonality_dt) + else: + initial = pd.Timedelta(initial) + # Compute Cutoffs + cutoffs = generate_cutoffs(df, horizon, initial, period) + else: + initial = cutoffs[0] - df['ds'].min() + + # Check if the initial window + # (that is, the amount of time between the start of the history and the first cutoff) + # is less than the maximum seasonality period + if initial < seasonality_dt: + msg = 'Seasonality has period of {} days '.format(period_max) + msg += 'which is larger than initial window. ' + msg += 'Consider increasing initial.' + logger.warning(msg) if multiprocess is True: with Pool() as pool: diff --git a/python/fbprophet/tests/test_diagnostics.py b/python/fbprophet/tests/test_diagnostics.py index a21971a..a635218 100644 --- a/python/fbprophet/tests/test_diagnostics.py +++ b/python/fbprophet/tests/test_diagnostics.py @@ -131,6 +131,18 @@ class TestDiagnostics(TestCase): self.assertAlmostEqual( ((df_cv1['yhat'] - df_cv2['yhat']) ** 2).sum(), 0.0) + def test_cross_validation_custom_cutoffs(self): + m = Prophet() + m.fit(self.__df) + # When specify a list of cutoffs + # the cutoff dates in df_cv are those specified + df_cv1 = diagnostics.cross_validation( + m, + horizon='32 days', + period='10 days', + cutoffs=[pd.Timestamp('2012-07-31'), pd.Timestamp('2012-08-31')]) + self.assertEqual(len(df_cv1['cutoff'].unique()), 2) + def test_cross_validation_uncertainty_disabled(self): df = self.__df.copy() for uncertainty in [0, False]: @@ -313,3 +325,4 @@ class TestDiagnostics(TestCase): self.assertTrue((changepoints == m2.changepoints).all()) self.assertTrue('custom' in m2.seasonalities) self.assertTrue('binary_feature' in m2.extra_regressors) +