External regressors v2 (#283)

Add regressors in R
This commit is contained in:
Simon Kim 2017-08-30 11:04:56 -07:00 committed by Ben Letham
parent 4523315ffc
commit 17efc9aecd
5 changed files with 303 additions and 86 deletions

View file

@ -9,7 +9,7 @@
globalVariables(c(
"ds", "y", "cap", ".",
"component", "dow", "doy", "holiday", "holidays", "holidays_lower", "holidays_upper", "ix",
"lower", "n", "stat", "trend", "row_number",
"lower", "n", "stat", "trend", "row_number", "extra_regressors",
"trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower", "weekly_upper",
"x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower", "yhat_upper"))
@ -80,6 +80,7 @@ prophet <- function(df = NULL,
weekly.seasonality = 'auto',
daily.seasonality = 'auto',
holidays = NULL,
extra_regressors = NULL, #new
seasonality.prior.scale = 10,
holidays.prior.scale = 10,
changepoint.prior.scale = 0.05,
@ -103,6 +104,7 @@ prophet <- function(df = NULL,
weekly.seasonality = weekly.seasonality,
daily.seasonality = daily.seasonality,
holidays = holidays,
extra_regressors = extra_regressors,
seasonality.prior.scale = seasonality.prior.scale,
changepoint.prior.scale = changepoint.prior.scale,
holidays.prior.scale = holidays.prior.scale,
@ -130,6 +132,48 @@ prophet <- function(df = NULL,
return(m)
}
#' Validates the name of a seasonality, holiday, or regressor
#'
#' @param m Prophet object.
#' @param name string
#' @param check_holidays bool check if name already used for holiday
#' @param check_seasonalities bool check if name already used for seasonality
#' @param check_regressors bool check if name already used for regressor
#'
validate_column_name <- function(m, name, check_holidays = TRUE,
check_seasonalities = TRUE, check_regressors = TRUE) {
if (grepl("_delim_", name)) {
stop('Holiday name cannot contain "_delim_"')
}
reserved_names = c('trend', 'seasonal', 'seasonalities', 'daily', 'weekly', 'yearly',
'holidays', 'zeros', 'extra_regressors', 'yhat')
rn_l = paste(reserved_names,"_lower",sep="")
rn_u = paste(reserved_names,"_upper",sep="")
reserved_names = c(reserved_names, rn_l, rn_u, c("ds","y"));
if(name %in% reserved_names){
error_message = paste("Name ", name, " is reserved.");
stop(error_message)
}
if(check_holidays & !is.null(m$holidays) & (name %in% unique(m$holidays$holiday))){
error_message = paste("Name ", name, " already used for a holiday.");
stop(error_message)
}
#m$yearly.seasonality
if(check_seasonalities & (name %in% m$seasonalities[[name]])){
error_message = paste("Name ", name, " already used for a seasonality.");
stop(error_message)
}
if(check_regressors & (name %in% m$extra_regressors[[name]])){
error_message = paste("Name ", name, " already used for an added regressor.");
stop(error_message)
}
}
#' Validates the inputs to Prophet.
#'
#' @param m Prophet object.
@ -161,13 +205,7 @@ validate_inputs <- function(m) {
}
}
for (h in unique(m$holidays$holiday)) {
if (grepl("_delim_", h)) {
stop('Holiday name cannot contain "_delim_"')
}
if (h %in% c('zeros', 'yearly', 'weekly', 'daily', 'yhat', 'seasonal',
'trend')) {
stop(paste0('Holiday name "', h, '" reserved.'))
}
validate_column_name(m,h, check_holidays=FALSE)
}
}
}
@ -219,7 +257,7 @@ compile_stan_model <- function(model) {
#' Convert date vector
#'
#' Convert the date to POSIXct object
#' Convert the date to POSIXct object
#'
#' @param ds Date vector, can be consisted of characters
#' @param tz string time zone
@ -230,18 +268,17 @@ compile_stan_model <- function(model) {
set_date <- function(ds = NULL, tz = "GMT") {
if (length(ds) == 0) {
return(NULL)
}
}
if (is.factor(ds)) {
ds <- as.character(ds)
}
if (min(nchar(ds)) < 12) {
ds <- as.POSIXct(ds, format = "%Y-%m-%d", tz = tz)
} else {
ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz)
}
attr(ds, "tzone") <- tz
return(ds)
}
@ -267,7 +304,8 @@ time_diff <- function(ds1, ds2, units = "days") {
#' and predicting.
#'
#' @param m Prophet object.
#' @param df Data frame with columns ds, y, and cap if logistic growth.
#' @param df Data frame with columns ds, y, and cap if logistic growth.Any
#' specified additional regressors must also be present.
#' @param initialize_scales Boolean set scaling factors in m from df.
#'
#' @return list with items 'df' and 'm'.
@ -283,6 +321,9 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
'format. Either %Y-%m-%d or %Y-%m-%d %H:%M:%S'))
}
#names(m$extra_regressors)
df <- df %>%
dplyr::arrange(ds)
@ -343,8 +384,10 @@ set_changepoints <- function(m) {
m$n.changepoints)
}
if (m$n.changepoints > 0) {
cp.indexes <- round(seq.int(1, hist.size,
length.out = (m$n.changepoints + 1))[-1])
# Place potential changepoints evenly through the first 80 pcnt of
# the history.
cp.indexes <- round(seq.int(1, floor(nrow(m$history) * .8),
length.out = (m$n.changepoints + 1))[-1])
m$changepoints <- m$history$ds[cp.indexes]
} else {
m$changepoints <- c()
@ -422,8 +465,7 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
make_holiday_features <- function(m, dates) {
scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
# Strip dates to be just days, for joining on holidays
dates <- set_date(format(dates, "%Y-%m-%d"))
dates <- set_date(format(dates))
wide <- m$holidays %>%
dplyr::mutate(ds = set_date(ds)) %>%
dplyr::group_by(holiday, ds) %>%
@ -450,6 +492,44 @@ make_holiday_features <- function(m, dates) {
return(holiday.mat)
}
#'Add an additional regressor to be used for fitting and predicting.
#'
#'The dataframe passed to `fit` and `predict` will have a column with the
#'specified name to be used as a regressor. When standardize='auto', the
#'regressor will be standardized unless it is binary. The regression
#'coefficient is given a prior with the specified scale parameter.
#'Decreasing the prior scale will add additional regularization. If no
#'prior scale is provided, self.holidays_prior_scale will be used.
#'
#' @param m
#' @param name string name of the regressor
#' @param prior_scale optional float scale for the normal prior. If not
#' provided, self.holidays_prior_scale will be used.
#' @param standardize optional, specify whether this regressor will be
#' standardized prior to fitting. Can be 'auto' (standardize if not
#' binary), True, or False.
#' @return The prophet model with the regressor added.
#' @export
add_regressor <- function(m, prior_scale=0.0, standardize='auto'){
if(!is.null(m$history)){
stop('Regressors must be added prior to model fitting.')
}
validate_column_name(m,check_regressors=FALSE);
if(prior_scale == 0){
prior_scale = m$holidays.prior.scale
}
if(prior_scale < 0){
stop("prior_scale is less than 0");
}
m$extra_regressors = list(name = list(prior_scale = prior_scale,
standardize=standardize,
mu=0,
std=1.0))
return(m)
}
#' Add a seasonal component with specified period and number of Fourier
#' components.
#'
@ -468,37 +548,93 @@ make_holiday_features <- function(m, dates) {
#' @export
add_seasonality <- function(m, name, period, fourier.order) {
if (!is.null(m$holidays)) {
if (name %in% (unique(m$holidays$holiday) %>% as.character())) {
stop('Name "', name, '" already used for holiday')
}
stop("Seasonality must be added prior to model fitting.")
}
if (!(name %in% c('daily', 'weekly', 'yearly'))) {
validate_column_name(name,check_seasonalities=FALSE)
}
m$seasonalities[[name]] <- c(period, fourier.order)
return(m)
}
#' Dataframe with seasonality features.
#' Includes seasonality features, holiday features, and added regressors.
#'
#' @param m Prophet object.
#' @param df Dataframe with dates for computing seasonality features.
#'
#' @return Dataframe with seasonality.
#' @return Dataframe with regressor features,
#' list of prior scales for each colum of the features and any added regressors
#'
#' @keywords internal
make_all_seasonality_features <- function(m, df) {
seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
#seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
seasonal.features <- c();
prior_scales <- c();
for (name in names(m$seasonalities)) {
period <- m$seasonalities[[name]][1]
series.order <- m$seasonalities[[name]][2]
seasonal.features <- cbind(
seasonal.features,
make_seasonality_features(df$ds, period, series.order, name))
features = make_seasonality_features(df$ds, period, series.order, name);
if(is.null(seasonal.features)){
seasonal.features <- features;
}
seasonal.features <- cbind(seasonal.features, features) #test append 와 문제가 없는지 확인
prior_scales = c(prior_scales, m$seasonality.prior.scale * dim(features)[2]);
}
if(!is.null(m$holidays)) {
seasonal.features <- cbind(
seasonal.features,
make_holiday_features(m, df$ds))
features = make_holiday_features(m, df$ds);
seasonal.features <- cbind(seasonal.features, features) #test
prior_scales <- c(prior_scales, m$holiday_prior_scale * dim(features)[2]);
}
return(seasonal.features)
# Additional regressors
for(name in names(m$extra_regressors)){
seasonal.features = cbind(seasonal.features, df[name]); #test
prior_scales = cbind(prior_scales, m$extra_regressors[[name]][[prior_scale]])
}
if(length(df) == 0){
seasonal.features =cbind(seasonal.features,data.frame(zeros = rep(0, nrow(df))));
prior_scales = c(prior_scales,0.1)
}
return(list(seasonal.features=seasonal.features, prior_scales=prior_scales))
}
#' Get number of Fourier components for built-in seasonalities.
#'
#' @param m Prophet object.
#' @param name String name of the seasonality component.
#' @param arg 'auto', TRUE, FALSE, or number of Fourier components as
#' provided.
#' @param auto.disable Bool if seasonality should be disabled when 'auto'.
#' @param default.order Int default Fourier order.
#'
#' @return Number of Fourier components, or 0 for disabled.
#'
#' @keywords internal
parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
if (arg == 'auto') {
fourier.order <- 0
if (name %in% names(m$seasonalities)) {
warning('Found custom seasonality named "', name,
'", disabling built-in ', name, ' seasonality.')
} else if (auto.disable) {
warning('Disabling ', name, ' seasonality. Run prophet with ', name,
'.seasonality=TRUE to override this.')
} else {
fourier.order <- default.order
}
} else if (arg == TRUE) {
fourier.order <- default.order
} else if (arg == FALSE) {
fourier.order <- 0
} else {
fourier.order <- arg
}
return(fourier.order)
}
#' Get number of Fourier components for built-in seasonalities.
@ -666,7 +802,8 @@ fit.prophet <- function(m, df, ...) {
m <- out$m
m$history <- history
m <- set_auto_seasonalities(m)
seasonal.features <- make_all_seasonality_features(m, history)
seasonal.features <- make_all_seasonality_features(m, history)[[1]]
prior_scales <- make_all_seasonality_features(m, history)[[2]]
m <- set_changepoints(m)
A <- get_changepoint_matrix(m)
@ -681,7 +818,7 @@ fit.prophet <- function(m, df, ...) {
A = A,
t_change = array(m$changepoints.t),
X = as.matrix(seasonal.features),
sigma = m$seasonality.prior.scale,
sigma = prior_scales,
tau = m$changepoint.prior.scale
)
@ -882,7 +1019,7 @@ predict_trend <- function(model, df) {
#'
#' @keywords internal
predict_seasonal_components <- function(m, df) {
seasonal.features <- make_all_seasonality_features(m, df)
seasonal.features <- make_all_seasonality_features(m, df)[[1]]
lower.p <- (1 - m$interval.width)/2
upper.p <- (1 + m$interval.width)/2
@ -893,32 +1030,65 @@ predict_seasonal_components <- function(m, df) {
extra = "merge", fill = "right") %>%
dplyr::filter(component != 'zeros')
if (nrow(components) > 0) {
component.predictions <- components %>%
dplyr::group_by(component) %>% dplyr::do({
comp <- (as.matrix(seasonal.features[, .$col])
%*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
dplyr::data_frame(ix = 1:nrow(seasonal.features),
mean = rowMeans(comp, na.rm = TRUE),
lower = apply(comp, 1, stats::quantile, lower.p,
na.rm = TRUE),
upper = apply(comp, 1, stats::quantile, upper.p,
na.rm = TRUE))
}) %>%
tidyr::gather(stat, value, mean, lower, upper) %>%
dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
tidyr::unite(component, component, stat, sep="") %>%
tidyr::spread(component, value) %>%
dplyr::select(-ix)
#components <-
components <- rbind(components[,c(1,3)], data.frame("component"=rep("seasonal"),
"col"=c(1:dim(seasonal.features)[2])));
component.predictions$seasonal <- rowSums(
component.predictions[unique(components$component)])
} else {
component.predictions <- data.frame(seasonal = rep(0, nrow(df)))
components <- add_group_component(m,components, 'seasonalities', names(m$seasonalities));
if(!is.null(m$holidays)){
components <- add_group_component(m,components, 'holidays', unique(m$holidays$holiday));
}
components <- add_group_component(m,components, 'extra_regressors', names(m$extra_regressors));
# I am stuck on here: I am little confused that do I need to set
# components as list or dataframe ??
#
#if (nrow(components) > 0) {
component.predictions <- components %>%
dplyr::group_by(component) %>% dplyr::do({
comp <- (as.matrix(seasonal.features[, .$col])
%*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
dplyr::data_frame(ix = 1:nrow(seasonal.features),
mean = rowMeans(comp, na.rm = TRUE),
lower = apply(comp, 1, stats::quantile, lower.p,
na.rm = TRUE),
upper = apply(comp, 1, stats::quantile, upper.p,
na.rm = TRUE))
}) %>%
tidyr::gather(stat, value, mean, lower, upper) %>%
dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
tidyr::unite(component, component, stat, sep="") %>%
tidyr::spread(component, value) %>%
dplyr::select(-ix)
component.predictions$seasonal <- rowSums(
component.predictions[unique(components$component)])
# } else {
# component.predictions <- data.frame(seasonal = rep(0, nrow(df)))
# }
return(component.predictions)
}
#' Adds a component with given name that contains all of the components
#' in group.
#'
#' @param m Prophet object.
#' @param components Dataframe with components.
#' @param name Name of new group component.
#' @param group List of components that form the group.
#'
#' @return Dataframe with components.
#'
#' @keywords internal
add_group_component <- function(m, components, name, group) {
loc = (components$component %in% group);
new_comp = components[loc,];
new_comp$component = name;
components= rbind(components, new_comp);
return(components);
}
#' Prophet posterior predictive samples.
#'
#' @param m Prophet object.
@ -933,7 +1103,7 @@ sample_posterior_predictive <- function(m, df) {
samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
nsamp <- n.iterations * samp.per.iter # The actual number of samples
seasonal.features <- make_all_seasonality_features(m, df)
seasonal.features <- make_all_seasonality_features(m, df)[[1]]
sim.values <- list("trend" = matrix(, nrow = nrow(df), ncol = nsamp),
"seasonal" = matrix(, nrow = nrow(df), ncol = nsamp),
"yhat" = matrix(, nrow = nrow(df), ncol = nsamp))
@ -950,7 +1120,7 @@ sample_posterior_predictive <- function(m, df) {
}
}
return(sim.values)
}
}
#' Sample from the posterior predictive distribution.
#'
@ -959,13 +1129,14 @@ sample_posterior_predictive <- function(m, df) {
#' (column cap) if logistic growth.
#'
#' @return A list with items "trend", "seasonal", and "yhat" containing
#' posterior predictive samples for that component.
#' posterior predictive samples for that component. "seasonal" is the sum
#' of seasonalities, holidays, and added regressors.
#'
#' @export
predictive_samples <- function(m, df) {
df <- setup_dataframe(m, df)$df
sim.values <- sample_posterior_predictive(m, df)
return(sim.values)
df <- setup_dataframe(m, df)$df
sim.values <- sample_posterior_predictive(m, df)
return(sim.values)
}
#' Prophet uncertainty intervals.
@ -1197,8 +1368,8 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
#' @export
#' @importFrom dplyr "%>%"
prophet_plot_components <- function(
m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
yearly_start = 0) {
m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
yearly_start = 0) {
df <- df_for_plotting(m, fcst)
# Plot the trend
panels <- list(plot_trend(df, uncertainty, plot_cap))
@ -1287,11 +1458,11 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
if (uncertainty) {
gg.holidays <- gg.holidays +
ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower,
ymax = holidays_upper),
alpha = 0.2,
fill = "#0072B2",
na.rm = TRUE)
ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower,
ymax = holidays_upper),
alpha = 0.2,
fill = "#0072B2",
na.rm = TRUE)
}
return(gg.holidays)
}
@ -1311,9 +1482,10 @@ plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
# Compute weekly seasonality for a Sun-Sat sequence of dates.
df.w <- data.frame(
ds=seq(set_date('2017-01-01'), by='d', length.out=7) +
weekly_start, cap=1.)
weekly_start, cap=1.)
df.w <- setup_dataframe(m, df.w)$df
seas <- predict_seasonal_components(m, df.w)
print(seas)
seas$dow <- factor(weekdays(df.w$ds), levels=weekdays(df.w$ds))
gg.weekly <- ggplot2::ggplot(seas, ggplot2::aes(x = dow, y = weekly,
@ -1322,11 +1494,11 @@ plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
ggplot2::labs(x = "Day of week")
if (uncertainty) {
gg.weekly <- gg.weekly +
ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower,
ymax = weekly_upper),
alpha = 0.2,
fill = "#0072B2",
na.rm = TRUE)
ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower,
ymax = weekly_upper),
alpha = 0.2,
fill = "#0072B2",
na.rm = TRUE)
}
return(gg.weekly)
}
@ -1346,7 +1518,7 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
# Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
df.y <- data.frame(
ds=seq(set_date('2017-01-01'), by='d', length.out=365) +
yearly_start, cap=1.)
yearly_start, cap=1.)
df.y <- setup_dataframe(m, df.y)$df
seas <- predict_seasonal_components(m, df.y)
seas$ds <- df.y$ds
@ -1358,11 +1530,11 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
ggplot2::scale_x_datetime(labels = scales::date_format('%B %d'))
if (uncertainty) {
gg.yearly <- gg.yearly +
ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower,
ymax = yearly_upper),
alpha = 0.2,
fill = "#0072B2",
na.rm = TRUE)
ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower,
ymax = yearly_upper),
alpha = 0.2,
fill = "#0072B2",
na.rm = TRUE)
}
return(gg.yearly)
}
@ -1449,6 +1621,51 @@ prophet_copy <- function(m, cutoff = NULL) {
uncertainty.samples = m$uncertainty.samples,
fit = FALSE,
))
#' Sample from the posterior predictive distribution.
#'
#' @param m Prophet model object.
#' @param name String name of the seasonality.
#' @param uncertainty Boolean to plot uncertainty intervals.
#'
#' @return A ggplot2 plot.
#'
#' @keywords internal
plot_seasonality <- function(m, name, uncertainty = TRUE) {
# Compute seasonality from Jan 1 through a single period.
start <- set_date('2017-01-01')
period <- m$seasonalities[[name]][1]
end <- start + period * 24 * 3600
plot.points <- 200
df.y <- data.frame(
ds=seq(from=start, to=end, length.out=plot.points), cap=1.)
df.y <- setup_dataframe(m, df.y)$df
seas <- predict_seasonal_components(m, df.y)
seas$ds <- df.y$ds
gg.s <- ggplot2::ggplot(
seas, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
if (period <= 2) {
fmt.str <- '%T'
}
else if (period < 14) {
fmt.str <- '%m/%d %R'
} else {
fmt.str <- '%m/%d'
}
gg.s <- gg.s +
ggplot2::scale_x_datetime(labels = scales::date_format(fmt.str))
if (uncertainty) {
gg.s <- gg.s +
ggplot2::geom_ribbon(
ggplot2::aes_string(
ymin = paste0(name, '_lower'), ymax = paste0(name, '_upper')
),
alpha = 0.2,
fill = "#0072B2",
na.rm = TRUE)
}
return(gg.s)
}
# fb-block 3

View file

@ -7,7 +7,7 @@ data {
matrix[T, S] A; // Split indicators
real t_change[S]; // Index of changepoints
matrix[T,K] X; // season vectors
real<lower=0> sigma; // scale on seasonality prior
vector[K] sigmas; // scale on seasonality prior
real<lower=0> tau; // scale on changepoints prior
}
@ -33,7 +33,7 @@ model {
m ~ normal(0, 5);
delta ~ double_exponential(0, tau);
sigma_obs ~ normal(0, 0.5);
beta ~ normal(0, sigma);
beta ~ normal(0, sigmas);
// Likelihood
y ~ normal((k + A * delta) .* t + (m + A * gamma) + X * beta, sigma_obs);

View file

@ -8,7 +8,7 @@ data {
matrix[T, S] A; // Split indicators
real t_change[S]; // Index of changepoints
matrix[T,K] X; // season vectors
real<lower=0> sigma; // scale on seasonality prior
vector[K] sigmas; // scale on seasonality prior
real<lower=0> tau; // scale on changepoints prior
}
@ -45,7 +45,7 @@ model {
m ~ normal(0, 5);
delta ~ double_exponential(0, tau);
sigma_obs ~ normal(0, 0.1);
beta ~ normal(0, sigma);
beta ~ normal(0, sigmas);
// Likelihood
y ~ normal(cap ./ (1 + exp(-(k + A * delta) .* (t - (m + A * gamma)))) + X * beta, sigma_obs);

View file

@ -4,8 +4,8 @@
\alias{plot.prophet}
\title{Plot the prophet forecast.}
\usage{
\method{plot}{prophet}(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
xlabel = "ds", ylabel = "y", ...)
plot.prophet(x, fcst, uncertainty = TRUE, plot_cap = TRUE, xlabel = "ds",
ylabel = "y", ...)
}
\arguments{
\item{x}{Prophet object.}

View file

@ -1,5 +1,5 @@
- title: Docs
href: /docs/
href: docs/
category: docs
- title: GitHub