diff --git a/R/R/prophet.R b/R/R/prophet.R index 96f7fce..e37567a 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -42,7 +42,8 @@ globalVariables(c( #' a column prior_scale specifying the prior scale for each holiday. #' @param seasonality.prior.scale Parameter modulating the strength of the #' seasonality model. Larger values allow the model to fit larger seasonal -#' fluctuations, smaller values dampen the seasonality. +#' fluctuations, smaller values dampen the seasonality. Can be specified for +#' individual seasonalities using add_seasonality. #' @param holidays.prior.scale Parameter modulating the strength of the holiday #' components model, unless overridden in the holidays input. #' @param changepoint.prior.scale Parameter modulating the flexibility of the @@ -508,37 +509,44 @@ make_holiday_features <- function(m, dates) { } else { offsets <- c(0) } - if (exists('prior_scale', where = .) && !is.na(.$prior_scale)) { - ps <- .$prior_scale - } else { - ps <- m$holidays.prior.scale - } names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '') - dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names, - prior_scale = ps) + dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names) }) %>% dplyr::mutate(x = 1.) %>% tidyr::spread(holiday, x, fill = 0) holiday.features <- data.frame(ds = set_date(dates)) %>% - dplyr::left_join(wide, by = 'ds') + dplyr::left_join(wide, by = 'ds') %>% + dplyr::select(-ds) - prior.scales.all <- holiday.features$prior_scale - prior.scales <- c() - - holiday.features <- dplyr::select(holiday.features, -ds, -prior_scale) holiday.features[is.na(holiday.features)] <- 0 - for (name in colnames(holiday.features)) { - rows <- !is.na(holiday.features[[name]]) & (holiday.features[[name]] == 1) - ps <- unique(prior.scales.all[rows]) + # Prior scales + if (!('prior_scale' %in% colnames(m$holidays))) { + m$holidays$prior_scale <- m$holidays.prior.scale + } + prior.scales.list <- list() + for (name in unique(m$holidays$holiday)) { + df.h <- m$holidays[m$holidays$holiday == name, ] + ps <- unique(df.h$prior_scale) if (length(ps) > 1) { - sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1] - stop('Holiday ', sn, ' does not have a consistent prior scale ', + stop('Holiday ', name, ' does not have a consistent prior scale ', 'specification') } - prior.scales <- c(prior.scales, ps) + if (is.na(ps)) { + ps <- m$holidays.prior.scale + } + if (ps <= 0) { + stop('Prior scale must be > 0.') + } + prior.scales.list[[name]] <- ps + } + + prior.scales <- c() + for (name in colnames(holiday.features)) { + sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1] + prior.scales <- c(prior.scales, prior.scales.list[[sn]]) } return(list(holiday.features = holiday.features, prior.scales = prior.scales)) @@ -584,23 +592,28 @@ add_regressor <- function(m, name, prior.scale = NULL, standardize = 'auto'){ return(m) } -#' Add a seasonal component with specified period and number of Fourier -#' components. +#' Add a seasonal component with specified period, number of Fourier +#' components, and prior scale. #' #' Increasing the number of Fourier components allows the seasonality to change #' more quickly (at risk of overfitting). Default values for yearly and weekly #' seasonalities are 10 and 3 respectively. #' +#' Increasing prior scale will allow this seasonality component more +#' flexibility, decreasing will dampen it. If not provided, will use the +#' seasonality.prior.scale provided on Prophet initialization (defaults to 10). +#' #' @param m Prophet object. #' @param name String name of the seasonality component. #' @param period Float number of days in one period. #' @param fourier.order Int number of Fourier components to use. +#' @param prior.scale Float prior scale for this component. #' #' @return The prophet model with the seasonality added. #' #' @importFrom dplyr "%>%" #' @export -add_seasonality <- function(m, name, period, fourier.order) { +add_seasonality <- function(m, name, period, fourier.order, prior.scale = NULL) { if (!is.null(m$history)) { stop("Seasonality must be added prior to model fitting.") } @@ -608,7 +621,19 @@ add_seasonality <- function(m, name, period, fourier.order) { # Allow overriding built-in seasonalities validate_column_name(m, name, check_seasonalities = FALSE) } - m$seasonalities[[name]] <- c(period, fourier.order) + if (is.null(prior.scale)) { + ps <- m$seasonality.prior.scale + } else { + ps <- prior.scale + } + if (ps <= 0) { + stop('Prior scale must be > 0') + } + m$seasonalities[[name]] <- list( + period = period, + fourier.order = fourier.order, + prior.scale = ps + ) return(m) } @@ -631,12 +656,12 @@ make_all_seasonality_features <- function(m, df) { # Seasonality features for (name in names(m$seasonalities)) { - period <- m$seasonalities[[name]][1] - series.order <- m$seasonalities[[name]][2] - features <- make_seasonality_features(df$ds, period, series.order, name) + props <- m$seasonalities[[name]] + features <- make_seasonality_features( + df$ds, props$period, props$fourier.order, name) seasonal.features <- cbind(seasonal.features, features) prior.scales <- c(prior.scales, - m$seasonality.prior.scale * rep(1, ncol(features))) + props$prior.scale * rep(1, ncol(features))) } # Holiday features @@ -751,21 +776,33 @@ set_auto_seasonalities <- function(m) { fourier.order <- parse_seasonality_args( m, 'yearly', m$yearly.seasonality, yearly.disable, 10) if (fourier.order > 0) { - m$seasonalities[['yearly']] <- c(365.25, fourier.order) + m$seasonalities[['yearly']] <- list( + period = 365.25, + fourier.order = fourier.order, + prior.scale = m$seasonality.prior.scale + ) } weekly.disable <- ((time_diff(last, first) < 14) || (min.dt >= 7)) fourier.order <- parse_seasonality_args( m, 'weekly', m$weekly.seasonality, weekly.disable, 3) if (fourier.order > 0) { - m$seasonalities[['weekly']] <- c(7, fourier.order) + m$seasonalities[['weekly']] <- list( + period = 7, + fourier.order = fourier.order, + prior.scale = m$seasonality.prior.scale + ) } daily.disable <- ((time_diff(last, first) < 2) || (min.dt >= 1)) fourier.order <- parse_seasonality_args( m, 'daily', m$daily.seasonality, daily.disable, 4) if (fourier.order > 0) { - m$seasonalities[['daily']] <- c(1, fourier.order) + m$seasonalities[['daily']] <- list( + period = 1, + fourier.order = fourier.order, + prior.scale = m$seasonality.prior.scale + ) } return(m) } @@ -1598,7 +1635,7 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) { plot_seasonality <- function(m, name, uncertainty = TRUE) { # Compute seasonality from Jan 1 through a single period. start <- set_date('2017-01-01') - period <- m$seasonalities[[name]][1] + period <- m$seasonalities[[name]]$period end <- start + period * 24 * 3600 plot.points <- 200 days <- seq(from=start, to=end, length.out=plot.points) diff --git a/R/man/add_seasonality.Rd b/R/man/add_seasonality.Rd index 5b3955c..6df3f19 100644 --- a/R/man/add_seasonality.Rd +++ b/R/man/add_seasonality.Rd @@ -2,10 +2,10 @@ % Please edit documentation in R/prophet.R \name{add_seasonality} \alias{add_seasonality} -\title{Add a seasonal component with specified period and number of Fourier -components.} +\title{Add a seasonal component with specified period, number of Fourier +components, and prior scale.} \usage{ -add_seasonality(m, name, period, fourier.order) +add_seasonality(m, name, period, fourier.order, prior.scale = NULL) } \arguments{ \item{m}{Prophet object.} @@ -15,6 +15,8 @@ add_seasonality(m, name, period, fourier.order) \item{period}{Float number of days in one period.} \item{fourier.order}{Int number of Fourier components to use.} + +\item{prior.scale}{Float prior scale for this component.} } \value{ The prophet model with the seasonality added. @@ -24,3 +26,8 @@ Increasing the number of Fourier components allows the seasonality to change more quickly (at risk of overfitting). Default values for yearly and weekly seasonalities are 10 and 3 respectively. } +\details{ +Increasing prior scale will allow this seasonality component more +flexibility, decreasing will dampen it. If not provided, will use the +seasonality.prior.scale provided on Prophet initialization (defaults to 10). +} diff --git a/R/man/make_holiday_features.Rd b/R/man/make_holiday_features.Rd index 7a46b6b..c502972 100644 --- a/R/man/make_holiday_features.Rd +++ b/R/man/make_holiday_features.Rd @@ -12,7 +12,9 @@ make_holiday_features(m, dates) \item{dates}{Vector with dates used for computing seasonality.} } \value{ -A dataframe with a column for each holiday. +A list with entries + holiday.features: dataframe with a column for each holiday. + prior.scales: array of prior scales for each holiday column. } \description{ Construct a matrix of holiday features. diff --git a/R/man/prophet.Rd b/R/man/prophet.Rd index 3ce933c..dc73bbf 100644 --- a/R/man/prophet.Rd +++ b/R/man/prophet.Rd @@ -43,14 +43,16 @@ FALSE, or a number of Fourier terms to generate.} \item{holidays}{data frame with columns holiday (character) and ds (date type)and optionally columns lower_window and upper_window which specify a range of days around the date to be included as holidays. lower_window=-2 -will include 2 days prior to the date as holidays.} +will include 2 days prior to the date as holidays. Also optionally can have +a column prior_scale specifying the prior scale for each holiday.} \item{seasonality.prior.scale}{Parameter modulating the strength of the seasonality model. Larger values allow the model to fit larger seasonal -fluctuations, smaller values dampen the seasonality.} +fluctuations, smaller values dampen the seasonality. Can be specified for +individual seasonalities using add_seasonality.} \item{holidays.prior.scale}{Parameter modulating the strength of the holiday -components model.} +components model, unless overridden in the holidays input.} \item{changepoint.prior.scale}{Parameter modulating the flexibility of the automatic changepoint selection. Large values will allow many changepoints, diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index fedcee5..74bcb31 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -247,11 +247,22 @@ test_that("holidays", { upper_window = c(1, 1), prior_scale = c(8, 8) ) - holiday2 <- rbind(holidays, holidays2) + holidays2 <- rbind(holidays, holidays2) m <- prophet(holidays = holidays2, fit = FALSE) out <- prophet:::make_holiday_features(m, df$ds) priors <- out$prior.scales - expect_true(all(priors == c(8,8, 5, 5))) + expect_true(all(priors == c(8, 8, 5, 5))) + holidays2 <- data.frame( + ds = prophet:::set_date(c('2012-06-06', '2013-06-06')), + holiday = c('seans-bday', 'seans-bday'), + lower_window = c(0, 0), + upper_window = c(1, 1) + ) + holidays2 <- dplyr::bind_rows(holidays, holidays2) + m <- prophet(holidays = holidays2, fit = FALSE, holidays.prior.scale = 4) + out <- prophet:::make_holiday_features(m, df$ds) + priors <- out$prior.scales + expect_true(all(priors == c(4, 4, 5, 5))) # Check incompatible priors holidays <- data.frame( ds = prophet:::set_date(c('2016-12-25', '2016-12-27')), @@ -296,9 +307,12 @@ test_that("auto_weekly_seasonality", { train.w <- DATA[1:N.w, ] m <- prophet(train.w, fit = FALSE) expect_equal(m$weekly.seasonality, 'auto') - m <- prophet:::fit.prophet(m, train.w) + m <- fit.prophet(m, train.w) expect_true('weekly' %in% names(m$seasonalities)) - expect_equal(m$seasonalities[['weekly']], c(7, 3)) + true <- list(period = 7, fourier.order = 3, prior.scale = 10) + for (name in names(true)) { + expect_equal(m$seasonalities$weekly[[name]], true[[name]]) + } # Should be disabled due to too short history N.w <- 9 train.w <- DATA[1:N.w, ] @@ -310,8 +324,11 @@ test_that("auto_weekly_seasonality", { train.w <- DATA[seq(1, nrow(DATA), 7), ] m <- prophet(train.w) expect_false('weekly' %in% names(m$seasonalities)) - m <- prophet(DATA, weekly.seasonality=2) - expect_equal(m$seasonalities[['weekly']], c(7, 2)) + m <- prophet(DATA, weekly.seasonality = 2, seasonality.prior.scale = 3) + true <- list(period = 7, fourier.order = 2, prior.scale = 3) + for (name in names(true)) { + expect_equal(m$seasonalities$weekly[[name]], true[[name]]) + } }) test_that("auto_yearly_seasonality", { @@ -319,9 +336,12 @@ test_that("auto_yearly_seasonality", { # Should be enabled m <- prophet(DATA, fit = FALSE) expect_equal(m$yearly.seasonality, 'auto') - m <- prophet:::fit.prophet(m, DATA) + m <- fit.prophet(m, DATA) expect_true('yearly' %in% names(m$seasonalities)) - expect_equal(m$seasonalities[['yearly']], c(365.25, 10)) + true <- list(period = 365.25, fourier.order = 10, prior.scale = 10) + for (name in names(true)) { + expect_equal(m$seasonalities$yearly[[name]], true[[name]]) + } # Should be disabled due to too short history N.w <- 240 train.y <- DATA[1:N.w, ] @@ -329,8 +349,11 @@ test_that("auto_yearly_seasonality", { expect_false('yearly' %in% names(m$seasonalities)) m <- prophet(train.y, yearly.seasonality = TRUE) expect_true('yearly' %in% names(m$seasonalities)) - m <- prophet(DATA, yearly.seasonality=7) - expect_equal(m$seasonalities[['yearly']], c(365.25, 7)) + m <- prophet(DATA, yearly.seasonality = 7, seasonality.prior.scale = 3) + true <- list(period = 365.25, fourier.order = 7, prior.scale = 3) + for (name in names(true)) { + expect_equal(m$seasonalities$yearly[[name]], true[[name]]) + } }) test_that("auto_daily_seasonality", { @@ -338,9 +361,12 @@ test_that("auto_daily_seasonality", { # Should be enabled m <- prophet(DATA2, fit = FALSE) expect_equal(m$daily.seasonality, 'auto') - m <- prophet:::fit.prophet(m, DATA2) + m <- fit.prophet(m, DATA2) expect_true('daily' %in% names(m$seasonalities)) - expect_equal(m$seasonalities[['daily']], c(1, 4)) + true <- list(period = 1, fourier.order = 4, prior.scale = 10) + for (name in names(true)) { + expect_equal(m$seasonalities$daily[[name]], true[[name]]) + } # Should be disabled due to too short history N.d <- 430 train.y <- DATA2[1:N.d, ] @@ -348,8 +374,11 @@ test_that("auto_daily_seasonality", { expect_false('daily' %in% names(m$seasonalities)) m <- prophet(train.y, daily.seasonality = TRUE) expect_true('daily' %in% names(m$seasonalities)) - m <- prophet(DATA2, daily.seasonality=7) - expect_equal(m$seasonalities[['daily']], c(1, 7)) + m <- prophet(DATA2, daily.seasonality = 7, seasonality.prior.scale = 3) + true <- list(period = 1, fourier.order = 7, prior.scale = 3) + for (name in names(true)) { + expect_equal(m$seasonalities$daily[[name]], true[[name]]) + } m <- prophet(DATA) expect_false('daily' %in% names(m$seasonalities)) }) @@ -366,10 +395,14 @@ test_that("test_subdaily_holidays", { test_that("custom_seasonality", { skip_if_not(Sys.getenv('R_ARCH') != '/i386') holidays <- data.frame(ds = c('2017-01-02'), - holiday = c('special_day')) + holiday = c('special_day'), + prior_scale = c(4)) m <- prophet(holidays=holidays) m <- add_seasonality(m, name='monthly', period=30, fourier.order=5) - expect_equal(m$seasonalities[['monthly']], c(30, 5)) + true <- list(period = 30, fourier.order = 5, prior.scale = 10) + for (name in names(true)) { + expect_equal(m$seasonalities$monthly[[name]], true[[name]]) + } expect_error( add_seasonality(m, name='special_day', period=30, fourier_order=5) ) @@ -377,6 +410,14 @@ test_that("custom_seasonality", { add_seasonality(m, name='trend', period=30, fourier_order=5) ) m <- add_seasonality(m, name='weekly', period=30, fourier.order=5) + # Test priors + m <- prophet(holidays = holidays, yearly.seasonality = FALSE) + m <- add_seasonality( + m, name='monthly', period=30, fourier.order=5, prior.scale = 2) + m <- fit.prophet(m, DATA) + prior.scales <- prophet:::make_all_seasonality_features( + m, m$history)$prior.scales + expect_true(all(prior.scales == c(rep(2, 10), rep(10, 6), 4))) }) test_that("added_regressors", { diff --git a/python/fbprophet/forecaster.py b/python/fbprophet/forecaster.py index 662217e..94728f7 100644 --- a/python/fbprophet/forecaster.py +++ b/python/fbprophet/forecaster.py @@ -413,9 +413,8 @@ class Prophet(object): except ValueError: lw = 0 uw = 0 - try: - ps = float(row.get('prior_scale', self.holidays_prior_scale)) - except ValueError: + ps = float(row.get('prior_scale', self.holidays_prior_scale)) + if np.isnan(ps): ps = float(self.holidays_prior_scale) if ( row.holiday in prior_scales and prior_scales[row.holiday] != ps diff --git a/python/fbprophet/tests/test_prophet.py b/python/fbprophet/tests/test_prophet.py index 068d0c4..23db966 100644 --- a/python/fbprophet/tests/test_prophet.py +++ b/python/fbprophet/tests/test_prophet.py @@ -308,6 +308,17 @@ class TestProphet(TestCase): holidays2 = pd.concat((holidays, holidays2)) feats, priors = Prophet(holidays=holidays2).make_holiday_features(df['ds']) self.assertEqual(priors, [8., 8., 5., 5.]) + holidays2 = pd.DataFrame({ + 'ds': pd.to_datetime(['2012-06-06', '2013-06-06']), + 'holiday': ['seans-bday'] * 2, + 'lower_window': [0] * 2, + 'upper_window': [1] * 2, + }) + holidays2 = pd.concat((holidays, holidays2)) + feats, priors = Prophet( + holidays=holidays2, holidays_prior_scale=4 + ).make_holiday_features(df['ds']) + self.assertEqual(priors, [4., 4., 5., 5.]) # Check incompatible priors holidays = pd.DataFrame({ 'ds': pd.to_datetime(['2016-12-25', '2016-12-27']),