mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-11 00:49:35 +00:00
Additional kwargs to Stan in Python
This commit is contained in:
parent
96ef0e236a
commit
e08cfd2176
1 changed files with 16 additions and 3 deletions
|
|
@ -330,7 +330,20 @@ class Prophet(object):
|
|||
return (k, m)
|
||||
|
||||
# fb-block 7
|
||||
def fit(self, df):
|
||||
def fit(self, df, **kwargs):
|
||||
"""Fit the Prophet model to data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df: pd.DataFrame containing history. Must have columns 'ds', 'y', and
|
||||
if logistic growth, 'cap'.
|
||||
kwargs: Additional arguments passed to Stan's sampling or optimizing
|
||||
function, as appropriate.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The fitted Prophet object.
|
||||
"""
|
||||
history = df[df['y'].notnull()].copy()
|
||||
history.reset_index(inplace=True, drop=True)
|
||||
|
||||
|
|
@ -377,14 +390,14 @@ class Prophet(object):
|
|||
stan_fit = model.sampling(
|
||||
dat,
|
||||
init=stan_init,
|
||||
chains=1,
|
||||
iter=self.mcmc_samples,
|
||||
**kwargs
|
||||
)
|
||||
for par in stan_fit.model_pars:
|
||||
self.params[par] = stan_fit[par]
|
||||
|
||||
else:
|
||||
params = model.optimizing(dat, init=stan_init, iter=1e4)
|
||||
params = model.optimizing(dat, init=stan_init, iter=1e4, **kwargs)
|
||||
for par in params:
|
||||
self.params[par] = params[par].reshape((1, -1))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue