mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-07-03 03:59:00 +00:00
Custom prior scales R
This commit is contained in:
parent
23d8bc25dc
commit
ddbb353278
3 changed files with 83 additions and 19 deletions
|
|
@ -38,12 +38,13 @@ globalVariables(c(
|
|||
#' @param holidays data frame with columns holiday (character) and ds (date
|
||||
#' type)and optionally columns lower_window and upper_window which specify a
|
||||
#' range of days around the date to be included as holidays. lower_window=-2
|
||||
#' will include 2 days prior to the date as holidays.
|
||||
#' will include 2 days prior to the date as holidays. Also optionally can have
|
||||
#' a column prior_scale specifying the prior scale for each holiday.
|
||||
#' @param seasonality.prior.scale Parameter modulating the strength of the
|
||||
#' seasonality model. Larger values allow the model to fit larger seasonal
|
||||
#' fluctuations, smaller values dampen the seasonality.
|
||||
#' @param holidays.prior.scale Parameter modulating the strength of the holiday
|
||||
#' components model.
|
||||
#' components model, unless overridden in the holidays input.
|
||||
#' @param changepoint.prior.scale Parameter modulating the flexibility of the
|
||||
#' automatic changepoint selection. Large values will allow many changepoints,
|
||||
#' small values will allow few changepoints.
|
||||
|
|
@ -487,7 +488,9 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
|
|||
#' @param m Prophet object.
|
||||
#' @param dates Vector with dates used for computing seasonality.
|
||||
#'
|
||||
#' @return A dataframe with a column for each holiday.
|
||||
#' @return A list with entries
|
||||
#' holiday.features: dataframe with a column for each holiday.
|
||||
#' prior.scales: array of prior scales for each holiday column.
|
||||
#'
|
||||
#' @importFrom dplyr "%>%"
|
||||
#' @keywords internal
|
||||
|
|
@ -505,19 +508,40 @@ make_holiday_features <- function(m, dates) {
|
|||
} else {
|
||||
offsets <- c(0)
|
||||
}
|
||||
names <- paste(
|
||||
.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
|
||||
dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
|
||||
if (exists('prior_scale', where = .) && !is.na(.$prior_scale)) {
|
||||
ps <- .$prior_scale
|
||||
} else {
|
||||
ps <- m$holidays.prior.scale
|
||||
}
|
||||
names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'),
|
||||
abs(offsets), sep = '')
|
||||
dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names,
|
||||
prior_scale = ps)
|
||||
}) %>%
|
||||
dplyr::mutate(x = 1.) %>%
|
||||
tidyr::spread(holiday, x, fill = 0)
|
||||
|
||||
holiday.mat <- data.frame(ds = dates) %>%
|
||||
dplyr::left_join(wide, by = 'ds') %>%
|
||||
dplyr::select(-ds)
|
||||
holiday.features <- data.frame(ds = set_date(dates)) %>%
|
||||
dplyr::left_join(wide, by = 'ds')
|
||||
|
||||
holiday.mat[is.na(holiday.mat)] <- 0
|
||||
return(holiday.mat)
|
||||
prior.scales.all <- holiday.features$prior_scale
|
||||
prior.scales <- c()
|
||||
|
||||
holiday.features <- dplyr::select(holiday.features, -ds, -prior_scale)
|
||||
holiday.features[is.na(holiday.features)] <- 0
|
||||
|
||||
for (name in colnames(holiday.features)) {
|
||||
rows <- !is.na(holiday.features[[name]]) & (holiday.features[[name]] == 1)
|
||||
ps <- unique(prior.scales.all[rows])
|
||||
if (length(ps) > 1) {
|
||||
sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1]
|
||||
stop('Holiday ', sn, ' does not have a consistent prior scale ',
|
||||
'specification')
|
||||
}
|
||||
prior.scales <- c(prior.scales, ps)
|
||||
}
|
||||
return(list(holiday.features = holiday.features,
|
||||
prior.scales = prior.scales))
|
||||
}
|
||||
|
||||
#' Add an additional regressor to be used for fitting and predicting.
|
||||
|
|
@ -617,10 +641,9 @@ make_all_seasonality_features <- function(m, df) {
|
|||
|
||||
# Holiday features
|
||||
if (!is.null(m$holidays)) {
|
||||
features <- make_holiday_features(m, df$ds)
|
||||
seasonal.features <- cbind(seasonal.features, features)
|
||||
prior.scales <- c(prior.scales,
|
||||
m$holidays.prior.scale * rep(1, ncol(features)))
|
||||
hf <- make_holiday_features(m, df$ds)
|
||||
seasonal.features <- cbind(seasonal.features, hf$holiday.features)
|
||||
prior.scales <- c(prior.scales, hf$prior.scales)
|
||||
}
|
||||
|
||||
# Additional regressors
|
||||
|
|
|
|||
|
|
@ -208,19 +208,60 @@ test_that("holidays", {
|
|||
ds = seq(prophet:::set_date('2016-12-20'),
|
||||
prophet:::set_date('2016-12-31'), by='d'))
|
||||
m <- prophet(train, holidays = holidays, fit = FALSE)
|
||||
feats <- prophet:::make_holiday_features(m, df$ds)
|
||||
out <- prophet:::make_holiday_features(m, df$ds)
|
||||
feats <- out$holiday.features
|
||||
priors <- out$prior.scales
|
||||
expect_equal(nrow(feats), nrow(df))
|
||||
expect_equal(ncol(feats), 2)
|
||||
expect_equal(sum(colSums(feats) - c(1, 1)), 0)
|
||||
expect_true(all(priors == c(10., 10.)))
|
||||
|
||||
holidays = data.frame(ds = c('2016-12-25'),
|
||||
holiday = c('xmas'),
|
||||
lower_window = c(-1),
|
||||
upper_window = c(10))
|
||||
m <- prophet(train, holidays = holidays, fit = FALSE)
|
||||
feats <- prophet:::make_holiday_features(m, df$ds)
|
||||
out <- prophet:::make_holiday_features(m, df$ds)
|
||||
feats <- out$holiday.features
|
||||
priors <- out$prior.scales
|
||||
expect_equal(nrow(feats), nrow(df))
|
||||
expect_equal(ncol(feats), 12)
|
||||
expect_true(all(priors == rep(10, 12)))
|
||||
# Check prior specifications
|
||||
holidays <- data.frame(
|
||||
ds = prophet:::set_date(c('2016-12-25', '2017-12-25')),
|
||||
holiday = c('xmas', 'xmas'),
|
||||
lower_window = c(-1, -1),
|
||||
upper_window = c(0, 0),
|
||||
prior_scale = c(5., 5.)
|
||||
)
|
||||
m <- prophet(holidays = holidays, fit = FALSE)
|
||||
out <- prophet:::make_holiday_features(m, df$ds)
|
||||
priors <- out$prior.scales
|
||||
expect_true(all(priors == c(5., 5.)))
|
||||
# 2 different priors
|
||||
holidays2 <- data.frame(
|
||||
ds = prophet:::set_date(c('2012-06-06', '2013-06-06')),
|
||||
holiday = c('seans-bday', 'seans-bday'),
|
||||
lower_window = c(0, 0),
|
||||
upper_window = c(1, 1),
|
||||
prior_scale = c(8, 8)
|
||||
)
|
||||
holiday2 <- rbind(holidays, holidays2)
|
||||
m <- prophet(holidays = holidays2, fit = FALSE)
|
||||
out <- prophet:::make_holiday_features(m, df$ds)
|
||||
priors <- out$prior.scales
|
||||
expect_true(all(priors == c(8,8, 5, 5)))
|
||||
# Check incompatible priors
|
||||
holidays <- data.frame(
|
||||
ds = prophet:::set_date(c('2016-12-25', '2016-12-27')),
|
||||
holiday = c('xmasish', 'xmasish'),
|
||||
lower_window = c(-1, -1),
|
||||
upper_window = c(0, 0),
|
||||
prior_scale = c(5., 6.)
|
||||
)
|
||||
m <- prophet(holidays = holidays, fit = FALSE)
|
||||
expect_error(prophet:::make_holiday_features(m, df$ds))
|
||||
})
|
||||
|
||||
test_that("fit_with_holidays", {
|
||||
|
|
|
|||
|
|
@ -310,8 +310,8 @@ class TestProphet(TestCase):
|
|||
self.assertEqual(priors, [8., 8., 5., 5.])
|
||||
# Check incompatible priors
|
||||
holidays = pd.DataFrame({
|
||||
'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
|
||||
'holiday': ['xmas', 'xmas'],
|
||||
'ds': pd.to_datetime(['2016-12-25', '2016-12-27']),
|
||||
'holiday': ['xmasish', 'xmasish'],
|
||||
'lower_window': [-1, -1],
|
||||
'upper_window': [0, 0],
|
||||
'prior_scale': [5., 6.],
|
||||
|
|
|
|||
Loading…
Reference in a new issue