mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-07-03 03:59:00 +00:00
modified cross_validation to allow custom cutoffs (#1402)
* modified cross_validation to allow custom cutoffs * moved set period, initials and identify largest seasonality * modified the diagnostics and added the test * reverted cv default value tests and added a new custom cutoff test * reorganized to raise the seasonality period warning message even if cutoffs are manually specified * moved the initials vs. seasonality check * changed assertCountEqual to assertItemsEqual in cv * modified to test lengths instead of cutoff values Co-authored-by: Fusi Marco <Marco.Fusi@valuelab.it>
This commit is contained in:
parent
d22922d08c
commit
3c69ce3312
2 changed files with 49 additions and 22 deletions
|
|
@ -59,12 +59,12 @@ def generate_cutoffs(df, horizon, initial, period):
|
|||
return list(reversed(result))
|
||||
|
||||
|
||||
def cross_validation(model, horizon, period=None, initial=None, multiprocess=False):
|
||||
def cross_validation(model, horizon, period=None, initial=None, multiprocess=False, cutoffs=None):
|
||||
"""Cross-Validation for time series.
|
||||
|
||||
Computes forecasts from historical cutoff points. Beginning from
|
||||
(end - horizon), works backwards making cutoffs with a spacing of period
|
||||
until initial is reached.
|
||||
Computes forecasts from historical cutoff points, which the user can supply.
|
||||
If not provided, begins from (end - horizon) and works backwards, making
|
||||
cutoffs with a spacing of period until initial is reached.
|
||||
|
||||
When period is equal to the time interval of the data, this is the
|
||||
technique described in https://robjhyndman.com/hyndsight/tscv/ .
|
||||
|
|
@ -78,6 +78,10 @@ def cross_validation(model, horizon, period=None, initial=None, multiprocess=Fal
|
|||
be done at every this period. If not provided, 0.5 * horizon is used.
|
||||
initial: string with pd.Timedelta compatible style. The first training
|
||||
period will begin here. If not provided, 3 * horizon is used.
|
||||
cutoffs: list of pd.Timestamp specifying the cutoffs to be used during
|
||||
cross-validation. If not provided, beginning from
|
||||
(end - horizon), works backwards making cutoffs with a spacing of period
|
||||
until initial is reached.
|
||||
multiprocess: True, False, Optional (defaults to False). If `True`, use the
|
||||
`multiprocessing` module to distribute each task to a different processor
|
||||
core.
|
||||
|
|
@ -89,29 +93,39 @@ def cross_validation(model, horizon, period=None, initial=None, multiprocess=Fal
|
|||
|
||||
df = model.history.copy().reset_index(drop=True)
|
||||
horizon = pd.Timedelta(horizon)
|
||||
# Set period
|
||||
period = 0.5 * horizon if period is None else pd.Timedelta(period)
|
||||
# Identify largest seasonality period
|
||||
period_max = 0.
|
||||
for s in model.seasonalities.values():
|
||||
period_max = max(period_max, s['period'])
|
||||
seasonality_dt = pd.Timedelta(str(period_max) + ' days')
|
||||
# Set initial
|
||||
if initial is None:
|
||||
initial = max(3 * horizon, seasonality_dt)
|
||||
else:
|
||||
initial = pd.Timedelta(initial)
|
||||
if initial < seasonality_dt:
|
||||
msg = 'Seasonality has period of {} days '.format(period_max)
|
||||
msg += 'which is larger than initial window. '
|
||||
msg += 'Consider increasing initial.'
|
||||
logger.warning(msg)
|
||||
|
||||
predict_columns = ['ds', 'yhat']
|
||||
if model.uncertainty_samples:
|
||||
predict_columns.extend(['yhat_lower', 'yhat_upper'])
|
||||
|
||||
# Identify largest seasonality period
|
||||
period_max = 0.
|
||||
for s in model.seasonalities.values():
|
||||
period_max = max(period_max, s['period'])
|
||||
seasonality_dt = pd.Timedelta(str(period_max) + ' days')
|
||||
|
||||
cutoffs = generate_cutoffs(df, horizon, initial, period)
|
||||
if cutoffs is None:
|
||||
# Set period
|
||||
period = 0.5 * horizon if period is None else pd.Timedelta(period)
|
||||
|
||||
# Set initial
|
||||
if initial is None:
|
||||
initial = max(3 * horizon, seasonality_dt)
|
||||
else:
|
||||
initial = pd.Timedelta(initial)
|
||||
# Compute Cutoffs
|
||||
cutoffs = generate_cutoffs(df, horizon, initial, period)
|
||||
else:
|
||||
initial = cutoffs[0] - df['ds'].min()
|
||||
|
||||
# Check if the initial window
|
||||
# (that is, the amount of time between the start of the history and the first cutoff)
|
||||
# is less than the maximum seasonality period
|
||||
if initial < seasonality_dt:
|
||||
msg = 'Seasonality has period of {} days '.format(period_max)
|
||||
msg += 'which is larger than initial window. '
|
||||
msg += 'Consider increasing initial.'
|
||||
logger.warning(msg)
|
||||
|
||||
if multiprocess is True:
|
||||
with Pool() as pool:
|
||||
|
|
|
|||
|
|
@ -131,6 +131,18 @@ class TestDiagnostics(TestCase):
|
|||
self.assertAlmostEqual(
|
||||
((df_cv1['yhat'] - df_cv2['yhat']) ** 2).sum(), 0.0)
|
||||
|
||||
def test_cross_validation_custom_cutoffs(self):
|
||||
m = Prophet()
|
||||
m.fit(self.__df)
|
||||
# When a list of cutoffs is specified,
|
||||
# the cutoff dates in df_cv are those specified
|
||||
df_cv1 = diagnostics.cross_validation(
|
||||
m,
|
||||
horizon='32 days',
|
||||
period='10 days',
|
||||
cutoffs=[pd.Timestamp('2012-07-31'), pd.Timestamp('2012-08-31')])
|
||||
self.assertEqual(len(df_cv1['cutoff'].unique()), 2)
|
||||
|
||||
def test_cross_validation_uncertainty_disabled(self):
|
||||
df = self.__df.copy()
|
||||
for uncertainty in [0, False]:
|
||||
|
|
@ -313,3 +325,4 @@ class TestDiagnostics(TestCase):
|
|||
self.assertTrue((changepoints == m2.changepoints).all())
|
||||
self.assertTrue('custom' in m2.seasonalities)
|
||||
self.assertTrue('binary_feature' in m2.extra_regressors)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue