diff --git a/python/fbprophet/forecaster.py b/python/fbprophet/forecaster.py index 0ab70ff..dcbf6b4 100644 --- a/python/fbprophet/forecaster.py +++ b/python/fbprophet/forecaster.py @@ -330,7 +330,20 @@ class Prophet(object): return (k, m) # fb-block 7 - def fit(self, df): + def fit(self, df, **kwargs): + """Fit the Prophet model to data. + + Parameters + ---------- + df: pd.DataFrame containing history. Must have columns 'ds', 'y', and + if logistic growth, 'cap'. + kwargs: Additional arguments passed to Stan's sampling or optimizing + function, as appropriate. + + Returns + ------- + The fitted Prophet object. + """ history = df[df['y'].notnull()].copy() history.reset_index(inplace=True, drop=True) @@ -377,14 +390,14 @@ class Prophet(object): stan_fit = model.sampling( dat, init=stan_init, - chains=1, iter=self.mcmc_samples, + **kwargs ) for par in stan_fit.model_pars: self.params[par] = stan_fit[par] else: - params = model.optimizing(dat, init=stan_init, iter=1e4) + params = model.optimizing(dat, init=stan_init, iter=1e4, **kwargs) for par in params: self.params[par] = params[par].reshape((1, -1))