Added growth='flat' functionality in R (#1778)

* added `flat_growth_init()` function

* added validation for 'flat'

* changed `fit.prophet()` to handle `growth='flat'`

* added `trend='flat'` capabilities to `sample_predictive_trend()` and `fit.prophet()`

* updated STAN code to handle flat trend

* [Syntax fix] Removed unnecessary bracket

* updated documentation

* undid formatting that was accidentally applied by autoformatter

* undid more formatting that was accidentally applied by autoformatter

* added tests

* typo in `sample_predictive_trend()`

* updated notebook with example in R

* updated documentation
This commit is contained in:
Sam Snarr 2021-01-14 16:53:08 -05:00 committed by GitHub
parent 73b53658e1
commit 2d8e6c7fd1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 186 additions and 32 deletions

View file

@ -20,8 +20,8 @@ globalVariables(c(
#' also have a column cap that specifies the capacity at each ds. If not
#' provided, then the model object will be instantiated but not fit; use
#' fit.prophet(m, df) to fit the model.
#' @param growth String 'linear' or 'logistic' to specify a linear or logistic
#' trend.
#' @param growth String 'linear', 'logistic', or 'flat' to specify a linear, logistic
#' or flat trend.
#' @param changepoints Vector of dates at which to include potential
#' changepoints. If not specified, potential changepoints are selected
#' automatically.
@ -154,8 +154,8 @@ prophet <- function(df = NULL,
#'
#' @keywords internal
validate_inputs <- function(m) {
if (!(m$growth %in% c('linear', 'logistic'))) {
stop("Parameter 'growth' should be 'linear' or 'logistic'.")
if (!(m$growth %in% c('linear', 'logistic', 'flat'))) {
stop("Parameter 'growth' should be 'linear', 'logistic', or 'flat'.")
}
if ((m$changepoint.range < 0) | (m$changepoint.range > 1)) {
stop("Parameter 'changepoint.range' must be in [0, 1]")
@ -1053,7 +1053,28 @@ set_auto_seasonalities <- function(m) {
return(m)
}
#' Initialize linear growth.
#' Initialize flat growth.
#'
#' Provides a strong initialization for flat growth by setting the
#' growth to 0 and calculating the offset parameter that passes the
#' function through the mean of the y_scaled values.
#'
#' @param df Data frame with columns ds (date), y_scaled (scaled time series),
#'  and t (scaled time).
#'
#' @return A vector (k, m) with the rate (k) and offset (m) of the flat
#'  growth function.
#'
#' @keywords internal
flat_growth_init <- function(df) {
  # First element is the rate k, always 0 for a flat trend; second element is
  # the offset m, the mean of the scaled series. Only y_scaled is read here.
  c(0, mean(df$y_scaled))
}
#' Initialize constant growth.
#'
#' Provides a strong initialization for linear growth by calculating the
#' growth and offset parameters that pass the function through the first and
@ -1177,7 +1198,7 @@ fit.prophet <- function(m, df, ...) {
X = as.matrix(seasonal.features),
sigmas = array(prior.scales),
tau = m$changepoint.prior.scale,
trend_indicator = as.numeric(m$growth == 'logistic'),
trend_indicator = switch(m$growth, 'linear'=0, 'logistic'=1, 'flat'=2),
s_a = array(component.cols$additive_terms),
s_m = array(component.cols$multiplicative_terms)
)
@ -1186,7 +1207,10 @@ fit.prophet <- function(m, df, ...) {
if (m$growth == 'linear') {
dat$cap <- rep(0, nrow(history)) # Unused inside Stan
kinit <- linear_growth_init(history)
} else {
} else if (m$growth == 'flat') {
dat$cap <- rep(0, nrow(history)) # Unused inside Stan
kinit <- flat_growth_init(history)
} else if (m$growth == 'logistic') {
dat$cap <- history$cap_scaled # Add capacities to the Stan data
kinit <- logistic_growth_init(history)
}
@ -1206,7 +1230,8 @@ fit.prophet <- function(m, df, ...) {
)
}
if (min(history$y) == max(history$y)) {
if (min(history$y) == max(history$y) &
(m$growth %in% c('linear', 'flat'))) {
# Nothing to fit.
m$params <- stan_init()
m$params$sigma_obs <- 0.
@ -1318,6 +1343,19 @@ predict.prophet <- function(object, df = NULL, ...) {
return(df)
}
#' Evaluate the flat trend function.
#'
#' @param t Vector of times on which the function is evaluated.
#' @param m Float initial offset.
#'
#' @return Vector y(t).
#'
#' @keywords internal
flat_trend <- function(t, m) {
  # The trend is the offset m repeated once per entry of t; the values of t
  # themselves are not used, only its length.
  rep(m, times = length(t))
}
#' Evaluate the piecewise linear function.
#'
#' @param t Vector of times on which the function is evaluated.
@ -1392,7 +1430,9 @@ predict_trend <- function(model, df) {
t <- df$t
if (model$growth == 'linear') {
trend <- piecewise_linear(t, deltas, k, param.m, model$changepoints.t)
} else {
} else if (model$growth == 'flat') {
trend <- flat_trend(t, param.m)
} else if (model$growth == 'logistic') {
cap <- df$cap_scaled
trend <- piecewise_logistic(
t, cap, deltas, k, param.m, model$changepoints.t)
@ -1592,7 +1632,9 @@ sample_predictive_trend <- function(model, df, iteration) {
# Get the corresponding trend
if (model$growth == 'linear') {
trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts)
} else {
} else if (model$growth == 'flat') {
trend <- flat_trend(t, param.m)
} else if (model$growth == 'logistic') {
cap <- df$cap_scaled
trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
}

View file

@ -27,7 +27,6 @@ functions {
}
// Logistic trend functions
vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) {
vector[S] gamma; // adjusted offsets, for piecewise continuity
vector[S + 1] k_s; // actual rate in each segment
@ -62,7 +61,6 @@ functions {
}
// Linear trend function
vector linear_trend(
real k,
real m,
@ -73,6 +71,14 @@ functions {
) {
return (k + A * delta) .* t + (m + A * (-t_change .* delta));
}
// Flat trend function: a length-T vector with every entry equal to the
// offset m.
vector flat_trend(real m, int T) {
  return rep_vector(m, T);
}
}
data {
@ -86,7 +92,7 @@ data {
matrix[T,K] X; // Regressors
vector[K] sigmas; // Scale on seasonality prior
real<lower=0> tau; // Scale on changepoints prior
int trend_indicator; // 0 for linear, 1 for logistic
int trend_indicator; // 0 for linear, 1 for logistic, 2 for flat
vector[K] s_a; // Indicator of additive features
vector[K] s_m; // Indicator of multiplicative features
}
@ -104,6 +110,17 @@ parameters {
vector[K] beta; // Regressor coefficients
}
transformed parameters {
  // Trend component, computed once here so the likelihood can reference it.
  vector[T] trend;
  // trend_indicator: 0 for linear, 1 for logistic, 2 for flat.
  if (trend_indicator == 2) {
    trend = flat_trend(m, T);
  } else if (trend_indicator == 1) {
    trend = logistic_trend(k, m, delta, t, cap, A, t_change, S);
  } else if (trend_indicator == 0) {
    trend = linear_trend(k, m, delta, t, A, t_change);
  }
}
model {
//priors
k ~ normal(0, 5);
@ -113,19 +130,10 @@ model {
beta ~ normal(0, sigmas);
// Likelihood
if (trend_indicator == 0) {
y ~ normal(
linear_trend(k, m, delta, t, A, t_change)
y ~ normal(
trend
.* (1 + X * (beta .* s_m))
+ X * (beta .* s_a),
sigma_obs
);
} else if (trend_indicator == 1) {
y ~ normal(
logistic_trend(k, m, delta, t, cap, A, t_change, S)
.* (1 + X * (beta .* s_m))
+ X * (beta .* s_a),
sigma_obs
);
}
}

22
R/man/flat_growth_init.Rd Normal file
View file

@ -0,0 +1,22 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/prophet.R
\name{flat_growth_init}
\alias{flat_growth_init}
\title{Initialize flat growth.}
\usage{
flat_growth_init(df)
}
\arguments{
\item{df}{Data frame with columns ds (date), y_scaled (scaled time series),
and t (scaled time).}
}
\value{
A vector (k, m) with the rate (k) and offset (m) of the flat
growth function.
}
\description{
Provides a strong initialization for flat growth by setting the
growth to 0 and calculating the offset parameter that passes the
function through the mean of the y_scaled values.
}
\keyword{internal}

20
R/man/flat_trend.Rd Normal file
View file

@ -0,0 +1,20 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/prophet.R
\name{flat_trend}
\alias{flat_trend}
\title{Evaluate the flat trend function.}
\usage{
flat_trend(t, m)
}
\arguments{
\item{t}{Vector of times on which the function is evaluated.}
\item{m}{Float initial offset.}
}
\value{
Vector y(t).
}
\description{
Evaluate the flat trend function.
}
\keyword{internal}

View file

@ -2,7 +2,7 @@
% Please edit documentation in R/prophet.R
\name{linear_growth_init}
\alias{linear_growth_init}
\title{Initialize linear growth.}
\title{Initialize constant growth.}
\usage{
linear_growth_init(df)
}

View file

@ -10,7 +10,7 @@ performance_metrics(df, metrics = NULL, rolling_window = 0.1)
\item{df}{The dataframe returned by cross_validation.}
\item{metrics}{An array of performance metrics to compute. If not provided,
will use c('mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage').}
will use c('mse', 'rmse', 'mae', 'mape', 'mdape', 'smape', 'coverage').}
\item{rolling_window}{Proportion of data to use in each rolling window for
computing the metrics. Should be in [0, 1] to average.}
@ -26,6 +26,7 @@ By default the following metrics are included:
'mae': mean absolute error,
'mape': mean percent error,
'mdape': median percent error,
'smape': symmetric mean absolute percentage error,
'coverage': coverage of the upper and lower intervals
}
\details{

View file

@ -32,8 +32,8 @@ also have a column cap that specifies the capacity at each ds. If not
provided, then the model object will be instantiated but not fit; use
fit.prophet(m, df) to fit the model.}
\item{growth}{String 'linear' or 'logistic' to specify a linear or logistic
trend.}
\item{growth}{String 'linear', 'logistic', or 'flat' to specify a linear, logistic
or flat trend.}
\item{changepoints}{Vector of dates at which to include potential
changepoints. If not specified, potential changepoints are selected

22
R/man/smape.Rd Normal file
View file

@ -0,0 +1,22 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/diagnostics.R
\name{smape}
\alias{smape}
\title{Symmetric mean absolute percentage error
based on Chen and Yang (2004) formula}
\usage{
smape(df, w)
}
\arguments{
\item{df}{Cross-validation results dataframe.}
\item{w}{Aggregation window size.}
}
\value{
Array of symmetric mean absolute percent errors.
}
\description{
Symmetric mean absolute percentage error
based on Chen and Yang (2004) formula
}
\keyword{internal}

View file

@ -56,6 +56,18 @@ test_that("cross_validation_logistic", {
expect_equal(sum((df.merged$y.x - df.merged$y.y) ** 2), 0)
})
test_that("cross_validation_flat", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  # Fit a flat-trend model on the shared test data.
  model <- prophet(DATA, growth = 'flat')
  cv <- cross_validation(
    model, horizon = 1, units = "days", period = 1, initial = 140)
  # Two distinct cutoffs are expected, each strictly before its forecast dates.
  expect_equal(length(unique(cv$cutoff)), 2)
  expect_true(all(cv$cutoff < cv$ds))
  # The y values carried in the CV output must match the model history.
  joined <- dplyr::left_join(cv, model$history, by = "ds")
  expect_equal(sum((joined$y.x - joined$y.y) ** 2), 0)
})
test_that("cross_validation_extra_regressors", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
df <- DATA
@ -222,7 +234,7 @@ test_that("copy", {
df$cap <- 200.
df$binary_feature <- c(rep(0, 255), rep(1, 255))
inputs <- list(
growth = c('linear', 'logistic'),
growth = c('linear', 'logistic', 'flat'),
yearly.seasonality = c(TRUE, FALSE),
weekly.seasonality = c(TRUE, FALSE),
daily.seasonality = c(TRUE, FALSE),

View file

@ -237,9 +237,13 @@ test_that("growth_init", {
expect_equal(params[2], 0.5307511, tolerance = 1e-6)
params <- prophet:::logistic_growth_init(history)
expect_equal(params[1], 1.507925, tolerance = 1e-6)
expect_equal(params[2], -0.08167497, tolerance = 1e-6)
params <- prophet:::flat_growth_init(history)
expect_equal(params[1], 0, tolerance = 1e-6)
expect_equal(params[2], 0.49335657, tolerance = 1e-6)
})
test_that("piecewise_linear", {
@ -279,6 +283,19 @@ test_that("piecewise_logistic", {
expect_equal(y, y.true, tolerance = 1e-6)
})
test_that("flat_trend", {
  times <- seq(0, 10)
  offset <- 0.5
  expected <- rep(0.5, length(times))
  # Full grid: every value equals the offset.
  expect_equal(prophet:::flat_trend(times, offset), expected, tolerance = 1e-6)

  # Tail of the grid: the output length follows the input length.
  tail.idx <- 8:length(times)
  expect_equal(
    prophet:::flat_trend(times[tail.idx], offset),
    expected[tail.idx],
    tolerance = 1e-6
  )
})
test_that("holidays", {
holidays <- data.frame(ds = c('2016-12-25'),
holiday = c('xmas'),

View file

@ -132,6 +132,16 @@
"For time series that exhibit strong seasonality patterns rather than trend changes, it may be useful to force the trend growth rate to be flat. This can be achieved simply by passing `growth='flat'` when creating the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%R\n",
"m <- prophet(df, growth='flat')"
]
},
{
"cell_type": "code",
"execution_count": 5,
@ -145,7 +155,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This is currently implemented only in the Python version of Prophet. Note that if this is used on a time series that doesn't have a constant trend, any trend will be fit with the noise term and so there will be high predictive uncertainty in the forecast.\n",
"Note that if this is used on a time series that doesn't have a constant trend, any trend will be fit with the noise term and so there will be high predictive uncertainty in the forecast.\n",
"\n",
"To use a trend besides these three built-in trend functions (piecewise linear, piecewise logistic growth, and flat), you can download the source code from github, modify the trend function as desired in a local branch, and then install that local version. This PR provides a good illustration of what must be done to implement a custom trend: https://github.com/facebook/prophet/pull/1466/files."
]
@ -237,4 +247,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}