Refactor R cross validation to match Py function structure

This commit is contained in:
Ben Letham 2021-03-04 14:04:48 -08:00
parent eb6b8f60ec
commit a4b31cd70b

View file

@ -131,45 +131,61 @@ cross_validation <- function(
predicts <- data.frame()
for (i in 1:length(cutoffs)) {
# Copy the model
cutoff <- cutoffs[i]
m <- prophet_copy(model, cutoff)
# Train model
history.c <- dplyr::filter(df, ds <= cutoff)
if (nrow(history.c) < 2) {
stop('Less than two datapoints before cutoff. Increase initial window.')
}
fit.args <- c(list(m=m, df=history.c), model$fit.kwargs)
m <- do.call(fit.prophet, fit.args)
# Calculate yhat
df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon.dt)
# Get the columns for the future dataframe
columns <- 'ds'
if (m$growth == 'logistic') {
columns <- c(columns, 'cap')
if (m$logistic.floor) {
columns <- c(columns, 'floor')
}
}
columns <- c(columns, names(m$extra_regressors))
for (name in names(m$seasonalities)) {
condition.name = m$seasonalities[[name]]$condition.name
if (!is.null(condition.name)) {
columns <- c(columns, condition.name)
}
}
future <- df.predict[columns]
yhat <- stats::predict(m, future)
# Merge yhat, y, and cutoff.
df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds")
df.c <- df.c[c(predict_columns, "y")]
df.c <- dplyr::select(df.c, y, predict_columns)
df.c$cutoff <- cutoff
df.c <- single_cutoff_forecast(df, model, cutoffs[i], horizon.dt, predict_columns)
predicts <- rbind(predicts, df.c)
}
return(predicts)
}
#' Forecast for a single cutoff.
#'
#' Used in cross_validation function when evaluating for multiple cutoffs.
#' Fits a copy of the model on the rows of df with ds <= cutoff, then predicts
#' the rows with cutoff < ds <= cutoff + horizon.dt and joins the predictions
#' back onto the observed values.
#'
#' @param df Dataframe with the full history. It is split at the cutoff into
#'  training rows (ds <= cutoff) and rows to forecast. Needs columns ds and y,
#'  plus any cap, floor, extra regressor, and seasonality condition columns
#'  that the model uses.
#' @param model Prophet model object. Its fit.kwargs are forwarded to
#'  fit.prophet when the copy is fitted.
#' @param cutoff Datetime of cutoff.
#' @param horizon.dt timediff forecast horizon.
#' @param predict_columns Array of names of columns to be returned in output.
#'  Must include 'ds', which is the key used to join forecasts to actuals.
#'
#' @return Dataframe with forecast, actual value, and cutoff: column y first,
#'  then predict_columns in the given order, then cutoff.
#'
#' @details Stops with an error if fewer than two rows of df fall at or before
#'  the cutoff.
#'
#' @keywords internal
single_cutoff_forecast <- function(df, model, cutoff, horizon.dt, predict_columns){
  # Copy the model so the caller's object is not the one being fitted.
  m <- prophet_copy(model, cutoff)
  # Train model
  history.c <- dplyr::filter(df, ds <= cutoff)
  if (nrow(history.c) < 2) {
    stop('Less than two datapoints before cutoff. Increase initial window.')
  }
  fit.args <- c(list(m = m, df = history.c), model$fit.kwargs)
  m <- do.call(fit.prophet, fit.args)
  # Calculate yhat
  df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon.dt)
  # Get the columns for the future dataframe
  columns <- 'ds'
  if (m$growth == 'logistic') {
    columns <- c(columns, 'cap')
    if (m$logistic.floor) {
      columns <- c(columns, 'floor')
    }
  }
  columns <- c(columns, names(m$extra_regressors))
  for (name in names(m$seasonalities)) {
    condition.name <- m$seasonalities[[name]]$condition.name
    if (!is.null(condition.name)) {
      columns <- c(columns, condition.name)
    }
  }
  future <- df.predict[columns]
  yhat <- stats::predict(m, future)
  # Merge yhat, y, and cutoff.
  df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds")
  # Select by name with base subsetting: passing the bare character vector to
  # dplyr::select() is ambiguous (it could match a data column of that name)
  # and is deprecated by tidyselect.
  df.c <- df.c[c("y", predict_columns)]
  df.c$cutoff <- cutoff
  return(df.c)
}
#' Copy Prophet object.
#'
#' @param m Prophet model object.