diff --git a/python/fbprophet/diagnostics.py b/python/fbprophet/diagnostics.py index 62c5ce1..6e83a7c 100644 --- a/python/fbprophet/diagnostics.py +++ b/python/fbprophet/diagnostics.py @@ -140,6 +140,13 @@ def cross_validation(model, horizon, period=None, initial=None, parallel=None, c # Compute Cutoffs cutoffs = generate_cutoffs(df, horizon, initial, period) else: + # add validation of the cutoff to make sure that the min cutoff is strictly greater than the min date in the history + if min(cutoffs) <= df['ds'].min(): + raise ValueError("Minimum cutoff value is not strictly greater than min date in history") + # max value of cutoffs is <= (end date minus horizon) + end_date_minus_horizon = df['ds'].max() - horizon + if max(cutoffs) > end_date_minus_horizon: + raise ValueError("Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining") initial = cutoffs[0] - df['ds'].min() # Check if the initial window