From a4b31cd70b7d9af67822375d4b982d5091c6b965 Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Thu, 4 Mar 2021 14:04:48 -0800 Subject: [PATCH] Refactor R cross validation to match Py function structure --- R/R/diagnostics.R | 84 ++++++++++++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 34 deletions(-) diff --git a/R/R/diagnostics.R b/R/R/diagnostics.R index 394c081..1e9aae4 100644 --- a/R/R/diagnostics.R +++ b/R/R/diagnostics.R @@ -131,45 +131,61 @@ cross_validation <- function( predicts <- data.frame() for (i in 1:length(cutoffs)) { - # Copy the model - cutoff <- cutoffs[i] - m <- prophet_copy(model, cutoff) - # Train model - history.c <- dplyr::filter(df, ds <= cutoff) - if (nrow(history.c) < 2) { - stop('Less than two datapoints before cutoff. Increase initial window.') - } - fit.args <- c(list(m=m, df=history.c), model$fit.kwargs) - m <- do.call(fit.prophet, fit.args) - # Calculate yhat - df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon.dt) - # Get the columns for the future dataframe - columns <- 'ds' - if (m$growth == 'logistic') { - columns <- c(columns, 'cap') - if (m$logistic.floor) { - columns <- c(columns, 'floor') - } - } - columns <- c(columns, names(m$extra_regressors)) - for (name in names(m$seasonalities)) { - condition.name = m$seasonalities[[name]]$condition.name - if (!is.null(condition.name)) { - columns <- c(columns, condition.name) - } - } - future <- df.predict[columns] - yhat <- stats::predict(m, future) - # Merge yhat, y, and cutoff. - df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds") - df.c <- df.c[c(predict_columns, "y")] - df.c <- dplyr::select(df.c, y, predict_columns) - df.c$cutoff <- cutoff + df.c <- single_cutoff_forecast(df, model, cutoffs[i], horizon.dt, predict_columns) predicts <- rbind(predicts, df.c) } return(predicts) } +#' Forecast for a single cutoff. +#' +#' Used in cross_validation function when evaluating for multiple cutoffs.
+#'
+#' @param df Dataframe with the full history. Rows with ds <= cutoff are used
+#'  to fit the model; rows after cutoff, up to cutoff + horizon.dt, are forecast.
+#' @param model Prophet model object.
+#' @param cutoff Datetime of cutoff.
+#' @param horizon.dt timediff forecast horizon.
+#' @param predict_columns Array of names of columns to be returned in output.
+#'
+#' @return Dataframe with y, the columns in predict_columns, and cutoff, for
+#'  the rows of df that fall inside the forecast window.
+#'
+#' @keywords internal
+single_cutoff_forecast <- function(df, model, cutoff, horizon.dt, predict_columns){
+  m <- prophet_copy(model, cutoff)
+  # Train model
+  history.c <- dplyr::filter(df, ds <= cutoff)
+  if (nrow(history.c) < 2) {
+    stop('Less than two datapoints before cutoff. Increase initial window.')
+  }
+  fit.args <- c(list(m=m, df=history.c), model$fit.kwargs)
+  m <- do.call(fit.prophet, fit.args)
+  # Calculate yhat
+  df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon.dt)
+  # Get the columns for the future dataframe
+  columns <- 'ds'
+  if (m$growth == 'logistic') {
+    columns <- c(columns, 'cap')
+    if (m$logistic.floor) {
+      columns <- c(columns, 'floor')
+    }
+  }
+  columns <- c(columns, names(m$extra_regressors))
+  for (name in names(m$seasonalities)) {
+    # c() drops NULL, so seasonalities without a condition add nothing.
+    columns <- c(columns, m$seasonalities[[name]]$condition.name)
+  }
+  future <- df.predict[columns]
+  yhat <- stats::predict(m, future)
+  # Merge yhat, y, and cutoff.
+  df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds")
+  # y first, then predict_columns; base indexing avoids tidyselect ambiguity.
+  df.c <- df.c[unique(c("y", predict_columns))]
+  df.c$cutoff <- cutoff
+  return(df.c)
+}
+
 #' Copy Prophet object.
 #'
 #' @param m Prophet model object.