Custom seasonalities in R

This commit is contained in:
bl 2017-07-05 01:20:22 -07:00
parent 707c885275
commit 8be35c2f34
9 changed files with 338 additions and 89 deletions

View file

@ -2,6 +2,7 @@
S3method(plot,prophet)
S3method(predict,prophet)
export(add_seasonality)
export(fit.prophet)
export(make_future_dataframe)
export(predictive_samples)

View file

@ -15,9 +15,11 @@ globalVariables(c(
#' Prophet forecaster.
#'
#' @param df Dataframe containing the history. Must have columns ds (date type)
#' and y, the time series. If growth is logistic, then df must also have a
#' column cap that specifies the capacity at each ds.
#' @param df (optional) Dataframe containing the history. Must have columns ds
#' (date type) and y, the time series. If growth is logistic, then df must
#' also have a column cap that specifies the capacity at each ds. If not
#' provided, then the model object will be instantiated but not fit; use
#' fit.prophet(m, df) to fit the model.
#' @param growth String 'linear' or 'logistic' to specify a linear or logistic
#' trend.
#' @param changepoints Vector of dates at which to include potential
@ -27,8 +29,10 @@ globalVariables(c(
#' if input `changepoints` is supplied. If `changepoints` is not supplied,
#' then n.changepoints potential changepoints are selected uniformly from the
#' first 80 percent of df$ds.
#' @param yearly.seasonality Fit yearly seasonality; 'auto', TRUE, or FALSE.
#' @param weekly.seasonality Fit weekly seasonality; 'auto', TRUE, or FALSE.
#' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE,
#' FALSE, or a number of Fourier terms to generate.
#' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
#' FALSE, or a number of Fourier terms to generate.
#' @param holidays data frame with columns holiday (character) and ds (date
#' type) and optionally columns lower_window and upper_window which specify a
#' range of days around the date to be included as holidays. lower_window=-2
@ -66,7 +70,7 @@ globalVariables(c(
#' @export
#' @importFrom dplyr "%>%"
#' @import Rcpp
prophet <- function(df = df,
prophet <- function(df = NULL,
growth = 'linear',
changepoints = NULL,
n.changepoints = 25,
@ -105,6 +109,7 @@ prophet <- function(df = df,
y.scale = NULL,
t.scale = NULL,
changepoints.t = NULL,
seasonalities = list(),
stan.fit = NULL,
params = list(),
history = NULL,
@ -112,7 +117,7 @@ prophet <- function(df = df,
)
validate_inputs(m)
class(m) <- append("prophet", class(m))
if (fit) {
if ((fit) && (!is.null(df))) {
m <- fit.prophet(m, df, ...)
}
@ -372,6 +377,31 @@ make_holiday_features <- function(m, dates) {
return(holiday.mat)
}
#' Add a seasonal component with specified period and number of Fourier
#' components.
#'
#' Increasing the number of Fourier components allows the seasonality to change
#' more quickly (at risk of overfitting).
#'
#' The seasonality is recorded in \code{m$seasonalities[[name]]} as the numeric
#' vector \code{c(period, fourier.order)}. Adding a seasonality under a name
#' that is already present replaces the earlier entry.
#'
#' @param m Prophet object.
#' @param name String name of the seasonality component. Must not be the name
#'  of a holiday in \code{m$holidays}.
#' @param period Float number of days in one period. Must be a single positive
#'  number.
#' @param fourier.order Int number of Fourier components to use. Must be a
#'  single positive number.
#'
#' @return The prophet model with the seasonality added.
#'
#' @export
add_seasonality <- function(m, name, period, fourier.order) {
  # Fail early on malformed arguments instead of storing values that only
  # break later, when the seasonal features are generated.
  if (!is.character(name) || length(name) != 1 || is.na(name) ||
      !nzchar(name)) {
    stop('name must be a single non-empty string')
  }
  if (!is.numeric(period) || length(period) != 1 || is.na(period) ||
      period <= 0) {
    stop('period must be a single positive number')
  }
  if (!is.numeric(fourier.order) || length(fourier.order) != 1 ||
      is.na(fourier.order) || fourier.order <= 0) {
    stop('fourier.order must be a single positive number')
  }
  # `[[` avoids the partial name matching that `$` performs on lists.
  holidays <- m[['holidays']]
  if (!is.null(holidays)) {
    holiday.names <- as.character(unique(holidays[['holiday']]))
    if (name %in% holiday.names) {
      stop('Name "', name, '" already used for holiday')
    }
  }
  m$seasonalities[[name]] <- c(period, fourier.order)
  return(m)
}
#' Dataframe with seasonality features.
#'
#' @param m Prophet object.
@ -381,19 +411,14 @@ make_holiday_features <- function(m, dates) {
#'
make_all_seasonality_features <- function(m, df) {
seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
if (m$yearly.seasonality) {
for (name in names(m$seasonalities)) {
period <- m$seasonalities[[name]][1]
series.order <- m$seasonalities[[name]][2]
seasonal.features <- cbind(
seasonal.features,
make_seasonality_features(df$ds, 365.25, 10, 'yearly'))
}
if (m$weekly.seasonality) {
seasonal.features <- cbind(
seasonal.features,
make_seasonality_features(df$ds, 7, 3, 'weekly'))
make_seasonality_features(df$ds, period, series.order, name))
}
if(!is.null(m$holidays)) {
# A smaller prior scale will shrink holiday estimates more than seasonality
scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
seasonal.features <- cbind(
seasonal.features,
make_holiday_features(m, df$ds))
@ -401,6 +426,39 @@ make_all_seasonality_features <- function(m, df) {
return(seasonal.features)
}
#' Get number of Fourier components for built-in seasonalities.
#'
#' The argument is dispatched on its type rather than compared with
#' \code{==}, so that a numeric order of 1 is not mistaken for \code{TRUE}
#' (in R, \code{1 == TRUE} is \code{TRUE}).
#'
#' @param m Prophet object.
#' @param name String name of the seasonality component.
#' @param arg 'auto', TRUE, FALSE, or number of Fourier components as
#'  provided.
#' @param auto.disable Bool if seasonality should be disabled when 'auto'.
#' @param default.order Int default Fourier order.
#'
#' @return Number of Fourier components, or 0 for disabled. A numeric
#'  \code{arg} is returned unchanged.
#'
parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
  is.scalar <- length(arg) == 1 && !is.na(arg)
  if (is.scalar && is.character(arg) && arg == 'auto') {
    fourier.order <- 0
    if (name %in% names(m$seasonalities)) {
      # A user-supplied seasonality with the built-in name takes precedence.
      warning('Found custom seasonality named "', name,
              '", disabling built-in ', name, ' seasonality.')
    } else if (auto.disable) {
      warning('Disabling ', name, ' seasonality. Run prophet with ', name,
              '.seasonality=TRUE to override this.')
    } else {
      fourier.order <- default.order
    }
  } else if (isTRUE(arg)) {
    fourier.order <- default.order
  } else if (identical(arg, FALSE)) {
    fourier.order <- 0
  } else if (is.scalar && is.numeric(arg)) {
    # Explicit number of Fourier terms, including 1.
    fourier.order <- arg
  } else {
    stop(name, ".seasonality must be 'auto', TRUE, FALSE, or a number of ",
         "Fourier terms")
  }
  return(fourier.order)
}
#' Set seasonalities that were left on auto.
#'
#' Turns on yearly seasonality if there is >=2 years of history.
@ -414,25 +472,21 @@ make_all_seasonality_features <- function(m, df) {
set_auto_seasonalities <- function(m) {
first <- min(m$history$ds)
last <- max(m$history$ds)
if (m$yearly.seasonality == 'auto') {
if (last - first < 730) {
warning('Disabling yearly seasonality. ',
'Run prophet with `yearly.seasonality=TRUE` to override this.')
m$yearly.seasonality <- FALSE
} else {
m$yearly.seasonality <- TRUE
}
dt <- diff(m$history$ds)
min.dt <- min(dt[dt > 0])
yearly.disable <- last - first < 730
fourier.order <- parse_seasonality_args(
m, 'yearly', m$yearly.seasonality, yearly.disable, 10)
if (fourier.order > 0) {
m$seasonalities[['yearly']] <- c(365.25, fourier.order)
}
if (m$weekly.seasonality == 'auto') {
dt <- diff(m$history$ds)
min.dt <- min(dt[dt > 0])
if ((last - first < 14) || (min.dt >= 7)) {
warning('Disabling weekly seasonality. ',
'Run prophet with `weekly.seasonality=TRUE` to override this.')
m$weekly.seasonality <- FALSE
} else {
m$weekly.seasonality <- TRUE
}
weekly.disable <- ((last - first < 14) || (min.dt >= 7))
fourier.order <- parse_seasonality_args(
m, 'weekly', m$weekly.seasonality, weekly.disable, 3)
if (fourier.order > 0) {
m$seasonalities[['weekly']] <- c(7, fourier.order)
}
return(m)
}
@ -1053,6 +1107,13 @@ prophet_plot_components <- function(
if ("yearly" %in% colnames(df)) {
panels[[length(panels) + 1]] <- plot_yearly(m, uncertainty, yearly_start)
}
# Plot other seasonalities
for (name in names(m$seasonalities)) {
if (!(name %in% c('weekly', 'yearly')) && (name %in% colnames(df))) {
panels[[length(panels) + 1]] <- plot_seasonality(m, name, uncertainty)
}
}
# Make the plot.
grid::grid.newpage()
grid::pushViewport(grid::viewport(layout = grid::grid.layout(length(panels),
@ -1190,4 +1251,44 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
return(gg.yearly)
}
#' Plot a custom seasonal component.
#'
#' Builds a grid of dates spanning one period of the named seasonality,
#' starting on 2017-01-01, passes it through \code{setup_dataframe} and
#' \code{predict_seasonal_components}, and draws the column called \code{name}
#' as a line. When \code{uncertainty} is TRUE a ribbon between the
#' \code{<name>_lower} and \code{<name>_upper} columns is added.
#'
#' @param m Prophet model object.
#' @param name String name of the seasonality.
#' @param uncertainty Boolean to plot uncertainty intervals.
#'
#' @return A ggplot2 plot.
plot_seasonality <- function(m, name, uncertainty = TRUE) {
  season.length <- m$seasonalities[[name]][1]

  # Date grid from Jan 1 through a single period.
  first.day <- zoo::as.Date('2017-01-01')
  last.day <- first.day + season.length
  n.points <- as.numeric(last.day - first.day)
  grid.df <- data.frame(
    ds = seq.Date(from = first.day, to = last.day, length.out = n.points),
    cap = 1.)
  grid.df <- setup_dataframe(m, grid.df)$df

  components <- predict_seasonal_components(m, grid.df)
  components$ds <- grid.df$ds

  # Periods shorter than two weeks get a time-of-day suffix on the labels.
  label.format <- if (season.length < 14) '%m/%d %R' else '%m/%d'

  line.mapping <- ggplot2::aes_string(x = 'ds', y = name, group = 1)
  gg <- ggplot2::ggplot(components, line.mapping) +
    ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
    ggplot2::scale_x_date(labels = scales::date_format(label.format))

  if (uncertainty) {
    band.mapping <- ggplot2::aes_string(
      ymin = paste0(name, '_lower'), ymax = paste0(name, '_upper'))
    gg <- gg +
      ggplot2::geom_ribbon(
        band.mapping,
        alpha = 0.2,
        fill = "#0072B2",
        na.rm = TRUE)
  }
  return(gg)
}
# fb-block 3

25
R/man/add_seasonality.Rd Normal file
View file

@ -0,0 +1,25 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/prophet.R
\name{add_seasonality}
\alias{add_seasonality}
\title{Add a seasonal component with specified period and number of Fourier
components.}
\usage{
add_seasonality(m, name, period, fourier.order)
}
\arguments{
\item{m}{Prophet object.}
\item{name}{String name of the seasonality component.}
\item{period}{Float number of days in one period.}
\item{fourier.order}{Int number of Fourier components to use.}
}
\value{
The prophet model with the seasonality added.
}
\description{
Increasing the number of Fourier components allows the seasonality to change
more quickly (at risk of overfitting).
}

View file

@ -0,0 +1,26 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/prophet.R
\name{parse_seasonality_args}
\alias{parse_seasonality_args}
\title{Get number of Fourier components for built-in seasonalities.}
\usage{
parse_seasonality_args(m, name, arg, auto.disable, default.order)
}
\arguments{
\item{m}{Prophet object.}
\item{name}{String name of the seasonality component.}
\item{arg}{'auto', TRUE, FALSE, or number of Fourier components as
provided.}
\item{auto.disable}{Bool if seasonality should be disabled when 'auto'.}
\item{default.order}{Int default Fourier order.}
}
\value{
Number of Fourier components, or 0 for disabled.
}
\description{
Get number of Fourier components for built-in seasonalities.
}

21
R/man/plot_seasonality.Rd Normal file
View file

@ -0,0 +1,21 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/prophet.R
\name{plot_seasonality}
\alias{plot_seasonality}
\title{Plot a custom seasonal component.}
\usage{
plot_seasonality(m, name, uncertainty = TRUE)
}
\arguments{
\item{m}{Prophet model object.}
\item{name}{String name of the seasonality.}
\item{uncertainty}{Boolean to plot uncertainty intervals.}
}
\value{
A ggplot2 plot.
}
\description{
Plot a custom seasonal component.
}

View file

@ -4,7 +4,7 @@
\alias{prophet}
\title{Prophet forecaster.}
\usage{
prophet(df = df, growth = "linear", changepoints = NULL,
prophet(df = NULL, growth = "linear", changepoints = NULL,
n.changepoints = 25, yearly.seasonality = "auto",
weekly.seasonality = "auto", holidays = NULL,
seasonality.prior.scale = 10, holidays.prior.scale = 10,
@ -12,9 +12,11 @@ prophet(df = df, growth = "linear", changepoints = NULL,
uncertainty.samples = 1000, fit = TRUE, ...)
}
\arguments{
\item{df}{Dataframe containing the history. Must have columns ds (date type)
and y, the time series. If growth is logistic, then df must also have a
column cap that specifies the capacity at each ds.}
\item{df}{(optional) Dataframe containing the history. Must have columns ds
(date type) and y, the time series. If growth is logistic, then df must
also have a column cap that specifies the capacity at each ds. If not
provided, then the model object will be instantiated but not fit; use
fit.prophet(m, df) to fit the model.}
\item{growth}{String 'linear' or 'logistic' to specify a linear or logistic
trend.}
@ -28,9 +30,11 @@ if input `changepoints` is supplied. If `changepoints` is not supplied,
then n.changepoints potential changepoints are selected uniformly from the
first 80 percent of df$ds.}
\item{yearly.seasonality}{Fit yearly seasonality; 'auto', TRUE, or FALSE.}
\item{yearly.seasonality}{Fit yearly seasonality. Can be 'auto', TRUE,
FALSE, or a number of Fourier terms to generate.}
\item{weekly.seasonality}{Fit weekly seasonality; 'auto', TRUE, or FALSE.}
\item{weekly.seasonality}{Fit weekly seasonality. Can be 'auto', TRUE,
FALSE, or a number of Fourier terms to generate.}
\item{holidays}{data frame with columns holiday (character) and ds (date
type) and optionally columns lower_window and upper_window which specify a

View file

@ -219,38 +219,51 @@ test_that("make_future_dataframe", {
test_that("auto_weekly_seasonality", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
# Should be True
# Should be enabled
N.w <- 15
train.w <- DATA[1:N.w, ]
m <- prophet(train.w, fit = FALSE)
expect_equal(m$weekly.seasonality, 'auto')
m <- prophet:::fit.prophet(m, train.w)
expect_equal(m$weekly.seasonality, TRUE)
# Should be False due to too short history
expect_true('weekly' %in% names(m$seasonalities))
expect_equal(m$seasonalities[['weekly']], c(7, 3))
# Should be disabled due to too short history
N.w <- 9
train.w <- DATA[1:N.w, ]
m <- prophet(train.w)
expect_equal(m$weekly.seasonality, FALSE)
expect_false('weekly' %in% names(m$seasonalities))
m <- prophet(train.w, weekly.seasonality = TRUE)
expect_equal(m$weekly.seasonality, TRUE)
expect_true('weekly' %in% names(m$seasonalities))
# Should be disabled due to weekly spacing
train.w <- DATA[seq(1, nrow(DATA), 7), ]
m <- prophet(train.w)
expect_equal(m$weekly.seasonality, FALSE)
expect_false('weekly' %in% names(m$seasonalities))
m <- prophet(DATA, weekly.seasonality=2)
expect_equal(m$seasonalities[['weekly']], c(7, 2))
})
test_that("auto_yearly_seasonality", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
# Should be True
# Should be enabled
m <- prophet(DATA, fit = FALSE)
expect_equal(m$yearly.seasonality, 'auto')
m <- prophet:::fit.prophet(m, DATA)
expect_equal(m$yearly.seasonality, TRUE)
# Should be False due to too short history
expect_true('yearly' %in% names(m$seasonalities))
expect_equal(m$seasonalities[['yearly']], c(365.25, 10))
# Should be disabled due to too short history
N.w <- 240
train.y <- DATA[1:N.w, ]
m <- prophet(train.y)
expect_equal(m$yearly.seasonality, FALSE)
expect_false('yearly' %in% names(m$seasonalities))
m <- prophet(train.y, yearly.seasonality = TRUE)
expect_equal(m$yearly.seasonality, TRUE)
expect_true('yearly' %in% names(m$seasonalities))
m <- prophet(DATA, yearly.seasonality=7)
expect_equal(m$seasonalities[['yearly']], c(365.25, 7))
})
test_that("custom_seasonality", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  # An unfitted model (no data frame supplied) can still take a custom
  # seasonality, stored as c(period, fourier.order) under its name.
  model <- add_seasonality(
    prophet(), name = 'monthly', period = 30, fourier.order = 5)
  expect_equal(model$seasonalities[['monthly']], c(30, 5))
})

View file

@ -4,7 +4,7 @@
- id: quick_start
- id: forecasting_growth
- id: trend_changepoints
- id: holiday_effects
- id: seasonality_and_holiday_effects
- id: uncertainty_intervals
- id: outliers
- id: non-daily_data

File diff suppressed because one or more lines are too long