diff --git a/R/R/diagnostics.R b/R/R/diagnostics.R index 13f6b2e..2c72a08 100644 --- a/R/R/diagnostics.R +++ b/R/R/diagnostics.R @@ -29,8 +29,11 @@ generate_cutoffs <- function(df, horizon, initial, period) { # If data does not exist in data range (cutoff, cutoff + horizon] if (!any((df$ds > cutoff) & (df$ds <= cutoff + horizon))) { # Next cutoff point is 'closest date before cutoff in data - horizon' - closest.date <- max(df$ds[df$ds <= cutoff]) - cutoff <- closest.date - horizon + if (cutoff > min(df$ds)) { + closest.date <- max(df$ds[df$ds <= cutoff]) + cutoff <- closest.date - horizon + } + # else no data left, leave cutoff as is, it will be dropped. } result <- c(result, cutoff) } @@ -73,19 +76,34 @@ generate_cutoffs <- function(df, horizon, initial, period) { cross_validation <- function( model, horizon, units, period = NULL, initial = NULL) { df <- model$history - te <- max(df$ds) - ts <- min(df$ds) + horizon.dt <- as.difftime(horizon, units = units) + # Set period if (is.null(period)) { period <- 0.5 * horizon } - if (is.null(initial)) { - initial <- 3 * horizon - } - horizon.dt <- as.difftime(horizon, units = units) - initial.dt <- as.difftime(initial, units = units) period.dt <- as.difftime(period, units = units) + # Identify largest seasonality period + period.max <- 0 + for (s in model$seasonalities) { + period.max <- max(period.max, s$period) + } + seasonality.dt <- as.difftime(period.max, units = 'days') + # Set initial + if (is.null(initial)) { + initial.dt <- max( + as.difftime(3 * horizon, units = units), + seasonality.dt + ) + } else { + initial.dt <- as.difftime(initial, units = units) + if (initial.dt < seasonality.dt) { + warning(paste0('Seasonality has period of ', period.max, ' days which ', + 'is larger than initial window. Consider increasing initial.')) + } + } cutoffs <- generate_cutoffs(df, horizon.dt, initial.dt, period.dt) + predicts <- data.frame() for (i in 1:length(cutoffs)) { cutoff <- cutoffs[i] diff --git a/python/fbprophet/diagnostics.py b/python/fbprophet/diagnostics.py index f174b7d..88a0d9d 100644 --- a/python/fbprophet/diagnostics.py +++ b/python/fbprophet/diagnostics.py @@ -42,8 +42,10 @@ def generate_cutoffs(df, horizon, initial, period): # If data does not exist in data range (cutoff, cutoff + horizon] if not (((df['ds'] > cutoff) & (df['ds'] <= cutoff + horizon)).any()): # Next cutoff point is 'last date before cutoff in data - horizon' - closest_date = df[df['ds'] <= cutoff].max()['ds'] - cutoff = closest_date - horizon + if cutoff > df['ds'].min(): + closest_date = df[df['ds'] <= cutoff].max()['ds'] + cutoff = closest_date - horizon + # else no data left, leave cutoff as is, it will be dropped. result.append(cutoff) result = result[:-1] if len(result) == 0: @@ -82,11 +84,24 @@ def cross_validation(model, horizon, period=None, initial=None): A pd.DataFrame with the forecast, actual value and cutoff. """ df = model.history.copy().reset_index(drop=True) - te = df['ds'].max() - ts = df['ds'].min() horizon = pd.Timedelta(horizon) + # Set period period = 0.5 * horizon if period is None else pd.Timedelta(period) - initial = 3 * horizon if initial is None else pd.Timedelta(initial) + # Identify largest seasonality period + period_max = 0. + for s in model.seasonalities.values(): + period_max = max(period_max, s['period']) + seasonality_dt = pd.Timedelta(str(period_max) + ' days') + # Set initial + if initial is None: + initial = max(3 * horizon, seasonality_dt) + else: + initial = pd.Timedelta(initial) + if initial < seasonality_dt: + msg = 'Seasonality has period of {} days '.format(period_max) + msg += 'which is larger than initial window. ' + msg += 'Consider increasing initial.' + logger.warning(msg) cutoffs = generate_cutoffs(df, horizon, initial, period) predicts = []