mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-18 21:21:22 +00:00
Move built-in country holidays to a function (R)
This commit is contained in:
parent
92f955d25a
commit
287fb2f6de
9 changed files with 183 additions and 83 deletions
|
|
@ -32,7 +32,7 @@ Suggests:
|
|||
readr
|
||||
License: BSD_3_clause + file LICENSE
|
||||
LazyData: true
|
||||
RoxygenNote: 6.1.0
|
||||
RoxygenNote: 6.1.1
|
||||
VignetteBuilder: knitr
|
||||
SystemRequirements: C++11
|
||||
Encoding: UTF-8
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
S3method(plot,prophet)
|
||||
S3method(predict,prophet)
|
||||
export(add_changepoints_to_plot)
|
||||
export(add_country_holidays)
|
||||
export(add_regressor)
|
||||
export(add_seasonality)
|
||||
export(cross_validation)
|
||||
|
|
|
|||
146
R/R/prophet.R
146
R/R/prophet.R
|
|
@ -8,7 +8,7 @@
|
|||
## Makes R CMD CHECK happy due to dplyr syntax below
|
||||
globalVariables(c(
|
||||
"ds", "y", "cap", ".",
|
||||
"component", "dow", "doy", "holiday", "holidays", "append.holidays", "holidays_lower",
|
||||
"component", "dow", "doy", "holiday", "holidays", "holidays_lower",
|
||||
"holidays_upper", "ix", "lower", "n", "stat", "trend", "row_number", "extra_regressors", "col",
|
||||
"trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower", "weekly_upper",
|
||||
"x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower", "yhat_upper"))
|
||||
|
|
@ -43,7 +43,6 @@ globalVariables(c(
|
|||
#' range of days around the date to be included as holidays. lower_window=-2
|
||||
#' will include 2 days prior to the date as holidays. Also optionally can have
|
||||
#' a column prior_scale specifying the prior scale for each holiday.
|
||||
#' @param append.holidays country name or abbreviation (character).
|
||||
#' @param seasonality.mode 'additive' (default) or 'multiplicative'.
|
||||
#' @param seasonality.prior.scale Parameter modulating the strength of the
|
||||
#' seasonality model. Larger values allow the model to fit larger seasonal
|
||||
|
|
@ -88,7 +87,6 @@ prophet <- function(df = NULL,
|
|||
weekly.seasonality = 'auto',
|
||||
daily.seasonality = 'auto',
|
||||
holidays = NULL,
|
||||
append.holidays = NULL,
|
||||
seasonality.mode = 'additive',
|
||||
seasonality.prior.scale = 10,
|
||||
holidays.prior.scale = 10,
|
||||
|
|
@ -112,7 +110,6 @@ prophet <- function(df = NULL,
|
|||
weekly.seasonality = weekly.seasonality,
|
||||
daily.seasonality = daily.seasonality,
|
||||
holidays = holidays,
|
||||
append.holidays = append.holidays,
|
||||
seasonality.mode = seasonality.mode,
|
||||
seasonality.prior.scale = seasonality.prior.scale,
|
||||
changepoint.prior.scale = changepoint.prior.scale,
|
||||
|
|
@ -128,6 +125,7 @@ prophet <- function(df = NULL,
|
|||
changepoints.t = NULL,
|
||||
seasonalities = list(),
|
||||
extra_regressors = list(),
|
||||
country_holidays = NULL,
|
||||
stan.fit = NULL,
|
||||
params = list(),
|
||||
history = NULL,
|
||||
|
|
@ -181,11 +179,6 @@ validate_inputs <- function(m) {
|
|||
validate_column_name(m, h, check_holidays = FALSE)
|
||||
}
|
||||
}
|
||||
if (!is.null(m$append.holidays)) {
|
||||
if (!(m$append.holidays %in% generated_holidays$country)){
|
||||
stop("Holidays in ", m$append.holidays," are not currently supported!")
|
||||
}
|
||||
}
|
||||
if (!(m$seasonality.mode %in% c('additive', 'multiplicative'))) {
|
||||
stop("seasonality.mode must be 'additive' or 'multiplicative'")
|
||||
}
|
||||
|
|
@ -223,9 +216,9 @@ validate_column_name <- function(
|
|||
(name %in% unique(m$holidays$holiday))){
|
||||
stop("Name ", name, " already used for a holiday.")
|
||||
}
|
||||
if(check_holidays & !is.null(m$append.holidays)){
|
||||
if(name %in% get_holiday_names(m$append.holidays)){
|
||||
stop("Name ", name, " is a holiday name in ", m$append.holidays, ".")
|
||||
if(check_holidays & !is.null(m$country_holidays)){
|
||||
if(name %in% get_holiday_names(m$country_holidays)){
|
||||
stop("Name ", name, " is a holiday name in ", m$country_holidays, ".")
|
||||
}
|
||||
}
|
||||
if(check_seasonalities & (!is.null(m$seasonalities[[name]]))){
|
||||
|
|
@ -533,10 +526,46 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
|
|||
return(data.frame(features))
|
||||
}
|
||||
|
||||
#' Construct a dataframe of holiday dates.
|
||||
#'
|
||||
#' @param m Prophet object.
|
||||
#' @param dates Vector with dates used for computing seasonality.
|
||||
#'
|
||||
#' @return A dataframe of holiday dates, in holiday dataframe format used in
|
||||
#' initialization.
|
||||
#'
|
||||
#' @importFrom dplyr "%>%"
|
||||
#' @keywords internal
|
||||
construct_holiday_dataframe <- function(m, dates) {
|
||||
all.holidays <- data.frame()
|
||||
if (!is.null(m$holidays)){
|
||||
all.holidays <- m$holidays
|
||||
}
|
||||
if (!is.null(m$country_holidays)) {
|
||||
year.list <- as.numeric(unique(format(dates, "%Y")))
|
||||
country.holidays.df <- make_holidays_df(year.list, m$country_holidays) %>%
|
||||
dplyr::mutate(ds=as.character(ds), holiday=as.character(holiday))
|
||||
all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, country.holidays.df))
|
||||
}
|
||||
# If the model has already been fit with a certain set of holidays,
|
||||
# make sure we are using those same ones.
|
||||
if (!is.null(m$train.holiday.names)) {
|
||||
row.to.keep <- which(all.holidays$holiday %in% m$train.holiday.names)
|
||||
all.holidays <- all.holidays[row.to.keep,]
|
||||
holidays.to.add <- data.frame(
|
||||
holiday=setdiff(m$train.holiday.names, all.holidays$holiday)
|
||||
)
|
||||
all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, holidays.to.add))
|
||||
}
|
||||
return(all.holidays)
|
||||
}
|
||||
|
||||
#' Construct a matrix of holiday features.
|
||||
#'
|
||||
#' @param m Prophet object.
|
||||
#' @param dates Vector with dates used for computing seasonality.
|
||||
#' @param holidays Dataframe containing holidays, as returned by
|
||||
#' construct_holiday_dataframe.
|
||||
#'
|
||||
#' @return A list with entries
|
||||
#' holiday.features: dataframe with a column for each holiday.
|
||||
|
|
@ -545,28 +574,10 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
|
|||
#'
|
||||
#' @importFrom dplyr "%>%"
|
||||
#' @keywords internal
|
||||
make_holiday_features <- function(m, dates) {
|
||||
make_holiday_features <- function(m, dates, holidays) {
|
||||
# Strip dates to be just days, for joining on holidays
|
||||
dates <- set_date(format(dates, "%Y-%m-%d"))
|
||||
all.holidays <- m$holidays
|
||||
if (!is.null(m$append.holidays)){
|
||||
years <- as.numeric(unique(format(dates, "%Y")))
|
||||
append.holidays.df <- make_holidays_df(years, m$append.holidays) %>%
|
||||
dplyr::mutate(ds=as.character(ds), holiday=as.character(holiday))
|
||||
all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, append.holidays.df))
|
||||
}
|
||||
# Make fit.prophet and predict.prophet holidays components match
|
||||
if (!is.null(m$append.holidays) && !is.null(m$train.holiday.names)){
|
||||
row.to.keep <- which(all.holidays$holiday %in% m$train.holiday.names)
|
||||
all.holidays <- all.holidays[row.to.keep,]
|
||||
holidays.to.add <- data.frame(holiday=setdiff(m$train.holiday.names,
|
||||
all.holidays$holiday))
|
||||
all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, holidays.to.add))
|
||||
}
|
||||
if (nrow(all.holidays)==0){
|
||||
return(NULL)
|
||||
}
|
||||
wide <- all.holidays %>%
|
||||
wide <- holidays %>%
|
||||
dplyr::mutate(ds = set_date(ds)) %>%
|
||||
dplyr::group_by(holiday, ds) %>%
|
||||
dplyr::filter(dplyr::row_number() == 1) %>%
|
||||
|
|
@ -587,17 +598,17 @@ make_holiday_features <- function(m, dates) {
|
|||
holiday.features <- data.frame(ds = set_date(dates)) %>%
|
||||
dplyr::left_join(wide, by = 'ds') %>%
|
||||
dplyr::select(-ds)
|
||||
# Make sure fit.prophet and predict.prophet component.cols perfectly equal
|
||||
# Make sure column order is consistent
|
||||
holiday.features <- holiday.features %>% dplyr::select(sort(names(.)))
|
||||
holiday.features[is.na(holiday.features)] <- 0
|
||||
|
||||
|
||||
# Prior scales
|
||||
if (!('prior_scale' %in% colnames(all.holidays))) {
|
||||
all.holidays$prior_scale <- m$holidays.prior.scale
|
||||
if (!('prior_scale' %in% colnames(holidays))) {
|
||||
holidays$prior_scale <- m$holidays.prior.scale
|
||||
}
|
||||
prior.scales.list <- list()
|
||||
for (name in unique(all.holidays$holiday)) {
|
||||
df.h <- all.holidays[all.holidays$holiday == name, ]
|
||||
for (name in unique(holidays$holiday)) {
|
||||
df.h <- holidays[holidays$holiday == name, ]
|
||||
ps <- unique(df.h$prior_scale)
|
||||
if (length(ps) > 1) {
|
||||
stop('Holiday ', name, ' does not have a consistent prior scale ',
|
||||
|
|
@ -707,7 +718,6 @@ add_regressor <- function(
|
|||
#'
|
||||
#' @return The prophet model with the seasonality added.
|
||||
#'
|
||||
#' @importFrom dplyr "%>%"
|
||||
#' @export
|
||||
add_seasonality <- function(
|
||||
m, name, period, fourier.order, prior.scale = NULL, mode = NULL
|
||||
|
|
@ -742,6 +752,46 @@ add_seasonality <- function(
|
|||
return(m)
|
||||
}
|
||||
|
||||
#' Add in built-in holidays for the specified country.
|
||||
#'
|
||||
#' These holidays will be included in addition to any specified on model
|
||||
#' initialization.
|
||||
#'
|
||||
#' Holidays will be calculated for arbitrary date ranges in the history
|
||||
#' and future. See the online documentation for the list of countries with
|
||||
#' built-in holidays.
|
||||
#'
|
||||
#' Built-in country holidays can only be set for a single country.
|
||||
#'
|
||||
#' @param m Prophet object.
|
||||
#' @param country_name Name of the country, like 'UnitedStates' or 'US'
|
||||
#'
|
||||
#' @return The prophet model with the holidays country set.
|
||||
#'
|
||||
#' @export
|
||||
add_country_holidays <- function(m, country_name) {
|
||||
if (!is.null(m$history)) {
|
||||
stop("Country holidays must be added prior to model fitting.")
|
||||
}
|
||||
if (!(country_name %in% generated_holidays$country)){
|
||||
stop("Holidays in ", country_name," are not currently supported!")
|
||||
}
|
||||
# Validate names.
|
||||
for (name in get_holiday_names(country_name)) {
|
||||
# Allow merging with existing holidays
|
||||
validate_column_name(m, name, check_holidays = FALSE)
|
||||
}
|
||||
# Set the holidays.
|
||||
if (!is.null(m$country_holidays)) {
|
||||
message(
|
||||
'Changing country holidays from ', m$country_holidays, ' to ',
|
||||
country_name
|
||||
)
|
||||
}
|
||||
m$country_holidays = country_name
|
||||
return(m)
|
||||
}
|
||||
|
||||
#' Dataframe with seasonality features.
|
||||
#' Includes seasonality features, holiday features, and added regressors.
|
||||
#'
|
||||
|
|
@ -776,15 +826,15 @@ make_all_seasonality_features <- function(m, df) {
|
|||
}
|
||||
|
||||
# Holiday features
|
||||
if (!is.null(m$holidays) || !is.null(m$append.holidays)) {
|
||||
out <- make_holiday_features(m, df$ds)
|
||||
if (!is.null(out)){
|
||||
m <- out$m
|
||||
seasonal.features <- cbind(seasonal.features, out$holiday.features)
|
||||
prior.scales <- c(prior.scales, out$prior.scales)
|
||||
modes[[m$seasonality.mode]] <- c(
|
||||
modes[[m$seasonality.mode]], out$holiday.names)
|
||||
}
|
||||
holidays <- construct_holiday_dataframe(m, df$ds)
|
||||
if (nrow(holidays) > 0) {
|
||||
out <- make_holiday_features(m, df$ds, holidays)
|
||||
m <- out$m
|
||||
seasonal.features <- cbind(seasonal.features, out$holiday.features)
|
||||
prior.scales <- c(prior.scales, out$prior.scales)
|
||||
modes[[m$seasonality.mode]] <- c(
|
||||
modes[[m$seasonality.mode]], out$holiday.names
|
||||
)
|
||||
}
|
||||
|
||||
# Additional regressors
|
||||
|
|
|
|||
27
R/man/add_country_holidays.Rd
Normal file
27
R/man/add_country_holidays.Rd
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/prophet.R
|
||||
\name{add_country_holidays}
|
||||
\alias{add_country_holidays}
|
||||
\title{Add in built-in holidays for the specified country.}
|
||||
\usage{
|
||||
add_country_holidays(m, country_name)
|
||||
}
|
||||
\arguments{
|
||||
\item{m}{Prophet object.}
|
||||
|
||||
\item{country_name}{Name of the country, like 'UnitedStates' or 'US'}
|
||||
}
|
||||
\value{
|
||||
The prophet model with the holidays country set.
|
||||
}
|
||||
\description{
|
||||
These holidays will be included in addition to any specified on model
|
||||
initialization.
|
||||
}
|
||||
\details{
|
||||
Holidays will be calculated for arbitrary date ranges in the history
|
||||
and future. See the online documentation for the list of countries with
|
||||
built-in holidays.
|
||||
|
||||
Built-in country holidays can only be set for a single country.
|
||||
}
|
||||
21
R/man/construct_holiday_dataframe.Rd
Normal file
21
R/man/construct_holiday_dataframe.Rd
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/prophet.R
|
||||
\name{construct_holiday_dataframe}
|
||||
\alias{construct_holiday_dataframe}
|
||||
\title{Construct a dataframe of holiday dates.}
|
||||
\usage{
|
||||
construct_holiday_dataframe(m, dates)
|
||||
}
|
||||
\arguments{
|
||||
\item{m}{Prophet object.}
|
||||
|
||||
\item{dates}{Vector with dates used for computing seasonality.}
|
||||
}
|
||||
\value{
|
||||
A dataframe of holiday dates, in holiday dataframe format used in
|
||||
initialization.
|
||||
}
|
||||
\description{
|
||||
Construct a dataframe of holiday dates.
|
||||
}
|
||||
\keyword{internal}
|
||||
|
|
@ -4,12 +4,15 @@
|
|||
\alias{make_holiday_features}
|
||||
\title{Construct a matrix of holiday features.}
|
||||
\usage{
|
||||
make_holiday_features(m, dates)
|
||||
make_holiday_features(m, dates, holidays)
|
||||
}
|
||||
\arguments{
|
||||
\item{m}{Prophet object.}
|
||||
|
||||
\item{dates}{Vector with dates used for computing seasonality.}
|
||||
|
||||
\item{holidays}{Dataframe containing holidays, as returned by
|
||||
construct_holiday_dataframe.}
|
||||
}
|
||||
\value{
|
||||
A list with entries
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ prophet(df = NULL, growth = "linear", changepoints = NULL,
|
|||
n.changepoints = 25, changepoint.range = 0.8,
|
||||
yearly.seasonality = "auto", weekly.seasonality = "auto",
|
||||
daily.seasonality = "auto", holidays = NULL,
|
||||
append.holidays = NULL, seasonality.mode = "additive",
|
||||
seasonality.prior.scale = 10, holidays.prior.scale = 10,
|
||||
changepoint.prior.scale = 0.05, mcmc.samples = 0,
|
||||
interval.width = 0.8, uncertainty.samples = 1000, fit = TRUE, ...)
|
||||
seasonality.mode = "additive", seasonality.prior.scale = 10,
|
||||
holidays.prior.scale = 10, changepoint.prior.scale = 0.05,
|
||||
mcmc.samples = 0, interval.width = 0.8, uncertainty.samples = 1000,
|
||||
fit = TRUE, ...)
|
||||
}
|
||||
\arguments{
|
||||
\item{df}{(optional) Dataframe containing the history. Must have columns ds
|
||||
|
|
@ -51,8 +51,6 @@ range of days around the date to be included as holidays. lower_window=-2
|
|||
will include 2 days prior to the date as holidays. Also optionally can have
|
||||
a column prior_scale specifying the prior scale for each holiday.}
|
||||
|
||||
\item{append.holidays}{country name or abbreviation (character).}
|
||||
|
||||
\item{seasonality.mode}{'additive' (default) or 'multiplicative'.}
|
||||
|
||||
\item{seasonality.prior.scale}{Parameter modulating the strength of the
|
||||
|
|
|
|||
|
|
@ -259,7 +259,7 @@ test_that("holidays", {
|
|||
ds = seq(prophet:::set_date('2016-12-20'),
|
||||
prophet:::set_date('2016-12-31'), by='d'))
|
||||
m <- prophet(train, holidays = holidays, fit = FALSE)
|
||||
out <- prophet:::make_holiday_features(m, df$ds)
|
||||
out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
|
||||
feats <- out$holiday.features
|
||||
priors <- out$prior.scales
|
||||
names <- out$holiday.names
|
||||
|
|
@ -274,7 +274,7 @@ test_that("holidays", {
|
|||
lower_window = c(-1),
|
||||
upper_window = c(10))
|
||||
m <- prophet(train, holidays = holidays, fit = FALSE)
|
||||
out <- prophet:::make_holiday_features(m, df$ds)
|
||||
out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
|
||||
feats <- out$holiday.features
|
||||
priors <- out$prior.scales
|
||||
names <- out$holiday.names
|
||||
|
|
@ -291,7 +291,7 @@ test_that("holidays", {
|
|||
prior_scale = c(5., 5.)
|
||||
)
|
||||
m <- prophet(holidays = holidays, fit = FALSE)
|
||||
out <- prophet:::make_holiday_features(m, df$ds)
|
||||
out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
|
||||
priors <- out$prior.scales
|
||||
names <- out$holiday.names
|
||||
expect_true(all(priors == c(5., 5.)))
|
||||
|
|
@ -306,7 +306,7 @@ test_that("holidays", {
|
|||
)
|
||||
holidays2 <- rbind(holidays, holidays2)
|
||||
m <- prophet(holidays = holidays2, fit = FALSE)
|
||||
out <- prophet:::make_holiday_features(m, df$ds)
|
||||
out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
|
||||
priors <- out$prior.scales
|
||||
names <- out$holiday.names
|
||||
expect_true(all(priors == c(8, 8, 5, 5)))
|
||||
|
|
@ -324,7 +324,7 @@ test_that("holidays", {
|
|||
# manual factorizing to avoid above bind_rows() warning
|
||||
holidays2$holiday <- factor(holidays2$holiday)
|
||||
m <- prophet(holidays = holidays2, fit = FALSE, holidays.prior.scale = 4)
|
||||
out <- prophet:::make_holiday_features(m, df$ds)
|
||||
out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
|
||||
priors <- out$prior.scales
|
||||
expect_true(all(priors == c(4, 4, 5, 5)))
|
||||
# Check incompatible priors
|
||||
|
|
@ -336,7 +336,7 @@ test_that("holidays", {
|
|||
prior_scale = c(5., 6.)
|
||||
)
|
||||
m <- prophet(holidays = holidays, fit = FALSE)
|
||||
expect_error(prophet:::make_holiday_features(m, df$ds))
|
||||
expect_error(prophet:::make_holiday_features(m, df$ds, m$holidays))
|
||||
})
|
||||
|
||||
test_that("fit_with_holidays", {
|
||||
|
|
@ -349,47 +349,45 @@ test_that("fit_with_holidays", {
|
|||
expect_error(predict(m), NA)
|
||||
})
|
||||
|
||||
test_that("fit_with_append_holidays", {
|
||||
test_that("fit_with_country_holidays", {
|
||||
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
|
||||
holidays <- data.frame(ds = c('2012-06-06', '2013-06-06'),
|
||||
holiday = c('seans-bday', 'seans-bday'),
|
||||
lower_window = c(0, 0),
|
||||
upper_window = c(1, 1))
|
||||
append.holidays = 'US'
|
||||
# Test with holidays and append_holidays
|
||||
m <- prophet(DATA,
|
||||
holidays = holidays,
|
||||
append.holidays = append.holidays,
|
||||
uncertainty.samples = 0)
|
||||
m <- prophet(holidays = holidays, uncertainty.samples = 0)
|
||||
m <- add_country_holidays(m, 'US')
|
||||
m <- fit.prophet(m, DATA)
|
||||
expect_error(predict(m), NA)
|
||||
# There are training holidays missing in the test set
|
||||
train2 <- DATA %>% head(155)
|
||||
future2 <- DATA %>% tail(355)
|
||||
model <- prophet(train2,
|
||||
append.holidays = append.holidays,
|
||||
uncertainty.samples = 0)
|
||||
m <- prophet(uncertainty.samples = 0)
|
||||
m <- add_country_holidays(m, 'US')
|
||||
m <- fit.prophet(m, train2)
|
||||
expect_error(predict(m, future2), NA)
|
||||
# There are test holidays missing in the training set
|
||||
train2 <- DATA %>% tail(355)
|
||||
future2 <- DATA2
|
||||
model <- prophet(train2,
|
||||
append.holidays = append.holidays,
|
||||
uncertainty.samples = 0)
|
||||
m <- prophet(uncertainty.samples = 0)
|
||||
m <- add_country_holidays(m, 'US')
|
||||
m <- fit.prophet(m, train2)
|
||||
expect_error(predict(m, future2), NA)
|
||||
# Append_holidays with non-existing year
|
||||
max.year <- generated_holidays %>%
|
||||
dplyr::filter(country==append.holidays) %>%
|
||||
dplyr::filter(country=='US') %>%
|
||||
dplyr::select(year) %>%
|
||||
max()
|
||||
train2 <- data.frame('ds'=c(paste(max.year+1, "-01-01", sep=''),
|
||||
paste(max.year+1, "-01-02", sep='')),
|
||||
'y'=1)
|
||||
expect_warning(prophet(train2,
|
||||
append.holidays = append.holidays))
|
||||
m <- prophet()
|
||||
m <- add_country_holidays(m, 'US')
|
||||
expect_warning(m <- fit.prophet(m, train2))
|
||||
# Append_holidays with non-existing country
|
||||
append.holidays = 'Utopia'
|
||||
expect_error(prophet(DATA,
|
||||
append.holidays = append.holidays))
|
||||
m <- prophet()
|
||||
expect_error(add_country_holidays(m, 'Utopia'))
|
||||
})
|
||||
|
||||
test_that("make_future_dataframe", {
|
||||
|
|
|
|||
|
|
@ -434,10 +434,12 @@ class Prophet(object):
|
|||
|
||||
Returns
|
||||
-------
|
||||
dataframe of holiday dates, in holiday dataframe format used in
|
||||
initialization.
|
||||
"""
|
||||
all_holidays = pd.DataFrame()
|
||||
if self.holidays is not None:
|
||||
all_holidays = pd.concat((all_holidays, self.holidays))
|
||||
all_holidays = self.holidays.copy()
|
||||
if self.country_holidays is not None:
|
||||
year_list = list({x.year for x in dates})
|
||||
country_holidays_df = make_holidays_df(
|
||||
|
|
@ -464,7 +466,7 @@ class Prophet(object):
|
|||
all_holidays = pd.concat((all_holidays, holidays_to_add), sort=False)
|
||||
all_holidays.reset_index(drop=True, inplace=True)
|
||||
return all_holidays
|
||||
|
||||
|
||||
def make_holiday_features(self, dates, holidays):
|
||||
"""Construct a dataframe of holiday features.
|
||||
|
||||
|
|
@ -526,7 +528,7 @@ class Prophet(object):
|
|||
# Access key to generate value
|
||||
expanded_holidays[key]
|
||||
holiday_features = pd.DataFrame(expanded_holidays)
|
||||
# Make sure fit and predict component_cols perfectly equal
|
||||
# Make sure column order is consistent
|
||||
holiday_features = holiday_features[sorted(holiday_features.columns.tolist())]
|
||||
prior_scale_list = [
|
||||
prior_scales[h.split('_delim_')[0]]
|
||||
|
|
|
|||
Loading…
Reference in a new issue