mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-25 22:26:34 +00:00
Added growth='flat' functionality in R (#1778)
* added `flat_growth_init()` function * added validation for 'flat' * changed `fit.prophet()` to handle `growth='flat'` * added `trend='flat'` capabilities to `sample_predictive_trend()` and `fit.prophet()` * updated STAN code to handle flat trend * [Syntax fix] Removed unnecessary bracket * updated documentation * undid formatting that was accidentally applied by autoformatter * undid more formatting that was accidentally applied by autoformatter * added tests * typo in `sample_predictive_trend()` * updated notebook with example in R * updated documentation
This commit is contained in:
parent
73b53658e1
commit
2d8e6c7fd1
11 changed files with 186 additions and 32 deletions
|
|
@ -20,8 +20,8 @@ globalVariables(c(
|
|||
#' also have a column cap that specifies the capacity at each ds. If not
|
||||
#' provided, then the model object will be instantiated but not fit; use
|
||||
#' fit.prophet(m, df) to fit the model.
|
||||
#' @param growth String 'linear' or 'logistic' to specify a linear or logistic
|
||||
#' trend.
|
||||
#' @param growth String 'linear', 'logistic', or 'flat' to specify a linear, logistic
|
||||
#' or flat trend.
|
||||
#' @param changepoints Vector of dates at which to include potential
|
||||
#' changepoints. If not specified, potential changepoints are selected
|
||||
#' automatically.
|
||||
|
|
@ -154,8 +154,8 @@ prophet <- function(df = NULL,
|
|||
#'
|
||||
#' @keywords internal
|
||||
validate_inputs <- function(m) {
|
||||
if (!(m$growth %in% c('linear', 'logistic'))) {
|
||||
stop("Parameter 'growth' should be 'linear' or 'logistic'.")
|
||||
if (!(m$growth %in% c('linear', 'logistic', 'flat'))) {
|
||||
stop("Parameter 'growth' should be 'linear', 'logistic', or 'flat'.")
|
||||
}
|
||||
if ((m$changepoint.range < 0) | (m$changepoint.range > 1)) {
|
||||
stop("Parameter 'changepoint.range' must be in [0, 1]")
|
||||
|
|
@ -1053,7 +1053,28 @@ set_auto_seasonalities <- function(m) {
|
|||
return(m)
|
||||
}
|
||||
|
||||
#' Initialize linear growth.
|
||||
#' Initialize flat growth.
#'
#' Provides a strong initialization for flat growth by fixing the growth
#' rate at 0 and choosing the offset so that the trend passes through the
#' mean of the y_scaled values.
#'
#' @param df Data frame with columns ds (date), y_scaled (scaled time series),
#'  and t (scaled time).
#'
#' @return A vector (k, m) with the rate (k) and offset (m) of the flat
#'  growth function.
#'
#' @keywords internal
flat_growth_init <- function(df) {
  # A flat trend has no slope: the rate is always zero and only the level
  # (the average of the scaled series) is taken from the data.
  offset <- mean(df$y_scaled)
  c(0, offset)
}
|
||||
|
||||
#' Initialize linear growth.
|
||||
#'
|
||||
#' Provides a strong initialization for linear growth by calculating the
|
||||
#' growth and offset parameters that pass the function through the first and
|
||||
|
|
@ -1177,7 +1198,7 @@ fit.prophet <- function(m, df, ...) {
|
|||
X = as.matrix(seasonal.features),
|
||||
sigmas = array(prior.scales),
|
||||
tau = m$changepoint.prior.scale,
|
||||
trend_indicator = as.numeric(m$growth == 'logistic'),
|
||||
trend_indicator = switch(m$growth, 'linear'=0, 'logistic'=1, 'flat'=2),
|
||||
s_a = array(component.cols$additive_terms),
|
||||
s_m = array(component.cols$multiplicative_terms)
|
||||
)
|
||||
|
|
@ -1186,7 +1207,10 @@ fit.prophet <- function(m, df, ...) {
|
|||
if (m$growth == 'linear') {
|
||||
dat$cap <- rep(0, nrow(history)) # Unused inside Stan
|
||||
kinit <- linear_growth_init(history)
|
||||
} else {
|
||||
} else if (m$growth == 'flat') {
|
||||
dat$cap <- rep(0, nrow(history)) # Unused inside Stan
|
||||
kinit <- flat_growth_init(history)
|
||||
} else if (m$growth == 'logistic') {
|
||||
dat$cap <- history$cap_scaled # Add capacities to the Stan data
|
||||
kinit <- logistic_growth_init(history)
|
||||
}
|
||||
|
|
@ -1206,7 +1230,8 @@ fit.prophet <- function(m, df, ...) {
|
|||
)
|
||||
}
|
||||
|
||||
if (min(history$y) == max(history$y)) {
|
||||
if (min(history$y) == max(history$y) &
|
||||
(m$growth %in% c('linear', 'flat'))) {
|
||||
# Nothing to fit.
|
||||
m$params <- stan_init()
|
||||
m$params$sigma_obs <- 0.
|
||||
|
|
@ -1318,6 +1343,19 @@ predict.prophet <- function(object, df = NULL, ...) {
|
|||
return(df)
|
||||
}
|
||||
|
||||
#' Evaluate the flat trend function.
#'
#' @param t Vector of times on which the function is evaluated.
#' @param m Float initial offset.
#'
#' @return Vector y(t), equal to m at every element of t.
#'
#' @keywords internal
flat_trend <- function(t, m) {
  # The trend does not depend on time; t only determines the output length.
  n.times <- length(t)
  rep(m, times = n.times)
}
|
||||
|
||||
#' Evaluate the piecewise linear function.
|
||||
#'
|
||||
#' @param t Vector of times on which the function is evaluated.
|
||||
|
|
@ -1392,7 +1430,9 @@ predict_trend <- function(model, df) {
|
|||
t <- df$t
|
||||
if (model$growth == 'linear') {
|
||||
trend <- piecewise_linear(t, deltas, k, param.m, model$changepoints.t)
|
||||
} else {
|
||||
} else if (model$growth == 'flat') {
|
||||
trend <- flat_trend(t, param.m)
|
||||
} else if (model$growth == 'logistic') {
|
||||
cap <- df$cap_scaled
|
||||
trend <- piecewise_logistic(
|
||||
t, cap, deltas, k, param.m, model$changepoints.t)
|
||||
|
|
@ -1592,7 +1632,9 @@ sample_predictive_trend <- function(model, df, iteration) {
|
|||
# Get the corresponding trend
|
||||
if (model$growth == 'linear') {
|
||||
trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts)
|
||||
} else {
|
||||
} else if (model$growth == 'flat') {
|
||||
trend <- flat_trend(t, param.m)
|
||||
} else if (model$growth == 'logistic') {
|
||||
cap <- df$cap_scaled
|
||||
trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@ functions {
|
|||
}
|
||||
|
||||
// Logistic trend functions
|
||||
|
||||
vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) {
|
||||
vector[S] gamma; // adjusted offsets, for piecewise continuity
|
||||
vector[S + 1] k_s; // actual rate in each segment
|
||||
|
|
@ -62,7 +61,6 @@ functions {
|
|||
}
|
||||
|
||||
// Linear trend function
|
||||
|
||||
vector linear_trend(
|
||||
real k,
|
||||
real m,
|
||||
|
|
@ -73,6 +71,14 @@ functions {
|
|||
) {
|
||||
return (k + A * delta) .* t + (m + A * (-t_change .* delta));
|
||||
}
|
||||
|
||||
// Flat trend function: the offset m repeated at each of the T time points.
vector flat_trend(real m, int T) {
  return rep_vector(m, T);
}
|
||||
}
|
||||
|
||||
data {
|
||||
|
|
@ -86,7 +92,7 @@ data {
|
|||
matrix[T,K] X; // Regressors
|
||||
vector[K] sigmas; // Scale on seasonality prior
|
||||
real<lower=0> tau; // Scale on changepoints prior
|
||||
int trend_indicator; // 0 for linear, 1 for logistic
|
||||
int trend_indicator; // 0 for linear, 1 for logistic, 2 for flat
|
||||
vector[K] s_a; // Indicator of additive features
|
||||
vector[K] s_m; // Indicator of multiplicative features
|
||||
}
|
||||
|
|
@ -104,6 +110,17 @@ parameters {
|
|||
vector[K] beta; // Regressor coefficients
|
||||
}
|
||||
|
||||
// Evaluate the trend selected by trend_indicator once, so that the model
// block can use `trend` directly in the likelihood.
transformed parameters {
|
||||
vector[T] trend;  // left unassigned if trend_indicator is not 0, 1 or 2
|
||||
// 0: piecewise linear trend.
if (trend_indicator == 0) {
|
||||
trend = linear_trend(k, m, delta, t, A, t_change);
|
||||
// 1: piecewise logistic trend; the only branch that reads cap.
} else if (trend_indicator == 1) {
|
||||
trend = logistic_trend(k, m, delta, t, cap, A, t_change, S);
|
||||
// 2: flat trend; only the offset m is used, k and delta are ignored here.
} else if (trend_indicator == 2) {
|
||||
trend = flat_trend(m, T);
|
||||
}
|
||||
}
|
||||
|
||||
model {
|
||||
//priors
|
||||
k ~ normal(0, 5);
|
||||
|
|
@ -113,19 +130,10 @@ model {
|
|||
beta ~ normal(0, sigmas);
|
||||
|
||||
// Likelihood
|
||||
if (trend_indicator == 0) {
|
||||
y ~ normal(
|
||||
linear_trend(k, m, delta, t, A, t_change)
|
||||
y ~ normal(
|
||||
trend
|
||||
.* (1 + X * (beta .* s_m))
|
||||
+ X * (beta .* s_a),
|
||||
sigma_obs
|
||||
);
|
||||
} else if (trend_indicator == 1) {
|
||||
y ~ normal(
|
||||
logistic_trend(k, m, delta, t, cap, A, t_change, S)
|
||||
.* (1 + X * (beta .* s_m))
|
||||
+ X * (beta .* s_a),
|
||||
sigma_obs
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
22
R/man/flat_growth_init.Rd
Normal file
22
R/man/flat_growth_init.Rd
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/prophet.R
|
||||
\name{flat_growth_init}
|
||||
\alias{flat_growth_init}
|
||||
\title{Initialize flat growth.}
|
||||
\usage{
|
||||
flat_growth_init(df)
|
||||
}
|
||||
\arguments{
|
||||
\item{df}{Data frame with columns ds (date), y_scaled (scaled time series),
|
||||
and t (scaled time).}
|
||||
}
|
||||
\value{
|
||||
A vector (k, m) with the rate (k) and offset (m) of the flat
|
||||
growth function.
|
||||
}
|
||||
\description{
|
||||
Provides a strong initialization for flat growth by setting the
|
||||
growth to 0 and calculating the offset parameter that passes the
|
||||
function through the mean of the y_scaled values.
|
||||
}
|
||||
\keyword{internal}
|
||||
20
R/man/flat_trend.Rd
Normal file
20
R/man/flat_trend.Rd
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/prophet.R
|
||||
\name{flat_trend}
|
||||
\alias{flat_trend}
|
||||
\title{Evaluate the flat trend function.}
|
||||
\usage{
|
||||
flat_trend(t, m)
|
||||
}
|
||||
\arguments{
|
||||
\item{t}{Vector of times on which the function is evaluated.}
|
||||
|
||||
\item{m}{Float initial offset.}
|
||||
}
|
||||
\value{
|
||||
Vector y(t).
|
||||
}
|
||||
\description{
|
||||
Evaluate the flat trend function.
|
||||
}
|
||||
\keyword{internal}
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
% Please edit documentation in R/prophet.R
|
||||
\name{linear_growth_init}
|
||||
\alias{linear_growth_init}
|
||||
\title{Initialize linear growth.}
|
||||
\title{Initialize linear growth.}
|
||||
\usage{
|
||||
linear_growth_init(df)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ performance_metrics(df, metrics = NULL, rolling_window = 0.1)
|
|||
\item{df}{The dataframe returned by cross_validation.}
|
||||
|
||||
\item{metrics}{An array of performance metrics to compute. If not provided,
|
||||
will use c('mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage').}
|
||||
will use c('mse', 'rmse', 'mae', 'mape', 'mdape', 'smape', 'coverage').}
|
||||
|
||||
\item{rolling_window}{Proportion of data to use in each rolling window for
|
||||
computing the metrics. Should be in [0, 1] to average.}
|
||||
|
|
@ -26,6 +26,7 @@ By default the following metrics are included:
|
|||
'mae': mean absolute error,
|
||||
'mape': mean absolute percent error,
|
||||
'mdape': median absolute percent error,
|
||||
'smape': symmetric mean absolute percentage error,
|
||||
'coverage': coverage of the upper and lower intervals
|
||||
}
|
||||
\details{
|
||||
|
|
|
|||
|
|
@ -32,8 +32,8 @@ also have a column cap that specifies the capacity at each ds. If not
|
|||
provided, then the model object will be instantiated but not fit; use
|
||||
fit.prophet(m, df) to fit the model.}
|
||||
|
||||
\item{growth}{String 'linear' or 'logistic' to specify a linear or logistic
|
||||
trend.}
|
||||
\item{growth}{String 'linear', 'logistic', or 'flat' to specify a linear, logistic
|
||||
or flat trend.}
|
||||
|
||||
\item{changepoints}{Vector of dates at which to include potential
|
||||
changepoints. If not specified, potential changepoints are selected
|
||||
|
|
|
|||
22
R/man/smape.Rd
Normal file
22
R/man/smape.Rd
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/diagnostics.R
|
||||
\name{smape}
|
||||
\alias{smape}
|
||||
\title{Symmetric mean absolute percentage error
|
||||
based on Chen and Yang (2004) formula}
|
||||
\usage{
|
||||
smape(df, w)
|
||||
}
|
||||
\arguments{
|
||||
\item{df}{Cross-validation results dataframe.}
|
||||
|
||||
\item{w}{Aggregation window size.}
|
||||
}
|
||||
\value{
|
||||
Array of symmetric mean absolute percent errors.
|
||||
}
|
||||
\description{
|
||||
Symmetric mean absolute percentage error
|
||||
based on Chen and Yang (2004) formula
|
||||
}
|
||||
\keyword{internal}
|
||||
|
|
@ -56,6 +56,18 @@ test_that("cross_validation_logistic", {
|
|||
expect_equal(sum((df.merged$y.x - df.merged$y.y) ** 2), 0)
|
||||
})
|
||||
|
||||
test_that("cross_validation_flat", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  # Fit a flat-trend model and cross-validate with a one-day horizon.
  history.df <- DATA
  model <- prophet(history.df, growth = 'flat')
  cv.result <- cross_validation(
    model,
    horizon = 1,
    units = "days",
    period = 1,
    initial = 140
  )
  # Two cutoffs are expected, each strictly before the dates it predicts.
  n.cutoffs <- length(unique(cv.result$cutoff))
  expect_equal(n.cutoffs, 2)
  expect_true(all(cv.result$cutoff < cv.result$ds))
  # The y values carried through cross-validation must match the history.
  joined <- dplyr::left_join(cv.result, model$history, by = "ds")
  squared.diff <- (joined$y.x - joined$y.y) ^ 2
  expect_equal(sum(squared.diff), 0)
})
|
||||
|
||||
test_that("cross_validation_extra_regressors", {
|
||||
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
|
||||
df <- DATA
|
||||
|
|
@ -222,7 +234,7 @@ test_that("copy", {
|
|||
df$cap <- 200.
|
||||
df$binary_feature <- c(rep(0, 255), rep(1, 255))
|
||||
inputs <- list(
|
||||
growth = c('linear', 'logistic'),
|
||||
growth = c('linear', 'logistic', 'flat'),
|
||||
yearly.seasonality = c(TRUE, FALSE),
|
||||
weekly.seasonality = c(TRUE, FALSE),
|
||||
daily.seasonality = c(TRUE, FALSE),
|
||||
|
|
|
|||
|
|
@ -237,9 +237,13 @@ test_that("growth_init", {
|
|||
expect_equal(params[2], 0.5307511, tolerance = 1e-6)
|
||||
|
||||
params <- prophet:::logistic_growth_init(history)
|
||||
|
||||
expect_equal(params[1], 1.507925, tolerance = 1e-6)
|
||||
expect_equal(params[2], -0.08167497, tolerance = 1e-6)
|
||||
|
||||
params <- prophet:::flat_growth_init(history)
|
||||
expect_equal(params[1], 0, tolerance = 1e-6)
|
||||
expect_equal(params[2], 0.49335657, tolerance = 1e-6)
|
||||
|
||||
})
|
||||
|
||||
test_that("piecewise_linear", {
|
||||
|
|
@ -279,6 +283,19 @@ test_that("piecewise_logistic", {
|
|||
expect_equal(y, y.true, tolerance = 1e-6)
|
||||
})
|
||||
|
||||
test_that("flat_trend", {
  offset <- 0.5
  times <- seq(0, 10)
  expected <- rep(0.5, length(times))

  # Over the full grid the trend equals the offset everywhere.
  full <- prophet:::flat_trend(times, offset)
  expect_equal(full, expected, tolerance = 1e-6)

  # Evaluating on the tail of the grid gives the matching tail of the result.
  tail.idx <- 8:length(times)
  partial <- prophet:::flat_trend(times[tail.idx], offset)
  expect_equal(partial, expected[tail.idx], tolerance = 1e-6)
})
|
||||
|
||||
test_that("holidays", {
|
||||
holidays <- data.frame(ds = c('2016-12-25'),
|
||||
holiday = c('xmas'),
|
||||
|
|
|
|||
|
|
@ -132,6 +132,16 @@
|
|||
"For time series that exhibit strong seasonality patterns rather than trend changes, it may be useful to force the trend growth rate to be flat. This can be achieved simply by passing `growth='flat'` when creating the model:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%R\n",
|
||||
"m <- prophet(df, growth='flat')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
|
|
@ -145,7 +155,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This is currently implemented only in the Python version of Prophet. Note that if this is used on a time series that doesn't have a constant trend, any trend will be fit with the noise term and so there will be high predictive uncertainty in the forecast.\n",
|
||||
"Note that if this is used on a time series that doesn't have a constant trend, any trend will be fit with the noise term and so there will be high predictive uncertainty in the forecast.\n",
|
||||
"\n",
|
||||
"To use a trend besides these three built-in trend functions (piecewise linear, piecewise logistic growth, and flat), you can download the source code from github, modify the trend function as desired in a local branch, and then install that local version. This PR provides a good illustration of what must be done to implement a custom trend: https://github.com/facebook/prophet/pull/1466/files."
|
||||
]
|
||||
|
|
@ -237,4 +247,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue