Custom prior scales R

This commit is contained in:
Ben Letham 2017-08-31 10:56:06 -07:00
parent 23d8bc25dc
commit ddbb353278
3 changed files with 83 additions and 19 deletions

View file

@ -38,12 +38,13 @@ globalVariables(c(
#' @param holidays data frame with columns holiday (character) and ds (date
#' type)and optionally columns lower_window and upper_window which specify a
#' range of days around the date to be included as holidays. lower_window=-2
#' will include 2 days prior to the date as holidays.
#' will include 2 days prior to the date as holidays. Also optionally can have
#' a column prior_scale specifying the prior scale for each holiday.
#' @param seasonality.prior.scale Parameter modulating the strength of the
#' seasonality model. Larger values allow the model to fit larger seasonal
#' fluctuations, smaller values dampen the seasonality.
#' @param holidays.prior.scale Parameter modulating the strength of the holiday
#' components model.
#' components model, unless overridden in the holidays input.
#' @param changepoint.prior.scale Parameter modulating the flexibility of the
#' automatic changepoint selection. Large values will allow many changepoints,
#' small values will allow few changepoints.
@ -487,7 +488,9 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
#' @param m Prophet object.
#' @param dates Vector with dates used for computing seasonality.
#'
#' @return A dataframe with a column for each holiday.
#' @return A list with entries
#' holiday.features: dataframe with a column for each holiday.
#' prior.scales: array of prior scales for each holiday column.
#'
#' @importFrom dplyr "%>%"
#' @keywords internal
@ -505,19 +508,40 @@ make_holiday_features <- function(m, dates) {
} else {
offsets <- c(0)
}
names <- paste(
.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
if (exists('prior_scale', where = .) && !is.na(.$prior_scale)) {
ps <- .$prior_scale
} else {
ps <- m$holidays.prior.scale
}
names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'),
abs(offsets), sep = '')
dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names,
prior_scale = ps)
}) %>%
dplyr::mutate(x = 1.) %>%
tidyr::spread(holiday, x, fill = 0)
holiday.mat <- data.frame(ds = dates) %>%
dplyr::left_join(wide, by = 'ds') %>%
dplyr::select(-ds)
holiday.features <- data.frame(ds = set_date(dates)) %>%
dplyr::left_join(wide, by = 'ds')
holiday.mat[is.na(holiday.mat)] <- 0
return(holiday.mat)
prior.scales.all <- holiday.features$prior_scale
prior.scales <- c()
holiday.features <- dplyr::select(holiday.features, -ds, -prior_scale)
holiday.features[is.na(holiday.features)] <- 0
for (name in colnames(holiday.features)) {
rows <- !is.na(holiday.features[[name]]) & (holiday.features[[name]] == 1)
ps <- unique(prior.scales.all[rows])
if (length(ps) > 1) {
sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1]
stop('Holiday ', sn, ' does not have a consistent prior scale ',
'specification')
}
prior.scales <- c(prior.scales, ps)
}
return(list(holiday.features = holiday.features,
prior.scales = prior.scales))
}
#' Add an additional regressor to be used for fitting and predicting.
@ -617,10 +641,9 @@ make_all_seasonality_features <- function(m, df) {
# Holiday features
if (!is.null(m$holidays)) {
features <- make_holiday_features(m, df$ds)
seasonal.features <- cbind(seasonal.features, features)
prior.scales <- c(prior.scales,
m$holidays.prior.scale * rep(1, ncol(features)))
hf <- make_holiday_features(m, df$ds)
seasonal.features <- cbind(seasonal.features, hf$holiday.features)
prior.scales <- c(prior.scales, hf$prior.scales)
}
# Additional regressors

View file

@ -208,19 +208,60 @@ test_that("holidays", {
ds = seq(prophet:::set_date('2016-12-20'),
prophet:::set_date('2016-12-31'), by='d'))
m <- prophet(train, holidays = holidays, fit = FALSE)
feats <- prophet:::make_holiday_features(m, df$ds)
out <- prophet:::make_holiday_features(m, df$ds)
feats <- out$holiday.features
priors <- out$prior.scales
expect_equal(nrow(feats), nrow(df))
expect_equal(ncol(feats), 2)
expect_equal(sum(colSums(feats) - c(1, 1)), 0)
expect_true(all(priors == c(10., 10.)))
holidays = data.frame(ds = c('2016-12-25'),
holiday = c('xmas'),
lower_window = c(-1),
upper_window = c(10))
m <- prophet(train, holidays = holidays, fit = FALSE)
feats <- prophet:::make_holiday_features(m, df$ds)
out <- prophet:::make_holiday_features(m, df$ds)
feats <- out$holiday.features
priors <- out$prior.scales
expect_equal(nrow(feats), nrow(df))
expect_equal(ncol(feats), 12)
expect_true(all(priors == rep(10, 12)))
# Check prior specifications
holidays <- data.frame(
ds = prophet:::set_date(c('2016-12-25', '2017-12-25')),
holiday = c('xmas', 'xmas'),
lower_window = c(-1, -1),
upper_window = c(0, 0),
prior_scale = c(5., 5.)
)
m <- prophet(holidays = holidays, fit = FALSE)
out <- prophet:::make_holiday_features(m, df$ds)
priors <- out$prior.scales
expect_true(all(priors == c(5., 5.)))
# 2 different priors
holidays2 <- data.frame(
ds = prophet:::set_date(c('2012-06-06', '2013-06-06')),
holiday = c('seans-bday', 'seans-bday'),
lower_window = c(0, 0),
upper_window = c(1, 1),
prior_scale = c(8, 8)
)
holiday2 <- rbind(holidays, holidays2)
m <- prophet(holidays = holidays2, fit = FALSE)
out <- prophet:::make_holiday_features(m, df$ds)
priors <- out$prior.scales
expect_true(all(priors == c(8,8, 5, 5)))
# Check incompatible priors
holidays <- data.frame(
ds = prophet:::set_date(c('2016-12-25', '2016-12-27')),
holiday = c('xmasish', 'xmasish'),
lower_window = c(-1, -1),
upper_window = c(0, 0),
prior_scale = c(5., 6.)
)
m <- prophet(holidays = holidays, fit = FALSE)
expect_error(prophet:::make_holiday_features(m, df$ds))
})
test_that("fit_with_holidays", {

View file

@ -310,8 +310,8 @@ class TestProphet(TestCase):
self.assertEqual(priors, [8., 8., 5., 5.])
# Check incompatible priors
holidays = pd.DataFrame({
'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
'holiday': ['xmas', 'xmas'],
'ds': pd.to_datetime(['2016-12-25', '2016-12-27']),
'holiday': ['xmasish', 'xmasish'],
'lower_window': [-1, -1],
'upper_window': [0, 0],
'prior_scale': [5., 6.],