diff --git a/R/R/prophet.R b/R/R/prophet.R index 401e24e..1d80a16 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -464,21 +464,6 @@ set_changepoints <- function(m) { return(m) } -#' Gets changepoint matrix for history dataframe. -#' -#' @param m Prophet object. -#' -#' @return array of indexes. -#' -#' @keywords internal -get_changepoint_matrix <- function(m) { - A <- matrix(0, nrow(m$history), length(m$changepoints.t)) - for (i in seq_along(m$changepoints.t)) { - A[m$history$t >= m$changepoints.t[i], i] <- 1 - } - return(A) -} - #' Provides Fourier series components with the specified frequency and order. #' #' @param dates Vector of dates. @@ -905,7 +890,6 @@ fit.prophet <- function(m, df, ...) { prior.scales <- out2$prior.scales m <- set_changepoints(m) - A <- get_changepoint_matrix(m) # Construct input to stan dat <- list( @@ -914,7 +898,6 @@ fit.prophet <- function(m, df, ...) { S = length(m$changepoints.t), y = history$y_scaled, t = history$t, - A = A, t_change = array(m$changepoints.t), X = as.matrix(seasonal.features), sigmas = array(prior.scales), diff --git a/R/inst/stan/prophet_linear_growth.stan b/R/inst/stan/prophet_linear_growth.stan index 1639f4a..0d0e4dc 100644 --- a/R/inst/stan/prophet_linear_growth.stan +++ b/R/inst/stan/prophet_linear_growth.stan @@ -1,16 +1,44 @@ +functions { + matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) { + // Assumes t and t_change are sorted. + matrix[T, S] A; + row_vector[S] a_row; + int cp_idx; + + // Start with an empty matrix. + A = rep_matrix(0, T, S); + a_row = rep_row_vector(0, S); + cp_idx = 1; + + // Fill in each row of A. + for (i in 1:T) { + while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) { + a_row[cp_idx] = 1; + cp_idx += 1; + } + A[i] = a_row; + } + return A; + } +} + data { int T; // Sample size int K; // Number of seasonal vectors vector[T] t; // Day vector[T] y; // Time-series int S; // Number of changepoints - matrix[T, S] A; // Split indicators - real t_change[S]; // Index of changepoints + vector[S] t_change; // Index of changepoints matrix[T,K] X; // season vectors vector[K] sigmas; // scale on seasonality prior real tau; // scale on changepoints prior } +transformed data { + matrix[T, S] A; + A = get_changepoint_matrix(t, t_change, T, S); +} + parameters { real k; // Base growth rate real m; // offset diff --git a/R/inst/stan/prophet_logistic_growth.stan b/R/inst/stan/prophet_logistic_growth.stan index 046b173..093de57 100644 --- a/R/inst/stan/prophet_logistic_growth.stan +++ b/R/inst/stan/prophet_logistic_growth.stan @@ -1,3 +1,27 @@ +functions { + matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) { + // Assumes t and t_change are sorted. + matrix[T, S] A; + row_vector[S] a_row; + int cp_idx; + + // Start with an empty matrix. + A = rep_matrix(0, T, S); + a_row = rep_row_vector(0, S); + cp_idx = 1; + + // Fill in each row of A. + for (i in 1:T) { + while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) { + a_row[cp_idx] = 1; + cp_idx += 1; + } + A[i] = a_row; + } + return A; + } +} + data { int T; // Sample size int K; // Number of seasonal vectors @@ -5,13 +29,17 @@ data { vector[T] cap; // Capacities vector[T] y; // Time-series int S; // Number of changepoints - matrix[T, S] A; // Split indicators - real t_change[S]; // Index of changepoints + vector[S] t_change; // Index of changepoints matrix[T,K] X; // season vectors vector[K] sigmas; // scale on seasonality prior real tau; // scale on changepoints prior } +transformed data { + matrix[T, S] A; + A = get_changepoint_matrix(t, t_change, T, S); +} + parameters { real k; // Base growth rate real m; // offset diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index 351082f..ee588fc 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -121,11 +121,7 @@ test_that("get_changepoints", { cp <- m$changepoints.t expect_equal(length(cp), m$n.changepoints) expect_true(min(cp) > 0) - expect_true(max(cp) < N) - - mat <- prophet:::get_changepoint_matrix(m) - expect_equal(nrow(mat), floor(N / 2)) - expect_equal(ncol(mat), m$n.changepoints) + expect_true(max(cp) < 1) }) test_that("get_zero_changepoints", { @@ -141,10 +137,6 @@ test_that("get_zero_changepoints", { cp <- m$changepoints.t expect_equal(length(cp), 1) expect_equal(cp[1], 0) - - mat <- prophet:::get_changepoint_matrix(m) - expect_equal(nrow(mat), floor(N / 2)) - expect_equal(ncol(mat), 1) }) test_that("override_n_changepoints", { diff --git a/R/tests/testthat/test_stan_functions.R b/R/tests/testthat/test_stan_functions.R new file mode 100644 index 0000000..815e21d --- /dev/null +++ b/R/tests/testthat/test_stan_functions.R @@ -0,0 +1,56 @@ +library(prophet) +context("Prophet stan model tests") + +rstan::expose_stan_functions(rstan::stanc(file="../..//inst/stan/prophet_logistic_growth.stan")) + +DATA <- read.csv('data.csv') +N <- nrow(DATA) +train <- DATA[1:floor(N / 2), ] +future <- DATA[(ceiling(N/2) + 1):N, ] + +DATA2 <- read.csv('data2.csv') + +DATA$ds <- prophet:::set_date(DATA$ds) +DATA2$ds <- prophet:::set_date(DATA2$ds) + +test_that("get_changepoint_matrix", { + history <- train + m <- prophet(history, fit = FALSE) + + out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE) + history <- out$df + m <- out$m + m$history <- history + + m <- prophet:::set_changepoints(m) + + cp <- m$changepoints.t + + mat <- get_changepoint_matrix(history$t, cp, nrow(history), length(cp)) + expect_equal(nrow(mat), floor(N / 2)) + expect_equal(ncol(mat), m$n.changepoints) + # Compare to the R implementation + A <- matrix(0, nrow(history), length(cp)) + for (i in 1:length(cp)) { + A[history$t >= cp[i], i] <- 1 + } + expect_true(all(A == mat)) +}) + +test_that("get_zero_changepoints", { + history <- train + m <- prophet(history, n.changepoints = 0, fit = FALSE) + + out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE) + m <- out$m + history <- out$df + m$history <- history + + m <- prophet:::set_changepoints(m) + cp <- m$changepoints.t + + mat <- get_changepoint_matrix(history$t, cp, nrow(history), length(cp)) + expect_equal(nrow(mat), floor(N / 2)) + expect_equal(ncol(mat), 1) + expect_true(all(mat == 1)) +})