diff --git a/python/fbprophet/models.py b/python/fbprophet/models.py index c75bc93..59a741a 100644 --- a/python/fbprophet/models.py +++ b/python/fbprophet/models.py @@ -87,20 +87,18 @@ class CmdStanPyBackend(IStanBackend): def fit(self, stan_init, stan_data, **kwargs): (stan_init, stan_data) = self.prepare_data(stan_init, stan_data) - if 'init' in kwargs: - kwargs['init'] = self.prepare_data(kwargs['init'], stan_data)[0] + + if 'inits' not in kwargs and 'init' in kwargs: + kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] args = dict( data=stan_data, - init=stan_init, + inits=stan_init, algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS', iter=int(1e4), ) args.update(kwargs) - args['inits'] = args['init'] - del args['init'] - try: self.stan_fit = self.model.optimize(**args) except RuntimeError as e: @@ -122,12 +120,13 @@ class CmdStanPyBackend(IStanBackend): def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: (stan_init, stan_data) = self.prepare_data(stan_init, stan_data) - if 'init' in kwargs: - kwargs['init'] = self.prepare_data(kwargs['init'], stan_data)[0] + + if 'inits' not in kwargs and 'init' in kwargs: + kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] args = dict( data=stan_data, - init=stan_init, + inits=stan_init, algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS', ) @@ -140,9 +139,6 @@ class CmdStanPyBackend(IStanBackend): args.update(kwargs) - args['inits'] = args['init'] - del args['init'] - self.stan_fit = self.model.sample(**args) res = self.stan_fit.sample (samples, c, columns) = res.shape