From 1d398c679d40007ade0dcf81ae0022cceaa1ce8b Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Mon, 3 Dec 2018 19:05:47 -0800 Subject: [PATCH] Allow overriding any of the arguments to stan functions --- R/R/prophet.R | 18 ++++++++++-------- python/fbprophet/forecaster.py | 27 +++++++++++++++++---------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/R/R/prophet.R b/R/R/prophet.R index 60e7612..75862d7 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -1213,24 +1213,26 @@ fit.prophet <- function(m, df, ...) { m$params$sigma_obs <- 0. n.iteration <- 1. } else if (m$mcmc.samples > 0) { - stan.fit <- rstan::sampling( - model, + args <- list( + object = model, data = dat, init = stan_init, - iter = m$mcmc.samples, - ... + iter = m$mcmc.samples ) + args <- modifyList(args, list(...)) + stan.fit <- do.call(rstan::sampling, args) m$params <- rstan::extract(stan.fit) n.iteration <- length(m$params$k) } else { - stan.fit <- rstan::optimizing( - model, + args <- list( + object = model, data = dat, init = stan_init, iter = 1e4, - as_vector = FALSE, - ... + as_vector = FALSE ) + args <- modifyList(args, list(...)) + stan.fit <- do.call(rstan::optimizing, args) m$params <- stan.fit$par n.iteration <- 1 } diff --git a/python/fbprophet/forecaster.py b/python/fbprophet/forecaster.py index 163e4e5..7c7399c 100644 --- a/python/fbprophet/forecaster.py +++ b/python/fbprophet/forecaster.py @@ -1082,27 +1082,34 @@ class Prophet(object): for par in self.params: self.params[par] = np.array([self.params[par]]) elif self.mcmc_samples > 0: - stan_fit = model.sampling( - dat, + args = dict( + data=dat, init=stan_init, iter=self.mcmc_samples, - **kwargs ) + args.update(kwargs) + stan_fit = model.sampling(**args) for par in stan_fit.model_pars: self.params[par] = stan_fit[par] # Shape vector parameters if par in ['delta', 'beta'] and len(self.params[par].shape) < 2: self.params[par] = self.params[par].reshape((-1, 1)) else: + args = dict( + data=dat, + init=stan_init, + iter=1e4, + ) + args.update(kwargs) try: - params = model.optimizing( - dat, init=stan_init, iter=1e4, **kwargs) + params = model.optimizing(**args) except RuntimeError: - kwargs.pop('algorithm', None) - params = model.optimizing( - dat, init=stan_init, iter=1e4, algorithm='Newton', - **kwargs - ) + if 'algorithm' not in args: + # Fall back on Newton + args['algorithm'] = 'Newton' + params = model.optimizing(**args) + else: + raise for par in params: self.params[par] = params[par].reshape((1, -1))