mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-31 23:27:52 +00:00
Move changepoint matrix calculation into stan (R)
This commit is contained in:
parent
7f9e4b80c1
commit
1f84fa960f
5 changed files with 117 additions and 30 deletions
|
|
@ -464,21 +464,6 @@ set_changepoints <- function(m) {
|
|||
return(m)
|
||||
}
|
||||
|
||||
#' Gets changepoint matrix for history dataframe.
|
||||
#'
|
||||
#' @param m Prophet object.
|
||||
#'
|
||||
#' @return array of indexes.
|
||||
#'
|
||||
#' @keywords internal
|
||||
get_changepoint_matrix <- function(m) {
|
||||
A <- matrix(0, nrow(m$history), length(m$changepoints.t))
|
||||
for (i in seq_along(m$changepoints.t)) {
|
||||
A[m$history$t >= m$changepoints.t[i], i] <- 1
|
||||
}
|
||||
return(A)
|
||||
}
|
||||
|
||||
#' Provides Fourier series components with the specified frequency and order.
|
||||
#'
|
||||
#' @param dates Vector of dates.
|
||||
|
|
@ -905,7 +890,6 @@ fit.prophet <- function(m, df, ...) {
|
|||
prior.scales <- out2$prior.scales
|
||||
|
||||
m <- set_changepoints(m)
|
||||
A <- get_changepoint_matrix(m)
|
||||
|
||||
# Construct input to stan
|
||||
dat <- list(
|
||||
|
|
@ -914,7 +898,6 @@ fit.prophet <- function(m, df, ...) {
|
|||
S = length(m$changepoints.t),
|
||||
y = history$y_scaled,
|
||||
t = history$t,
|
||||
A = A,
|
||||
t_change = array(m$changepoints.t),
|
||||
X = as.matrix(seasonal.features),
|
||||
sigmas = array(prior.scales),
|
||||
|
|
|
|||
|
|
@ -1,16 +1,44 @@
|
|||
functions {
|
||||
matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
|
||||
// Assumes t and t_change are sorted.
|
||||
matrix[T, S] A;
|
||||
row_vector[S] a_row;
|
||||
int cp_idx;
|
||||
|
||||
// Start with an empty matrix.
|
||||
A = rep_matrix(0, T, S);
|
||||
a_row = rep_row_vector(0, S);
|
||||
cp_idx = 1;
|
||||
|
||||
// Fill in each row of A.
|
||||
for (i in 1:T) {
|
||||
while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) {
|
||||
a_row[cp_idx] = 1;
|
||||
cp_idx += 1;
|
||||
}
|
||||
A[i] = a_row;
|
||||
}
|
||||
return A;
|
||||
}
|
||||
}
|
||||
|
||||
data {
|
||||
int T; // Sample size
|
||||
int<lower=1> K; // Number of seasonal vectors
|
||||
vector[T] t; // Day
|
||||
vector[T] y; // Time-series
|
||||
int S; // Number of changepoints
|
||||
matrix[T, S] A; // Split indicators
|
||||
real t_change[S]; // Index of changepoints
|
||||
vector[S] t_change; // Index of changepoints
|
||||
matrix[T,K] X; // season vectors
|
||||
vector[K] sigmas; // scale on seasonality prior
|
||||
real<lower=0> tau; // scale on changepoints prior
|
||||
}
|
||||
|
||||
transformed data {
|
||||
matrix[T, S] A;
|
||||
A = get_changepoint_matrix(t, t_change, T, S);
|
||||
}
|
||||
|
||||
parameters {
|
||||
real k; // Base growth rate
|
||||
real m; // offset
|
||||
|
|
|
|||
|
|
@ -1,3 +1,27 @@
|
|||
functions {
|
||||
matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
|
||||
// Assumes t and t_change are sorted.
|
||||
matrix[T, S] A;
|
||||
row_vector[S] a_row;
|
||||
int cp_idx;
|
||||
|
||||
// Start with an empty matrix.
|
||||
A = rep_matrix(0, T, S);
|
||||
a_row = rep_row_vector(0, S);
|
||||
cp_idx = 1;
|
||||
|
||||
// Fill in each row of A.
|
||||
for (i in 1:T) {
|
||||
while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) {
|
||||
a_row[cp_idx] = 1;
|
||||
cp_idx += 1;
|
||||
}
|
||||
A[i] = a_row;
|
||||
}
|
||||
return A;
|
||||
}
|
||||
}
|
||||
|
||||
data {
|
||||
int T; // Sample size
|
||||
int<lower=1> K; // Number of seasonal vectors
|
||||
|
|
@ -5,13 +29,17 @@ data {
|
|||
vector[T] cap; // Capacities
|
||||
vector[T] y; // Time-series
|
||||
int S; // Number of changepoints
|
||||
matrix[T, S] A; // Split indicators
|
||||
real t_change[S]; // Index of changepoints
|
||||
vector[S] t_change; // Index of changepoints
|
||||
matrix[T,K] X; // season vectors
|
||||
vector[K] sigmas; // scale on seasonality prior
|
||||
real<lower=0> tau; // scale on changepoints prior
|
||||
}
|
||||
|
||||
transformed data {
|
||||
matrix[T, S] A;
|
||||
A = get_changepoint_matrix(t, t_change, T, S);
|
||||
}
|
||||
|
||||
parameters {
|
||||
real k; // Base growth rate
|
||||
real m; // offset
|
||||
|
|
|
|||
|
|
@ -121,11 +121,7 @@ test_that("get_changepoints", {
|
|||
cp <- m$changepoints.t
|
||||
expect_equal(length(cp), m$n.changepoints)
|
||||
expect_true(min(cp) > 0)
|
||||
expect_true(max(cp) < N)
|
||||
|
||||
mat <- prophet:::get_changepoint_matrix(m)
|
||||
expect_equal(nrow(mat), floor(N / 2))
|
||||
expect_equal(ncol(mat), m$n.changepoints)
|
||||
expect_true(max(cp) < 1)
|
||||
})
|
||||
|
||||
test_that("get_zero_changepoints", {
|
||||
|
|
@ -141,10 +137,6 @@ test_that("get_zero_changepoints", {
|
|||
cp <- m$changepoints.t
|
||||
expect_equal(length(cp), 1)
|
||||
expect_equal(cp[1], 0)
|
||||
|
||||
mat <- prophet:::get_changepoint_matrix(m)
|
||||
expect_equal(nrow(mat), floor(N / 2))
|
||||
expect_equal(ncol(mat), 1)
|
||||
})
|
||||
|
||||
test_that("override_n_changepoints", {
|
||||
|
|
|
|||
56
R/tests/testthat/test_stan_functions.R
Normal file
56
R/tests/testthat/test_stan_functions.R
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
library(prophet)
|
||||
context("Prophet stan model tests")
|
||||
|
||||
rstan::expose_stan_functions(rstan::stanc(file="../..//inst/stan/prophet_logistic_growth.stan"))
|
||||
|
||||
DATA <- read.csv('data.csv')
|
||||
N <- nrow(DATA)
|
||||
train <- DATA[1:floor(N / 2), ]
|
||||
future <- DATA[(ceiling(N/2) + 1):N, ]
|
||||
|
||||
DATA2 <- read.csv('data2.csv')
|
||||
|
||||
DATA$ds <- prophet:::set_date(DATA$ds)
|
||||
DATA2$ds <- prophet:::set_date(DATA2$ds)
|
||||
|
||||
test_that("get_changepoint_matrix", {
|
||||
history <- train
|
||||
m <- prophet(history, fit = FALSE)
|
||||
|
||||
out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
|
||||
history <- out$df
|
||||
m <- out$m
|
||||
m$history <- history
|
||||
|
||||
m <- prophet:::set_changepoints(m)
|
||||
|
||||
cp <- m$changepoints.t
|
||||
|
||||
mat <- get_changepoint_matrix(history$t, cp, nrow(history), length(cp))
|
||||
expect_equal(nrow(mat), floor(N / 2))
|
||||
expect_equal(ncol(mat), m$n.changepoints)
|
||||
# Compare to the R implementation
|
||||
A <- matrix(0, nrow(history), length(cp))
|
||||
for (i in 1:length(cp)) {
|
||||
A[history$t >= cp[i], i] <- 1
|
||||
}
|
||||
expect_true(all(A == mat))
|
||||
})
|
||||
|
||||
test_that("get_zero_changepoints", {
|
||||
history <- train
|
||||
m <- prophet(history, n.changepoints = 0, fit = FALSE)
|
||||
|
||||
out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
|
||||
m <- out$m
|
||||
history <- out$df
|
||||
m$history <- history
|
||||
|
||||
m <- prophet:::set_changepoints(m)
|
||||
cp <- m$changepoints.t
|
||||
|
||||
mat <- get_changepoint_matrix(history$t, cp, nrow(history), length(cp))
|
||||
expect_equal(nrow(mat), floor(N / 2))
|
||||
expect_equal(ncol(mat), 1)
|
||||
expect_true(all(mat == 1))
|
||||
})
|
||||
Loading…
Reference in a new issue