mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-07 00:13:24 +00:00
Combine trend models into a single stan file (R)
This commit is contained in:
parent
1f84fa960f
commit
aa4e223152
7 changed files with 203 additions and 185 deletions
|
|
@ -218,22 +218,25 @@ validate_column_name <- function(
|
|||
#' @return Stan model.
|
||||
#'
|
||||
#' @keywords internal
|
||||
get_prophet_stan_model <- function() {
  ## Returns the cached Stan model shipped in the package's libs directory.
  ## If the cached model doesn't work, just compile a new one.
  tryCatch({
    binary <- system.file(
      'libs',
      Sys.getenv('R_ARCH'),
      'prophet_stan_model.RData',
      package = 'prophet',
      mustWork = TRUE
    )
    ## Load into a scratch environment and look the object up by name there
    ## only, so a same-named object elsewhere can never be picked up and no
    ## code has to be parsed from a string.
    cache <- new.env()
    load(binary, envir = cache)
    stanm <- get('model.stanm', envir = cache, inherits = FALSE)

    ## Should cause an error if the model doesn't work.
    stanm@mk_cppmodule(stanm)
    stanm
  }, error = function(cond) {
    ## Any failure above (missing file, missing object, unusable binary)
    ## falls back to compiling the model from its Stan source.
    compile_stan_model()
  })
}
|
||||
|
||||
|
|
@ -245,14 +248,13 @@ get_prophet_stan_model <- function(model) {
|
|||
#' @return Stan model.
|
||||
#'
|
||||
#' @keywords internal
|
||||
compile_stan_model <- function() {
  ## Locate the Stan program inside the installed prophet package; errors if
  ## the file cannot be found (mustWork = TRUE).
  stan.path <- system.file(
    'stan/prophet.stan',
    package = 'prophet',
    mustWork = TRUE
  )

  ## Parse the program with rstan::stanc, then build the model object from
  ## the result with rstan::stan_model.
  rstan::stan_model(
    stanc_ret = rstan::stanc(stan.path),
    model_name = 'prophet_model'
  )
}
|
||||
|
||||
#' Convert date vector
|
||||
|
|
@ -901,21 +903,23 @@ fit.prophet <- function(m, df, ...) {
|
|||
t_change = array(m$changepoints.t),
|
||||
X = as.matrix(seasonal.features),
|
||||
sigmas = array(prior.scales),
|
||||
tau = m$changepoint.prior.scale
|
||||
tau = m$changepoint.prior.scale,
|
||||
trend_indicator = as.numeric(m$growth == 'logistic')
|
||||
)
|
||||
|
||||
# Run stan
|
||||
if (m$growth == 'linear') {
|
||||
dat$cap <- rep(0, nrow(history)) # Unused inside Stan
|
||||
kinit <- linear_growth_init(history)
|
||||
} else {
|
||||
dat$cap <- history$cap_scaled # Add capacities to the Stan data
|
||||
kinit <- logistic_growth_init(history)
|
||||
}
|
||||
|
||||
if (exists(".prophet.stan.models")) {
|
||||
model <- .prophet.stan.models[[m$growth]]
|
||||
if (exists(".prophet.stan.model")) {
|
||||
model <- .prophet.stan.model
|
||||
} else {
|
||||
model <- get_prophet_stan_model(m$growth)
|
||||
model <- get_prophet_stan_model()
|
||||
}
|
||||
|
||||
stan_init <- function() {
|
||||
|
|
|
|||
11
R/R/zzz.R
11
R/R/zzz.R
|
|
@ -6,9 +6,10 @@
|
|||
## of patent rights can be found in the PATENTS file in the same directory.
|
||||
|
||||
.onLoad <- function(libname, pkgname) {
  ## Fetch (or compile) the Stan model once at load time and store it as
  ## `.prophet.stan.model` in the environment that encloses this function,
  ## where fit.prophet looks for it.
  enclosing.env <- parent.env(environment())
  assign(
    ".prophet.stan.model",
    get_prophet_stan_model(),
    envir = enclosing.env
  )
}
|
||||
|
|
|
|||
123
R/inst/stan/prophet.stan
Normal file
123
R/inst/stan/prophet.stan
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
functions {
|
||||
matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
  // Builds the T x S indicator matrix A: row n has a 1 in column j for every
  // changepoint j that has been reached by time t[n] (t[n] >= t_change[j]).
  // Assumes t and t_change are sorted.
  matrix[T, S] A = rep_matrix(0, T, S);
  row_vector[S] reached = rep_row_vector(0, S);  // changepoints passed so far
  int next_cp = 1;                               // first one not yet reached

  for (n in 1:T) {
    // Switch on every changepoint that t[n] has caught up with; the running
    // row carries over to later times.
    while ((next_cp <= S) && (t[n] >= t_change[next_cp])) {
      reached[next_cp] = 1;
      next_cp += 1;
    }
    A[n] = reached;
  }
  return A;
}
|
||||
|
||||
// Logistic trend functions
|
||||
|
||||
vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) {
  // Computes the offset adjustment at each of the S changepoints that keeps
  // the piecewise logistic trend continuous.
  vector[S] gamma;       // adjusted offsets, for piecewise continuity
  real rate_prev = k;    // growth rate in the segment before changepoint s
  real m_seg = m;        // offset in the segment before changepoint s

  for (s in 1:S) {
    // Rate in the segment that starts at changepoint s.
    real rate_next = rate_prev + delta[s];

    gamma[s] = (t_change[s] - m_seg) * (1 - rate_prev / rate_next);

    // Advance to the next segment.
    m_seg = m_seg + gamma[s];
    rate_prev = rate_next;
  }
  return gamma;
}
|
||||
|
||||
vector logistic_trend(
  real k,
  real m,
  vector delta,
  vector t,
  vector cap,
  matrix A,
  vector t_change,
  int S
) {
  // Piecewise logistic trend: cap / (1 + exp(-rate * (t - shift))), with the
  // rate and shift changing at each changepoint selected by A.
  vector[S] gamma = logistic_gamma(k, m, delta, t_change, S);
  vector[rows(A)] rate = k + A * delta;   // growth rate at each time
  vector[rows(A)] shift = m + A * gamma;  // offset at each time

  return cap ./ (1 + exp(-rate .* (t - shift)));
}
|
||||
|
||||
// Linear trend function
|
||||
|
||||
vector linear_trend(
  real k,
  real m,
  vector delta,
  vector t,
  matrix A,
  vector t_change
) {
  // Piecewise linear trend: rate .* t + shift, with the rate and shift
  // changing at each changepoint selected by A.
  // Offset adjustment at changepoint j is -t_change[j] * delta[j].
  vector[rows(delta)] gamma = -t_change .* delta;
  vector[rows(A)] rate = k + A * delta;   // slope at each time
  vector[rows(A)] shift = m + A * gamma;  // intercept at each time

  return rate .* t + shift;
}
|
||||
}
|
||||
|
||||
data {
  int T;                // Number of time periods
  int<lower=1> K;       // Number of regressors
  vector[T] t;          // Time
  vector[T] cap;        // Capacities for logistic trend
  vector[T] y;          // Time series
  int S;                // Number of changepoints
  vector[S] t_change;   // Times of trend changepoints
  matrix[T,K] X;        // Regressors
  vector[K] sigmas;     // Scale on seasonality prior
  real<lower=0> tau;    // Scale on changepoints prior
  // Bounded to {0, 1} so an invalid value is rejected when the data are
  // read: the trend is only ever assigned for these two values.
  int<lower=0, upper=1> trend_indicator;  // 0 for linear, 1 for logistic
}
|
||||
|
||||
transformed data {
  // Changepoint indicator matrix; depends only on data, so it is built once.
  matrix[T, S] A = get_changepoint_matrix(t, t_change, T, S);
}
|
||||
|
||||
parameters {
  real k;                   // Base trend growth rate
  real m;                   // Trend offset
  vector[S] delta;          // Trend rate adjustment at each changepoint
  real<lower=0> sigma_obs;  // Observation noise scale
  vector[K] beta;           // Coefficient for each regressor column of X
}
|
||||
|
||||
transformed parameters {
  vector[T] trend;  // Trend component at each time point

  // Pick the trend function named by the data flag.
  if (trend_indicator == 0) {
    trend = linear_trend(k, m, delta, t, A, t_change);
  } else if (trend_indicator == 1) {
    trend = logistic_trend(k, m, delta, t, cap, A, t_change, S);
  } else {
    // Any other value would leave trend unassigned; fail with a clear
    // message instead.
    reject("trend_indicator must be 0 (linear) or 1 (logistic); found ",
           trend_indicator);
  }
}
|
||||
|
||||
model {
  // Expected value of y: trend plus the regressor contribution.
  vector[T] y_hat = trend + X * beta;

  // Priors
  k ~ normal(0, 5);
  m ~ normal(0, 5);
  delta ~ double_exponential(0, tau);
  // NOTE(review): the removed linear-only model used normal(0, 0.5) here;
  // confirm that 0.1 is intended for linear trends as well.
  sigma_obs ~ normal(0, 0.1);
  beta ~ normal(0, sigmas);

  // Likelihood
  y ~ normal(y_hat, sigma_obs);
}
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
functions {
|
||||
matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
|
||||
// Assumes t and t_change are sorted.
|
||||
matrix[T, S] A;
|
||||
row_vector[S] a_row;
|
||||
int cp_idx;
|
||||
|
||||
// Start with an empty matrix.
|
||||
A = rep_matrix(0, T, S);
|
||||
a_row = rep_row_vector(0, S);
|
||||
cp_idx = 1;
|
||||
|
||||
// Fill in each row of A.
|
||||
for (i in 1:T) {
|
||||
while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) {
|
||||
a_row[cp_idx] = 1;
|
||||
cp_idx += 1;
|
||||
}
|
||||
A[i] = a_row;
|
||||
}
|
||||
return A;
|
||||
}
|
||||
}
|
||||
|
||||
data {
|
||||
int T; // Sample size
|
||||
int<lower=1> K; // Number of seasonal vectors
|
||||
vector[T] t; // Day
|
||||
vector[T] y; // Time-series
|
||||
int S; // Number of changepoints
|
||||
vector[S] t_change; // Index of changepoints
|
||||
matrix[T,K] X; // season vectors
|
||||
vector[K] sigmas; // scale on seasonality prior
|
||||
real<lower=0> tau; // scale on changepoints prior
|
||||
}
|
||||
|
||||
transformed data {
|
||||
matrix[T, S] A;
|
||||
A = get_changepoint_matrix(t, t_change, T, S);
|
||||
}
|
||||
|
||||
parameters {
|
||||
real k; // Base growth rate
|
||||
real m; // offset
|
||||
vector[S] delta; // Rate adjustments
|
||||
real<lower=0> sigma_obs; // Observation noise (incl. seasonal variation)
|
||||
vector[K] beta; // seasonal vector
|
||||
}
|
||||
|
||||
transformed parameters {
|
||||
vector[S] gamma; // adjusted offsets, for piecewise continuity
|
||||
|
||||
for (i in 1:S) {
|
||||
gamma[i] = -t_change[i] * delta[i];
|
||||
}
|
||||
}
|
||||
|
||||
model {
|
||||
//priors
|
||||
k ~ normal(0, 5);
|
||||
m ~ normal(0, 5);
|
||||
delta ~ double_exponential(0, tau);
|
||||
sigma_obs ~ normal(0, 0.5);
|
||||
beta ~ normal(0, sigmas);
|
||||
|
||||
// Likelihood
|
||||
y ~ normal((k + A * delta) .* t + (m + A * gamma) + X * beta, sigma_obs);
|
||||
}
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
functions {
|
||||
matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
|
||||
// Assumes t and t_change are sorted.
|
||||
matrix[T, S] A;
|
||||
row_vector[S] a_row;
|
||||
int cp_idx;
|
||||
|
||||
// Start with an empty matrix.
|
||||
A = rep_matrix(0, T, S);
|
||||
a_row = rep_row_vector(0, S);
|
||||
cp_idx = 1;
|
||||
|
||||
// Fill in each row of A.
|
||||
for (i in 1:T) {
|
||||
while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) {
|
||||
a_row[cp_idx] = 1;
|
||||
cp_idx += 1;
|
||||
}
|
||||
A[i] = a_row;
|
||||
}
|
||||
return A;
|
||||
}
|
||||
}
|
||||
|
||||
data {
|
||||
int T; // Sample size
|
||||
int<lower=1> K; // Number of seasonal vectors
|
||||
vector[T] t; // Day
|
||||
vector[T] cap; // Capacities
|
||||
vector[T] y; // Time-series
|
||||
int S; // Number of changepoints
|
||||
vector[S] t_change; // Index of changepoints
|
||||
matrix[T,K] X; // season vectors
|
||||
vector[K] sigmas; // scale on seasonality prior
|
||||
real<lower=0> tau; // scale on changepoints prior
|
||||
}
|
||||
|
||||
transformed data {
|
||||
matrix[T, S] A;
|
||||
A = get_changepoint_matrix(t, t_change, T, S);
|
||||
}
|
||||
|
||||
parameters {
|
||||
real k; // Base growth rate
|
||||
real m; // offset
|
||||
vector[S] delta; // Rate adjustments
|
||||
real<lower=0> sigma_obs; // Observation noise (incl. seasonal variation)
|
||||
vector[K] beta; // seasonal vector
|
||||
}
|
||||
|
||||
transformed parameters {
|
||||
vector[S] gamma; // adjusted offsets, for piecewise continuity
|
||||
vector[S + 1] k_s; // actual rate in each segment
|
||||
real m_pr;
|
||||
|
||||
// Compute the rate in each segment
|
||||
k_s[1] = k;
|
||||
for (i in 1:S) {
|
||||
k_s[i + 1] = k_s[i] + delta[i];
|
||||
}
|
||||
|
||||
// Piecewise offsets
|
||||
m_pr = m; // The offset in the previous segment
|
||||
for (i in 1:S) {
|
||||
gamma[i] = (t_change[i] - m_pr) * (1 - k_s[i] / k_s[i + 1]);
|
||||
m_pr = m_pr + gamma[i]; // update for the next segment
|
||||
}
|
||||
}
|
||||
|
||||
model {
|
||||
//priors
|
||||
k ~ normal(0, 5);
|
||||
m ~ normal(0, 5);
|
||||
delta ~ double_exponential(0, tau);
|
||||
sigma_obs ~ normal(0, 0.1);
|
||||
beta ~ normal(0, sigmas);
|
||||
|
||||
// Likelihood
|
||||
y ~ normal(cap ./ (1 + exp(-(k + A * delta) .* (t - (m + A * gamma)))) + X * beta, sigma_obs);
|
||||
}
|
||||
|
|
@ -1,26 +1,21 @@
|
|||
|
||||
|
||||
# Compiles the single Prophet Stan model and caches it as an .RData file in
# the package's libs directory.
# NOTE(review): presumably run at package install time, since it relies on
# R_PACKAGE_DIR, R_PACKAGE_SOURCE and R_ARCH being defined - confirm.
packageStartupMessage('Compiling model (this will take a minute...)')

