Add support for fitting daily seasonality, make holiday features work when daily seasonality is enabled (#246)

* Add support for fitting daily seasonality, make holiday features work
when daily seasonality is enabled

* fix wrong comment in make_future_dataframe()
This commit is contained in:
Qi Wang 2017-07-10 18:46:49 -07:00 committed by Ben Letham
parent b07d345155
commit b0938df109
2 changed files with 125 additions and 43 deletions

View file

@ -33,6 +33,8 @@ globalVariables(c(
#' FALSE, or a number of Fourier terms to generate.
#' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
#' FALSE, or a number of Fourier terms to generate.
#' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE,
#' FALSE, or a number of Fourier terms to generate.
#' @param holidays data frame with columns holiday (character) and ds (date
#' type)and optionally columns lower_window and upper_window which specify a
#' range of days around the date to be included as holidays. lower_window=-2
@ -76,6 +78,7 @@ prophet <- function(df = NULL,
n.changepoints = 25,
yearly.seasonality = 'auto',
weekly.seasonality = 'auto',
daily.seasonality = 'auto',
holidays = NULL,
seasonality.prior.scale = 10,
holidays.prior.scale = 10,
@ -98,6 +101,7 @@ prophet <- function(df = NULL,
n.changepoints = n.changepoints,
yearly.seasonality = yearly.seasonality,
weekly.seasonality = weekly.seasonality,
daily.seasonality = daily.seasonality,
holidays = holidays,
seasonality.prior.scale = seasonality.prior.scale,
changepoint.prior.scale = changepoint.prior.scale,
@ -206,6 +210,47 @@ compile_stan_model <- function(model) {
return(rstan::stan_model(stanc_ret = stanc, model_name = model.name))
}
#' Convert date vector
#'
#' Convert the date to POSIXct object
#'
#' @param ds Date vector, can be consisted of characters
#'
#' @return vector of POSIXct object converted from date
#'
set_date <- function(ds = NULL, tz = "GMT") {
if (length(ds) == 0) {
return(NULL)
}
if (is.factor(ds)) {
ds <- as.character(ds)
}
if (min(nchar(ds)) < 12) {
ds <- as.POSIXct(ds, format = "%Y-%m-%d", tz = tz)
} else {
ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz)
}
return(ds)
}
#' Extract hour
#'
#' Extract hour from a POSIXct object
#'
#' @param ds POSIXct object
#'
#' @return hour of POSIXct object
#'
get_hour <- function(ds) {
if (!("POSIXct" %in% is(ds))) {
stop("ds must be a POSIXct object, use function set_date() to convert first.")
}
return(format(ds , "%H"))
}
#' Prepare dataframe for fitting or predicting.
#'
#' Adds a time index and scales y. Creates auxillary columns 't', 't_ix',
@ -222,9 +267,9 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
if (exists('y', where=df)) {
df$y <- as.numeric(df$y)
}
df$ds <- zoo::as.Date(df$ds)
df$ds <- set_date(df$ds)
if (anyNA(df$ds)) {
stop('Unable to parse date format in column ds. Convert to date format.')
stop('Unable to parse date format in column ds. Convert to date format. Either %Y-%m-%d or %Y-%m-%d %H:%M:%S')
}
df <- df %>%
@ -233,10 +278,10 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
if (initialize_scales) {
m$y.scale <- max(abs(df$y))
m$start <- min(df$ds)
m$t.scale <- as.numeric(max(df$ds) - m$start)
m$t.scale <- as.numeric(difftime(max(df$ds), m$start, units = "secs"))
}
df$t <- as.numeric(df$ds - m$start) / m$t.scale
df$t <- as.numeric(difftime(df$ds, m$start, units = "secs")) / m$t.scale
if (exists('y', where=df)) {
df$y_scaled <- df$y / m$y.scale
}
@ -285,8 +330,8 @@ set_changepoints <- function(m) {
}
}
if (length(m$changepoints) > 0) {
m$changepoints <- zoo::as.Date(m$changepoints)
m$changepoints.t <- sort(as.numeric(m$changepoints - m$start) / m$t.scale)
m$changepoints <- set_date(m$changepoints)
m$changepoints.t <- sort(as.numeric(difftime(m$changepoints, m$start, units = "secs"))) / m$t.scale
} else {
m$changepoints.t <- c(0) # dummy changepoint
}
@ -316,7 +361,7 @@ get_changepoint_matrix <- function(m) {
#' @return Matrix with seasonality features.
#'
fourier_series <- function(dates, period, series.order) {
t <- dates - zoo::as.Date('1970-01-01')
t <- as.numeric(difftime(dates, set_date('1970-01-01 00:00:00'), units = 'days'))
features <- matrix(0, length(t), 2 * series.order)
for (i in 1:series.order) {
x <- as.numeric(2 * i * pi * t / period)
@ -352,7 +397,7 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
make_holiday_features <- function(m, dates) {
scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
wide <- m$holidays %>%
dplyr::mutate(ds = zoo::as.Date(ds)) %>%
dplyr::mutate(ds = set_date(ds)) %>%
dplyr::group_by(holiday, ds) %>%
dplyr::filter(row_number() == 1) %>%
dplyr::do({
@ -364,7 +409,7 @@ make_holiday_features <- function(m, dates) {
}
names <- paste(
.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
dplyr::data_frame(ds = .$ds + offsets, holiday = names)
dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
}) %>%
dplyr::mutate(x = scale.ratio) %>%
tidyr::spread(holiday, x, fill = 0)
@ -472,22 +517,29 @@ parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
set_auto_seasonalities <- function(m) {
first <- min(m$history$ds)
last <- max(m$history$ds)
dt <- diff(m$history$ds)
dt <- diff(as.numeric(difftime(m$history$ds, m$start, units = "d")))
min.dt <- min(dt[dt > 0])
yearly.disable <- last - first < 730
yearly.disable <- as.numeric(difftime(last, first, unit = "days")) < 730
fourier.order <- parse_seasonality_args(
m, 'yearly', m$yearly.seasonality, yearly.disable, 10)
if (fourier.order > 0) {
m$seasonalities[['yearly']] <- c(365.25, fourier.order)
}
weekly.disable <- ((last - first < 14) || (min.dt >= 7))
weekly.disable <- ((as.numeric(difftime(last, first, unit = "days")) < 14) || (min.dt >= 7))
fourier.order <- parse_seasonality_args(
m, 'weekly', m$weekly.seasonality, weekly.disable, 3)
if (fourier.order > 0) {
m$seasonalities[['weekly']] <- c(7, fourier.order)
}
daily.disable <- ((as.numeric(difftime(last, first, unit = "days")) < 2)) || (min.dt >= 1)
fourier.order <- parse_seasonality_args(
m, 'daily', m$daily.seasonality, daily.disable, 4)
if (fourier.order > 0) {
m$seasonalities[['daily']] <- c(1, fourier.order)
}
return(m)
}
@ -571,7 +623,7 @@ fit.prophet <- function(m, df, ...) {
if (any(is.infinite(history$y))) {
stop("Found infinity in column y.")
}
m$history.dates <- sort(zoo::as.Date(df$ds))
m$history.dates <- sort(set_date(df$ds))
out <- setup_dataframe(m, history, initialize_scales = TRUE)
history <- out$df
@ -985,7 +1037,7 @@ sample_predictive_trend <- function(model, df, iteration) {
#'
#' @param m Prophet model object.
#' @param periods Int number of periods to forecast forward.
#' @param freq 'day', 'week', 'month', 'quarter', or 'year'.
#' @param freq 'day', 'week', 'month', 'quarter', 'year', 1(1 sec), 60(1 minute) or 3600(1 hour).
#' @param include_history Boolean to include the historical dates in the data
#' frame for predictions.
#'
@ -993,7 +1045,7 @@ sample_predictive_trend <- function(model, df, iteration) {
#' requested number of periods.
#'
#' @export
make_future_dataframe <- function(m, periods, freq = 'd',
make_future_dataframe <- function(m, periods, freq = 'day',
include_history = TRUE) {
dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
dates <- dates[2:(periods + 1)] # Drop the first, which is max(history$ds)
@ -1091,7 +1143,7 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
#' @importFrom dplyr "%>%"
prophet_plot_components <- function(
m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
yearly_start = 0) {
yearly_start = 0, daily_start = 0) {
df <- df_for_plotting(m, fcst)
# Plot the trend
panels <- list(plot_trend(df, uncertainty, plot_cap))
@ -1099,6 +1151,10 @@ prophet_plot_components <- function(
if (!is.null(m$holidays)) {
panels[[length(panels) + 1]] <- plot_holidays(m, df, uncertainty)
}
# Plot daily seasonality, if present
if ("daily" %in% colnames(df)) {
panels[[length(panels) + 1]] <- plot_daily(m, uncertainty, daily_start)
}
# Plot weekly seasonality, if present
if ("weekly" %in% colnames(df)) {
panels[[length(panels) + 1]] <- plot_weekly(m, uncertainty, weekly_start)
@ -1109,7 +1165,7 @@ prophet_plot_components <- function(
}
# Plot other seasonalities
for (name in names(m$seasonalities)) {
if (!(name %in% c('weekly', 'yearly')) && (name %in% colnames(df))) {
if (!(name %in% c('daily', 'weekly', 'yearly')) && (name %in% colnames(df))) {
panels[[length(panels) + 1]] <- plot_seasonality(m, name, uncertainty)
}
}
@ -1184,6 +1240,39 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
return(gg.holidays)
}
#' Plot the daily component of the forecast.
#'
#' @param m Prophet model object
#' @param uncertainty Boolean to plot uncertainty intervals.
#' @param daily_start Integer specifying the start day of the daily
#' seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
#' to Monday, and so on.
#'
#' @return A ggplot2 plot.
plot_daily <- function(m, uncertainty = TRUE, daily_start = 0) {
# Compute weekly seasonality for a Sun-Sat sequence of dates.
df.d <- data.frame(
ds=seq(set_date('2017-01-01 00:00:00'), length.out=24, by = "hour") +
daily_start, cap=1.)
df.d <- setup_dataframe(m, df.d)$df
seas <- predict_seasonal_components(m, df.d)
seas$hod <- factor(get_hour(df.d$ds), levels=get_hour(df.d$ds))
gg.daily <- ggplot2::ggplot(seas, ggplot2::aes(x = hod, y = daily,
group = 1)) +
ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
ggplot2::labs(x = "Hour of day")
if (uncertainty) {
gg.daily <- gg.daily +
ggplot2::geom_ribbon(ggplot2::aes(ymin = daily_lower,
ymax = daily_upper),
alpha = 0.2,
fill = "#0072B2",
na.rm = TRUE)
}
return(gg.daily)
}
#' Plot the weekly component of the forecast.
#'
#' @param m Prophet model object
@ -1196,7 +1285,7 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
# Compute weekly seasonality for a Sun-Sat sequence of dates.
df.w <- data.frame(
ds=seq.Date(zoo::as.Date('2017-01-01'), by='d', length.out=7) +
ds=seq(set_date('2017-01-01'), by='d', length.out=7) +
weekly_start, cap=1.)
df.w <- setup_dataframe(m, df.w)$df
seas <- predict_seasonal_components(m, df.w)
@ -1229,7 +1318,7 @@ plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
# Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
df.y <- data.frame(
ds=seq.Date(zoo::as.Date('2017-01-01'), by='d', length.out=365) +
ds=seq(set_date('2017-01-01'), by='d', length.out=365) +
yearly_start, cap=1.)
df.y <- setup_dataframe(m, df.y)$df
seas <- predict_seasonal_components(m, df.y)
@ -1238,7 +1327,6 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
gg.yearly <- ggplot2::ggplot(seas, ggplot2::aes(x = ds, y = yearly,
group = 1)) +
ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
ggplot2::labs(x = "Day of year")
if (uncertainty) {
gg.yearly <- gg.yearly +
@ -1260,24 +1348,18 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
#' @return A ggplot2 plot.
plot_seasonality <- function(m, name, uncertainty = TRUE) {
# Compute seasonality from Jan 1 through a single period.
start <- zoo::as.Date('2017-01-01')
start <- set_date('2017-01-01')
period <- m$seasonalities[[name]][1]
end <- start + period
plot.points <- as.numeric(end - start)
end <- start + period * 24 * 3600
plot.points <- as.numeric(difftime(end, start))
df.y <- data.frame(
ds=seq.Date(from=start, to=end, length.out=plot.points), cap=1.)
ds=seq(from=start, by='d', length.out=plot.points), cap=1.)
df.y <- setup_dataframe(m, df.y)$df
seas <- predict_seasonal_components(m, df.y)
seas$ds <- df.y$ds
gg.s <- ggplot2::ggplot(
seas, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
if (period < 14) {
fmt.str <- '%m/%d %R'
} else {
fmt.str <- '%m/%d'
}
gg.s <- gg.s + ggplot2::scale_x_date(labels = scales::date_format(fmt.str))
if (uncertainty) {
gg.s <- gg.s +
ggplot2::geom_ribbon(

View file

@ -2,7 +2,7 @@ library(prophet)
context("Prophet tests")
DATA <- read.csv('data.csv')
DATA$ds <- as.Date(DATA$ds)
DATA$ds <- set_date(DATA$ds)
N <- nrow(DATA)
train <- DATA[1:floor(N / 2), ]
future <- DATA[(ceiling(N/2) + 1):N, ]
@ -27,9 +27,9 @@ test_that("fit_predict_no_changepoints", {
test_that("fit_predict_changepoint_not_in_history", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
train_t <- dplyr::mutate(DATA, ds=zoo::as.Date(ds))
train_t <- dplyr::filter(train_t, (ds < zoo::as.Date('2013-01-01')) |
(ds > zoo::as.Date('2014-01-01')))
train_t <- dplyr::mutate(DATA, ds=set_date(ds))
train_t <- dplyr::filter(train_t, (ds < set_date('2013-01-01')) |
(ds > set_date('2014-01-01')))
future <- data.frame(ds=DATA$ds)
m <- prophet(train_t, changepoints=c('2013-06-06'))
expect_error(predict(m, future), NA)
@ -170,19 +170,19 @@ test_that("piecewise_logistic", {
})
test_that("holidays", {
holidays = data.frame(ds = zoo::as.Date(c('2016-12-25')),
holidays = data.frame(ds = set_date(c('2016-12-25')),
holiday = c('xmas'),
lower_window = c(-1),
upper_window = c(0))
df <- data.frame(
ds = seq(zoo::as.Date('2016-12-20'), zoo::as.Date('2016-12-31'), by='d'))
ds = seq(set_date('2016-12-20'), set_date('2016-12-31'), by='d'))
m <- prophet(train, holidays = holidays, fit = FALSE)
feats <- prophet:::make_holiday_features(m, df$ds)
expect_equal(nrow(feats), nrow(df))
expect_equal(ncol(feats), 2)
expect_equal(sum(colSums(feats) - c(1, 1)), 0)
holidays = data.frame(ds = zoo::as.Date(c('2016-12-25')),
holidays = data.frame(ds = set_date(c('2016-12-25')),
holiday = c('xmas'),
lower_window = c(-1),
upper_window = c(10))
@ -194,7 +194,7 @@ test_that("holidays", {
test_that("fit_with_holidays", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
holidays <- data.frame(ds = zoo::as.Date(c('2012-06-06', '2013-06-06')),
holidays <- data.frame(ds = set_date(c('2012-06-06', '2013-06-06')),
holiday = c('seans-bday', 'seans-bday'),
lower_window = c(0, 0),
upper_window = c(1, 1))
@ -206,14 +206,14 @@ test_that("make_future_dataframe", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
train.t <- DATA[1:234, ]
m <- prophet(train.t)
future <- make_future_dataframe(m, periods = 3, freq = 'd',
future <- make_future_dataframe(m, periods = 3, freq = 'day',
include_history = FALSE)
correct <- as.Date(c('2013-04-26', '2013-04-27', '2013-04-28'))
correct <- set_date(c('2013-04-26', '2013-04-27', '2013-04-28'))
expect_equal(future$ds, correct)
future <- make_future_dataframe(m, periods = 3, freq = 'm',
future <- make_future_dataframe(m, periods = 3, freq = 'month',
include_history = FALSE)
correct <- as.Date(c('2013-05-25', '2013-06-25', '2013-07-25'))
correct <- set_date(c('2013-05-25', '2013-06-25', '2013-07-25'))
expect_equal(future$ds, correct)
})
@ -263,7 +263,7 @@ test_that("auto_yearly_seasonality", {
test_that("custom_seasonality", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
holidays <- data.frame(ds = zoo::as.Date(c('2017-01-02')),
holidays <- data.frame(ds = set_date(c('2017-01-02')),
holiday = c('special_day'))
m <- prophet(holidays=holidays)
m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)