From a794018d654402ab6a97cb262e80d347db3485bd Mon Sep 17 00:00:00 2001 From: Cuong Duong Date: Tue, 11 May 2021 09:09:25 +1000 Subject: [PATCH] Add support for cmdstanr backend (#1806) * init * add suggested packages * use environment variables and align more with Py package * remove additional testing logic, default to lbfgs * Remove Newton specifier from test because cmdstanr expects newton Co-authored-by: Ben Letham --- R/DESCRIPTION | 6 +- R/R/prophet.R | 129 ++++++-------- R/R/stan_backends.R | 256 ++++++++++++++++++++++++++++ R/man/check_cmdstanr.Rd | 15 ++ R/man/dot-fit.Rd | 24 +++ R/man/dot-load_model.Rd | 18 ++ R/man/dot-sampling.Rd | 24 +++ R/man/dot-stan_args.Rd | 30 ++++ R/man/flat_growth_init.Rd | 2 +- R/man/get_stan_backend.Rd | 15 ++ R/man/prophet.Rd | 28 +-- R/tests/testthat/test_diagnostics.R | 2 +- R/tests/testthat/test_prophet.R | 5 +- 13 files changed, 462 insertions(+), 92 deletions(-) create mode 100644 R/R/stan_backends.R create mode 100644 R/man/check_cmdstanr.Rd create mode 100644 R/man/dot-fit.Rd create mode 100644 R/man/dot-load_model.Rd create mode 100644 R/man/dot-sampling.Rd create mode 100644 R/man/dot-stan_args.Rd create mode 100644 R/man/get_stan_backend.Rd diff --git a/R/DESCRIPTION b/R/DESCRIPTION index f0680a3..cbad205 100644 --- a/R/DESCRIPTION +++ b/R/DESCRIPTION @@ -35,14 +35,18 @@ Imports: tidyr (>= 0.6.1), xts Suggests: + cmdstanr, + posterior, knitr, testthat, readr, rmarkdown +Additional_repositories: + https://mc-stan.org/r-packages/ SystemRequirements: GNU make, C++11 Biarch: true License: MIT + file LICENSE -LinkingTo: +LinkingTo: BH (>= 1.66.0), Rcpp (>= 0.12.0), RcppParallel (>= 5.0.1), diff --git a/R/R/prophet.R b/R/R/prophet.R index 1f5af86..dd1efa1 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -16,54 +16,56 @@ globalVariables(c( #' Prophet forecaster. #' #' @param df (optional) Dataframe containing the history. Must have columns ds -#' (date type) and y, the time series. If growth is logistic, then df must -#' also have a column cap that specifies the capacity at each ds. If not -#' provided, then the model object will be instantiated but not fit; use -#' fit.prophet(m, df) to fit the model. -#' @param growth String 'linear', 'logistic', or 'flat' to specify a linear, logistic -#' or flat trend. +#' (date type) and y, the time series. If growth is logistic, then df must +#' also have a column cap that specifies the capacity at each ds. If not +#' provided, then the model object will be instantiated but not fit; use +#' fit.prophet(m, df) to fit the model. +#' @param growth String 'linear', 'logistic', or 'flat' to specify a linear, +#' logistic or flat trend. #' @param changepoints Vector of dates at which to include potential -#' changepoints. If not specified, potential changepoints are selected -#' automatically. +#' changepoints. If not specified, potential changepoints are selected +#' automatically. #' @param n.changepoints Number of potential changepoints to include. Not used -#' if input `changepoints` is supplied. If `changepoints` is not supplied, -#' then n.changepoints potential changepoints are selected uniformly from the -#' first `changepoint.range` proportion of df$ds. +#' if input `changepoints` is supplied. If `changepoints` is not supplied, +#' then n.changepoints potential changepoints are selected uniformly from the +#' first `changepoint.range` proportion of df$ds. #' @param changepoint.range Proportion of history in which trend changepoints -#' will be estimated. Defaults to 0.8 for the first 80%. Not used if -#' `changepoints` is specified. -#' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE, -#' FALSE, or a number of Fourier terms to generate. -#' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE, -#' FALSE, or a number of Fourier terms to generate. -#' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE, -#' FALSE, or a number of Fourier terms to generate. +#' will be estimated. Defaults to 0.8 for the first 80%. Not used if +#' `changepoints` is specified. +#' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE, FALSE, +#' or a number of Fourier terms to generate. +#' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE, FALSE, +#' or a number of Fourier terms to generate. +#' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE, FALSE, +#' or a number of Fourier terms to generate. #' @param holidays data frame with columns holiday (character) and ds (date -#' type)and optionally columns lower_window and upper_window which specify a -#' range of days around the date to be included as holidays. lower_window=-2 -#' will include 2 days prior to the date as holidays. Also optionally can have -#' a column prior_scale specifying the prior scale for each holiday. +#' type)and optionally columns lower_window and upper_window which specify a +#' range of days around the date to be included as holidays. lower_window=-2 +#' will include 2 days prior to the date as holidays. Also optionally can have +#' a column prior_scale specifying the prior scale for each holiday. #' @param seasonality.mode 'additive' (default) or 'multiplicative'. #' @param seasonality.prior.scale Parameter modulating the strength of the -#' seasonality model. Larger values allow the model to fit larger seasonal -#' fluctuations, smaller values dampen the seasonality. Can be specified for -#' individual seasonalities using add_seasonality. +#' seasonality model. Larger values allow the model to fit larger seasonal +#' fluctuations, smaller values dampen the seasonality. Can be specified for +#' individual seasonalities using add_seasonality. #' @param holidays.prior.scale Parameter modulating the strength of the holiday -#' components model, unless overridden in the holidays input. +#' components model, unless overridden in the holidays input. #' @param changepoint.prior.scale Parameter modulating the flexibility of the -#' automatic changepoint selection. Large values will allow many changepoints, -#' small values will allow few changepoints. +#' automatic changepoint selection. Large values will allow many changepoints, +#' small values will allow few changepoints. #' @param mcmc.samples Integer, if greater than 0, will do full Bayesian -#' inference with the specified number of MCMC samples. If 0, will do MAP -#' estimation. +#' inference with the specified number of MCMC samples. If 0, will do MAP +#' estimation. #' @param interval.width Numeric, width of the uncertainty intervals provided -#' for the forecast. If mcmc.samples=0, this will be only the uncertainty -#' in the trend using the MAP estimate of the extrapolated generative model. -#' If mcmc.samples>0, this will be integrated over all model parameters, -#' which will include uncertainty in seasonality. +#' for the forecast. If mcmc.samples=0, this will be only the uncertainty in +#' the trend using the MAP estimate of the extrapolated generative model. If +#' mcmc.samples>0, this will be integrated over all model parameters, which +#' will include uncertainty in seasonality. #' @param uncertainty.samples Number of simulated draws used to estimate -#' uncertainty intervals. Settings this value to 0 or False will disable -#' uncertainty estimation and speed up the calculation. +#' uncertainty intervals. Settings this value to 0 or False will disable +#' uncertainty estimation and speed up the calculation. +#' @param backend Whether to use the "rstan" or "cmdstanr" backend to fit the +#' model. If not provided, uses the R_STAN_BACKEND environment variable. #' @param fit Boolean, if FALSE the model is initialized but not fit. #' @param ... Additional arguments, passed to \code{\link{fit.prophet}} #' @@ -99,12 +101,15 @@ prophet <- function(df = NULL, interval.width = 0.80, uncertainty.samples = 1000, fit = TRUE, + backend = NULL, ... ) { if (!is.null(changepoints)) { n.changepoints <- length(changepoints) } + if (is.null(backend)) backend <- get_stan_backend() + m <- list( growth = growth, changepoints = changepoints, @@ -121,6 +126,7 @@ prophet <- function(df = NULL, mcmc.samples = mcmc.samples, interval.width = interval.width, uncertainty.samples = uncertainty.samples, + backend = backend, specified.changepoints = !is.null(changepoints), start = NULL, # This and following attributes are set during fitting y.scale = NULL, @@ -1068,7 +1074,7 @@ set_auto_seasonalities <- function(m) { #' Initialize flat growth. #' #' Provides a strong initialization for flat growth by setting the -#' growth to 0 and calculates the offset parameter that pass the +#' growth to 0 and calculates the offset parameter that pass the #' function through the mean of the the y_scaled values. #' #' @param df Data frame with columns ds (date), y_scaled (scaled time series), @@ -1227,11 +1233,7 @@ fit.prophet <- function(m, df, ...) { kinit <- logistic_growth_init(history) } - if (exists(".prophet.stan.model", where = prophet_model_env)) { - model <- get('.prophet.stan.model', envir = prophet_model_env) - } else { - model <- stanmodels$prophet - } + model <- .load_model(m$backend) stan_init <- function() { list(k = kinit[1], @@ -1242,45 +1244,24 @@ fit.prophet <- function(m, df, ...) { ) } - if (min(history$y) == max(history$y) & + if (min(history$y) == max(history$y) & (m$growth %in% c('linear', 'flat'))) { # Nothing to fit. m$params <- stan_init() m$params$sigma_obs <- 0. n.iteration <- 1. - } else if (m$mcmc.samples > 0) { - args <- list( - object = model, - data = dat, - init = stan_init, - iter = m$mcmc.samples - ) - args <- utils::modifyList(args, list(...)) - m$stan.fit <- do.call(rstan::sampling, args) - m$params <- rstan::extract(m$stan.fit) - n.iteration <- length(m$params$k) } else { - args <- list( - object = model, - data = dat, - init = stan_init, - algorithm = if(dat$T < 100) {'Newton'} else {'LBFGS'}, - iter = 1e4, - as_vector = FALSE - ) - args <- utils::modifyList(args, list(...)) - m$stan.fit <- do.call(rstan::optimizing, args) - if (m$stan.fit$return_code != 0) { - message( - 'Optimization terminated abnormally. Falling back to Newton optimizer.' - ) - args$algorithm = 'Newton' - m$stan.fit <- do.call(rstan::optimizing, args) + if (m$mcmc.samples > 0) { + args <- .stan_args(model, dat, stan_init, m$backend, type = "mcmc", m$mcmc.samples, ...) + model_output <- .sampling(args, m$backend) + } else { + args <- .stan_args(model, dat, stan_init, m$backend, type = "optimize", ...) + model_output <- .fit(args, m$backend) } - m$params <- m$stan.fit$par - n.iteration <- 1 + m$stan.fit <- model_output$stan_fit + m$params <- model_output$params + n.iteration <- model_output$n_iteration } - # Cast the parameters to have consistent form, whether full bayes or MAP for (name in c('delta', 'beta')){ m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration) diff --git a/R/R/stan_backends.R b/R/R/stan_backends.R new file mode 100644 index 0000000..b29cd28 --- /dev/null +++ b/R/R/stan_backends.R @@ -0,0 +1,256 @@ +#' Get the stan backend defined in the environment variables. +#' +#' @return 'rstan' or 'cmdstanr'. 'rstan' if variable is not set. +#' @keywords internal +get_stan_backend <- function() { + backend_setting <- Sys.getenv("R_STAN_BACKEND", "RSTAN") + if (backend_setting %in% c("RSTAN", "CMDSTANR")) { + backend <- switch( + backend_setting, + "RSTAN" = "rstan", + "CMDSTANR" = "cmdstanr" + ) + if (backend == "cmdstanr") check_cmdstanr() + return(backend) + } else { + return("rstan") + } +} + +#' Check that the required packages for using the cmdstanr backend are installed. +#' +#' @return NULL if successful, and prints the current version of cmdstan being used. +#' @keywords internal +check_cmdstanr <- function() { + if (!requireNamespace("cmdstanr", quietly = TRUE)) { + stop( + "Package \"cmdstanr\" needed to use cmdstanr backend. See installation instructions: https://mc-stan.org/cmdstanr/.", + call. = FALSE + ) + } + if (!requireNamespace("posterior", quietly = TRUE)) { + stop( + "Package \"posterior\" needed to use cmdstanr backend. See installation instructions: https://mc-stan.org/posterior/.", + call. = FALSE + ) + } + cmdstanr_version <- cmdstanr::cmdstan_version() + return(invisible(TRUE)) +} + +#' Load the Prophet Stan model. +#' +#' @param backend "rstan" or "cmdstanr". +#' +#' @return stanmodel object if backend = "rstan", CmdStanModel object if backend = "cmdstanr" +#' @keywords internal +.load_model <- function(backend) { + switch( + backend, + "rstan" = .load_model_rstan(), + "cmdstanr" = .load_model_cmdstanr() + ) +} + +#' @rdname .load_model +.load_model_rstan <- function() { + if (exists(".prophet.stan.model", where = prophet_model_env)) { + model <- get('.prophet.stan.model', envir = prophet_model_env) + } else { + model <- stanmodels$prophet + } + + return(model) +} + +#' @rdname .load_model +.load_model_cmdstanr <- function() { + model_file <- system.file( + "stan", + "prophet.stan", + package = "prophet", + mustWork = TRUE + ) + model <- cmdstanr::cmdstan_model(model_file) + + return(model) +} + +#' Gives Stan arguments the appropriate names depending on the chosen Stan backend. +#' +#' @param model Model object. +#' @param dat List containing data to use in fitting. +#' @param stan_init Function to initialize parameters for stan fit. +#' @param backend "rstan" or "cmdstanr". +#' @param type "mcmc" or "optimize". +#' @param mcmc_samples Integer, if greater than 0, will do full Bayesian +#' inference with the specified number of MCMC samples. If 0, will do MAP +#' estimation. +#' +#' @return Named list of arguments. +#' @keywords internal +.stan_args <- function(model, dat, stan_init, backend, type, mcmc_samples = 0, ...) { + args <- switch( + backend, + "rstan" = .stan_args_rstan(model, dat, stan_init, type, mcmc_samples), + "cmdstanr" = .stan_args_cmdstanr(model, dat, stan_init, type, mcmc_samples) + ) + args <- utils::modifyList(args, list(...)) + + return(args) +} + +#' @rdname .stan_args +.stan_args_rstan <- function(model, dat, stan_init, type, mcmc_samples = NULL) { + if (type == "mcmc") { + args <- list( + object = model, + data = dat, + init = stan_init, + iter = mcmc_samples, + chains = 4 + ) + } else if (type == "optimize") { + args <- list( + object = model, + data = dat, + init = stan_init, + algorithm = if(dat$T < 100) {'Newton'} else {'LBFGS'}, + iter = 1e4, + as_vector = FALSE + ) + } + + return(args) +} + +#' @rdname .stan_args +.stan_args_cmdstanr <- function(model, dat, stan_init, type, mcmc_samples = NULL) { + if (type == "mcmc") { + args <- list( + object = model, + data = dat, + init = stan_init, + iter_warmup = mcmc_samples / 2, + iter_sampling = mcmc_samples / 2, + chains = 4, + refresh = 0, + show_messages = FALSE + ) + } else if (type == "optimize") { + args <- list( + object = model, + data = dat, + init = stan_init, + algorithm = if(dat$T < 100) {'newton'} else {'lbfgs'}, + iter = 1e4, + refresh = 0 + ) + } + + return(args) +} + +#' Obtain the point estimates of the parameters of the Prophet model using +#' stan's optimization algorithms. +#' +#' @param args Named list of arguments suitable for the chosen backend. Must +#' include arguments required for optimization. +#' @param backend "rstan" or "cmdstanr". +#' +#' @return A named list containing "stan_fit" (the fitted stan object), +#' "params", and "n_iteration" +#' @keywords internal +.fit <- function(args, backend) { + switch( + backend, + "rstan" = .fit_rstan(args), + "cmdstanr" = .fit_cmdstanr(args) + ) +} + +#' Obtain the joint posterior distribution of the parameters of the Prophet +#' model using MCMC sampling. +#' +#' @param args Named list of arguments suitable for the chosen backend. Must +#' include arguments required for MCMC sampling. +#' @param backend "rstan" or "cmdstanr". +#' +#' @return A named list containing "stan_fit" (the fitted stan object), +#' "params", and "n_iteration" +#' @keywords internal +.sampling <- function(args, backend) { + switch( + backend, + "rstan" = .sampling_rstan(args), + "cmdstanr" = .sampling_cmdstanr(args) + ) +} + +#' @rdname .fit +.fit_rstan <- function(args) { + model_output <- list() + model_output$stan_fit <- do.call(rstan::optimizing, args) + if (model_output$stan_fit$return_code != 0) { + message( + 'Optimization terminated abnormally. Falling back to Newton optimizer.' + ) + args$algorithm = 'Newton' + model_output$stan_fit <- do.call(rstan::optimizing, args) + } + model_output$params <- model_output$stan_fit$par + model_output$n_iteration <- 1 + + return(model_output) +} + +#' @rdname .sampling +.sampling_rstan <- function(args) { + model_output <- list() + model_output$stan_fit <- do.call(rstan::sampling, args) + model_output$params <- rstan::extract(model_output$stan_fit) + model_output$n_iteration <- length(model_output$params$k) + + return(model_output) +} + +#' @rdname .fit +.fit_cmdstanr <- function(args) { + # TODO: Replace with method to extract parameter names once implemented in cmdstanr + param_names <- c("k", "m", "delta", "sigma_obs", "beta", "trend") + model_output <- list() + model <- args$object + args$object <- NULL + model_output$stan_fit <- do.call(model$optimize, args) + if (model_output$stan_fit$return_codes()[1] != 0) { + message( + 'Optimization terminated abnormally. Falling back to Newton optimizer.' + ) + args$algorithm = 'newton' + model_output$stan_fit <- do.call(model$optimize, args) + } + model_output$params <- list() + for (name in param_names) { + model_output$params[[name]] <- unname(model_output$stan_fit$mle(name)) + } + model_output$n_iteration <- 1 + + return(model_output) +} + +#' @rdname .sampling +.sampling_cmdstanr <- function(args) { + param_names <- c("k", "m", "delta", "sigma_obs", "beta", "trend") + model_output <- list() + model <- args$object + args$object <- NULL + param_names <- c(param_names, "lp__") + model_output$stan_fit <- do.call(model$sample, args) + model_output$params <- list() + for (name in param_names) { + model_output$params[[name]] <- posterior::as_draws_matrix(model_output$stan_fit$draws(name)) + } + model_output$n_iteration <- nrow(model_output$params$k) + + return(model_output) +} diff --git a/R/man/check_cmdstanr.Rd b/R/man/check_cmdstanr.Rd new file mode 100644 index 0000000..0e0ebe4 --- /dev/null +++ b/R/man/check_cmdstanr.Rd @@ -0,0 +1,15 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/stan_backends.R +\name{check_cmdstanr} +\alias{check_cmdstanr} +\title{Check that the required packages for using the cmdstanr backend are installed.} +\usage{ +check_cmdstanr() +} +\value{ +NULL if successful, and prints the current version of cmdstan being used. +} +\description{ +Check that the required packages for using the cmdstanr backend are installed. +} +\keyword{internal} diff --git a/R/man/dot-fit.Rd b/R/man/dot-fit.Rd new file mode 100644 index 0000000..1b57717 --- /dev/null +++ b/R/man/dot-fit.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/stan_backends.R +\name{.fit} +\alias{.fit} +\title{Obtain the point estimates of the parameters of the Prophet model using +stan's optimization algorithms.} +\usage{ +.fit(args, backend) +} +\arguments{ +\item{args}{Named list of arguments suitable for the chosen backend. Must +include arguments required for optimization.} + +\item{backend}{"rstan" or "cmdstanr".} +} +\value{ +A named list containing "stan_fit" (the fitted stan object), + "params", and "n_iteration" +} +\description{ +Obtain the point estimates of the parameters of the Prophet model using +stan's optimization algorithms. +} +\keyword{internal} diff --git a/R/man/dot-load_model.Rd b/R/man/dot-load_model.Rd new file mode 100644 index 0000000..10315e0 --- /dev/null +++ b/R/man/dot-load_model.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/stan_backends.R +\name{.load_model} +\alias{.load_model} +\title{Load the Prophet Stan model.} +\usage{ +.load_model(backend) +} +\arguments{ +\item{backend}{"rstan" or "cmdstanr".} +} +\value{ +stanmodel object if backend = "rstan", CmdStanModel object if backend = "cmdstanr" +} +\description{ +Load the Prophet Stan model. +} +\keyword{internal} diff --git a/R/man/dot-sampling.Rd b/R/man/dot-sampling.Rd new file mode 100644 index 0000000..18acdf9 --- /dev/null +++ b/R/man/dot-sampling.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/stan_backends.R +\name{.sampling} +\alias{.sampling} +\title{Obtain the joint posterior distribution of the parameters of the Prophet +model using MCMC sampling.} +\usage{ +.sampling(args, backend) +} +\arguments{ +\item{args}{Named list of arguments suitable for the chosen backend. Must +include arguments required for MCMC sampling.} + +\item{backend}{"rstan" or "cmdstanr".} +} +\value{ +A named list containing "stan_fit" (the fitted stan object), + "params", and "n_iteration" +} +\description{ +Obtain the joint posterior distribution of the parameters of the Prophet +model using MCMC sampling. +} +\keyword{internal} diff --git a/R/man/dot-stan_args.Rd b/R/man/dot-stan_args.Rd new file mode 100644 index 0000000..727aab9 --- /dev/null +++ b/R/man/dot-stan_args.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/stan_backends.R +\name{.stan_args} +\alias{.stan_args} +\title{Gives Stan arguments the appropriate names depending on the chosen Stan backend.} +\usage{ +.stan_args(model, dat, stan_init, backend, type, mcmc_samples = 0, ...) +} +\arguments{ +\item{model}{Model object.} + +\item{dat}{List containing data to use in fitting.} + +\item{stan_init}{Function to initialize parameters for stan fit.} + +\item{backend}{"rstan" or "cmdstanr".} + +\item{type}{"mcmc" or "optimize".} + +\item{mcmc_samples}{Integer, if greater than 0, will do full Bayesian +inference with the specified number of MCMC samples. If 0, will do MAP +estimation.} +} +\value{ +Named list of arguments. +} +\description{ +Gives Stan arguments the appropriate names depending on the chosen Stan backend. +} +\keyword{internal} diff --git a/R/man/flat_growth_init.Rd b/R/man/flat_growth_init.Rd index a1a0002..c2c7288 100644 --- a/R/man/flat_growth_init.Rd +++ b/R/man/flat_growth_init.Rd @@ -16,7 +16,7 @@ A vector (k, m) with the rate (k) and offset (m) of the flat } \description{ Provides a strong initialization for flat growth by setting the -growth to 0 and calculates the offset parameter that pass the +growth to 0 and calculates the offset parameter that pass the function through the mean of the the y_scaled values. } \keyword{internal} diff --git a/R/man/get_stan_backend.Rd b/R/man/get_stan_backend.Rd new file mode 100644 index 0000000..6cb38f3 --- /dev/null +++ b/R/man/get_stan_backend.Rd @@ -0,0 +1,15 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/stan_backends.R +\name{get_stan_backend} +\alias{get_stan_backend} +\title{Get the stan backend defined in the environment variables.} +\usage{ +get_stan_backend() +} +\value{ +'rstan' or 'cmdstanr'. 'rstan' if variable is not set. +} +\description{ +Get the stan backend defined in the environment variables. +} +\keyword{internal} diff --git a/R/man/prophet.Rd b/R/man/prophet.Rd index 10d4407..86bd9d8 100644 --- a/R/man/prophet.Rd +++ b/R/man/prophet.Rd @@ -22,6 +22,7 @@ prophet( interval.width = 0.8, uncertainty.samples = 1000, fit = TRUE, + backend = NULL, ... ) } @@ -32,8 +33,8 @@ also have a column cap that specifies the capacity at each ds. If not provided, then the model object will be instantiated but not fit; use fit.prophet(m, df) to fit the model.} -\item{growth}{String 'linear', 'logistic', or 'flat' to specify a linear, logistic -or flat trend.} +\item{growth}{String 'linear', 'logistic', or 'flat' to specify a linear, +logistic or flat trend.} \item{changepoints}{Vector of dates at which to include potential changepoints. If not specified, potential changepoints are selected @@ -48,14 +49,14 @@ first `changepoint.range` proportion of df$ds.} will be estimated. Defaults to 0.8 for the first 80%. Not used if `changepoints` is specified.} -\item{yearly.seasonality}{Fit yearly seasonality. Can be 'auto', TRUE, -FALSE, or a number of Fourier terms to generate.} +\item{yearly.seasonality}{Fit yearly seasonality. Can be 'auto', TRUE, FALSE, +or a number of Fourier terms to generate.} -\item{weekly.seasonality}{Fit weekly seasonality. Can be 'auto', TRUE, -FALSE, or a number of Fourier terms to generate.} +\item{weekly.seasonality}{Fit weekly seasonality. Can be 'auto', TRUE, FALSE, +or a number of Fourier terms to generate.} -\item{daily.seasonality}{Fit daily seasonality. Can be 'auto', TRUE, -FALSE, or a number of Fourier terms to generate.} +\item{daily.seasonality}{Fit daily seasonality. Can be 'auto', TRUE, FALSE, +or a number of Fourier terms to generate.} \item{holidays}{data frame with columns holiday (character) and ds (date type)and optionally columns lower_window and upper_window which specify a @@ -82,10 +83,10 @@ inference with the specified number of MCMC samples. If 0, will do MAP estimation.} \item{interval.width}{Numeric, width of the uncertainty intervals provided -for the forecast. If mcmc.samples=0, this will be only the uncertainty -in the trend using the MAP estimate of the extrapolated generative model. -If mcmc.samples>0, this will be integrated over all model parameters, -which will include uncertainty in seasonality.} +for the forecast. If mcmc.samples=0, this will be only the uncertainty in +the trend using the MAP estimate of the extrapolated generative model. If +mcmc.samples>0, this will be integrated over all model parameters, which +will include uncertainty in seasonality.} \item{uncertainty.samples}{Number of simulated draws used to estimate uncertainty intervals. Settings this value to 0 or False will disable @@ -93,6 +94,9 @@ uncertainty estimation and speed up the calculation.} \item{fit}{Boolean, if FALSE the model is initialized but not fit.} +\item{backend}{Whether to use the "rstan" or "cmdstanr" backend to fit the +model. If not provided, uses the R_STAN_BACKEND environment variable.} + \item{...}{Additional arguments, passed to \code{\link{fit.prophet}}} } \value{ diff --git a/R/tests/testthat/test_diagnostics.R b/R/tests/testthat/test_diagnostics.R index a594296..94af754 100644 --- a/R/tests/testthat/test_diagnostics.R +++ b/R/tests/testthat/test_diagnostics.R @@ -121,7 +121,7 @@ test_that("cross_validation_uncertainty_disabled", { skip_if_not(Sys.getenv('R_ARCH') != '/i386') for (uncertainty in c(0, FALSE)) { m <- prophet(uncertainty.samples = uncertainty) - m <- fit.prophet(m = m, df = DATA, algorithm = "Newton") + m <- fit.prophet(m = m, df = DATA) df.cv <- cross_validation( m, horizon = 4, units = "days", period = 4, initial = 115) expected.cols <- c('y', 'ds', 'yhat', 'cutoff') diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index 88469dc..06cb231 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -125,11 +125,10 @@ test_that("logistic_floor", { future1 <- future future1$cap <- 80. future1$floor <- 10. - m <- fit.prophet(m, history, algorithm = 'Newton') + m <- fit.prophet(m, history) expect_true(m$logistic.floor) expect_true('floor' %in% colnames(m$history)) expect_equal(m$history$y_scaled[1], 1., tolerance = 1e-6) - expect_equal(m$fit.kwargs, list(algorithm = 'Newton')) fcst1 <- predict(m, future1) m2 <- prophet(growth = 'logistic') @@ -139,7 +138,7 @@ test_that("logistic_floor", { history2$cap <- history2$cap + 10. future1$cap <- future1$cap + 10. future1$floor <- future1$floor + 10. - m2 <- fit.prophet(m2, history2, algorithm = 'Newton') + m2 <- fit.prophet(m2, history2) expect_equal(m2$history$y_scaled[1], 1., tolerance = 1e-6) })