dest <- file.path(R_PACKAGE_DIR, paste0('libs', R_ARCH))
dir.create(dest, recursive = TRUE, showWarnings = FALSE)

packageStartupMessage(paste('Writing model to:', dest))
packageStartupMessage(paste('Compiling using binary:', R.home('bin')))

model.src <- file.path(R_PACKAGE_SOURCE, 'inst', 'stan', 'prophet.stan')
model.binary <- file.path(dest, 'prophet_stan_model.RData')

model.stanc <- rstan::stanc(model.src)
model.stanm <- rstan::stan_model(
  stanc_ret = model.stanc,
  model_name = 'prophet_model'
)

# The object is stored under the name 'model.stanm'; the loader looks it up
# by that name.
save('model.stanm', file = model.binary)

packageStartupMessage('------ Model successfully compiled!')
packageStartupMessage('You can ignore any compiler warnings above.')
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
library(prophet)
|
||||
context("Prophet stan model tests")
|
||||
|
||||
rstan::expose_stan_functions(rstan::stanc(file="../..//inst/stan/prophet_logistic_growth.stan"))
|
||||
rstan::expose_stan_functions(
|
||||
rstan::stanc(file="../..//inst/stan/prophet.stan")
|
||||
)
|
||||
|
||||
DATA <- read.csv('data.csv')
|
||||
N <- nrow(DATA)
|
||||
|
|
@ -54,3 +56,44 @@ test_that("get_zero_changepoints", {
|
|||
expect_equal(ncol(mat), 1)
|
||||
expect_true(all(mat == 1))
|
||||
})
|
||||
|
||||
test_that("linear_trend", {
  # One changepoint at t = 5 raises the slope from 1.0 to 1.5.
  k <- 1.0
  m <- 0
  deltas <- c(0.5)
  changepoint.ts <- c(5)
  y.true <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)

  # Full history.
  t.full <- seq(0, 10)
  A.full <- get_changepoint_matrix(
    t.full, changepoint.ts, length(t.full), 1
  )
  expect_equal(
    linear_trend(k, m, deltas, t.full, A.full, changepoint.ts),
    y.true
  )

  # History restricted to elements 8 onward, all after the changepoint.
  keep <- 8:length(t.full)
  t.late <- t.full[keep]
  A.late <- get_changepoint_matrix(
    t.late, changepoint.ts, length(t.late), 1
  )
  expect_equal(
    linear_trend(k, m, deltas, t.late, A.late, changepoint.ts),
    y.true[keep]
  )
})
|
||||
|
||||
test_that("piecewise_logistic", {
  # Constant capacity of 10 with one changepoint at t = 5.
  k <- 1.0
  m <- 0
  deltas <- c(0.5)
  changepoint.ts <- c(5)
  y.true <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
              9.984988, 9.996646, 9.999252, 9.999833, 9.999963)

  # Full history.
  t.full <- seq(0, 10)
  cap.full <- rep(10, 11)
  A.full <- get_changepoint_matrix(
    t.full, changepoint.ts, length(t.full), 1
  )
  expect_equal(
    logistic_trend(k, m, deltas, t.full, cap.full, A.full, changepoint.ts, 1),
    y.true,
    tolerance = 1e-6
  )

  # History restricted to elements 8 onward, all after the changepoint.
  keep <- 8:length(t.full)
  t.late <- t.full[keep]
  cap.late <- cap.full[keep]
  A.late <- get_changepoint_matrix(
    t.late, changepoint.ts, length(t.late), 1
  )
  expect_equal(
    logistic_trend(k, m, deltas, t.late, cap.late, A.late, changepoint.ts, 1),
    y.true[keep],
    tolerance = 1e-6
  )
})
|
||||
|
|
|
|||
Loading…
Reference in a new issue