Move built-in country holidays to a function (R)

This commit is contained in:
Ben Letham 2018-11-30 23:12:19 -08:00
parent 92f955d25a
commit 287fb2f6de
9 changed files with 183 additions and 83 deletions

View file

@ -32,7 +32,7 @@ Suggests:
readr
License: BSD_3_clause + file LICENSE
LazyData: true
RoxygenNote: 6.1.0
RoxygenNote: 6.1.1
VignetteBuilder: knitr
SystemRequirements: C++11
Encoding: UTF-8

View file

@ -3,6 +3,7 @@
S3method(plot,prophet)
S3method(predict,prophet)
export(add_changepoints_to_plot)
export(add_country_holidays)
export(add_regressor)
export(add_seasonality)
export(cross_validation)

View file

@ -8,7 +8,7 @@
## Makes R CMD CHECK happy due to dplyr syntax below
globalVariables(c(
"ds", "y", "cap", ".",
"component", "dow", "doy", "holiday", "holidays", "append.holidays", "holidays_lower",
"component", "dow", "doy", "holiday", "holidays", "holidays_lower",
"holidays_upper", "ix", "lower", "n", "stat", "trend", "row_number", "extra_regressors", "col",
"trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower", "weekly_upper",
"x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower", "yhat_upper"))
@ -43,7 +43,6 @@ globalVariables(c(
#' range of days around the date to be included as holidays. lower_window=-2
#' will include 2 days prior to the date as holidays. Also optionally can have
#' a column prior_scale specifying the prior scale for each holiday.
#' @param append.holidays country name or abbreviation (character).
#' @param seasonality.mode 'additive' (default) or 'multiplicative'.
#' @param seasonality.prior.scale Parameter modulating the strength of the
#' seasonality model. Larger values allow the model to fit larger seasonal
@ -88,7 +87,6 @@ prophet <- function(df = NULL,
weekly.seasonality = 'auto',
daily.seasonality = 'auto',
holidays = NULL,
append.holidays = NULL,
seasonality.mode = 'additive',
seasonality.prior.scale = 10,
holidays.prior.scale = 10,
@ -112,7 +110,6 @@ prophet <- function(df = NULL,
weekly.seasonality = weekly.seasonality,
daily.seasonality = daily.seasonality,
holidays = holidays,
append.holidays = append.holidays,
seasonality.mode = seasonality.mode,
seasonality.prior.scale = seasonality.prior.scale,
changepoint.prior.scale = changepoint.prior.scale,
@ -128,6 +125,7 @@ prophet <- function(df = NULL,
changepoints.t = NULL,
seasonalities = list(),
extra_regressors = list(),
country_holidays = NULL,
stan.fit = NULL,
params = list(),
history = NULL,
@ -181,11 +179,6 @@ validate_inputs <- function(m) {
validate_column_name(m, h, check_holidays = FALSE)
}
}
if (!is.null(m$append.holidays)) {
if (!(m$append.holidays %in% generated_holidays$country)){
stop("Holidays in ", m$append.holidays," are not currently supported!")
}
}
if (!(m$seasonality.mode %in% c('additive', 'multiplicative'))) {
stop("seasonality.mode must be 'additive' or 'multiplicative'")
}
@ -223,9 +216,9 @@ validate_column_name <- function(
(name %in% unique(m$holidays$holiday))){
stop("Name ", name, " already used for a holiday.")
}
if(check_holidays & !is.null(m$append.holidays)){
if(name %in% get_holiday_names(m$append.holidays)){
stop("Name ", name, " is a holiday name in ", m$append.holidays, ".")
if(check_holidays & !is.null(m$country_holidays)){
if(name %in% get_holiday_names(m$country_holidays)){
stop("Name ", name, " is a holiday name in ", m$country_holidays, ".")
}
}
if(check_seasonalities & (!is.null(m$seasonalities[[name]]))){
@ -533,10 +526,46 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
return(data.frame(features))
}
#' Construct a dataframe of holiday dates.
#'
#' @param m Prophet object.
#' @param dates Vector with dates used for computing seasonality.
#'
#' @return A dataframe of holiday dates, in holiday dataframe format used in
#' initialization.
#'
#' @importFrom dplyr "%>%"
#' @keywords internal
construct_holiday_dataframe <- function(m, dates) {
all.holidays <- data.frame()
if (!is.null(m$holidays)){
all.holidays <- m$holidays
}
if (!is.null(m$country_holidays)) {
year.list <- as.numeric(unique(format(dates, "%Y")))
country.holidays.df <- make_holidays_df(year.list, m$country_holidays) %>%
dplyr::mutate(ds=as.character(ds), holiday=as.character(holiday))
all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, country.holidays.df))
}
# If the model has already been fit with a certain set of holidays,
# make sure we are using those same ones.
if (!is.null(m$train.holiday.names)) {
row.to.keep <- which(all.holidays$holiday %in% m$train.holiday.names)
all.holidays <- all.holidays[row.to.keep,]
holidays.to.add <- data.frame(
holiday=setdiff(m$train.holiday.names, all.holidays$holiday)
)
all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, holidays.to.add))
}
return(all.holidays)
}
#' Construct a matrix of holiday features.
#'
#' @param m Prophet object.
#' @param dates Vector with dates used for computing seasonality.
#' @param holidays Dataframe containing holidays, as returned by
#' construct_holiday_dataframe.
#'
#' @return A list with entries
#' holiday.features: dataframe with a column for each holiday.
@ -545,28 +574,10 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
#'
#' @importFrom dplyr "%>%"
#' @keywords internal
make_holiday_features <- function(m, dates) {
make_holiday_features <- function(m, dates, holidays) {
# Strip dates to be just days, for joining on holidays
dates <- set_date(format(dates, "%Y-%m-%d"))
all.holidays <- m$holidays
if (!is.null(m$append.holidays)){
years <- as.numeric(unique(format(dates, "%Y")))
append.holidays.df <- make_holidays_df(years, m$append.holidays) %>%
dplyr::mutate(ds=as.character(ds), holiday=as.character(holiday))
all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, append.holidays.df))
}
# Make fit.prophet and predict.prophet holidays components match
if (!is.null(m$append.holidays) && !is.null(m$train.holiday.names)){
row.to.keep <- which(all.holidays$holiday %in% m$train.holiday.names)
all.holidays <- all.holidays[row.to.keep,]
holidays.to.add <- data.frame(holiday=setdiff(m$train.holiday.names,
all.holidays$holiday))
all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, holidays.to.add))
}
if (nrow(all.holidays)==0){
return(NULL)
}
wide <- all.holidays %>%
wide <- holidays %>%
dplyr::mutate(ds = set_date(ds)) %>%
dplyr::group_by(holiday, ds) %>%
dplyr::filter(dplyr::row_number() == 1) %>%
@ -587,17 +598,17 @@ make_holiday_features <- function(m, dates) {
holiday.features <- data.frame(ds = set_date(dates)) %>%
dplyr::left_join(wide, by = 'ds') %>%
dplyr::select(-ds)
# Make sure fit.prophet and predict.prophet component.cols perfectly equal
# Make sure column order is consistent
holiday.features <- holiday.features %>% dplyr::select(sort(names(.)))
holiday.features[is.na(holiday.features)] <- 0
# Prior scales
if (!('prior_scale' %in% colnames(all.holidays))) {
all.holidays$prior_scale <- m$holidays.prior.scale
if (!('prior_scale' %in% colnames(holidays))) {
holidays$prior_scale <- m$holidays.prior.scale
}
prior.scales.list <- list()
for (name in unique(all.holidays$holiday)) {
df.h <- all.holidays[all.holidays$holiday == name, ]
for (name in unique(holidays$holiday)) {
df.h <- holidays[holidays$holiday == name, ]
ps <- unique(df.h$prior_scale)
if (length(ps) > 1) {
stop('Holiday ', name, ' does not have a consistent prior scale ',
@ -707,7 +718,6 @@ add_regressor <- function(
#'
#' @return The prophet model with the seasonality added.
#'
#' @importFrom dplyr "%>%"
#' @export
add_seasonality <- function(
m, name, period, fourier.order, prior.scale = NULL, mode = NULL
@ -742,6 +752,46 @@ add_seasonality <- function(
return(m)
}
#' Add in built-in holidays for the specified country.
#'
#' These holidays will be included in addition to any specified on model
#' initialization.
#'
#' Holidays will be calculated for arbitrary date ranges in the history
#' and future. See the online documentation for the list of countries with
#' built-in holidays.
#'
#' Built-in country holidays can only be set for a single country.
#'
#' @param m Prophet object.
#' @param country_name Name of the country, like 'UnitedStates' or 'US'
#'
#' @return The prophet model with the holidays country set.
#'
#' @export
add_country_holidays <- function(m, country_name) {
if (!is.null(m$history)) {
stop("Country holidays must be added prior to model fitting.")
}
if (!(country_name %in% generated_holidays$country)){
stop("Holidays in ", country_name," are not currently supported!")
}
# Validate names.
for (name in get_holiday_names(country_name)) {
# Allow merging with existing holidays
validate_column_name(m, name, check_holidays = FALSE)
}
# Set the holidays.
if (!is.null(m$country_holidays)) {
message(
'Changing country holidays from ', m$country_holidays, ' to ',
country_name
)
}
m$country_holidays = country_name
return(m)
}
#' Dataframe with seasonality features.
#' Includes seasonality features, holiday features, and added regressors.
#'
@ -776,15 +826,15 @@ make_all_seasonality_features <- function(m, df) {
}
# Holiday features
if (!is.null(m$holidays) || !is.null(m$append.holidays)) {
out <- make_holiday_features(m, df$ds)
if (!is.null(out)){
m <- out$m
seasonal.features <- cbind(seasonal.features, out$holiday.features)
prior.scales <- c(prior.scales, out$prior.scales)
modes[[m$seasonality.mode]] <- c(
modes[[m$seasonality.mode]], out$holiday.names)
}
holidays <- construct_holiday_dataframe(m, df$ds)
if (nrow(holidays) > 0) {
out <- make_holiday_features(m, df$ds, holidays)
m <- out$m
seasonal.features <- cbind(seasonal.features, out$holiday.features)
prior.scales <- c(prior.scales, out$prior.scales)
modes[[m$seasonality.mode]] <- c(
modes[[m$seasonality.mode]], out$holiday.names
)
}
# Additional regressors

View file

@ -0,0 +1,27 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/prophet.R
\name{add_country_holidays}
\alias{add_country_holidays}
\title{Add in built-in holidays for the specified country.}
\usage{
add_country_holidays(m, country_name)
}
\arguments{
\item{m}{Prophet object.}
\item{country_name}{Name of the country, like 'UnitedStates' or 'US'}
}
\value{
The prophet model with the holidays country set.
}
\description{
These holidays will be included in addition to any specified on model
initialization.
}
\details{
Holidays will be calculated for arbitrary date ranges in the history
and future. See the online documentation for the list of countries with
built-in holidays.
Built-in country holidays can only be set for a single country.
}

View file

@ -0,0 +1,21 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/prophet.R
\name{construct_holiday_dataframe}
\alias{construct_holiday_dataframe}
\title{Construct a dataframe of holiday dates.}
\usage{
construct_holiday_dataframe(m, dates)
}
\arguments{
\item{m}{Prophet object.}
\item{dates}{Vector with dates used for computing seasonality.}
}
\value{
A dataframe of holiday dates, in holiday dataframe format used in
initialization.
}
\description{
Construct a dataframe of holiday dates.
}
\keyword{internal}

View file

@ -4,12 +4,15 @@
\alias{make_holiday_features}
\title{Construct a matrix of holiday features.}
\usage{
make_holiday_features(m, dates)
make_holiday_features(m, dates, holidays)
}
\arguments{
\item{m}{Prophet object.}
\item{dates}{Vector with dates used for computing seasonality.}
\item{holidays}{Dataframe containing holidays, as returned by
construct_holiday_dataframe.}
}
\value{
A list with entries

View file

@ -8,10 +8,10 @@ prophet(df = NULL, growth = "linear", changepoints = NULL,
n.changepoints = 25, changepoint.range = 0.8,
yearly.seasonality = "auto", weekly.seasonality = "auto",
daily.seasonality = "auto", holidays = NULL,
append.holidays = NULL, seasonality.mode = "additive",
seasonality.prior.scale = 10, holidays.prior.scale = 10,
changepoint.prior.scale = 0.05, mcmc.samples = 0,
interval.width = 0.8, uncertainty.samples = 1000, fit = TRUE, ...)
seasonality.mode = "additive", seasonality.prior.scale = 10,
holidays.prior.scale = 10, changepoint.prior.scale = 0.05,
mcmc.samples = 0, interval.width = 0.8, uncertainty.samples = 1000,
fit = TRUE, ...)
}
\arguments{
\item{df}{(optional) Dataframe containing the history. Must have columns ds
@ -51,8 +51,6 @@ range of days around the date to be included as holidays. lower_window=-2
will include 2 days prior to the date as holidays. Also optionally can have
a column prior_scale specifying the prior scale for each holiday.}
\item{append.holidays}{country name or abbreviation (character).}
\item{seasonality.mode}{'additive' (default) or 'multiplicative'.}
\item{seasonality.prior.scale}{Parameter modulating the strength of the

View file

@ -259,7 +259,7 @@ test_that("holidays", {
ds = seq(prophet:::set_date('2016-12-20'),
prophet:::set_date('2016-12-31'), by='d'))
m <- prophet(train, holidays = holidays, fit = FALSE)
out <- prophet:::make_holiday_features(m, df$ds)
out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
feats <- out$holiday.features
priors <- out$prior.scales
names <- out$holiday.names
@ -274,7 +274,7 @@ test_that("holidays", {
lower_window = c(-1),
upper_window = c(10))
m <- prophet(train, holidays = holidays, fit = FALSE)
out <- prophet:::make_holiday_features(m, df$ds)
out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
feats <- out$holiday.features
priors <- out$prior.scales
names <- out$holiday.names
@ -291,7 +291,7 @@ test_that("holidays", {
prior_scale = c(5., 5.)
)
m <- prophet(holidays = holidays, fit = FALSE)
out <- prophet:::make_holiday_features(m, df$ds)
out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
priors <- out$prior.scales
names <- out$holiday.names
expect_true(all(priors == c(5., 5.)))
@ -306,7 +306,7 @@ test_that("holidays", {
)
holidays2 <- rbind(holidays, holidays2)
m <- prophet(holidays = holidays2, fit = FALSE)
out <- prophet:::make_holiday_features(m, df$ds)
out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
priors <- out$prior.scales
names <- out$holiday.names
expect_true(all(priors == c(8, 8, 5, 5)))
@ -324,7 +324,7 @@ test_that("holidays", {
# manual factorizing to avoid above bind_rows() warning
holidays2$holiday <- factor(holidays2$holiday)
m <- prophet(holidays = holidays2, fit = FALSE, holidays.prior.scale = 4)
out <- prophet:::make_holiday_features(m, df$ds)
out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
priors <- out$prior.scales
expect_true(all(priors == c(4, 4, 5, 5)))
# Check incompatible priors
@ -336,7 +336,7 @@ test_that("holidays", {
prior_scale = c(5., 6.)
)
m <- prophet(holidays = holidays, fit = FALSE)
expect_error(prophet:::make_holiday_features(m, df$ds))
expect_error(prophet:::make_holiday_features(m, df$ds, m$holidays))
})
test_that("fit_with_holidays", {
@ -349,47 +349,45 @@ test_that("fit_with_holidays", {
expect_error(predict(m), NA)
})
test_that("fit_with_append_holidays", {
test_that("fit_with_country_holidays", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
holidays <- data.frame(ds = c('2012-06-06', '2013-06-06'),
holiday = c('seans-bday', 'seans-bday'),
lower_window = c(0, 0),
upper_window = c(1, 1))
append.holidays = 'US'
# Test with holidays and append_holidays
m <- prophet(DATA,
holidays = holidays,
append.holidays = append.holidays,
uncertainty.samples = 0)
m <- prophet(holidays = holidays, uncertainty.samples = 0)
m <- add_country_holidays(m, 'US')
m <- fit.prophet(m, DATA)
expect_error(predict(m), NA)
# There are training holidays missing in the test set
train2 <- DATA %>% head(155)
future2 <- DATA %>% tail(355)
model <- prophet(train2,
append.holidays = append.holidays,
uncertainty.samples = 0)
m <- prophet(uncertainty.samples = 0)
m <- add_country_holidays(m, 'US')
m <- fit.prophet(m, train2)
expect_error(predict(m, future2), NA)
# There are test holidays missing in the training set
train2 <- DATA %>% tail(355)
future2 <- DATA2
model <- prophet(train2,
append.holidays = append.holidays,
uncertainty.samples = 0)
m <- prophet(uncertainty.samples = 0)
m <- add_country_holidays(m, 'US')
m <- fit.prophet(m, train2)
expect_error(predict(m, future2), NA)
# Append_holidays with non-existing year
max.year <- generated_holidays %>%
dplyr::filter(country==append.holidays) %>%
dplyr::filter(country=='US') %>%
dplyr::select(year) %>%
max()
train2 <- data.frame('ds'=c(paste(max.year+1, "-01-01", sep=''),
paste(max.year+1, "-01-02", sep='')),
'y'=1)
expect_warning(prophet(train2,
append.holidays = append.holidays))
m <- prophet()
m <- add_country_holidays(m, 'US')
expect_warning(m <- fit.prophet(m, train2))
# Append_holidays with non-existing country
append.holidays = 'Utopia'
expect_error(prophet(DATA,
append.holidays = append.holidays))
m <- prophet()
expect_error(add_country_holidays(m, 'Utopia'))
})
test_that("make_future_dataframe", {

View file

@ -434,10 +434,12 @@ class Prophet(object):
Returns
-------
dataframe of holiday dates, in holiday dataframe format used in
initialization.
"""
all_holidays = pd.DataFrame()
if self.holidays is not None:
all_holidays = pd.concat((all_holidays, self.holidays))
all_holidays = self.holidays.copy()
if self.country_holidays is not None:
year_list = list({x.year for x in dates})
country_holidays_df = make_holidays_df(
@ -464,7 +466,7 @@ class Prophet(object):
all_holidays = pd.concat((all_holidays, holidays_to_add), sort=False)
all_holidays.reset_index(drop=True, inplace=True)
return all_holidays
def make_holiday_features(self, dates, holidays):
"""Construct a dataframe of holiday features.
@ -526,7 +528,7 @@ class Prophet(object):
# Access key to generate value
expanded_holidays[key]
holiday_features = pd.DataFrame(expanded_holidays)
# Make sure fit and predict component_cols perfectly equal
# Make sure column order is consistent
holiday_features = holiday_features[sorted(holiday_features.columns.tolist())]
prior_scale_list = [
prior_scales[h.split('_delim_')[0]]