mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-18 21:21:22 +00:00
Warn in cross validation if initial period is less than seasonality (#557), and fix bug that can produce error if period > initial
This commit is contained in:
parent
35d470cbff
commit
c8f2e8f847
2 changed files with 47 additions and 14 deletions
|
|
@ -29,8 +29,11 @@ generate_cutoffs <- function(df, horizon, initial, period) {
|
|||
# If data does not exist in data range (cutoff, cutoff + horizon]
|
||||
if (!any((df$ds > cutoff) & (df$ds <= cutoff + horizon))) {
|
||||
# Next cutoff point is 'closest date before cutoff in data - horizon'
|
||||
closest.date <- max(df$ds[df$ds <= cutoff])
|
||||
cutoff <- closest.date - horizon
|
||||
if (cutoff > min(df$ds)) {
|
||||
closest.date <- max(df$ds[df$ds <= cutoff])
|
||||
cutoff <- closest.date - horizon
|
||||
}
|
||||
# else no data left, leave cutoff as is, it will be dropped.
|
||||
}
|
||||
result <- c(result, cutoff)
|
||||
}
|
||||
|
|
@ -73,19 +76,34 @@ generate_cutoffs <- function(df, horizon, initial, period) {
|
|||
cross_validation <- function(
|
||||
model, horizon, units, period = NULL, initial = NULL) {
|
||||
df <- model$history
|
||||
te <- max(df$ds)
|
||||
ts <- min(df$ds)
|
||||
horizon.dt <- as.difftime(horizon, units = units)
|
||||
# Set period
|
||||
if (is.null(period)) {
|
||||
period <- 0.5 * horizon
|
||||
}
|
||||
if (is.null(initial)) {
|
||||
initial <- 3 * horizon
|
||||
}
|
||||
horizon.dt <- as.difftime(horizon, units = units)
|
||||
initial.dt <- as.difftime(initial, units = units)
|
||||
period.dt <- as.difftime(period, units = units)
|
||||
# Identify largest seasonality period
|
||||
period.max <- 0
|
||||
for (s in model$seasonalities) {
|
||||
period.max <- max(period.max, s$period)
|
||||
}
|
||||
seasonality.dt <- as.difftime(period.max, units = 'days')
|
||||
# Set initial
|
||||
if (is.null(initial)) {
|
||||
initial.dt <- max(
|
||||
as.difftime(3 * horizon, units = units),
|
||||
seasonality.dt
|
||||
)
|
||||
} else {
|
||||
initial.dt <- as.difftime(initial, units = units)
|
||||
if (initial.dt < seasonality.dt) {
|
||||
warning(paste0('Seasonality has period of ', period.max, ' days which ',
|
||||
'is larger than initial window. Consider increasing initial.'))
|
||||
}
|
||||
}
|
||||
|
||||
cutoffs <- generate_cutoffs(df, horizon.dt, initial.dt, period.dt)
|
||||
|
||||
predicts <- data.frame()
|
||||
for (i in 1:length(cutoffs)) {
|
||||
cutoff <- cutoffs[i]
|
||||
|
|
|
|||
|
|
@ -42,8 +42,10 @@ def generate_cutoffs(df, horizon, initial, period):
|
|||
# If data does not exist in data range (cutoff, cutoff + horizon]
|
||||
if not (((df['ds'] > cutoff) & (df['ds'] <= cutoff + horizon)).any()):
|
||||
# Next cutoff point is 'last date before cutoff in data - horizon'
|
||||
closest_date = df[df['ds'] <= cutoff].max()['ds']
|
||||
cutoff = closest_date - horizon
|
||||
if cutoff > df['ds'].min():
|
||||
closest_date = df[df['ds'] <= cutoff].max()['ds']
|
||||
cutoff = closest_date - horizon
|
||||
# else no data left, leave cutoff as is, it will be dropped.
|
||||
result.append(cutoff)
|
||||
result = result[:-1]
|
||||
if len(result) == 0:
|
||||
|
|
@ -82,11 +84,24 @@ def cross_validation(model, horizon, period=None, initial=None):
|
|||
A pd.DataFrame with the forecast, actual value and cutoff.
|
||||
"""
|
||||
df = model.history.copy().reset_index(drop=True)
|
||||
te = df['ds'].max()
|
||||
ts = df['ds'].min()
|
||||
horizon = pd.Timedelta(horizon)
|
||||
# Set period
|
||||
period = 0.5 * horizon if period is None else pd.Timedelta(period)
|
||||
initial = 3 * horizon if initial is None else pd.Timedelta(initial)
|
||||
# Identify largest seasonality period
|
||||
period_max = 0.
|
||||
for s in model.seasonalities.values():
|
||||
period_max = max(period_max, s['period'])
|
||||
seasonality_dt = pd.Timedelta(str(period_max) + ' days')
|
||||
# Set initial
|
||||
if initial is None:
|
||||
initial = max(3 * horizon, seasonality_dt)
|
||||
else:
|
||||
initial = pd.Timedelta(initial)
|
||||
if initial < seasonality_dt:
|
||||
msg = 'Seasonality has period of {} days '.format(period_max)
|
||||
msg += 'which is larger than initial window. '
|
||||
msg += 'Consider increasing initial.'
|
||||
logger.warning(msg)
|
||||
|
||||
cutoffs = generate_cutoffs(df, horizon, initial, period)
|
||||
predicts = []
|
||||
|
|
|
|||
Loading…
Reference in a new issue