diff --git a/python/fbprophet/diagnostics.py b/python/fbprophet/diagnostics.py index 6e83a7c..7b4a5b3 100644 --- a/python/fbprophet/diagnostics.py +++ b/python/fbprophet/diagnostics.py @@ -131,12 +131,13 @@ def cross_validation(model, horizon, period=None, initial=None, parallel=None, c if cutoffs is None: # Set period period = 0.5 * horizon if period is None else pd.Timedelta(period) - + # Set initial - if initial is None: - initial = max(3 * horizon, seasonality_dt) - else: - initial = pd.Timedelta(initial) + initial = ( + max(3 * horizon, seasonality_dt) if initial is None + else pd.Timedelta(initial) + ) + # Compute Cutoffs cutoffs = generate_cutoffs(df, horizon, initial, period) else: