mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-29 23:06:49 +00:00
Fix the issue #1814
I did as PyStanBackend. And now when we use the method fit of Prophet, we can do like in the documentation: https://facebook.github.io/prophet/docs/additional_topics.html#updating-fitted-models def stan_init(m): """Retrieve parameters from a trained model. Retrieve parameters from a trained model in the format used to initialize a new Stan model. Parameters ---------- m: A trained model of the Prophet class. Returns ------- A Dictionary containing retrieved parameters of m. """ res = {} for pname in ['k', 'm', 'sigma_obs']: res[pname] = m.params[pname][0][0] for pname in ['delta', 'beta']: res[pname] = m.params[pname][0] return res df = pd.read_csv('../examples/example_wp_log_peyton_manning.csv') df1 = df.loc[df['ds'] < '2016-01-19', :] # All data except the last day m1 = Prophet().fit(df1) # A model fit to all data except the last day %timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch %timeit m2 = Prophet().fit(df, init=stan_init(m1)) # Adding the last day, warm-starting from m1 Update models.py Update models.py Update models.py Update models.py Update models.py Update models.py Update models.py Test Test2 Test4 Test4 Test are fixed
This commit is contained in:
parent
8882c6a3e3
commit
fc8fa49aac
1 changed files with 35 additions and 19 deletions
|
|
@ -87,25 +87,30 @@ class CmdStanPyBackend(IStanBackend):
|
|||
|
||||
def fit(self, stan_init, stan_data, **kwargs):
|
||||
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
|
||||
if 'algorithm' not in kwargs:
|
||||
kwargs['algorithm'] = 'Newton' if stan_data['T'] < 100 else 'LBFGS'
|
||||
iterations = int(1e4)
|
||||
if 'init' in kwargs:
|
||||
kwargs['init'] = self.prepare_data(kwargs['init'], stan_data)[0]
|
||||
|
||||
args = dict(
|
||||
data=stan_data,
|
||||
init=stan_init,
|
||||
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
|
||||
iter=int(1e4),
|
||||
)
|
||||
args.update(kwargs)
|
||||
|
||||
args['inits'] = args['init']
|
||||
del args['init']
|
||||
|
||||
try:
|
||||
self.stan_fit = self.model.optimize(data=stan_data,
|
||||
inits=stan_init,
|
||||
iter=iterations,
|
||||
**kwargs)
|
||||
self.stan_fit = self.model.optimize(**args)
|
||||
except RuntimeError as e:
|
||||
# Fall back on Newton
|
||||
if self.newton_fallback and kwargs['algorithm'] != 'Newton':
|
||||
if self.newton_fallback and args['algorithm'] != 'Newton':
|
||||
logger.warning(
|
||||
'Optimization terminated abnormally. Falling back to Newton.'
|
||||
)
|
||||
kwargs['algorithm'] = 'Newton'
|
||||
self.stan_fit = self.model.optimize(data=stan_data,
|
||||
inits=stan_init,
|
||||
iter=iterations,
|
||||
**kwargs)
|
||||
args['algorithm'] = 'Newton'
|
||||
self.stan_fit = self.model.optimize(**args)
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
|
@ -117,17 +122,28 @@ class CmdStanPyBackend(IStanBackend):
|
|||
|
||||
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
|
||||
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
|
||||
if 'init' in kwargs:
|
||||
kwargs['init'] = self.prepare_data(kwargs['init'], stan_data)[0]
|
||||
|
||||
args = dict(
|
||||
data=stan_data,
|
||||
init=stan_init,
|
||||
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
|
||||
)
|
||||
|
||||
if 'chains' not in kwargs:
|
||||
kwargs['chains'] = 4
|
||||
iter_half = samples // 2
|
||||
kwargs['iter_sampling'] = iter_half
|
||||
if 'iter_warmup' not in kwargs:
|
||||
kwargs['iter_warmup'] = iter_half
|
||||
|
||||
args.update(kwargs)
|
||||
|
||||
self.stan_fit = self.model.sample(data=stan_data,
|
||||
inits=stan_init,
|
||||
iter_sampling=iter_half,
|
||||
**kwargs)
|
||||
args['inits'] = args['init']
|
||||
del args['init']
|
||||
|
||||
self.stan_fit = self.model.sample(**args)
|
||||
res = self.stan_fit.sample
|
||||
(samples, c, columns) = res.shape
|
||||
res = res.reshape((samples * c, columns))
|
||||
|
|
@ -166,10 +182,10 @@ class CmdStanPyBackend(IStanBackend):
|
|||
'm': init['m'],
|
||||
'delta': init['delta'].tolist(),
|
||||
'beta': init['beta'].tolist(),
|
||||
'sigma_obs': 1
|
||||
'sigma_obs': init['sigma_obs']
|
||||
}
|
||||
return (cmdstanpy_init, cmdstanpy_data)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def stan_to_dict_numpy(column_names: Tuple[str, ...], data: 'np.array'):
|
||||
import numpy as np
|
||||
|
|
|
|||
Loading…
Reference in a new issue