diff --git a/R/R/prophet.R b/R/R/prophet.R index b024c1f..96f7fce 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -38,12 +38,13 @@ globalVariables(c( #' @param holidays data frame with columns holiday (character) and ds (date #' type)and optionally columns lower_window and upper_window which specify a #' range of days around the date to be included as holidays. lower_window=-2 -#' will include 2 days prior to the date as holidays. +#' will include 2 days prior to the date as holidays. Also optionally can have +#' a column prior_scale specifying the prior scale for each holiday. #' @param seasonality.prior.scale Parameter modulating the strength of the #' seasonality model. Larger values allow the model to fit larger seasonal #' fluctuations, smaller values dampen the seasonality. #' @param holidays.prior.scale Parameter modulating the strength of the holiday -#' components model. +#' components model, unless overridden in the holidays input. #' @param changepoint.prior.scale Parameter modulating the flexibility of the #' automatic changepoint selection. Large values will allow many changepoints, #' small values will allow few changepoints. @@ -487,7 +488,9 @@ make_seasonality_features <- function(dates, period, series.order, prefix) { #' @param m Prophet object. #' @param dates Vector with dates used for computing seasonality. #' -#' @return A dataframe with a column for each holiday. +#' @return A list with entries +#' holiday.features: dataframe with a column for each holiday. +#' prior.scales: array of prior scales for each holiday column. #' #' @importFrom dplyr "%>%" #' @keywords internal @@ -505,19 +508,40 @@ make_holiday_features <- function(m, dates) { } else { offsets <- c(0) } - names <- paste( - .$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '') - dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names) + if (exists('prior_scale', where = .) && !is.na(.$prior_scale)) { + ps <- .$prior_scale + } else { + ps <- m$holidays.prior.scale + } + names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), + abs(offsets), sep = '') + dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names, + prior_scale = ps) }) %>% dplyr::mutate(x = 1.) %>% tidyr::spread(holiday, x, fill = 0) - holiday.mat <- data.frame(ds = dates) %>% - dplyr::left_join(wide, by = 'ds') %>% - dplyr::select(-ds) + holiday.features <- data.frame(ds = set_date(dates)) %>% + dplyr::left_join(wide, by = 'ds') - holiday.mat[is.na(holiday.mat)] <- 0 - return(holiday.mat) + prior.scales.all <- holiday.features$prior_scale + prior.scales <- c() + + holiday.features <- dplyr::select(holiday.features, -ds, -prior_scale) + holiday.features[is.na(holiday.features)] <- 0 + + for (name in colnames(holiday.features)) { + rows <- !is.na(holiday.features[[name]]) & (holiday.features[[name]] == 1) + ps <- unique(prior.scales.all[rows]) + if (length(ps) > 1) { + sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1] + stop('Holiday ', sn, ' does not have a consistent prior scale ', + 'specification') + } + prior.scales <- c(prior.scales, ps) + } + return(list(holiday.features = holiday.features, + prior.scales = prior.scales)) } #' Add an additional regressor to be used for fitting and predicting. @@ -617,10 +641,9 @@ make_all_seasonality_features <- function(m, df) { # Holiday features if (!is.null(m$holidays)) { - features <- make_holiday_features(m, df$ds) - seasonal.features <- cbind(seasonal.features, features) - prior.scales <- c(prior.scales, - m$holidays.prior.scale * rep(1, ncol(features))) + hf <- make_holiday_features(m, df$ds) + seasonal.features <- cbind(seasonal.features, hf$holiday.features) + prior.scales <- c(prior.scales, hf$prior.scales) } # Additional regressors diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index 149731e..fedcee5 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -208,19 +208,60 @@ test_that("holidays", { ds = seq(prophet:::set_date('2016-12-20'), prophet:::set_date('2016-12-31'), by='d')) m <- prophet(train, holidays = holidays, fit = FALSE) - feats <- prophet:::make_holiday_features(m, df$ds) + out <- prophet:::make_holiday_features(m, df$ds) + feats <- out$holiday.features + priors <- out$prior.scales expect_equal(nrow(feats), nrow(df)) expect_equal(ncol(feats), 2) expect_equal(sum(colSums(feats) - c(1, 1)), 0) + expect_true(all(priors == c(10., 10.))) holidays = data.frame(ds = c('2016-12-25'), holiday = c('xmas'), lower_window = c(-1), upper_window = c(10)) m <- prophet(train, holidays = holidays, fit = FALSE) - feats <- prophet:::make_holiday_features(m, df$ds) + out <- prophet:::make_holiday_features(m, df$ds) + feats <- out$holiday.features + priors <- out$prior.scales expect_equal(nrow(feats), nrow(df)) expect_equal(ncol(feats), 12) + expect_true(all(priors == rep(10, 12))) + # Check prior specifications + holidays <- data.frame( + ds = prophet:::set_date(c('2016-12-25', '2017-12-25')), + holiday = c('xmas', 'xmas'), + lower_window = c(-1, -1), + upper_window = c(0, 0), + prior_scale = c(5., 5.) + ) + m <- prophet(holidays = holidays, fit = FALSE) + out <- prophet:::make_holiday_features(m, df$ds) + priors <- out$prior.scales + expect_true(all(priors == c(5., 5.))) + # 2 different priors + holidays2 <- data.frame( + ds = prophet:::set_date(c('2012-06-06', '2013-06-06')), + holiday = c('seans-bday', 'seans-bday'), + lower_window = c(0, 0), + upper_window = c(1, 1), + prior_scale = c(8, 8) + ) + holiday2 <- rbind(holidays, holidays2) + m <- prophet(holidays = holidays2, fit = FALSE) + out <- prophet:::make_holiday_features(m, df$ds) + priors <- out$prior.scales + expect_true(all(priors == c(8,8, 5, 5))) + # Check incompatible priors + holidays <- data.frame( + ds = prophet:::set_date(c('2016-12-25', '2016-12-27')), + holiday = c('xmasish', 'xmasish'), + lower_window = c(-1, -1), + upper_window = c(0, 0), + prior_scale = c(5., 6.) + ) + m <- prophet(holidays = holidays, fit = FALSE) + expect_error(prophet:::make_holiday_features(m, df$ds)) }) test_that("fit_with_holidays", { diff --git a/python/fbprophet/tests/test_prophet.py b/python/fbprophet/tests/test_prophet.py index 7226959..068d0c4 100644 --- a/python/fbprophet/tests/test_prophet.py +++ b/python/fbprophet/tests/test_prophet.py @@ -310,8 +310,8 @@ class TestProphet(TestCase): self.assertEqual(priors, [8., 8., 5., 5.]) # Check incompatible priors holidays = pd.DataFrame({ - 'ds': pd.to_datetime(['2016-12-25', '2017-12-25']), - 'holiday': ['xmas', 'xmas'], + 'ds': pd.to_datetime(['2016-12-25', '2016-12-27']), + 'holiday': ['xmasish', 'xmasish'], 'lower_window': [-1, -1], 'upper_window': [0, 0], 'prior_scale': [5., 6.],