Merge branch 'patch-1' of https://github.com/loulo1/prophet into loulo1-patch-1

This commit is contained in:
Ben Letham 2021-04-20 17:49:17 -07:00
commit 83c4ef3e2b

View file

@ -87,25 +87,28 @@ class CmdStanPyBackend(IStanBackend):
def fit(self, stan_init, stan_data, **kwargs):
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
if 'algorithm' not in kwargs:
kwargs['algorithm'] = 'Newton' if stan_data['T'] < 100 else 'LBFGS'
iterations = int(1e4)
if 'inits' not in kwargs and 'init' in kwargs:
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
args = dict(
data=stan_data,
inits=stan_init,
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
iter=int(1e4),
)
args.update(kwargs)
try:
self.stan_fit = self.model.optimize(data=stan_data,
inits=stan_init,
iter=iterations,
**kwargs)
self.stan_fit = self.model.optimize(**args)
except RuntimeError as e:
# Fall back on Newton
if self.newton_fallback and kwargs['algorithm'] != 'Newton':
if self.newton_fallback and args['algorithm'] != 'Newton':
logger.warning(
'Optimization terminated abnormally. Falling back to Newton.'
)
kwargs['algorithm'] = 'Newton'
self.stan_fit = self.model.optimize(data=stan_data,
inits=stan_init,
iter=iterations,
**kwargs)
args['algorithm'] = 'Newton'
self.stan_fit = self.model.optimize(**args)
else:
raise e
@ -117,17 +120,26 @@ class CmdStanPyBackend(IStanBackend):
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
if 'inits' not in kwargs and 'init' in kwargs:
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
args = dict(
data=stan_data,
inits=stan_init,
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
)
if 'chains' not in kwargs:
kwargs['chains'] = 4
iter_half = samples // 2
kwargs['iter_sampling'] = iter_half
if 'iter_warmup' not in kwargs:
kwargs['iter_warmup'] = iter_half
args.update(kwargs)
self.stan_fit = self.model.sample(data=stan_data,
inits=stan_init,
iter_sampling=iter_half,
**kwargs)
self.stan_fit = self.model.sample(**args)
res = self.stan_fit.draws()
(samples, c, columns) = res.shape
res = res.reshape((samples * c, columns))
@ -166,10 +178,10 @@ class CmdStanPyBackend(IStanBackend):
'm': init['m'],
'delta': init['delta'].tolist(),
'beta': init['beta'].tolist(),
'sigma_obs': 1
'sigma_obs': init['sigma_obs']
}
return (cmdstanpy_init, cmdstanpy_data)
@staticmethod
def stan_to_dict_numpy(column_names: Tuple[str, ...], data: 'np.array'):
import numpy as np