Add documentation for country holidays

This commit is contained in:
Ben Letham 2018-12-03 11:54:55 -08:00
parent 287fb2f6de
commit 536fe931c6
6 changed files with 206 additions and 29 deletions

View file

@ -41,6 +41,7 @@ make_holidays_df <- function(years, country.name){
holidays.df <- country.holidays %>%
dplyr::filter(year %in% years) %>%
dplyr::select(ds, holiday) %>%
dplyr::mutate(ds = as.Date(ds)) %>%
data.frame
return(holidays.df)
}
}

View file

@ -100,7 +100,7 @@ prophet_plot_components <- function(
panels <- list(
plot_forecast_component(m, fcst, 'trend', uncertainty, plot_cap))
# Plot holiday components, if present.
if (!is.null(m$holidays) && ('holidays' %in% colnames(fcst))) {
if (!is.null(m$train.holiday.names) && ('holidays' %in% colnames(fcst))) {
panels[[length(panels) + 1]] <- plot_forecast_component(
m, fcst, 'holidays', uncertainty, FALSE)
}

View file

@ -134,7 +134,7 @@ prophet <- function(df = NULL,
train.component.cols = NULL,
component.modes = NULL
)
validate_inputs(m)
m <- validate_inputs(m)
class(m) <- append("prophet", class(m))
if ((fit) && (!is.null(df))) {
m <- fit.prophet(m, df, ...)
@ -146,6 +146,8 @@ prophet <- function(df = NULL,
#'
#' @param m Prophet object.
#'
#' @return The Prophet object.
#'
#' @keywords internal
validate_inputs <- function(m) {
if (!(m$growth %in% c('linear', 'logistic'))) {
@ -161,6 +163,7 @@ validate_inputs <- function(m) {
if (!(exists('ds', where = m$holidays))) {
stop('Holidays dataframe must have ds field.')
}
m$holidays$ds <- as.Date(m$holidays$ds)
has.lower <- exists('lower_window', where = m$holidays)
has.upper <- exists('upper_window', where = m$holidays)
if (has.lower + has.upper == 1) {
@ -182,6 +185,7 @@ validate_inputs <- function(m) {
if (!(m$seasonality.mode %in% c('additive', 'multiplicative'))) {
stop("seasonality.mode must be 'additive' or 'multiplicative'")
}
return(m)
}
#' Validates the name of a seasonality, holiday, or regressor.
@ -543,8 +547,7 @@ construct_holiday_dataframe <- function(m, dates) {
}
if (!is.null(m$country_holidays)) {
year.list <- as.numeric(unique(format(dates, "%Y")))
country.holidays.df <- make_holidays_df(year.list, m$country_holidays) %>%
dplyr::mutate(ds=as.character(ds), holiday=as.character(holiday))
country.holidays.df <- make_holidays_df(year.list, m$country_holidays)
all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, country.holidays.df))
}
# If the model has already been fit with a certain set of holidays,

File diff suppressed because one or more lines are too long

View file

@ -122,16 +122,6 @@ class Prophet(object):
self.yearly_seasonality = yearly_seasonality
self.weekly_seasonality = weekly_seasonality
self.daily_seasonality = daily_seasonality
if holidays is not None:
if not (
isinstance(holidays, pd.DataFrame)
and 'ds' in holidays # noqa W503
and 'holiday' in holidays # noqa W503
):
raise ValueError("holidays must be a DataFrame with 'ds' and "
"'holiday' columns.")
holidays['ds'] = pd.to_datetime(holidays['ds'])
self.holidays = holidays
self.seasonality_mode = seasonality_mode
@ -169,6 +159,14 @@ class Prophet(object):
if ((self.changepoint_range < 0) or (self.changepoint_range > 1)):
raise ValueError("Parameter 'changepoint_range' must be in [0, 1]")
if self.holidays is not None:
if not (
isinstance(holidays, pd.DataFrame)
and 'ds' in holidays # noqa W503
and 'holiday' in holidays # noqa W503
):
raise ValueError("holidays must be a DataFrame with 'ds' and "
"'holiday' columns.")
holidays['ds'] = pd.to_datetime(holidays['ds'])
has_lower = 'lower_window' in self.holidays
has_upper = 'upper_window' in self.holidays
if has_lower + has_upper == 1:

View file

@ -104,7 +104,7 @@ def plot_components(
"""
# Identify components to be plotted
components = ['trend']
if m.holidays is not None and 'holidays' in fcst:
if m.train_holiday_names is not None and 'holidays' in fcst:
components.append('holidays')
components.extend([name for name in m.seasonalities
if name in fcst])