mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-28 22:56:38 +00:00
Allow overriding any of the arguments to stan functions
This commit is contained in:
parent
545045d615
commit
1d398c679d
2 changed files with 27 additions and 18 deletions
|
|
@ -1213,24 +1213,26 @@ fit.prophet <- function(m, df, ...) {
|
|||
m$params$sigma_obs <- 0.
|
||||
n.iteration <- 1.
|
||||
} else if (m$mcmc.samples > 0) {
|
||||
stan.fit <- rstan::sampling(
|
||||
model,
|
||||
args <- list(
|
||||
object = model,
|
||||
data = dat,
|
||||
init = stan_init,
|
||||
iter = m$mcmc.samples,
|
||||
...
|
||||
iter = m$mcmc.samples
|
||||
)
|
||||
args <- modifyList(args, list(...))
|
||||
stan.fit <- do.call(rstan::sampling, args)
|
||||
m$params <- rstan::extract(stan.fit)
|
||||
n.iteration <- length(m$params$k)
|
||||
} else {
|
||||
stan.fit <- rstan::optimizing(
|
||||
model,
|
||||
args <- list(
|
||||
object = model,
|
||||
data = dat,
|
||||
init = stan_init,
|
||||
iter = 1e4,
|
||||
as_vector = FALSE,
|
||||
...
|
||||
as_vector = FALSE
|
||||
)
|
||||
args <- modifyList(args, list(...))
|
||||
stan.fit <- do.call(rstan::optimizing, args)
|
||||
m$params <- stan.fit$par
|
||||
n.iteration <- 1
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1082,27 +1082,34 @@ class Prophet(object):
|
|||
for par in self.params:
|
||||
self.params[par] = np.array([self.params[par]])
|
||||
elif self.mcmc_samples > 0:
|
||||
stan_fit = model.sampling(
|
||||
dat,
|
||||
args = dict(
|
||||
data=dat,
|
||||
init=stan_init,
|
||||
iter=self.mcmc_samples,
|
||||
**kwargs
|
||||
)
|
||||
args.update(kwargs)
|
||||
stan_fit = model.sampling(**args)
|
||||
for par in stan_fit.model_pars:
|
||||
self.params[par] = stan_fit[par]
|
||||
# Shape vector parameters
|
||||
if par in ['delta', 'beta'] and len(self.params[par].shape) < 2:
|
||||
self.params[par] = self.params[par].reshape((-1, 1))
|
||||
else:
|
||||
args = dict(
|
||||
data=dat,
|
||||
init=stan_init,
|
||||
iter=1e4,
|
||||
)
|
||||
args.update(kwargs)
|
||||
try:
|
||||
params = model.optimizing(
|
||||
dat, init=stan_init, iter=1e4, **kwargs)
|
||||
params = model.optimizing(**args)
|
||||
except RuntimeError:
|
||||
kwargs.pop('algorithm', None)
|
||||
params = model.optimizing(
|
||||
dat, init=stan_init, iter=1e4, algorithm='Newton',
|
||||
**kwargs
|
||||
)
|
||||
if 'algorithm' not in args:
|
||||
# Fall back on Newton
|
||||
args['algorithm'] = 'Newton'
|
||||
params = model.optimizing(**args)
|
||||
else:
|
||||
raise
|
||||
for par in params:
|
||||
self.params[par] = params[par].reshape((1, -1))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue