mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-14 20:48:08 +00:00
Properly allow warm starts (#2335)
This commit is contained in:
parent
e2b4ef3a8e
commit
eb71d0117b
4 changed files with 80 additions and 22 deletions
|
|
@ -82,7 +82,7 @@ A common setting for forecasting is fitting models that need to be updated as ad
|
|||
|
||||
```python
|
||||
# Python
|
||||
def stan_init(m):
|
||||
def get_stan_init(m):
|
||||
"""Retrieve parameters from a trained model.
|
||||
|
||||
Retrieve parameters from a trained model in the format
|
||||
|
|
@ -95,13 +95,18 @@ def stan_init(m):
|
|||
Returns
|
||||
-------
|
||||
A Dictionary containing retrieved parameters of m.
|
||||
|
||||
"""
|
||||
res = {}
|
||||
for pname in ['k', 'm', 'sigma_obs']:
|
||||
res[pname] = m.params[pname][0][0]
|
||||
if m.mcmc_samples == 0:
|
||||
res[pname] = m.params[pname][0][0]
|
||||
else:
|
||||
res[pname] = np.mean(m.params[pname])
|
||||
for pname in ['delta', 'beta']:
|
||||
res[pname] = m.params[pname][0]
|
||||
if m.mcmc_samples == 0:
|
||||
res[pname] = m.params[pname][0]
|
||||
else:
|
||||
res[pname] = np.mean(m.params[pname], axis=0)
|
||||
return res
|
||||
|
||||
df = pd.read_csv('https://raw.githubusercontent.com/facebook/prophet/main/examples/example_wp_log_peyton_manning.csv')
|
||||
|
|
@ -110,7 +115,7 @@ m1 = Prophet().fit(df1) # A model fit to all data except the last day
|
|||
|
||||
|
||||
%timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch
|
||||
%timeit m2 = Prophet().fit(df, init=stan_init(m1)) # Adding the last day, warm-starting from m1
|
||||
%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1
|
||||
```
|
||||
1.33 s ± 55.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
|
||||
185 ms ± 4.46 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
|
||||
|
|
|
|||
|
|
@ -216,26 +216,31 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"def stan_init(m):\n",
|
||||
"def get_stan_init(m):\n",
|
||||
" \"\"\"Retrieve parameters from a trained model.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Retrieve parameters from a trained model in the format\n",
|
||||
" used to initialize a new Stan model.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Parameters\n",
|
||||
" ----------\n",
|
||||
" m: A trained model of the Prophet class.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Returns\n",
|
||||
" -------\n",
|
||||
" A Dictionary containing retrieved parameters of m.\n",
|
||||
" \n",
|
||||
" \"\"\"\n",
|
||||
" res = {}\n",
|
||||
" for pname in ['k', 'm', 'sigma_obs']:\n",
|
||||
" res[pname] = m.params[pname][0][0]\n",
|
||||
" if m.mcmc_samples == 0:\n",
|
||||
" res[pname] = m.params[pname][0][0]\n",
|
||||
" else:\n",
|
||||
" res[pname] = np.mean(m.params[pname])\n",
|
||||
" for pname in ['delta', 'beta']:\n",
|
||||
" res[pname] = m.params[pname][0]\n",
|
||||
" if m.mcmc_samples == 0:\n",
|
||||
" res[pname] = m.params[pname][0]\n",
|
||||
" else:\n",
|
||||
" res[pname] = np.mean(m.params[pname], axis=0)\n",
|
||||
" return res\n",
|
||||
"\n",
|
||||
"df = pd.read_csv('../examples/example_wp_log_peyton_manning.csv')\n",
|
||||
|
|
@ -244,7 +249,7 @@
|
|||
"\n",
|
||||
"\n",
|
||||
"%timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch\n",
|
||||
"%timeit m2 = Prophet().fit(df, init=stan_init(m1)) # Adding the last day, warm-starting from m1"
|
||||
"%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -79,11 +79,11 @@ class CmdStanPyBackend(IStanBackend):
|
|||
return cmdstanpy.CmdStanModel(exe_file=model_file)
|
||||
|
||||
def fit(self, stan_init, stan_data, **kwargs):
|
||||
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
|
||||
|
||||
if 'inits' not in kwargs and 'init' in kwargs:
|
||||
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
|
||||
kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data)
|
||||
del kwargs['init']
|
||||
|
||||
stan_init, stan_data = self.prepare_data(stan_init, stan_data)
|
||||
args = dict(
|
||||
data=stan_data,
|
||||
inits=stan_init,
|
||||
|
|
@ -108,11 +108,11 @@ class CmdStanPyBackend(IStanBackend):
|
|||
return params
|
||||
|
||||
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
|
||||
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
|
||||
|
||||
if 'inits' not in kwargs and 'init' in kwargs:
|
||||
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
|
||||
kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data)
|
||||
del kwargs['init']
|
||||
|
||||
stan_init, stan_data = self.prepare_data(stan_init, stan_data)
|
||||
args = dict(
|
||||
data=stan_data,
|
||||
inits=stan_init,
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class TestProphet(TestCase):
|
|||
forecaster = Prophet(mcmc_samples=500)
|
||||
|
||||
# chains adjusted from 4 to 7 to satisfy test for cmdstanpy
|
||||
forecaster.fit(train, seed=1237861298, chains=7)
|
||||
forecaster.fit(train, seed=1237861298, chains=7, show_progress=False)
|
||||
np.random.seed(876543987)
|
||||
future = forecaster.make_future_dataframe(days, include_history=False)
|
||||
future = forecaster.predict(future)
|
||||
|
|
@ -110,9 +110,9 @@ class TestProphet(TestCase):
|
|||
N = DATA.shape[0]
|
||||
train = DATA.head(N // 2)
|
||||
future = DATA.tail(N // 2)
|
||||
|
||||
|
||||
forecaster = Prophet(n_changepoints=0, mcmc_samples=100)
|
||||
forecaster.fit(train)
|
||||
forecaster.fit(train, show_progress=False)
|
||||
forecaster.predict(future)
|
||||
|
||||
def test_fit_changepoint_not_in_history(self):
|
||||
|
|
@ -891,3 +891,51 @@ class TestProphet(TestCase):
|
|||
'holidays',
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_stan_init(m):
|
||||
"""Retrieve parameters from a trained model.
|
||||
|
||||
Retrieve parameters from a trained model in the format
|
||||
used to initialize a new Stan model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: A trained model of the Prophet class.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A Dictionary containing retrieved parameters of m.
|
||||
"""
|
||||
res = {}
|
||||
for pname in ['k', 'm', 'sigma_obs']:
|
||||
if m.mcmc_samples == 0:
|
||||
res[pname] = m.params[pname][0][0]
|
||||
else:
|
||||
res[pname] = np.mean(m.params[pname])
|
||||
for pname in ['delta', 'beta']:
|
||||
if m.mcmc_samples == 0:
|
||||
res[pname] = m.params[pname][0]
|
||||
else:
|
||||
res[pname] = np.mean(m.params[pname], axis=0)
|
||||
return res
|
||||
|
||||
def test_fit_warm_start(self):
|
||||
previous_df = DATA.iloc[:500]
|
||||
df = DATA.iloc[:510]
|
||||
m = Prophet()
|
||||
m = m.fit(previous_df)
|
||||
m_params = self.get_stan_init(m)
|
||||
m2 = Prophet()
|
||||
m2 = m2.fit(df, init=m_params)
|
||||
self.assertEqual(len(m2.params['delta'][0]), 25)
|
||||
|
||||
def test_sampling_warm_start(self):
|
||||
previous_df = DATA.iloc[:500]
|
||||
df = DATA.iloc[:510]
|
||||
m = Prophet(mcmc_samples=100)
|
||||
m = m.fit(previous_df, show_progress=False)
|
||||
m_params = self.get_stan_init(m)
|
||||
m2 = Prophet(mcmc_samples=100)
|
||||
m2 = m2.fit(df, init=m_params, show_progress=False)
|
||||
self.assertEqual(m2.params['delta'].shape, (200, 25))
|
||||
|
|
|
|||
Loading…
Reference in a new issue