Add support for cmdstanr backend (#1806)

* init

* add suggested packages

* use environment variables and align more with Py package

* remove additional testing logic, default to lbfgs

* Remove Newton specifier from test because cmdstanr expects newton

Co-authored-by: Ben Letham <bletham@gmail.com>
This commit is contained in:
Cuong Duong 2021-05-11 09:09:25 +10:00 committed by GitHub
parent 9cad5a05fb
commit a794018d65
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 462 additions and 92 deletions

View file

@ -35,14 +35,18 @@ Imports:
tidyr (>= 0.6.1),
xts
Suggests:
cmdstanr,
posterior,
knitr,
testthat,
readr,
rmarkdown
Additional_repositories:
https://mc-stan.org/r-packages/
SystemRequirements: GNU make, C++11
Biarch: true
License: MIT + file LICENSE
LinkingTo:
LinkingTo:
BH (>= 1.66.0),
Rcpp (>= 0.12.0),
RcppParallel (>= 5.0.1),

View file

@ -16,54 +16,56 @@ globalVariables(c(
#' Prophet forecaster.
#'
#' @param df (optional) Dataframe containing the history. Must have columns ds
#' (date type) and y, the time series. If growth is logistic, then df must
#' also have a column cap that specifies the capacity at each ds. If not
#' provided, then the model object will be instantiated but not fit; use
#' fit.prophet(m, df) to fit the model.
#' @param growth String 'linear', 'logistic', or 'flat' to specify a linear, logistic
#' or flat trend.
#' (date type) and y, the time series. If growth is logistic, then df must
#' also have a column cap that specifies the capacity at each ds. If not
#' provided, then the model object will be instantiated but not fit; use
#' fit.prophet(m, df) to fit the model.
#' @param growth String 'linear', 'logistic', or 'flat' to specify a linear,
#' logistic or flat trend.
#' @param changepoints Vector of dates at which to include potential
#' changepoints. If not specified, potential changepoints are selected
#' automatically.
#' changepoints. If not specified, potential changepoints are selected
#' automatically.
#' @param n.changepoints Number of potential changepoints to include. Not used
#' if input `changepoints` is supplied. If `changepoints` is not supplied,
#' then n.changepoints potential changepoints are selected uniformly from the
#' first `changepoint.range` proportion of df$ds.
#' if input `changepoints` is supplied. If `changepoints` is not supplied,
#' then n.changepoints potential changepoints are selected uniformly from the
#' first `changepoint.range` proportion of df$ds.
#' @param changepoint.range Proportion of history in which trend changepoints
#' will be estimated. Defaults to 0.8 for the first 80%. Not used if
#' `changepoints` is specified.
#' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE,
#' FALSE, or a number of Fourier terms to generate.
#' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
#' FALSE, or a number of Fourier terms to generate.
#' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE,
#' FALSE, or a number of Fourier terms to generate.
#' will be estimated. Defaults to 0.8 for the first 80%. Not used if
#' `changepoints` is specified.
#' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE, FALSE,
#' or a number of Fourier terms to generate.
#' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE, FALSE,
#' or a number of Fourier terms to generate.
#' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE, FALSE,
#' or a number of Fourier terms to generate.
#' @param holidays data frame with columns holiday (character) and ds (date
#' type)and optionally columns lower_window and upper_window which specify a
#' range of days around the date to be included as holidays. lower_window=-2
#' will include 2 days prior to the date as holidays. Also optionally can have
#' a column prior_scale specifying the prior scale for each holiday.
#' type)and optionally columns lower_window and upper_window which specify a
#' range of days around the date to be included as holidays. lower_window=-2
#' will include 2 days prior to the date as holidays. Also optionally can have
#' a column prior_scale specifying the prior scale for each holiday.
#' @param seasonality.mode 'additive' (default) or 'multiplicative'.
#' @param seasonality.prior.scale Parameter modulating the strength of the
#' seasonality model. Larger values allow the model to fit larger seasonal
#' fluctuations, smaller values dampen the seasonality. Can be specified for
#' individual seasonalities using add_seasonality.
#' seasonality model. Larger values allow the model to fit larger seasonal
#' fluctuations, smaller values dampen the seasonality. Can be specified for
#' individual seasonalities using add_seasonality.
#' @param holidays.prior.scale Parameter modulating the strength of the holiday
#' components model, unless overridden in the holidays input.
#' components model, unless overridden in the holidays input.
#' @param changepoint.prior.scale Parameter modulating the flexibility of the
#' automatic changepoint selection. Large values will allow many changepoints,
#' small values will allow few changepoints.
#' automatic changepoint selection. Large values will allow many changepoints,
#' small values will allow few changepoints.
#' @param mcmc.samples Integer, if greater than 0, will do full Bayesian
#' inference with the specified number of MCMC samples. If 0, will do MAP
#' estimation.
#' inference with the specified number of MCMC samples. If 0, will do MAP
#' estimation.
#' @param interval.width Numeric, width of the uncertainty intervals provided
#' for the forecast. If mcmc.samples=0, this will be only the uncertainty
#' in the trend using the MAP estimate of the extrapolated generative model.
#' If mcmc.samples>0, this will be integrated over all model parameters,
#' which will include uncertainty in seasonality.
#' for the forecast. If mcmc.samples=0, this will be only the uncertainty in
#' the trend using the MAP estimate of the extrapolated generative model. If
#' mcmc.samples>0, this will be integrated over all model parameters, which
#' will include uncertainty in seasonality.
#' @param uncertainty.samples Number of simulated draws used to estimate
#' uncertainty intervals. Settings this value to 0 or False will disable
#' uncertainty estimation and speed up the calculation.
#' uncertainty intervals. Settings this value to 0 or False will disable
#' uncertainty estimation and speed up the calculation.
#' @param backend Whether to use the "rstan" or "cmdstanr" backend to fit the
#' model. If not provided, uses the R_STAN_BACKEND environment variable.
#' @param fit Boolean, if FALSE the model is initialized but not fit.
#' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
#'
@ -99,12 +101,15 @@ prophet <- function(df = NULL,
interval.width = 0.80,
uncertainty.samples = 1000,
fit = TRUE,
backend = NULL,
...
) {
if (!is.null(changepoints)) {
n.changepoints <- length(changepoints)
}
if (is.null(backend)) backend <- get_stan_backend()
m <- list(
growth = growth,
changepoints = changepoints,
@ -121,6 +126,7 @@ prophet <- function(df = NULL,
mcmc.samples = mcmc.samples,
interval.width = interval.width,
uncertainty.samples = uncertainty.samples,
backend = backend,
specified.changepoints = !is.null(changepoints),
start = NULL, # This and following attributes are set during fitting
y.scale = NULL,
@ -1068,7 +1074,7 @@ set_auto_seasonalities <- function(m) {
#' Initialize flat growth.
#'
#' Provides a strong initialization for flat growth by setting the
#' growth to 0 and calculates the offset parameter that pass the
#' growth to 0 and calculates the offset parameter that pass the
#' function through the mean of the the y_scaled values.
#'
#' @param df Data frame with columns ds (date), y_scaled (scaled time series),
@ -1227,11 +1233,7 @@ fit.prophet <- function(m, df, ...) {
kinit <- logistic_growth_init(history)
}
if (exists(".prophet.stan.model", where = prophet_model_env)) {
model <- get('.prophet.stan.model', envir = prophet_model_env)
} else {
model <- stanmodels$prophet
}
model <- .load_model(m$backend)
stan_init <- function() {
list(k = kinit[1],
@ -1242,45 +1244,24 @@ fit.prophet <- function(m, df, ...) {
)
}
if (min(history$y) == max(history$y) &
if (min(history$y) == max(history$y) &
(m$growth %in% c('linear', 'flat'))) {
# Nothing to fit.
m$params <- stan_init()
m$params$sigma_obs <- 0.
n.iteration <- 1.
} else if (m$mcmc.samples > 0) {
args <- list(
object = model,
data = dat,
init = stan_init,
iter = m$mcmc.samples
)
args <- utils::modifyList(args, list(...))
m$stan.fit <- do.call(rstan::sampling, args)
m$params <- rstan::extract(m$stan.fit)
n.iteration <- length(m$params$k)
} else {
args <- list(
object = model,
data = dat,
init = stan_init,
algorithm = if(dat$T < 100) {'Newton'} else {'LBFGS'},
iter = 1e4,
as_vector = FALSE
)
args <- utils::modifyList(args, list(...))
m$stan.fit <- do.call(rstan::optimizing, args)
if (m$stan.fit$return_code != 0) {
message(
'Optimization terminated abnormally. Falling back to Newton optimizer.'
)
args$algorithm = 'Newton'
m$stan.fit <- do.call(rstan::optimizing, args)
if (m$mcmc.samples > 0) {
args <- .stan_args(model, dat, stan_init, m$backend, type = "mcmc", m$mcmc.samples, ...)
model_output <- .sampling(args, m$backend)
} else {
args <- .stan_args(model, dat, stan_init, m$backend, type = "optimize", ...)
model_output <- .fit(args, m$backend)
}
m$params <- m$stan.fit$par
n.iteration <- 1
m$stan.fit <- model_output$stan_fit
m$params <- model_output$params
n.iteration <- model_output$n_iteration
}
# Cast the parameters to have consistent form, whether full bayes or MAP
for (name in c('delta', 'beta')){
m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration)

256
R/R/stan_backends.R Normal file
View file

@ -0,0 +1,256 @@
#' Get the stan backend defined in the environment variables.
#'
#' @return 'rstan' or 'cmdstanr'. 'rstan' if variable is not set.
#' @keywords internal
get_stan_backend <- function() {
backend_setting <- Sys.getenv("R_STAN_BACKEND", "RSTAN")
if (backend_setting %in% c("RSTAN", "CMDSTANR")) {
backend <- switch(
backend_setting,
"RSTAN" = "rstan",
"CMDSTANR" = "cmdstanr"
)
if (backend == "cmdstanr") check_cmdstanr()
return(backend)
} else {
return("rstan")
}
}
#' Check that the required packages for using the cmdstanr backend are installed.
#'
#' @return NULL if successful, and prints the current version of cmdstan being used.
#' @keywords internal
check_cmdstanr <- function() {
if (!requireNamespace("cmdstanr", quietly = TRUE)) {
stop(
"Package \"cmdstanr\" needed to use cmdstanr backend. See installation instructions: https://mc-stan.org/cmdstanr/.",
call. = FALSE
)
}
if (!requireNamespace("posterior", quietly = TRUE)) {
stop(
"Package \"posterior\" needed to use cmdstanr backend. See installation instructions: https://mc-stan.org/posterior/.",
call. = FALSE
)
}
cmdstanr_version <- cmdstanr::cmdstan_version()
return(invisible(TRUE))
}
#' Load the Prophet Stan model.
#'
#' @param backend "rstan" or "cmdstanr".
#'
#' @return stanmodel object if backend = "rstan", CmdStanModel object if backend = "cmdstanr"
#' @keywords internal
.load_model <- function(backend) {
switch(
backend,
"rstan" = .load_model_rstan(),
"cmdstanr" = .load_model_cmdstanr()
)
}
#' @rdname .load_model
.load_model_rstan <- function() {
if (exists(".prophet.stan.model", where = prophet_model_env)) {
model <- get('.prophet.stan.model', envir = prophet_model_env)
} else {
model <- stanmodels$prophet
}
return(model)
}
#' @rdname .load_model
.load_model_cmdstanr <- function() {
model_file <- system.file(
"stan",
"prophet.stan",
package = "prophet",
mustWork = TRUE
)
model <- cmdstanr::cmdstan_model(model_file)
return(model)
}
#' Gives Stan arguments the appropriate names depending on the chosen Stan backend.
#'
#' @param model Model object.
#' @param dat List containing data to use in fitting.
#' @param stan_init Function to initialize parameters for stan fit.
#' @param backend "rstan" or "cmdstanr".
#' @param type "mcmc" or "optimize".
#' @param mcmc_samples Integer, if greater than 0, will do full Bayesian
#' inference with the specified number of MCMC samples. If 0, will do MAP
#' estimation.
#'
#' @return Named list of arguments.
#' @keywords internal
.stan_args <- function(model, dat, stan_init, backend, type, mcmc_samples = 0, ...) {
args <- switch(
backend,
"rstan" = .stan_args_rstan(model, dat, stan_init, type, mcmc_samples),
"cmdstanr" = .stan_args_cmdstanr(model, dat, stan_init, type, mcmc_samples)
)
args <- utils::modifyList(args, list(...))
return(args)
}
#' @rdname .stan_args
.stan_args_rstan <- function(model, dat, stan_init, type, mcmc_samples = NULL) {
if (type == "mcmc") {
args <- list(
object = model,
data = dat,
init = stan_init,
iter = mcmc_samples,
chains = 4
)
} else if (type == "optimize") {
args <- list(
object = model,
data = dat,
init = stan_init,
algorithm = if(dat$T < 100) {'Newton'} else {'LBFGS'},
iter = 1e4,
as_vector = FALSE
)
}
return(args)
}
#' @rdname .stan_args
.stan_args_cmdstanr <- function(model, dat, stan_init, type, mcmc_samples = NULL) {
if (type == "mcmc") {
args <- list(
object = model,
data = dat,
init = stan_init,
iter_warmup = mcmc_samples / 2,
iter_sampling = mcmc_samples / 2,
chains = 4,
refresh = 0,
show_messages = FALSE
)
} else if (type == "optimize") {
args <- list(
object = model,
data = dat,
init = stan_init,
algorithm = if(dat$T < 100) {'newton'} else {'lbfgs'},
iter = 1e4,
refresh = 0
)
}
return(args)
}
#' Obtain the point estimates of the parameters of the Prophet model using
#' stan's optimization algorithms.
#'
#' @param args Named list of arguments suitable for the chosen backend. Must
#' include arguments required for optimization.
#' @param backend "rstan" or "cmdstanr".
#'
#' @return A named list containing "stan_fit" (the fitted stan object),
#' "params", and "n_iteration"
#' @keywords internal
.fit <- function(args, backend) {
switch(
backend,
"rstan" = .fit_rstan(args),
"cmdstanr" = .fit_cmdstanr(args)
)
}
#' Obtain the joint posterior distribution of the parameters of the Prophet
#' model using MCMC sampling.
#'
#' @param args Named list of arguments suitable for the chosen backend. Must
#' include arguments required for MCMC sampling.
#' @param backend "rstan" or "cmdstanr".
#'
#' @return A named list containing "stan_fit" (the fitted stan object),
#' "params", and "n_iteration"
#' @keywords internal
.sampling <- function(args, backend) {
switch(
backend,
"rstan" = .sampling_rstan(args),
"cmdstanr" = .sampling_cmdstanr(args)
)
}
#' @rdname .fit
.fit_rstan <- function(args) {
model_output <- list()
model_output$stan_fit <- do.call(rstan::optimizing, args)
if (model_output$stan_fit$return_code != 0) {
message(
'Optimization terminated abnormally. Falling back to Newton optimizer.'
)
args$algorithm = 'Newton'
model_output$stan_fit <- do.call(rstan::optimizing, args)
}
model_output$params <- model_output$stan_fit$par
model_output$n_iteration <- 1
return(model_output)
}
#' @rdname .sampling
.sampling_rstan <- function(args) {
model_output <- list()
model_output$stan_fit <- do.call(rstan::sampling, args)
model_output$params <- rstan::extract(model_output$stan_fit)
model_output$n_iteration <- length(model_output$params$k)
return(model_output)
}
#' @rdname .fit
.fit_cmdstanr <- function(args) {
# TODO: Replace with method to extract parameter names once implemented in cmdstanr
param_names <- c("k", "m", "delta", "sigma_obs", "beta", "trend")
model_output <- list()
model <- args$object
args$object <- NULL
model_output$stan_fit <- do.call(model$optimize, args)
if (model_output$stan_fit$return_codes()[1] != 0) {
message(
'Optimization terminated abnormally. Falling back to Newton optimizer.'
)
args$algorithm = 'newton'
model_output$stan_fit <- do.call(model$optimize, args)
}
model_output$params <- list()
for (name in param_names) {
model_output$params[[name]] <- unname(model_output$stan_fit$mle(name))
}
model_output$n_iteration <- 1
return(model_output)
}
#' @rdname .sampling
.sampling_cmdstanr <- function(args) {
param_names <- c("k", "m", "delta", "sigma_obs", "beta", "trend")
model_output <- list()
model <- args$object
args$object <- NULL
param_names <- c(param_names, "lp__")
model_output$stan_fit <- do.call(model$sample, args)
model_output$params <- list()
for (name in param_names) {
model_output$params[[name]] <- posterior::as_draws_matrix(model_output$stan_fit$draws(name))
}
model_output$n_iteration <- nrow(model_output$params$k)
return(model_output)
}

15
R/man/check_cmdstanr.Rd Normal file
View file

@ -0,0 +1,15 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/stan_backends.R
\name{check_cmdstanr}
\alias{check_cmdstanr}
\title{Check that the required packages for using the cmdstanr backend are installed.}
\usage{
check_cmdstanr()
}
\value{
NULL if successful, and prints the current version of cmdstan being used.
}
\description{
Check that the required packages for using the cmdstanr backend are installed.
}
\keyword{internal}

24
R/man/dot-fit.Rd Normal file
View file

@ -0,0 +1,24 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/stan_backends.R
\name{.fit}
\alias{.fit}
\title{Obtain the point estimates of the parameters of the Prophet model using
stan's optimization algorithms.}
\usage{
.fit(args, backend)
}
\arguments{
\item{args}{Named list of arguments suitable for the chosen backend. Must
include arguments required for optimization.}
\item{backend}{"rstan" or "cmdstanr".}
}
\value{
A named list containing "stan_fit" (the fitted stan object),
"params", and "n_iteration"
}
\description{
Obtain the point estimates of the parameters of the Prophet model using
stan's optimization algorithms.
}
\keyword{internal}

18
R/man/dot-load_model.Rd Normal file
View file

@ -0,0 +1,18 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/stan_backends.R
\name{.load_model}
\alias{.load_model}
\title{Load the Prophet Stan model.}
\usage{
.load_model(backend)
}
\arguments{
\item{backend}{"rstan" or "cmdstanr".}
}
\value{
stanmodel object if backend = "rstan", CmdStanModel object if backend = "cmdstanr"
}
\description{
Load the Prophet Stan model.
}
\keyword{internal}

24
R/man/dot-sampling.Rd Normal file
View file

@ -0,0 +1,24 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/stan_backends.R
\name{.sampling}
\alias{.sampling}
\title{Obtain the joint posterior distribution of the parameters of the Prophet
model using MCMC sampling.}
\usage{
.sampling(args, backend)
}
\arguments{
\item{args}{Named list of arguments suitable for the chosen backend. Must
include arguments required for MCMC sampling.}
\item{backend}{"rstan" or "cmdstanr".}
}
\value{
A named list containing "stan_fit" (the fitted stan object),
"params", and "n_iteration"
}
\description{
Obtain the joint posterior distribution of the parameters of the Prophet
model using MCMC sampling.
}
\keyword{internal}

30
R/man/dot-stan_args.Rd Normal file
View file

@ -0,0 +1,30 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/stan_backends.R
\name{.stan_args}
\alias{.stan_args}
\title{Gives Stan arguments the appropriate names depending on the chosen Stan backend.}
\usage{
.stan_args(model, dat, stan_init, backend, type, mcmc_samples = 0, ...)
}
\arguments{
\item{model}{Model object.}
\item{dat}{List containing data to use in fitting.}
\item{stan_init}{Function to initialize parameters for stan fit.}
\item{backend}{"rstan" or "cmdstanr".}
\item{type}{"mcmc" or "optimize".}
\item{mcmc_samples}{Integer, if greater than 0, will do full Bayesian
inference with the specified number of MCMC samples. If 0, will do MAP
estimation.}
}
\value{
Named list of arguments.
}
\description{
Gives Stan arguments the appropriate names depending on the chosen Stan backend.
}
\keyword{internal}

View file

@ -16,7 +16,7 @@ A vector (k, m) with the rate (k) and offset (m) of the flat
}
\description{
Provides a strong initialization for flat growth by setting the
growth to 0 and calculates the offset parameter that pass the
growth to 0 and calculates the offset parameter that pass the
function through the mean of the the y_scaled values.
}
\keyword{internal}

15
R/man/get_stan_backend.Rd Normal file
View file

@ -0,0 +1,15 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/stan_backends.R
\name{get_stan_backend}
\alias{get_stan_backend}
\title{Get the stan backend defined in the environment variables.}
\usage{
get_stan_backend()
}
\value{
'rstan' or 'cmdstanr'. 'rstan' if variable is not set.
}
\description{
Get the stan backend defined in the environment variables.
}
\keyword{internal}

View file

@ -22,6 +22,7 @@ prophet(
interval.width = 0.8,
uncertainty.samples = 1000,
fit = TRUE,
backend = NULL,
...
)
}
@ -32,8 +33,8 @@ also have a column cap that specifies the capacity at each ds. If not
provided, then the model object will be instantiated but not fit; use
fit.prophet(m, df) to fit the model.}
\item{growth}{String 'linear', 'logistic', or 'flat' to specify a linear, logistic
or flat trend.}
\item{growth}{String 'linear', 'logistic', or 'flat' to specify a linear,
logistic or flat trend.}
\item{changepoints}{Vector of dates at which to include potential
changepoints. If not specified, potential changepoints are selected
@ -48,14 +49,14 @@ first `changepoint.range` proportion of df$ds.}
will be estimated. Defaults to 0.8 for the first 80%. Not used if
`changepoints` is specified.}
\item{yearly.seasonality}{Fit yearly seasonality. Can be 'auto', TRUE,
FALSE, or a number of Fourier terms to generate.}
\item{yearly.seasonality}{Fit yearly seasonality. Can be 'auto', TRUE, FALSE,
or a number of Fourier terms to generate.}
\item{weekly.seasonality}{Fit weekly seasonality. Can be 'auto', TRUE,
FALSE, or a number of Fourier terms to generate.}
\item{weekly.seasonality}{Fit weekly seasonality. Can be 'auto', TRUE, FALSE,
or a number of Fourier terms to generate.}
\item{daily.seasonality}{Fit daily seasonality. Can be 'auto', TRUE,
FALSE, or a number of Fourier terms to generate.}
\item{daily.seasonality}{Fit daily seasonality. Can be 'auto', TRUE, FALSE,
or a number of Fourier terms to generate.}
\item{holidays}{data frame with columns holiday (character) and ds (date
type)and optionally columns lower_window and upper_window which specify a
@ -82,10 +83,10 @@ inference with the specified number of MCMC samples. If 0, will do MAP
estimation.}
\item{interval.width}{Numeric, width of the uncertainty intervals provided
for the forecast. If mcmc.samples=0, this will be only the uncertainty
in the trend using the MAP estimate of the extrapolated generative model.
If mcmc.samples>0, this will be integrated over all model parameters,
which will include uncertainty in seasonality.}
for the forecast. If mcmc.samples=0, this will be only the uncertainty in
the trend using the MAP estimate of the extrapolated generative model. If
mcmc.samples>0, this will be integrated over all model parameters, which
will include uncertainty in seasonality.}
\item{uncertainty.samples}{Number of simulated draws used to estimate
uncertainty intervals. Settings this value to 0 or False will disable
@ -93,6 +94,9 @@ uncertainty estimation and speed up the calculation.}
\item{fit}{Boolean, if FALSE the model is initialized but not fit.}
\item{backend}{Whether to use the "rstan" or "cmdstanr" backend to fit the
model. If not provided, uses the R_STAN_BACKEND environment variable.}
\item{...}{Additional arguments, passed to \code{\link{fit.prophet}}}
}
\value{

View file

@ -121,7 +121,7 @@ test_that("cross_validation_uncertainty_disabled", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
for (uncertainty in c(0, FALSE)) {
m <- prophet(uncertainty.samples = uncertainty)
m <- fit.prophet(m = m, df = DATA, algorithm = "Newton")
m <- fit.prophet(m = m, df = DATA)
df.cv <- cross_validation(
m, horizon = 4, units = "days", period = 4, initial = 115)
expected.cols <- c('y', 'ds', 'yhat', 'cutoff')

View file

@ -125,11 +125,10 @@ test_that("logistic_floor", {
future1 <- future
future1$cap <- 80.
future1$floor <- 10.
m <- fit.prophet(m, history, algorithm = 'Newton')
m <- fit.prophet(m, history)
expect_true(m$logistic.floor)
expect_true('floor' %in% colnames(m$history))
expect_equal(m$history$y_scaled[1], 1., tolerance = 1e-6)
expect_equal(m$fit.kwargs, list(algorithm = 'Newton'))
fcst1 <- predict(m, future1)
m2 <- prophet(growth = 'logistic')
@ -139,7 +138,7 @@ test_that("logistic_floor", {
history2$cap <- history2$cap + 10.
future1$cap <- future1$cap + 10.
future1$floor <- future1$floor + 10.
m2 <- fit.prophet(m2, history2, algorithm = 'Newton')
m2 <- fit.prophet(m2, history2)
expect_equal(m2$history$y_scaled[1], 1., tolerance = 1e-6)
})