mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-30 03:37:53 +00:00
Add documentation for country holidays
This commit is contained in:
parent
287fb2f6de
commit
536fe931c6
6 changed files with 206 additions and 29 deletions
|
|
@ -41,6 +41,7 @@ make_holidays_df <- function(years, country.name){
|
|||
holidays.df <- country.holidays %>%
|
||||
dplyr::filter(year %in% years) %>%
|
||||
dplyr::select(ds, holiday) %>%
|
||||
dplyr::mutate(ds = as.Date(ds)) %>%
|
||||
data.frame
|
||||
return(holidays.df)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ prophet_plot_components <- function(
|
|||
panels <- list(
|
||||
plot_forecast_component(m, fcst, 'trend', uncertainty, plot_cap))
|
||||
# Plot holiday components, if present.
|
||||
if (!is.null(m$holidays) && ('holidays' %in% colnames(fcst))) {
|
||||
if (!is.null(m$train.holiday.names) && ('holidays' %in% colnames(fcst))) {
|
||||
panels[[length(panels) + 1]] <- plot_forecast_component(
|
||||
m, fcst, 'holidays', uncertainty, FALSE)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ prophet <- function(df = NULL,
|
|||
train.component.cols = NULL,
|
||||
component.modes = NULL
|
||||
)
|
||||
validate_inputs(m)
|
||||
m <- validate_inputs(m)
|
||||
class(m) <- append("prophet", class(m))
|
||||
if ((fit) && (!is.null(df))) {
|
||||
m <- fit.prophet(m, df, ...)
|
||||
|
|
@ -146,6 +146,8 @@ prophet <- function(df = NULL,
|
|||
#'
|
||||
#' @param m Prophet object.
|
||||
#'
|
||||
#' @return The Prophet object.
|
||||
#'
|
||||
#' @keywords internal
|
||||
validate_inputs <- function(m) {
|
||||
if (!(m$growth %in% c('linear', 'logistic'))) {
|
||||
|
|
@ -161,6 +163,7 @@ validate_inputs <- function(m) {
|
|||
if (!(exists('ds', where = m$holidays))) {
|
||||
stop('Holidays dataframe must have ds field.')
|
||||
}
|
||||
m$holidays$ds <- as.Date(m$holidays$ds)
|
||||
has.lower <- exists('lower_window', where = m$holidays)
|
||||
has.upper <- exists('upper_window', where = m$holidays)
|
||||
if (has.lower + has.upper == 1) {
|
||||
|
|
@ -182,6 +185,7 @@ validate_inputs <- function(m) {
|
|||
if (!(m$seasonality.mode %in% c('additive', 'multiplicative'))) {
|
||||
stop("seasonality.mode must be 'additive' or 'multiplicative'")
|
||||
}
|
||||
return(m)
|
||||
}
|
||||
|
||||
#' Validates the name of a seasonality, holiday, or regressor.
|
||||
|
|
@ -543,8 +547,7 @@ construct_holiday_dataframe <- function(m, dates) {
|
|||
}
|
||||
if (!is.null(m$country_holidays)) {
|
||||
year.list <- as.numeric(unique(format(dates, "%Y")))
|
||||
country.holidays.df <- make_holidays_df(year.list, m$country_holidays) %>%
|
||||
dplyr::mutate(ds=as.character(ds), holiday=as.character(holiday))
|
||||
country.holidays.df <- make_holidays_df(year.list, m$country_holidays)
|
||||
all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, country.holidays.df))
|
||||
}
|
||||
# If the model has already been fit with a certain set of holidays,
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -122,16 +122,6 @@ class Prophet(object):
|
|||
self.yearly_seasonality = yearly_seasonality
|
||||
self.weekly_seasonality = weekly_seasonality
|
||||
self.daily_seasonality = daily_seasonality
|
||||
|
||||
if holidays is not None:
|
||||
if not (
|
||||
isinstance(holidays, pd.DataFrame)
|
||||
and 'ds' in holidays # noqa W503
|
||||
and 'holiday' in holidays # noqa W503
|
||||
):
|
||||
raise ValueError("holidays must be a DataFrame with 'ds' and "
|
||||
"'holiday' columns.")
|
||||
holidays['ds'] = pd.to_datetime(holidays['ds'])
|
||||
self.holidays = holidays
|
||||
|
||||
self.seasonality_mode = seasonality_mode
|
||||
|
|
@ -169,6 +159,14 @@ class Prophet(object):
|
|||
if ((self.changepoint_range < 0) or (self.changepoint_range > 1)):
|
||||
raise ValueError("Parameter 'changepoint_range' must be in [0, 1]")
|
||||
if self.holidays is not None:
|
||||
if not (
|
||||
isinstance(holidays, pd.DataFrame)
|
||||
and 'ds' in holidays # noqa W503
|
||||
and 'holiday' in holidays # noqa W503
|
||||
):
|
||||
raise ValueError("holidays must be a DataFrame with 'ds' and "
|
||||
"'holiday' columns.")
|
||||
holidays['ds'] = pd.to_datetime(holidays['ds'])
|
||||
has_lower = 'lower_window' in self.holidays
|
||||
has_upper = 'upper_window' in self.holidays
|
||||
if has_lower + has_upper == 1:
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ def plot_components(
|
|||
"""
|
||||
# Identify components to be plotted
|
||||
components = ['trend']
|
||||
if m.holidays is not None and 'holidays' in fcst:
|
||||
if m.train_holiday_names is not None and 'holidays' in fcst:
|
||||
components.append('holidays')
|
||||
components.extend([name for name in m.seasonalities
|
||||
if name in fcst])
|
||||
|
|
|
|||
Loading…
Reference in a new issue