mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-12 00:59:25 +00:00
Add support for fitting daily seasonality, make holiday features work when daily seasonality is enabled (#246)
* Add support for fitting daily seasonality, make holiday features work when daily seasonality is enabled * fix wrong comment in make_future_dataframe()
This commit is contained in:
parent
b07d345155
commit
b0938df109
2 changed files with 125 additions and 43 deletions
142
R/R/prophet.R
142
R/R/prophet.R
|
|
@ -33,6 +33,8 @@ globalVariables(c(
|
|||
#' FALSE, or a number of Fourier terms to generate.
|
||||
#' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
|
||||
#' FALSE, or a number of Fourier terms to generate.
|
||||
#' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE,
|
||||
#' FALSE, or a number of Fourier terms to generate.
|
||||
#' @param holidays data frame with columns holiday (character) and ds (date
|
||||
#' type) and optionally columns lower_window and upper_window which specify a
|
||||
#' range of days around the date to be included as holidays. lower_window=-2
|
||||
|
|
@ -76,6 +78,7 @@ prophet <- function(df = NULL,
|
|||
n.changepoints = 25,
|
||||
yearly.seasonality = 'auto',
|
||||
weekly.seasonality = 'auto',
|
||||
daily.seasonality = 'auto',
|
||||
holidays = NULL,
|
||||
seasonality.prior.scale = 10,
|
||||
holidays.prior.scale = 10,
|
||||
|
|
@ -98,6 +101,7 @@ prophet <- function(df = NULL,
|
|||
n.changepoints = n.changepoints,
|
||||
yearly.seasonality = yearly.seasonality,
|
||||
weekly.seasonality = weekly.seasonality,
|
||||
daily.seasonality = daily.seasonality,
|
||||
holidays = holidays,
|
||||
seasonality.prior.scale = seasonality.prior.scale,
|
||||
changepoint.prior.scale = changepoint.prior.scale,
|
||||
|
|
@ -206,6 +210,47 @@ compile_stan_model <- function(model) {
|
|||
return(rstan::stan_model(stanc_ret = stanc, model_name = model.name))
|
||||
}
|
||||
|
||||
#' Convert date vector
#'
#' Convert dates to a POSIXct vector. Values are parsed either as dates
#' (\%Y-\%m-\%d) or as date-times (\%Y-\%m-\%d \%H:\%M:\%S); the format is chosen
#' once for the whole vector, based on its shortest non-missing element.
#'
#' @param ds Date vector. May be character or factor; factors are converted
#'  to character before parsing.
#' @param tz Time zone passed to as.POSIXct. Defaults to "GMT".
#'
#' @return POSIXct vector converted from ds, or NULL if ds has length zero.
#'  Missing entries stay NA so that callers can report them.
#'
set_date <- function(ds = NULL, tz = "GMT") {
  if (length(ds) == 0) {
    return(NULL)
  }

  if (is.factor(ds)) {
    ds <- as.character(ds)
  }

  # Missing strings have an NA width. Drop them before choosing the format so
  # that a single NA does not turn the `if` condition into NA (which would
  # abort with an unhelpful "missing value where TRUE/FALSE needed" error).
  widths <- nchar(ds)
  widths <- widths[!is.na(widths)]

  # Fewer than 12 characters cannot hold "YYYY-mm-dd HH:MM:SS", so any short
  # element switches the whole vector to the date-only format.
  if (length(widths) == 0 || min(widths) < 12) {
    ds <- as.POSIXct(ds, format = "%Y-%m-%d", tz = tz)
  } else {
    ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz)
  }
  return(ds)
}
|
||||
|
||||
#' Extract hour
#'
#' Extract the hour of day from a POSIXct object.
#'
#' @param ds POSIXct object
#'
#' @return Character vector holding the hour of each element of ds, as
#'  produced by format(ds, "\%H") (two digits, e.g. "05").
#'
get_hour <- function(ds) {
  # inherits() is the base-R class test; it does not need the methods
  # package to be attached, unlike is().
  if (!inherits(ds, "POSIXct")) {
    stop("ds must be a POSIXct object, use function set_date() to convert first.")
  }

  return(format(ds, "%H"))
}
|
||||
|
||||
#' Prepare dataframe for fitting or predicting.
|
||||
#'
|
||||
#' Adds a time index and scales y. Creates auxiliary columns 't', 't_ix',
|
||||
|
|
@ -222,9 +267,9 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
|
|||
if (exists('y', where=df)) {
|
||||
df$y <- as.numeric(df$y)
|
||||
}
|
||||
df$ds <- zoo::as.Date(df$ds)
|
||||
df$ds <- set_date(df$ds)
|
||||
if (anyNA(df$ds)) {
|
||||
stop('Unable to parse date format in column ds. Convert to date format.')
|
||||
stop('Unable to parse date format in column ds. Convert to date format. Either %Y-%m-%d or %Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
|
||||
df <- df %>%
|
||||
|
|
@ -233,10 +278,10 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
|
|||
if (initialize_scales) {
|
||||
m$y.scale <- max(abs(df$y))
|
||||
m$start <- min(df$ds)
|
||||
m$t.scale <- as.numeric(max(df$ds) - m$start)
|
||||
m$t.scale <- as.numeric(difftime(max(df$ds), m$start, units = "secs"))
|
||||
}
|
||||
|
||||
df$t <- as.numeric(df$ds - m$start) / m$t.scale
|
||||
df$t <- as.numeric(difftime(df$ds, m$start, units = "secs")) / m$t.scale
|
||||
if (exists('y', where=df)) {
|
||||
df$y_scaled <- df$y / m$y.scale
|
||||
}
|
||||
|
|
@ -285,8 +330,8 @@ set_changepoints <- function(m) {
|
|||
}
|
||||
}
|
||||
if (length(m$changepoints) > 0) {
|
||||
m$changepoints <- zoo::as.Date(m$changepoints)
|
||||
m$changepoints.t <- sort(as.numeric(m$changepoints - m$start) / m$t.scale)
|
||||
m$changepoints <- set_date(m$changepoints)
|
||||
m$changepoints.t <- sort(as.numeric(difftime(m$changepoints, m$start, units = "secs"))) / m$t.scale
|
||||
} else {
|
||||
m$changepoints.t <- c(0) # dummy changepoint
|
||||
}
|
||||
|
|
@ -316,7 +361,7 @@ get_changepoint_matrix <- function(m) {
|
|||
#' @return Matrix with seasonality features.
|
||||
#'
|
||||
fourier_series <- function(dates, period, series.order) {
|
||||
t <- dates - zoo::as.Date('1970-01-01')
|
||||
t <- as.numeric(difftime(dates, set_date('1970-01-01 00:00:00'), units = 'days'))
|
||||
features <- matrix(0, length(t), 2 * series.order)
|
||||
for (i in 1:series.order) {
|
||||
x <- as.numeric(2 * i * pi * t / period)
|
||||
|
|
@ -352,7 +397,7 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
|
|||
make_holiday_features <- function(m, dates) {
|
||||
scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
|
||||
wide <- m$holidays %>%
|
||||
dplyr::mutate(ds = zoo::as.Date(ds)) %>%
|
||||
dplyr::mutate(ds = set_date(ds)) %>%
|
||||
dplyr::group_by(holiday, ds) %>%
|
||||
dplyr::filter(row_number() == 1) %>%
|
||||
dplyr::do({
|
||||
|
|
@ -364,7 +409,7 @@ make_holiday_features <- function(m, dates) {
|
|||
}
|
||||
names <- paste(
|
||||
.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
|
||||
dplyr::data_frame(ds = .$ds + offsets, holiday = names)
|
||||
dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
|
||||
}) %>%
|
||||
dplyr::mutate(x = scale.ratio) %>%
|
||||
tidyr::spread(holiday, x, fill = 0)
|
||||
|
|
@ -472,22 +517,29 @@ parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
|
|||
set_auto_seasonalities <- function(m) {
  # Decide which of the built-in seasonalities (yearly, weekly, daily) to
  # register on the model, based on the length and spacing of the history.
  # Reads m$history$ds, m$start and the m$*.seasonality settings; returns m
  # with m$seasonalities updated.
  first <- min(m$history$ds)
  last <- max(m$history$ds)

  # Gaps between consecutive observations, in days. Timestamps are converted
  # to numeric day offsets first so the unit is fixed regardless of how
  # finely the data are sampled.
  dt <- diff(as.numeric(difftime(m$history$ds, m$start, units = "days")))
  # Smallest strictly positive gap (duplicate timestamps are ignored).
  min.dt <- min(dt[dt > 0])

  # Total length of the history in days, shared by all three checks below.
  span.days <- as.numeric(difftime(last, first, units = "days"))

  # Each block passes an "auto disable" flag and a default Fourier order to
  # parse_seasonality_args (defined elsewhere; presumably it combines them
  # with the user's setting -- confirm there). A positive order registers
  # the seasonality as c(period in days, order).
  yearly.disable <- span.days < 730
  fourier.order <- parse_seasonality_args(
    m, 'yearly', m$yearly.seasonality, yearly.disable, 10)
  if (fourier.order > 0) {
    m$seasonalities[['yearly']] <- c(365.25, fourier.order)
  }

  weekly.disable <- ((span.days < 14) || (min.dt >= 7))
  fourier.order <- parse_seasonality_args(
    m, 'weekly', m$weekly.seasonality, weekly.disable, 3)
  if (fourier.order > 0) {
    m$seasonalities[['weekly']] <- c(7, fourier.order)
  }

  daily.disable <- ((span.days < 2) || (min.dt >= 1))
  fourier.order <- parse_seasonality_args(
    m, 'daily', m$daily.seasonality, daily.disable, 4)
  if (fourier.order > 0) {
    m$seasonalities[['daily']] <- c(1, fourier.order)
  }
  return(m)
}
|
||||
|
||||
|
|
@ -571,7 +623,7 @@ fit.prophet <- function(m, df, ...) {
|
|||
if (any(is.infinite(history$y))) {
|
||||
stop("Found infinity in column y.")
|
||||
}
|
||||
m$history.dates <- sort(zoo::as.Date(df$ds))
|
||||
m$history.dates <- sort(set_date(df$ds))
|
||||
|
||||
out <- setup_dataframe(m, history, initialize_scales = TRUE)
|
||||
history <- out$df
|
||||
|
|
@ -985,7 +1037,7 @@ sample_predictive_trend <- function(model, df, iteration) {
|
|||
#'
|
||||
#' @param m Prophet model object.
|
||||
#' @param periods Int number of periods to forecast forward.
|
||||
#' @param freq 'day', 'week', 'month', 'quarter', or 'year'.
|
||||
#' @param freq 'day', 'week', 'month', 'quarter', 'year', 1 (1 sec), 60 (1 minute) or 3600 (1 hour).
|
||||
#' @param include_history Boolean to include the historical dates in the data
|
||||
#' frame for predictions.
|
||||
#'
|
||||
|
|
@ -993,7 +1045,7 @@ sample_predictive_trend <- function(model, df, iteration) {
|
|||
#' requested number of periods.
|
||||
#'
|
||||
#' @export
|
||||
make_future_dataframe <- function(m, periods, freq = 'd',
|
||||
make_future_dataframe <- function(m, periods, freq = 'day',
|
||||
include_history = TRUE) {
|
||||
dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
|
||||
dates <- dates[2:(periods + 1)] # Drop the first, which is max(history$ds)
|
||||
|
|
@ -1091,7 +1143,7 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
|
|||
#' @importFrom dplyr "%>%"
|
||||
prophet_plot_components <- function(
|
||||
m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
|
||||
yearly_start = 0) {
|
||||
yearly_start = 0, daily_start = 0) {
|
||||
df <- df_for_plotting(m, fcst)
|
||||
# Plot the trend
|
||||
panels <- list(plot_trend(df, uncertainty, plot_cap))
|
||||
|
|
@ -1099,6 +1151,10 @@ prophet_plot_components <- function(
|
|||
if (!is.null(m$holidays)) {
|
||||
panels[[length(panels) + 1]] <- plot_holidays(m, df, uncertainty)
|
||||
}
|
||||
# Plot daily seasonality, if present
|
||||
if ("daily" %in% colnames(df)) {
|
||||
panels[[length(panels) + 1]] <- plot_daily(m, uncertainty, daily_start)
|
||||
}
|
||||
# Plot weekly seasonality, if present
|
||||
if ("weekly" %in% colnames(df)) {
|
||||
panels[[length(panels) + 1]] <- plot_weekly(m, uncertainty, weekly_start)
|
||||
|
|
@ -1109,7 +1165,7 @@ prophet_plot_components <- function(
|
|||
}
|
||||
# Plot other seasonalities
|
||||
for (name in names(m$seasonalities)) {
|
||||
if (!(name %in% c('weekly', 'yearly')) && (name %in% colnames(df))) {
|
||||
if (!(name %in% c('daily', 'weekly', 'yearly')) && (name %in% colnames(df))) {
|
||||
panels[[length(panels) + 1]] <- plot_seasonality(m, name, uncertainty)
|
||||
}
|
||||
}
|
||||
|
|
@ -1184,6 +1240,39 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
|
|||
return(gg.holidays)
|
||||
}
|
||||
|
||||
#' Plot the daily component of the forecast.
#'
#' @param m Prophet model object
#' @param uncertainty Boolean to plot uncertainty intervals.
#' @param daily_start Number of hours by which to shift the start of the
#'  daily seasonality plot. 0 (default) starts the plot at midnight. 1 shifts
#'  by 1 hour to 01:00, and so on.
#'
#' @return A ggplot2 plot.
plot_daily <- function(m, uncertainty = TRUE, daily_start = 0) {
  # Compute daily seasonality for a 24-hour sequence of hourly timestamps.
  # The shift is wrapped in as.difftime() because adding a bare number to a
  # POSIXct moves it by seconds, not hours.
  df.d <- data.frame(
    ds = seq(set_date('2017-01-01 00:00:00'), length.out = 24, by = "hour") +
      as.difftime(daily_start, units = "hours"),
    cap = 1.)
  df.d <- setup_dataframe(m, df.d)$df
  seas <- predict_seasonal_components(m, df.d)
  # Use the hour labels as an ordered factor so the x axis follows the
  # (possibly shifted) sequence rather than sorting the labels.
  seas$hod <- factor(get_hour(df.d$ds), levels = get_hour(df.d$ds))

  gg.daily <- ggplot2::ggplot(seas, ggplot2::aes(x = hod, y = daily,
                                                 group = 1)) +
    ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
    ggplot2::labs(x = "Hour of day")
  if (uncertainty) {
    gg.daily <- gg.daily +
      ggplot2::geom_ribbon(ggplot2::aes(ymin = daily_lower,
                                        ymax = daily_upper),
                           alpha = 0.2,
                           fill = "#0072B2",
                           na.rm = TRUE)
  }
  return(gg.daily)
}
|
||||
|
||||
#' Plot the weekly component of the forecast.
|
||||
#'
|
||||
#' @param m Prophet model object
|
||||
|
|
@ -1196,7 +1285,7 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
|
|||
plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
|
||||
# Compute weekly seasonality for a Sun-Sat sequence of dates.
|
||||
df.w <- data.frame(
|
||||
ds=seq.Date(zoo::as.Date('2017-01-01'), by='d', length.out=7) +
|
||||
ds=seq(set_date('2017-01-01'), by='d', length.out=7) +
|
||||
weekly_start, cap=1.)
|
||||
df.w <- setup_dataframe(m, df.w)$df
|
||||
seas <- predict_seasonal_components(m, df.w)
|
||||
|
|
@ -1229,7 +1318,7 @@ plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
|
|||
plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
|
||||
# Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
|
||||
df.y <- data.frame(
|
||||
ds=seq.Date(zoo::as.Date('2017-01-01'), by='d', length.out=365) +
|
||||
ds=seq(set_date('2017-01-01'), by='d', length.out=365) +
|
||||
yearly_start, cap=1.)
|
||||
df.y <- setup_dataframe(m, df.y)$df
|
||||
seas <- predict_seasonal_components(m, df.y)
|
||||
|
|
@ -1238,7 +1327,6 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
|
|||
gg.yearly <- ggplot2::ggplot(seas, ggplot2::aes(x = ds, y = yearly,
|
||||
group = 1)) +
|
||||
ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
|
||||
ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
|
||||
ggplot2::labs(x = "Day of year")
|
||||
if (uncertainty) {
|
||||
gg.yearly <- gg.yearly +
|
||||
|
|
@ -1260,24 +1348,18 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
|
|||
#' @return A ggplot2 plot.
|
||||
plot_seasonality <- function(m, name, uncertainty = TRUE) {
|
||||
# Compute seasonality from Jan 1 through a single period.
|
||||
start <- zoo::as.Date('2017-01-01')
|
||||
start <- set_date('2017-01-01')
|
||||
period <- m$seasonalities[[name]][1]
|
||||
end <- start + period
|
||||
plot.points <- as.numeric(end - start)
|
||||
end <- start + period * 24 * 3600
|
||||
plot.points <- as.numeric(difftime(end, start))
|
||||
df.y <- data.frame(
|
||||
ds=seq.Date(from=start, to=end, length.out=plot.points), cap=1.)
|
||||
ds=seq(from=start, by='d', length.out=plot.points), cap=1.)
|
||||
df.y <- setup_dataframe(m, df.y)$df
|
||||
seas <- predict_seasonal_components(m, df.y)
|
||||
seas$ds <- df.y$ds
|
||||
gg.s <- ggplot2::ggplot(
|
||||
seas, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
|
||||
ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
|
||||
if (period < 14) {
|
||||
fmt.str <- '%m/%d %R'
|
||||
} else {
|
||||
fmt.str <- '%m/%d'
|
||||
}
|
||||
gg.s <- gg.s + ggplot2::scale_x_date(labels = scales::date_format(fmt.str))
|
||||
if (uncertainty) {
|
||||
gg.s <- gg.s +
|
||||
ggplot2::geom_ribbon(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ library(prophet)
|
|||
context("Prophet tests")
|
||||
|
||||
DATA <- read.csv('data.csv')
|
||||
DATA$ds <- as.Date(DATA$ds)
|
||||
DATA$ds <- set_date(DATA$ds)
|
||||
N <- nrow(DATA)
|
||||
train <- DATA[1:floor(N / 2), ]
|
||||
future <- DATA[(ceiling(N/2) + 1):N, ]
|
||||
|
|
@ -27,9 +27,9 @@ test_that("fit_predict_no_changepoints", {
|
|||
|
||||
test_that("fit_predict_changepoint_not_in_history", {
|
||||
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
|
||||
train_t <- dplyr::mutate(DATA, ds=zoo::as.Date(ds))
|
||||
train_t <- dplyr::filter(train_t, (ds < zoo::as.Date('2013-01-01')) |
|
||||
(ds > zoo::as.Date('2014-01-01')))
|
||||
train_t <- dplyr::mutate(DATA, ds=set_date(ds))
|
||||
train_t <- dplyr::filter(train_t, (ds < set_date('2013-01-01')) |
|
||||
(ds > set_date('2014-01-01')))
|
||||
future <- data.frame(ds=DATA$ds)
|
||||
m <- prophet(train_t, changepoints=c('2013-06-06'))
|
||||
expect_error(predict(m, future), NA)
|
||||
|
|
@ -170,19 +170,19 @@ test_that("piecewise_logistic", {
|
|||
})
|
||||
|
||||
test_that("holidays", {
|
||||
holidays = data.frame(ds = zoo::as.Date(c('2016-12-25')),
|
||||
holidays = data.frame(ds = set_date(c('2016-12-25')),
|
||||
holiday = c('xmas'),
|
||||
lower_window = c(-1),
|
||||
upper_window = c(0))
|
||||
df <- data.frame(
|
||||
ds = seq(zoo::as.Date('2016-12-20'), zoo::as.Date('2016-12-31'), by='d'))
|
||||
ds = seq(set_date('2016-12-20'), set_date('2016-12-31'), by='d'))
|
||||
m <- prophet(train, holidays = holidays, fit = FALSE)
|
||||
feats <- prophet:::make_holiday_features(m, df$ds)
|
||||
expect_equal(nrow(feats), nrow(df))
|
||||
expect_equal(ncol(feats), 2)
|
||||
expect_equal(sum(colSums(feats) - c(1, 1)), 0)
|
||||
|
||||
holidays = data.frame(ds = zoo::as.Date(c('2016-12-25')),
|
||||
holidays = data.frame(ds = set_date(c('2016-12-25')),
|
||||
holiday = c('xmas'),
|
||||
lower_window = c(-1),
|
||||
upper_window = c(10))
|
||||
|
|
@ -194,7 +194,7 @@ test_that("holidays", {
|
|||
|
||||
test_that("fit_with_holidays", {
|
||||
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
|
||||
holidays <- data.frame(ds = zoo::as.Date(c('2012-06-06', '2013-06-06')),
|
||||
holidays <- data.frame(ds = set_date(c('2012-06-06', '2013-06-06')),
|
||||
holiday = c('seans-bday', 'seans-bday'),
|
||||
lower_window = c(0, 0),
|
||||
upper_window = c(1, 1))
|
||||
|
|
@ -206,14 +206,14 @@ test_that("make_future_dataframe", {
|
|||
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
|
||||
train.t <- DATA[1:234, ]
|
||||
m <- prophet(train.t)
|
||||
future <- make_future_dataframe(m, periods = 3, freq = 'd',
|
||||
future <- make_future_dataframe(m, periods = 3, freq = 'day',
|
||||
include_history = FALSE)
|
||||
correct <- as.Date(c('2013-04-26', '2013-04-27', '2013-04-28'))
|
||||
correct <- set_date(c('2013-04-26', '2013-04-27', '2013-04-28'))
|
||||
expect_equal(future$ds, correct)
|
||||
|
||||
future <- make_future_dataframe(m, periods = 3, freq = 'm',
|
||||
future <- make_future_dataframe(m, periods = 3, freq = 'month',
|
||||
include_history = FALSE)
|
||||
correct <- as.Date(c('2013-05-25', '2013-06-25', '2013-07-25'))
|
||||
correct <- set_date(c('2013-05-25', '2013-06-25', '2013-07-25'))
|
||||
expect_equal(future$ds, correct)
|
||||
})
|
||||
|
||||
|
|
@ -263,7 +263,7 @@ test_that("auto_yearly_seasonality", {
|
|||
|
||||
test_that("custom_seasonality", {
|
||||
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
|
||||
holidays <- data.frame(ds = zoo::as.Date(c('2017-01-02')),
|
||||
holidays <- data.frame(ds = set_date(c('2017-01-02')),
|
||||
holiday = c('special_day'))
|
||||
m <- prophet(holidays=holidays)
|
||||
m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
|
||||
|
|
|
|||
Loading…
Reference in a new issue