mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-29 23:06:49 +00:00
Custom seasonalities in R
This commit is contained in:
parent
707c885275
commit
8be35c2f34
9 changed files with 338 additions and 89 deletions
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
S3method(plot,prophet)
|
||||
S3method(predict,prophet)
|
||||
export(add_seasonality)
|
||||
export(fit.prophet)
|
||||
export(make_future_dataframe)
|
||||
export(predictive_samples)
|
||||
|
|
|
|||
169
R/R/prophet.R
169
R/R/prophet.R
|
|
@ -15,9 +15,11 @@ globalVariables(c(
|
|||
|
||||
#' Prophet forecaster.
|
||||
#'
|
||||
#' @param df Dataframe containing the history. Must have columns ds (date type)
|
||||
#' and y, the time series. If growth is logistic, then df must also have a
|
||||
#' column cap that specifies the capacity at each ds.
|
||||
#' @param df (optional) Dataframe containing the history. Must have columns ds
|
||||
#' (date type) and y, the time series. If growth is logistic, then df must
|
||||
#' also have a column cap that specifies the capacity at each ds. If not
|
||||
#' provided, then the model object will be instantiated but not fit; use
|
||||
#' fit.prophet(m, df) to fit the model.
|
||||
#' @param growth String 'linear' or 'logistic' to specify a linear or logistic
|
||||
#' trend.
|
||||
#' @param changepoints Vector of dates at which to include potential
|
||||
|
|
@ -27,8 +29,10 @@ globalVariables(c(
|
|||
#' if input `changepoints` is supplied. If `changepoints` is not supplied,
|
||||
#' then n.changepoints potential changepoints are selected uniformly from the
|
||||
#' first 80 percent of df$ds.
|
||||
#' @param yearly.seasonality Fit yearly seasonality; 'auto', TRUE, or FALSE.
|
||||
#' @param weekly.seasonality Fit weekly seasonality; 'auto', TRUE, or FALSE.
|
||||
#' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE,
|
||||
#' FALSE, or a number of Fourier terms to generate.
|
||||
#' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
|
||||
#' FALSE, or a number of Fourier terms to generate.
|
||||
#' @param holidays data frame with columns holiday (character) and ds (date
|
||||
#' type) and optionally columns lower_window and upper_window which specify a
|
||||
#' range of days around the date to be included as holidays. lower_window=-2
|
||||
|
|
@ -66,7 +70,7 @@ globalVariables(c(
|
|||
#' @export
|
||||
#' @importFrom dplyr "%>%"
|
||||
#' @import Rcpp
|
||||
prophet <- function(df = df,
|
||||
prophet <- function(df = NULL,
|
||||
growth = 'linear',
|
||||
changepoints = NULL,
|
||||
n.changepoints = 25,
|
||||
|
|
@ -105,6 +109,7 @@ prophet <- function(df = df,
|
|||
y.scale = NULL,
|
||||
t.scale = NULL,
|
||||
changepoints.t = NULL,
|
||||
seasonalities = list(),
|
||||
stan.fit = NULL,
|
||||
params = list(),
|
||||
history = NULL,
|
||||
|
|
@ -112,7 +117,7 @@ prophet <- function(df = df,
|
|||
)
|
||||
validate_inputs(m)
|
||||
class(m) <- append("prophet", class(m))
|
||||
if (fit) {
|
||||
if ((fit) && (!is.null(df))) {
|
||||
m <- fit.prophet(m, df, ...)
|
||||
}
|
||||
|
||||
|
|
@ -372,6 +377,31 @@ make_holiday_features <- function(m, dates) {
|
|||
return(holiday.mat)
|
||||
}
|
||||
|
||||
#' Add a seasonal component with specified period and number of Fourier
#' components.
#'
#' Increasing the number of Fourier components allows the seasonality to change
#' more quickly (at risk of overfitting).
#'
#' The seasonality is stored in m$seasonalities under `name` as the numeric
#' vector c(period, fourier.order). An error is raised if `name` is already
#' used by a holiday in m$holidays, or if `period` or `fourier.order` is not a
#' single positive number.
#'
#' @param m Prophet object.
#' @param name String name of the seasonality component.
#' @param period Float number of days in one period. Must be > 0.
#' @param fourier.order Int number of Fourier components to use. Must be > 0.
#'
#' @return The prophet model with the seasonality added.
#'
#' @export
add_seasonality <- function(m, name, period, fourier.order) {
  # Reject unusable values up front; otherwise they are stored silently and
  # only surface later, when the seasonality features are generated.
  if (!is.numeric(period) || length(period) != 1 || is.na(period) ||
      period <= 0) {
    stop('Period must be a single number > 0.')
  }
  if (!is.numeric(fourier.order) || length(fourier.order) != 1 ||
      is.na(fourier.order) || fourier.order <= 0) {
    stop('Fourier order must be a single number > 0.')
  }
  # A seasonality may not reuse the name of a holiday.
  if (!is.null(m$holidays)) {
    holiday.names <- as.character(unique(m$holidays$holiday))
    if (name %in% holiday.names) {
      stop('Name "', name, '" already used for holiday')
    }
  }
  m$seasonalities[[name]] <- c(period, fourier.order)
  return(m)
}
|
||||
|
||||
#' Dataframe with seasonality features.
|
||||
#'
|
||||
#' @param m Prophet object.
|
||||
|
|
@ -381,19 +411,14 @@ make_holiday_features <- function(m, dates) {
|
|||
#'
|
||||
make_all_seasonality_features <- function(m, df) {
|
||||
seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
|
||||
if (m$yearly.seasonality) {
|
||||
for (name in names(m$seasonalities)) {
|
||||
period <- m$seasonalities[[name]][1]
|
||||
series.order <- m$seasonalities[[name]][2]
|
||||
seasonal.features <- cbind(
|
||||
seasonal.features,
|
||||
make_seasonality_features(df$ds, 365.25, 10, 'yearly'))
|
||||
}
|
||||
if (m$weekly.seasonality) {
|
||||
seasonal.features <- cbind(
|
||||
seasonal.features,
|
||||
make_seasonality_features(df$ds, 7, 3, 'weekly'))
|
||||
make_seasonality_features(df$ds, period, series.order, name))
|
||||
}
|
||||
if(!is.null(m$holidays)) {
|
||||
# A smaller prior scale will shrink holiday estimates more than seasonality
|
||||
scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
|
||||
seasonal.features <- cbind(
|
||||
seasonal.features,
|
||||
make_holiday_features(m, df$ds))
|
||||
|
|
@ -401,6 +426,39 @@ make_all_seasonality_features <- function(m, df) {
|
|||
return(seasonal.features)
|
||||
}
|
||||
|
||||
#' Get number of Fourier components for built-in seasonalities.
#'
#' With arg = 'auto' the seasonality is disabled (with a warning) when a
#' custom seasonality called `name` already exists in m$seasonalities or when
#' `auto.disable` is TRUE; otherwise `default.order` is used. TRUE selects
#' `default.order`, FALSE disables, and any other value is returned as the
#' requested Fourier order.
#'
#' @param m Prophet object.
#' @param name String name of the seasonality component.
#' @param arg 'auto', TRUE, FALSE, or number of Fourier components as
#'  provided.
#' @param auto.disable Bool if seasonality should be disabled when 'auto'.
#' @param default.order Int default Fourier order.
#'
#' @return Number of Fourier components, or 0 for disabled.
#'
parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
  # identical()/isTRUE() are used rather than `==` so that a numeric Fourier
  # order is never coerced to logical: with `==`, arg = 1 compares equal to
  # TRUE and would silently be replaced by default.order.
  if (identical(arg, 'auto')) {
    fourier.order <- 0
    if (name %in% names(m$seasonalities)) {
      warning('Found custom seasonality named "', name,
              '", disabling built-in ', name, ' seasonality.')
    } else if (auto.disable) {
      warning('Disabling ', name, ' seasonality. Run prophet with ', name,
              '.seasonality=TRUE to override this.')
    } else {
      fourier.order <- default.order
    }
  } else if (isTRUE(arg)) {
    fourier.order <- default.order
  } else if (identical(arg, FALSE)) {
    fourier.order <- 0
  } else {
    # Explicit number of Fourier terms supplied by the caller.
    fourier.order <- arg
  }
  return(fourier.order)
}
|
||||
|
||||
#' Set seasonalities that were left on auto.
|
||||
#'
|
||||
#' Turns on yearly seasonality if there is >=2 years of history.
|
||||
|
|
@ -414,25 +472,21 @@ make_all_seasonality_features <- function(m, df) {
|
|||
set_auto_seasonalities <- function(m) {
|
||||
first <- min(m$history$ds)
|
||||
last <- max(m$history$ds)
|
||||
if (m$yearly.seasonality == 'auto') {
|
||||
if (last - first < 730) {
|
||||
warning('Disabling yearly seasonality. ',
|
||||
'Run prophet with `yearly.seasonality=TRUE` to override this.')
|
||||
m$yearly.seasonality <- FALSE
|
||||
} else {
|
||||
m$yearly.seasonality <- TRUE
|
||||
}
|
||||
dt <- diff(m$history$ds)
|
||||
min.dt <- min(dt[dt > 0])
|
||||
|
||||
yearly.disable <- last - first < 730
|
||||
fourier.order <- parse_seasonality_args(
|
||||
m, 'yearly', m$yearly.seasonality, yearly.disable, 10)
|
||||
if (fourier.order > 0) {
|
||||
m$seasonalities[['yearly']] <- c(365.25, fourier.order)
|
||||
}
|
||||
if (m$weekly.seasonality == 'auto') {
|
||||
dt <- diff(m$history$ds)
|
||||
min.dt <- min(dt[dt > 0])
|
||||
if ((last - first < 14) || (min.dt >= 7)) {
|
||||
warning('Disabling weekly seasonality. ',
|
||||
'Run prophet with `weekly.seasonality=TRUE` to override this.')
|
||||
m$weekly.seasonality <- FALSE
|
||||
} else {
|
||||
m$weekly.seasonality <- TRUE
|
||||
}
|
||||
|
||||
weekly.disable <- ((last - first < 14) || (min.dt >= 7))
|
||||
fourier.order <- parse_seasonality_args(
|
||||
m, 'weekly', m$weekly.seasonality, weekly.disable, 3)
|
||||
if (fourier.order > 0) {
|
||||
m$seasonalities[['weekly']] <- c(7, fourier.order)
|
||||
}
|
||||
return(m)
|
||||
}
|
||||
|
|
@ -1053,6 +1107,13 @@ prophet_plot_components <- function(
|
|||
if ("yearly" %in% colnames(df)) {
|
||||
panels[[length(panels) + 1]] <- plot_yearly(m, uncertainty, yearly_start)
|
||||
}
|
||||
# Plot other seasonalities
|
||||
for (name in names(m$seasonalities)) {
|
||||
if (!(name %in% c('weekly', 'yearly')) && (name %in% colnames(df))) {
|
||||
panels[[length(panels) + 1]] <- plot_seasonality(m, name, uncertainty)
|
||||
}
|
||||
}
|
||||
|
||||
# Make the plot.
|
||||
grid::grid.newpage()
|
||||
grid::pushViewport(grid::viewport(layout = grid::grid.layout(length(panels),
|
||||
|
|
@ -1190,4 +1251,44 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
|
|||
return(gg.yearly)
|
||||
}
|
||||
|
||||
#' Plot a custom seasonal component.
#'
#' The component is evaluated on a grid of dates covering one period starting
#' at 2017-01-01 and drawn as a line, optionally with a ribbon spanning the
#' `<name>_lower` and `<name>_upper` columns of the predicted components.
#'
#' @param m Prophet model object.
#' @param name String name of the seasonality.
#' @param uncertainty Boolean to plot uncertainty intervals.
#'
#' @return A ggplot2 plot.
plot_seasonality <- function(m, name, uncertainty = TRUE) {
  # Date grid: from Jan 1 across one period, with as many points as the
  # period has days.
  period <- m$seasonalities[[name]][1]
  first.day <- zoo::as.Date('2017-01-01')
  last.day <- first.day + period
  n.points <- as.numeric(last.day - first.day)
  grid.df <- data.frame(
    ds = seq.Date(from = first.day, to = last.day, length.out = n.points),
    cap = 1.)
  grid.df <- setup_dataframe(m, grid.df)$df

  # Seasonal components evaluated on the grid, keyed by the grid dates.
  components <- predict_seasonal_components(m, grid.df)
  components$ds <- grid.df$ds

  # Short periods get a time-of-day in the axis labels.
  date.fmt <- if (period < 14) '%m/%d %R' else '%m/%d'

  line.layer <- ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
  x.scale <- ggplot2::scale_x_date(labels = scales::date_format(date.fmt))
  gg.out <- ggplot2::ggplot(
    components, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
    line.layer +
    x.scale

  if (uncertainty) {
    band.aes <- ggplot2::aes_string(
      ymin = paste0(name, '_lower'), ymax = paste0(name, '_upper'))
    gg.out <- gg.out +
      ggplot2::geom_ribbon(
        band.aes, alpha = 0.2, fill = "#0072B2", na.rm = TRUE)
  }
  return(gg.out)
}
|
||||
|
||||
# fb-block 3
|
||||
|
|
|
|||
25
R/man/add_seasonality.Rd
Normal file
25
R/man/add_seasonality.Rd
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/prophet.R
|
||||
\name{add_seasonality}
|
||||
\alias{add_seasonality}
|
||||
\title{Add a seasonal component with specified period and number of Fourier
|
||||
components.}
|
||||
\usage{
|
||||
add_seasonality(m, name, period, fourier.order)
|
||||
}
|
||||
\arguments{
|
||||
\item{m}{Prophet object.}
|
||||
|
||||
\item{name}{String name of the seasonality component.}
|
||||
|
||||
\item{period}{Float number of days in one period.}
|
||||
|
||||
\item{fourier.order}{Int number of Fourier components to use.}
|
||||
}
|
||||
\value{
|
||||
The prophet model with the seasonality added.
|
||||
}
|
||||
\description{
|
||||
Increasing the number of Fourier components allows the seasonality to change
|
||||
more quickly (at risk of overfitting).
|
||||
}
|
||||
26
R/man/parse_seasonality_args.Rd
Normal file
26
R/man/parse_seasonality_args.Rd
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/prophet.R
|
||||
\name{parse_seasonality_args}
|
||||
\alias{parse_seasonality_args}
|
||||
\title{Get number of Fourier components for built-in seasonalities.}
|
||||
\usage{
|
||||
parse_seasonality_args(m, name, arg, auto.disable, default.order)
|
||||
}
|
||||
\arguments{
|
||||
\item{m}{Prophet object.}
|
||||
|
||||
\item{name}{String name of the seasonality component.}
|
||||
|
||||
\item{arg}{'auto', TRUE, FALSE, or number of Fourier components as
|
||||
provided.}
|
||||
|
||||
\item{auto.disable}{Bool if seasonality should be disabled when 'auto'.}
|
||||
|
||||
\item{default.order}{Int default Fourier order.}
|
||||
}
|
||||
\value{
|
||||
Number of Fourier components, or 0 for disabled.
|
||||
}
|
||||
\description{
|
||||
Get number of Fourier components for built-in seasonalities.
|
||||
}
|
||||
21
R/man/plot_seasonality.Rd
Normal file
21
R/man/plot_seasonality.Rd
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/prophet.R
|
||||
\name{plot_seasonality}
|
||||
\alias{plot_seasonality}
|
||||
\title{Plot a custom seasonal component.}
|
||||
\usage{
|
||||
plot_seasonality(m, name, uncertainty = TRUE)
|
||||
}
|
||||
\arguments{
|
||||
\item{m}{Prophet model object.}
|
||||
|
||||
\item{name}{String name of the seasonality.}
|
||||
|
||||
\item{uncertainty}{Boolean to plot uncertainty intervals.}
|
||||
}
|
||||
\value{
|
||||
A ggplot2 plot.
|
||||
}
|
||||
\description{
|
||||
Plot a custom seasonal component.
|
||||
}
|
||||
|
|
@ -4,7 +4,7 @@
|
|||
\alias{prophet}
|
||||
\title{Prophet forecaster.}
|
||||
\usage{
|
||||
prophet(df = df, growth = "linear", changepoints = NULL,
|
||||
prophet(df = NULL, growth = "linear", changepoints = NULL,
|
||||
n.changepoints = 25, yearly.seasonality = "auto",
|
||||
weekly.seasonality = "auto", holidays = NULL,
|
||||
seasonality.prior.scale = 10, holidays.prior.scale = 10,
|
||||
|
|
@ -12,9 +12,11 @@ prophet(df = df, growth = "linear", changepoints = NULL,
|
|||
uncertainty.samples = 1000, fit = TRUE, ...)
|
||||
}
|
||||
\arguments{
|
||||
\item{df}{Dataframe containing the history. Must have columns ds (date type)
|
||||
and y, the time series. If growth is logistic, then df must also have a
|
||||
column cap that specifies the capacity at each ds.}
|
||||
\item{df}{(optional) Dataframe containing the history. Must have columns ds
|
||||
(date type) and y, the time series. If growth is logistic, then df must
|
||||
also have a column cap that specifies the capacity at each ds. If not
|
||||
provided, then the model object will be instantiated but not fit; use
|
||||
fit.prophet(m, df) to fit the model.}
|
||||
|
||||
\item{growth}{String 'linear' or 'logistic' to specify a linear or logistic
|
||||
trend.}
|
||||
|
|
@ -28,9 +30,11 @@ if input `changepoints` is supplied. If `changepoints` is not supplied,
|
|||
then n.changepoints potential changepoints are selected uniformly from the
|
||||
first 80 percent of df$ds.}
|
||||
|
||||
\item{yearly.seasonality}{Fit yearly seasonality; 'auto', TRUE, or FALSE.}
|
||||
\item{yearly.seasonality}{Fit yearly seasonality. Can be 'auto', TRUE,
|
||||
FALSE, or a number of Fourier terms to generate.}
|
||||
|
||||
\item{weekly.seasonality}{Fit weekly seasonality; 'auto', TRUE, or FALSE.}
|
||||
\item{weekly.seasonality}{Fit weekly seasonality. Can be 'auto', TRUE,
|
||||
FALSE, or a number of Fourier terms to generate.}
|
||||
|
||||
\item{holidays}{data frame with columns holiday (character) and ds (date
|
||||
type) and optionally columns lower_window and upper_window which specify a
|
||||
|
|
|
|||
|
|
@ -219,38 +219,51 @@ test_that("make_future_dataframe", {
|
|||
|
||||
test_that("auto_weekly_seasonality", {
|
||||
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
|
||||
# Should be True
|
||||
# Should be enabled
|
||||
N.w <- 15
|
||||
train.w <- DATA[1:N.w, ]
|
||||
m <- prophet(train.w, fit = FALSE)
|
||||
expect_equal(m$weekly.seasonality, 'auto')
|
||||
m <- prophet:::fit.prophet(m, train.w)
|
||||
expect_equal(m$weekly.seasonality, TRUE)
|
||||
# Should be False due to too short history
|
||||
expect_true('weekly' %in% names(m$seasonalities))
|
||||
expect_equal(m$seasonalities[['weekly']], c(7, 3))
|
||||
# Should be disabled due to too short history
|
||||
N.w <- 9
|
||||
train.w <- DATA[1:N.w, ]
|
||||
m <- prophet(train.w)
|
||||
expect_equal(m$weekly.seasonality, FALSE)
|
||||
expect_false('weekly' %in% names(m$seasonalities))
|
||||
m <- prophet(train.w, weekly.seasonality = TRUE)
|
||||
expect_equal(m$weekly.seasonality, TRUE)
|
||||
expect_true('weekly' %in% names(m$seasonalities))
|
||||
# Should be disabled due to weekly spacing
|
||||
train.w <- DATA[seq(1, nrow(DATA), 7), ]
|
||||
m <- prophet(train.w)
|
||||
expect_equal(m$weekly.seasonality, FALSE)
|
||||
expect_false('weekly' %in% names(m$seasonalities))
|
||||
m <- prophet(DATA, weekly.seasonality=2)
|
||||
expect_equal(m$seasonalities[['weekly']], c(7, 2))
|
||||
})
|
||||
|
||||
test_that("auto_yearly_seasonality", {
|
||||
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
|
||||
# Should be True
|
||||
# Should be enabled
|
||||
m <- prophet(DATA, fit = FALSE)
|
||||
expect_equal(m$yearly.seasonality, 'auto')
|
||||
m <- prophet:::fit.prophet(m, DATA)
|
||||
expect_equal(m$yearly.seasonality, TRUE)
|
||||
# Should be False due to too short history
|
||||
expect_true('yearly' %in% names(m$seasonalities))
|
||||
expect_equal(m$seasonalities[['yearly']], c(365.25, 10))
|
||||
# Should be disabled due to too short history
|
||||
N.w <- 240
|
||||
train.y <- DATA[1:N.w, ]
|
||||
m <- prophet(train.y)
|
||||
expect_equal(m$yearly.seasonality, FALSE)
|
||||
expect_false('yearly' %in% names(m$seasonalities))
|
||||
m <- prophet(train.y, yearly.seasonality = TRUE)
|
||||
expect_equal(m$yearly.seasonality, TRUE)
|
||||
expect_true('yearly' %in% names(m$seasonalities))
|
||||
m <- prophet(DATA, yearly.seasonality=7)
|
||||
expect_equal(m$seasonalities[['yearly']], c(365.25, 7))
|
||||
})
|
||||
|
||||
test_that("custom_seasonality", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  # prophet() without a dataframe returns an unfitted model that can still
  # be given extra seasonalities.
  model <- add_seasonality(
    prophet(), name = 'monthly', period = 30, fourier.order = 5)
  monthly <- model$seasonalities[['monthly']]
  expect_equal(monthly, c(30, 5))
})
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
- id: quick_start
|
||||
- id: forecasting_growth
|
||||
- id: trend_changepoints
|
||||
- id: holiday_effects
|
||||
- id: seasonality_and_holiday_effects
|
||||
- id: uncertainty_intervals
|
||||
- id: outliers
|
||||
- id: non-daily_data
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
Loading…
Reference in a new issue