mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-07-05 04:17:56 +00:00
Refactor R cross validation to match Py function structure
This commit is contained in:
parent
eb6b8f60ec
commit
a4b31cd70b
1 changed files with 50 additions and 34 deletions
|
|
@ -131,45 +131,61 @@ cross_validation <- function(
|
|||
|
||||
predicts <- data.frame()
|
||||
for (i in 1:length(cutoffs)) {
|
||||
# Copy the model
|
||||
cutoff <- cutoffs[i]
|
||||
m <- prophet_copy(model, cutoff)
|
||||
# Train model
|
||||
history.c <- dplyr::filter(df, ds <= cutoff)
|
||||
if (nrow(history.c) < 2) {
|
||||
stop('Less than two datapoints before cutoff. Increase initial window.')
|
||||
}
|
||||
fit.args <- c(list(m=m, df=history.c), model$fit.kwargs)
|
||||
m <- do.call(fit.prophet, fit.args)
|
||||
# Calculate yhat
|
||||
df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon.dt)
|
||||
# Get the columns for the future dataframe
|
||||
columns <- 'ds'
|
||||
if (m$growth == 'logistic') {
|
||||
columns <- c(columns, 'cap')
|
||||
if (m$logistic.floor) {
|
||||
columns <- c(columns, 'floor')
|
||||
}
|
||||
}
|
||||
columns <- c(columns, names(m$extra_regressors))
|
||||
for (name in names(m$seasonalities)) {
|
||||
condition.name = m$seasonalities[[name]]$condition.name
|
||||
if (!is.null(condition.name)) {
|
||||
columns <- c(columns, condition.name)
|
||||
}
|
||||
}
|
||||
future <- df.predict[columns]
|
||||
yhat <- stats::predict(m, future)
|
||||
# Merge yhat, y, and cutoff.
|
||||
df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds")
|
||||
df.c <- df.c[c(predict_columns, "y")]
|
||||
df.c <- dplyr::select(df.c, y, predict_columns)
|
||||
df.c$cutoff <- cutoff
|
||||
df.c <- single_cutoff_forecast(df, model, cutoffs[i], horizon.dt, predict_columns)
|
||||
predicts <- rbind(predicts, df.c)
|
||||
}
|
||||
return(predicts)
|
||||
}
|
||||
|
||||
#' Forecast for a single cutoff.
|
||||
|
||||
#' Used in cross_validation function when evaluating for multiple cutoffs.
|
||||
#'
|
||||
#' @param df Dataframe with history for cutoff.
|
||||
#' @param model Prophet model object.
|
||||
#' @param cutoff Datetime of cutoff.
|
||||
#' @param horizon.dt timediff forecast horizon.
|
||||
#' @param predict_columns Array of names of columns to be returned in output.
|
||||
#'
|
||||
#' @return Dataframe with forecast, actual value, and cutoff.
|
||||
#'
|
||||
#' @keywords internal
|
||||
single_cutoff_forecast <- function(df, model, cutoff, horizon.dt, predict_columns){
|
||||
m <- prophet_copy(model, cutoff)
|
||||
# Train model
|
||||
history.c <- dplyr::filter(df, ds <= cutoff)
|
||||
if (nrow(history.c) < 2) {
|
||||
stop('Less than two datapoints before cutoff. Increase initial window.')
|
||||
}
|
||||
fit.args <- c(list(m=m, df=history.c), model$fit.kwargs)
|
||||
m <- do.call(fit.prophet, fit.args)
|
||||
# Calculate yhat
|
||||
df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon.dt)
|
||||
# Get the columns for the future dataframe
|
||||
columns <- 'ds'
|
||||
if (m$growth == 'logistic') {
|
||||
columns <- c(columns, 'cap')
|
||||
if (m$logistic.floor) {
|
||||
columns <- c(columns, 'floor')
|
||||
}
|
||||
}
|
||||
columns <- c(columns, names(m$extra_regressors))
|
||||
for (name in names(m$seasonalities)) {
|
||||
condition.name = m$seasonalities[[name]]$condition.name
|
||||
if (!is.null(condition.name)) {
|
||||
columns <- c(columns, condition.name)
|
||||
}
|
||||
}
|
||||
future <- df.predict[columns]
|
||||
yhat <- stats::predict(m, future)
|
||||
# Merge yhat, y, and cutoff.
|
||||
df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds")
|
||||
df.c <- df.c[c(predict_columns, "y")]
|
||||
df.c <- dplyr::select(df.c, y, predict_columns)
|
||||
df.c$cutoff <- cutoff
|
||||
return(df.c)
|
||||
}
|
||||
|
||||
#' Copy Prophet object.
|
||||
#'
|
||||
#' @param m Prophet model object.
|
||||
|
|
|
|||
Loading…
Reference in a new issue