diff --git a/python/prophet/models.py b/python/prophet/models.py index 1b74cee..3cf4b93 100644 --- a/python/prophet/models.py +++ b/python/prophet/models.py @@ -87,25 +87,28 @@ class CmdStanPyBackend(IStanBackend): def fit(self, stan_init, stan_data, **kwargs): (stan_init, stan_data) = self.prepare_data(stan_init, stan_data) - if 'algorithm' not in kwargs: - kwargs['algorithm'] = 'Newton' if stan_data['T'] < 100 else 'LBFGS' - iterations = int(1e4) + + if 'inits' not in kwargs and 'init' in kwargs: + kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] + + args = dict( + data=stan_data, + inits=stan_init, + algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS', + iter=int(1e4), + ) + args.update(kwargs) + try: - self.stan_fit = self.model.optimize(data=stan_data, - inits=stan_init, - iter=iterations, - **kwargs) + self.stan_fit = self.model.optimize(**args) except RuntimeError as e: # Fall back on Newton - if self.newton_fallback and kwargs['algorithm'] != 'Newton': + if self.newton_fallback and args['algorithm'] != 'Newton': logger.warning( 'Optimization terminated abnormally. Falling back to Newton.' ) - kwargs['algorithm'] = 'Newton' - self.stan_fit = self.model.optimize(data=stan_data, - inits=stan_init, - iter=iterations, - **kwargs) + args['algorithm'] = 'Newton' + self.stan_fit = self.model.optimize(**args) else: raise e @@ -117,17 +120,26 @@ class CmdStanPyBackend(IStanBackend): def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: (stan_init, stan_data) = self.prepare_data(stan_init, stan_data) + + if 'inits' not in kwargs and 'init' in kwargs: + kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] + + args = dict( + data=stan_data, + inits=stan_init, + algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS', + ) if 'chains' not in kwargs: kwargs['chains'] = 4 iter_half = samples // 2 + kwargs['iter_sampling'] = iter_half if 'iter_warmup' not in kwargs: kwargs['iter_warmup'] = iter_half + + args.update(kwargs) - self.stan_fit = self.model.sample(data=stan_data, - inits=stan_init, - iter_sampling=iter_half, - **kwargs) + self.stan_fit = self.model.sample(**args) res = self.stan_fit.draws() (samples, c, columns) = res.shape res = res.reshape((samples * c, columns)) @@ -166,10 +178,10 @@ class CmdStanPyBackend(IStanBackend): 'm': init['m'], 'delta': init['delta'].tolist(), 'beta': init['beta'].tolist(), - 'sigma_obs': 1 + 'sigma_obs': init['sigma_obs'] } return (cmdstanpy_init, cmdstanpy_data) - + @staticmethod def stan_to_dict_numpy(column_names: Tuple[str, ...], data: 'np.array'): import numpy as np