Allow overriding any of the arguments to stan functions

This commit is contained in:
Ben Letham 2018-12-03 19:05:47 -08:00
parent 545045d615
commit 1d398c679d
2 changed files with 27 additions and 18 deletions

View file

@ -1213,24 +1213,26 @@ fit.prophet <- function(m, df, ...) {
m$params$sigma_obs <- 0.
n.iteration <- 1.
} else if (m$mcmc.samples > 0) {
stan.fit <- rstan::sampling(
model,
args <- list(
object = model,
data = dat,
init = stan_init,
iter = m$mcmc.samples,
...
iter = m$mcmc.samples
)
args <- modifyList(args, list(...))
stan.fit <- do.call(rstan::sampling, args)
m$params <- rstan::extract(stan.fit)
n.iteration <- length(m$params$k)
} else {
stan.fit <- rstan::optimizing(
model,
args <- list(
object = model,
data = dat,
init = stan_init,
iter = 1e4,
as_vector = FALSE,
...
as_vector = FALSE
)
args <- modifyList(args, list(...))
stan.fit <- do.call(rstan::optimizing, args)
m$params <- stan.fit$par
n.iteration <- 1
}

View file

@ -1082,27 +1082,34 @@ class Prophet(object):
for par in self.params:
self.params[par] = np.array([self.params[par]])
elif self.mcmc_samples > 0:
stan_fit = model.sampling(
dat,
args = dict(
data=dat,
init=stan_init,
iter=self.mcmc_samples,
**kwargs
)
args.update(kwargs)
stan_fit = model.sampling(**args)
for par in stan_fit.model_pars:
self.params[par] = stan_fit[par]
# Shape vector parameters
if par in ['delta', 'beta'] and len(self.params[par].shape) < 2:
self.params[par] = self.params[par].reshape((-1, 1))
else:
args = dict(
data=dat,
init=stan_init,
iter=1e4,
)
args.update(kwargs)
try:
params = model.optimizing(
dat, init=stan_init, iter=1e4, **kwargs)
params = model.optimizing(**args)
except RuntimeError:
kwargs.pop('algorithm', None)
params = model.optimizing(
dat, init=stan_init, iter=1e4, algorithm='Newton',
**kwargs
)
if 'algorithm' not in args:
# Fall back on Newton
args['algorithm'] = 'Newton'
params = model.optimizing(**args)
else:
raise
for par in params:
self.params[par] = params[par].reshape((1, -1))