From b0938df109c2f77f855fce3bd87ca30fc5b64269 Mon Sep 17 00:00:00 2001 From: Qi Wang Date: Mon, 10 Jul 2017 18:46:49 -0700 Subject: [PATCH] Add support for fitting daily seasonality, make holiday features work when daily seasonality is enabled (#246) * Add support for fitting daily seasonality, make holiday features work when daily seasonality is enabled * fix wrong comment in make_future_dataframe() --- R/R/prophet.R | 142 +++++++++++++++++++++++++------- R/tests/testthat/test_prophet.R | 26 +++--- 2 files changed, 125 insertions(+), 43 deletions(-) diff --git a/R/R/prophet.R b/R/R/prophet.R index 9ae7268..f3c4fd4 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -33,6 +33,8 @@ globalVariables(c( #' FALSE, or a number of Fourier terms to generate. #' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE, #' FALSE, or a number of Fourier terms to generate. +#' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE, +#' FALSE, or a number of Fourier terms to generate. #' @param holidays data frame with columns holiday (character) and ds (date #' type)and optionally columns lower_window and upper_window which specify a #' range of days around the date to be included as holidays. lower_window=-2 @@ -76,6 +78,7 @@ prophet <- function(df = NULL, n.changepoints = 25, yearly.seasonality = 'auto', weekly.seasonality = 'auto', + daily.seasonality = 'auto', holidays = NULL, seasonality.prior.scale = 10, holidays.prior.scale = 10, @@ -98,6 +101,7 @@ prophet <- function(df = NULL, n.changepoints = n.changepoints, yearly.seasonality = yearly.seasonality, weekly.seasonality = weekly.seasonality, + daily.seasonality = daily.seasonality, holidays = holidays, seasonality.prior.scale = seasonality.prior.scale, changepoint.prior.scale = changepoint.prior.scale, @@ -206,6 +210,47 @@ compile_stan_model <- function(model) { return(rstan::stan_model(stanc_ret = stanc, model_name = model.name)) } +#' Convert date vector +#' +#' Convert the date to POSIXct object +#' +#' @param ds Date vector, can be consisted of characters +#' +#' @return vector of POSIXct object converted from date +#' +set_date <- function(ds = NULL, tz = "GMT") { + if (length(ds) == 0) { + return(NULL) + } + + if (is.factor(ds)) { + ds <- as.character(ds) + } + + if (min(nchar(ds)) < 12) { + ds <- as.POSIXct(ds, format = "%Y-%m-%d", tz = tz) + } else { + ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz) + } + return(ds) +} + +#' Extract hour +#' +#' Extract hour from a POSIXct object +#' +#' @param ds POSIXct object +#' +#' @return hour of POSIXct object +#' +get_hour <- function(ds) { + if (!("POSIXct" %in% is(ds))) { + stop("ds must be a POSIXct object, use function set_date() to convert first.") + } + + return(format(ds , "%H")) +} + #' Prepare dataframe for fitting or predicting. #' #' Adds a time index and scales y. Creates auxillary columns 't', 't_ix', @@ -222,9 +267,9 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) { if (exists('y', where=df)) { df$y <- as.numeric(df$y) } - df$ds <- zoo::as.Date(df$ds) + df$ds <- set_date(df$ds) if (anyNA(df$ds)) { - stop('Unable to parse date format in column ds. Convert to date format.') + stop('Unable to parse date format in column ds. Convert to date format. Either %Y-%m-%d or %Y-%m-%d %H:%M:%S') } df <- df %>% @@ -233,10 +278,10 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) { if (initialize_scales) { m$y.scale <- max(abs(df$y)) m$start <- min(df$ds) - m$t.scale <- as.numeric(max(df$ds) - m$start) + m$t.scale <- as.numeric(difftime(max(df$ds), m$start, units = "secs")) } - df$t <- as.numeric(df$ds - m$start) / m$t.scale + df$t <- as.numeric(difftime(df$ds, m$start, units = "secs")) / m$t.scale if (exists('y', where=df)) { df$y_scaled <- df$y / m$y.scale } @@ -285,8 +330,8 @@ set_changepoints <- function(m) { } } if (length(m$changepoints) > 0) { - m$changepoints <- zoo::as.Date(m$changepoints) - m$changepoints.t <- sort(as.numeric(m$changepoints - m$start) / m$t.scale) + m$changepoints <- set_date(m$changepoints) + m$changepoints.t <- sort(as.numeric(difftime(m$changepoints, m$start, units = "secs"))) / m$t.scale } else { m$changepoints.t <- c(0) # dummy changepoint } @@ -316,7 +361,7 @@ get_changepoint_matrix <- function(m) { #' @return Matrix with seasonality features. #' fourier_series <- function(dates, period, series.order) { - t <- dates - zoo::as.Date('1970-01-01') + t <- as.numeric(difftime(dates, set_date('1970-01-01 00:00:00'), units = 'days')) features <- matrix(0, length(t), 2 * series.order) for (i in 1:series.order) { x <- as.numeric(2 * i * pi * t / period) @@ -352,7 +397,7 @@ make_seasonality_features <- function(dates, period, series.order, prefix) { make_holiday_features <- function(m, dates) { scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale wide <- m$holidays %>% - dplyr::mutate(ds = zoo::as.Date(ds)) %>% + dplyr::mutate(ds = set_date(ds)) %>% dplyr::group_by(holiday, ds) %>% dplyr::filter(row_number() == 1) %>% dplyr::do({ @@ -364,7 +409,7 @@ make_holiday_features <- function(m, dates) { } names <- paste( .$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '') - dplyr::data_frame(ds = .$ds + offsets, holiday = names) + dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names) }) %>% dplyr::mutate(x = scale.ratio) %>% tidyr::spread(holiday, x, fill = 0) @@ -472,22 +517,29 @@ parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) { set_auto_seasonalities <- function(m) { first <- min(m$history$ds) last <- max(m$history$ds) - dt <- diff(m$history$ds) + dt <- diff(as.numeric(difftime(m$history$ds, m$start, units = "d"))) min.dt <- min(dt[dt > 0]) - yearly.disable <- last - first < 730 + yearly.disable <- as.numeric(difftime(last, first, unit = "days")) < 730 fourier.order <- parse_seasonality_args( m, 'yearly', m$yearly.seasonality, yearly.disable, 10) if (fourier.order > 0) { m$seasonalities[['yearly']] <- c(365.25, fourier.order) } - weekly.disable <- ((last - first < 14) || (min.dt >= 7)) + weekly.disable <- ((as.numeric(difftime(last, first, unit = "days")) < 14) || (min.dt >= 7)) fourier.order <- parse_seasonality_args( m, 'weekly', m$weekly.seasonality, weekly.disable, 3) if (fourier.order > 0) { m$seasonalities[['weekly']] <- c(7, fourier.order) } + + daily.disable <- ((as.numeric(difftime(last, first, unit = "days")) < 2)) || (min.dt >= 1) + fourier.order <- parse_seasonality_args( + m, 'daily', m$daily.seasonality, daily.disable, 4) + if (fourier.order > 0) { + m$seasonalities[['daily']] <- c(1, fourier.order) + } return(m) } @@ -571,7 +623,7 @@ fit.prophet <- function(m, df, ...) { if (any(is.infinite(history$y))) { stop("Found infinity in column y.") } - m$history.dates <- sort(zoo::as.Date(df$ds)) + m$history.dates <- sort(set_date(df$ds)) out <- setup_dataframe(m, history, initialize_scales = TRUE) history <- out$df @@ -985,7 +1037,7 @@ sample_predictive_trend <- function(model, df, iteration) { #' #' @param m Prophet model object. #' @param periods Int number of periods to forecast forward. -#' @param freq 'day', 'week', 'month', 'quarter', or 'year'. +#' @param freq 'day', 'week', 'month', 'quarter', 'year', 1(1 sec), 60(1 minute) or 3600(1 hour). #' @param include_history Boolean to include the historical dates in the data #' frame for predictions. #' @@ -993,7 +1045,7 @@ sample_predictive_trend <- function(model, df, iteration) { #' requested number of periods. #' #' @export -make_future_dataframe <- function(m, periods, freq = 'd', +make_future_dataframe <- function(m, periods, freq = 'day', include_history = TRUE) { dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq) dates <- dates[2:(periods + 1)] # Drop the first, which is max(history$ds) @@ -1091,7 +1143,7 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE, #' @importFrom dplyr "%>%" prophet_plot_components <- function( m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0, - yearly_start = 0) { + yearly_start = 0, daily_start = 0) { df <- df_for_plotting(m, fcst) # Plot the trend panels <- list(plot_trend(df, uncertainty, plot_cap)) @@ -1099,6 +1151,10 @@ prophet_plot_components <- function( if (!is.null(m$holidays)) { panels[[length(panels) + 1]] <- plot_holidays(m, df, uncertainty) } + # Plot daily seasonality, if present + if ("daily" %in% colnames(df)) { + panels[[length(panels) + 1]] <- plot_daily(m, uncertainty, daily_start) + } # Plot weekly seasonality, if present if ("weekly" %in% colnames(df)) { panels[[length(panels) + 1]] <- plot_weekly(m, uncertainty, weekly_start) @@ -1109,7 +1165,7 @@ prophet_plot_components <- function( } # Plot other seasonalities for (name in names(m$seasonalities)) { - if (!(name %in% c('weekly', 'yearly')) && (name %in% colnames(df))) { + if (!(name %in% c('daily', 'weekly', 'yearly')) && (name %in% colnames(df))) { panels[[length(panels) + 1]] <- plot_seasonality(m, name, uncertainty) } } @@ -1184,6 +1240,39 @@ plot_holidays <- function(m, df, uncertainty = TRUE) { return(gg.holidays) } +#' Plot the daily component of the forecast. +#' +#' @param m Prophet model object +#' @param uncertainty Boolean to plot uncertainty intervals. +#' @param daily_start Integer specifying the start day of the daily +#' seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day +#' to Monday, and so on. +#' +#' @return A ggplot2 plot. +plot_daily <- function(m, uncertainty = TRUE, daily_start = 0) { + # Compute weekly seasonality for a Sun-Sat sequence of dates. + df.d <- data.frame( + ds=seq(set_date('2017-01-01 00:00:00'), length.out=24, by = "hour") + + daily_start, cap=1.) + df.d <- setup_dataframe(m, df.d)$df + seas <- predict_seasonal_components(m, df.d) + seas$hod <- factor(get_hour(df.d$ds), levels=get_hour(df.d$ds)) + + gg.daily <- ggplot2::ggplot(seas, ggplot2::aes(x = hod, y = daily, + group = 1)) + + ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) + + ggplot2::labs(x = "Hour of day") + if (uncertainty) { + gg.daily <- gg.daily + + ggplot2::geom_ribbon(ggplot2::aes(ymin = daily_lower, + ymax = daily_upper), + alpha = 0.2, + fill = "#0072B2", + na.rm = TRUE) + } + return(gg.daily) +} + #' Plot the weekly component of the forecast. #' #' @param m Prophet model object @@ -1196,7 +1285,7 @@ plot_holidays <- function(m, df, uncertainty = TRUE) { plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) { # Compute weekly seasonality for a Sun-Sat sequence of dates. df.w <- data.frame( - ds=seq.Date(zoo::as.Date('2017-01-01'), by='d', length.out=7) + + ds=seq(set_date('2017-01-01'), by='d', length.out=7) + weekly_start, cap=1.) df.w <- setup_dataframe(m, df.w)$df seas <- predict_seasonal_components(m, df.w) @@ -1229,7 +1318,7 @@ plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) { plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) { # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates. df.y <- data.frame( - ds=seq.Date(zoo::as.Date('2017-01-01'), by='d', length.out=365) + + ds=seq(set_date('2017-01-01'), by='d', length.out=365) + yearly_start, cap=1.) df.y <- setup_dataframe(m, df.y)$df seas <- predict_seasonal_components(m, df.y) @@ -1238,7 +1327,6 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) { gg.yearly <- ggplot2::ggplot(seas, ggplot2::aes(x = ds, y = yearly, group = 1)) + ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) + - ggplot2::scale_x_date(labels = scales::date_format('%B %d')) + ggplot2::labs(x = "Day of year") if (uncertainty) { gg.yearly <- gg.yearly + @@ -1260,24 +1348,18 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) { #' @return A ggplot2 plot. plot_seasonality <- function(m, name, uncertainty = TRUE) { # Compute seasonality from Jan 1 through a single period. - start <- zoo::as.Date('2017-01-01') + start <- set_date('2017-01-01') period <- m$seasonalities[[name]][1] - end <- start + period - plot.points <- as.numeric(end - start) + end <- start + period * 24 * 3600 + plot.points <- as.numeric(difftime(end, start)) df.y <- data.frame( - ds=seq.Date(from=start, to=end, length.out=plot.points), cap=1.) + ds=seq(from=start, by='d', length.out=plot.points), cap=1.) df.y <- setup_dataframe(m, df.y)$df seas <- predict_seasonal_components(m, df.y) seas$ds <- df.y$ds gg.s <- ggplot2::ggplot( seas, ggplot2::aes_string(x = 'ds', y = name, group = 1)) + ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) - if (period < 14) { - fmt.str <- '%m/%d %R' - } else { - fmt.str <- '%m/%d' - } - gg.s <- gg.s + ggplot2::scale_x_date(labels = scales::date_format(fmt.str)) if (uncertainty) { gg.s <- gg.s + ggplot2::geom_ribbon( diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index c8d0c91..f65a479 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -2,7 +2,7 @@ library(prophet) context("Prophet tests") DATA <- read.csv('data.csv') -DATA$ds <- as.Date(DATA$ds) +DATA$ds <- set_date(DATA$ds) N <- nrow(DATA) train <- DATA[1:floor(N / 2), ] future <- DATA[(ceiling(N/2) + 1):N, ] @@ -27,9 +27,9 @@ test_that("fit_predict_no_changepoints", { test_that("fit_predict_changepoint_not_in_history", { skip_if_not(Sys.getenv('R_ARCH') != '/i386') - train_t <- dplyr::mutate(DATA, ds=zoo::as.Date(ds)) - train_t <- dplyr::filter(train_t, (ds < zoo::as.Date('2013-01-01')) | - (ds > zoo::as.Date('2014-01-01'))) + train_t <- dplyr::mutate(DATA, ds=set_date(ds)) + train_t <- dplyr::filter(train_t, (ds < set_date('2013-01-01')) | + (ds > set_date('2014-01-01'))) future <- data.frame(ds=DATA$ds) m <- prophet(train_t, changepoints=c('2013-06-06')) expect_error(predict(m, future), NA) @@ -170,19 +170,19 @@ test_that("piecewise_logistic", { }) test_that("holidays", { - holidays = data.frame(ds = zoo::as.Date(c('2016-12-25')), + holidays = data.frame(ds = set_date(c('2016-12-25')), holiday = c('xmas'), lower_window = c(-1), upper_window = c(0)) df <- data.frame( - ds = seq(zoo::as.Date('2016-12-20'), zoo::as.Date('2016-12-31'), by='d')) + ds = seq(set_date('2016-12-20'), set_date('2016-12-31'), by='d')) m <- prophet(train, holidays = holidays, fit = FALSE) feats <- prophet:::make_holiday_features(m, df$ds) expect_equal(nrow(feats), nrow(df)) expect_equal(ncol(feats), 2) expect_equal(sum(colSums(feats) - c(1, 1)), 0) - holidays = data.frame(ds = zoo::as.Date(c('2016-12-25')), + holidays = data.frame(ds = set_date(c('2016-12-25')), holiday = c('xmas'), lower_window = c(-1), upper_window = c(10)) @@ -194,7 +194,7 @@ test_that("holidays", { test_that("fit_with_holidays", { skip_if_not(Sys.getenv('R_ARCH') != '/i386') - holidays <- data.frame(ds = zoo::as.Date(c('2012-06-06', '2013-06-06')), + holidays <- data.frame(ds = set_date(c('2012-06-06', '2013-06-06')), holiday = c('seans-bday', 'seans-bday'), lower_window = c(0, 0), upper_window = c(1, 1)) @@ -206,14 +206,14 @@ test_that("make_future_dataframe", { skip_if_not(Sys.getenv('R_ARCH') != '/i386') train.t <- DATA[1:234, ] m <- prophet(train.t) - future <- make_future_dataframe(m, periods = 3, freq = 'd', + future <- make_future_dataframe(m, periods = 3, freq = 'day', include_history = FALSE) - correct <- as.Date(c('2013-04-26', '2013-04-27', '2013-04-28')) + correct <- set_date(c('2013-04-26', '2013-04-27', '2013-04-28')) expect_equal(future$ds, correct) - future <- make_future_dataframe(m, periods = 3, freq = 'm', + future <- make_future_dataframe(m, periods = 3, freq = 'month', include_history = FALSE) - correct <- as.Date(c('2013-05-25', '2013-06-25', '2013-07-25')) + correct <- set_date(c('2013-05-25', '2013-06-25', '2013-07-25')) expect_equal(future$ds, correct) }) @@ -263,7 +263,7 @@ test_that("auto_yearly_seasonality", { test_that("custom_seasonality", { skip_if_not(Sys.getenv('R_ARCH') != '/i386') - holidays <- data.frame(ds = zoo::as.Date(c('2017-01-02')), + holidays <- data.frame(ds = set_date(c('2017-01-02')), holiday = c('special_day')) m <- prophet(holidays=holidays) m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)