mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-14 20:48:08 +00:00
Sanitize warm start parameters (#2342)
This commit is contained in:
parent
0a5becb777
commit
8b3d09caf7
5 changed files with 80 additions and 67 deletions
|
|
@ -82,11 +82,12 @@ A common setting for forecasting is fitting models that need to be updated as ad
|
|||
|
||||
```python
|
||||
# Python
|
||||
def get_stan_init(m):
|
||||
"""Retrieve parameters from a trained model.
|
||||
|
||||
Retrieve parameters from a trained model in the format
|
||||
used to initialize a new Stan model.
|
||||
def warm_start_params(m):
|
||||
"""
|
||||
Retrieve parameters from a trained model in the format used to initialize a new Stan model.
|
||||
Note that the new Stan model must have these same settings:
|
||||
n_changepoints, seasonality features, mcmc sampling
|
||||
for the retrieved parameters to be valid for the new model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -115,7 +116,7 @@ m1 = Prophet().fit(df1) # A model fit to all data except the last day
|
|||
|
||||
|
||||
%timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch
|
||||
%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1
|
||||
%timeit m2 = Prophet().fit(df, init=warm_start_params(m1)) # Adding the last day, warm-starting from m1
|
||||
```
|
||||
1.33 s ± 55.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
|
||||
185 ms ± 4.46 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
|
||||
|
|
|
|||
|
|
@ -216,11 +216,12 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"def get_stan_init(m):\n",
|
||||
" \"\"\"Retrieve parameters from a trained model.\n",
|
||||
"\n",
|
||||
" Retrieve parameters from a trained model in the format\n",
|
||||
" used to initialize a new Stan model.\n",
|
||||
"def warm_start_params(m):\n",
|
||||
" \"\"\"\n",
|
||||
" Retrieve parameters from a trained model in the format used to initialize a new Stan model.\n",
|
||||
" Note that the new Stan model must have these same settings:\n",
|
||||
" n_changepoints, seasonality features, mcmc sampling\n",
|
||||
" for the retrieved parameters to be valid for the new model.\n",
|
||||
"\n",
|
||||
" Parameters\n",
|
||||
" ----------\n",
|
||||
|
|
@ -249,7 +250,7 @@
|
|||
"\n",
|
||||
"\n",
|
||||
"%timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch\n",
|
||||
"%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1"
|
||||
"%timeit m2 = Prophet().fit(df, init=warm_start_params(m1)) # Adding the last day, warm-starting from m1"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -275,7 +276,7 @@
|
|||
"metadata": {
|
||||
"celltoolbar": "Edit Metadata",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
|
@ -289,7 +290,12 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.8.10 (default, Mar 25 2022, 22:18:25) \n[Clang 12.0.5 (clang-1205.0.22.11)]"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "9f764ff66e3236555d51a00749f5293db82082e341e153963b8b2deea93b52fc"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
|||
|
|
@ -80,14 +80,14 @@ class CmdStanPyBackend(IStanBackend):
|
|||
|
||||
def fit(self, stan_init, stan_data, **kwargs):
|
||||
if 'inits' not in kwargs and 'init' in kwargs:
|
||||
kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data)
|
||||
stan_init = self.sanitize_custom_inits(stan_init, kwargs['init'])
|
||||
del kwargs['init']
|
||||
|
||||
stan_init, stan_data = self.prepare_data(stan_init, stan_data)
|
||||
inits_list, data_list = self.prepare_data(stan_init, stan_data)
|
||||
args = dict(
|
||||
data=stan_data,
|
||||
inits=stan_init,
|
||||
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
|
||||
data=data_list,
|
||||
inits=inits_list,
|
||||
algorithm='Newton' if data_list['T'] < 100 else 'LBFGS',
|
||||
iter=int(1e4),
|
||||
)
|
||||
args.update(kwargs)
|
||||
|
|
@ -109,22 +109,20 @@ class CmdStanPyBackend(IStanBackend):
|
|||
|
||||
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
|
||||
if 'inits' not in kwargs and 'init' in kwargs:
|
||||
kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data)
|
||||
stan_init = self.sanitize_custom_inits(stan_init, kwargs['init'])
|
||||
del kwargs['init']
|
||||
|
||||
stan_init, stan_data = self.prepare_data(stan_init, stan_data)
|
||||
inits_list, data_list = self.prepare_data(stan_init, stan_data)
|
||||
args = dict(
|
||||
data=stan_data,
|
||||
inits=stan_init,
|
||||
data=data_list,
|
||||
inits=inits_list,
|
||||
)
|
||||
|
||||
if 'chains' not in kwargs:
|
||||
kwargs['chains'] = 4
|
||||
iter_half = samples // 2
|
||||
kwargs['iter_sampling'] = iter_half
|
||||
if 'iter_warmup' not in kwargs:
|
||||
kwargs['iter_warmup'] = iter_half
|
||||
|
||||
args.update(kwargs)
|
||||
|
||||
self.stan_fit = self.model.sample(**args)
|
||||
|
|
@ -143,8 +141,25 @@ class CmdStanPyBackend(IStanBackend):
|
|||
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def sanitize_custom_inits(default_inits, custom_inits):
    """Validate custom inits, falling back to the defaults when they are unusable.

    Parameters
    ----------
    default_inits: Dictionary of default initial values with the keys
        'k', 'm', 'sigma_obs', 'delta' and 'beta'. The 'delta' and 'beta'
        entries must expose a ``shape`` attribute.
    custom_inits: User-supplied initial values. May be incomplete, or not a
        dictionary at all; anything unusable is replaced by the default.

    Returns
    -------
    A new dictionary with exactly the five keys above. Scalar entries are
    the custom value converted to float when that conversion succeeds.
    Array entries are the custom value when its ``shape`` equals the
    default's shape. In every other case the default value is used.
    """
    # A non-dict `init` (e.g. a number) carries no per-parameter values, so
    # treat it as "nothing supplied" instead of failing on item access.
    if not hasattr(custom_inits, 'get'):
        custom_inits = {}

    sanitized = {}
    for param in ['k', 'm', 'sigma_obs']:
        try:
            sanitized[param] = float(custom_inits.get(param))
        except (TypeError, ValueError, OverflowError):
            # Missing (None), non-numeric or non-scalar custom value.
            sanitized[param] = default_inits[param]
    for param in ['delta', 'beta']:
        default = default_inits[param]
        custom = custom_inits.get(param)
        # Missing entries and objects without a `shape` (e.g. plain lists)
        # compare as None here and therefore fall back to the default.
        if getattr(custom, 'shape', None) == default.shape:
            sanitized[param] = custom
        else:
            sanitized[param] = default
    return sanitized
|
||||
|
||||
@staticmethod
|
||||
def prepare_data(init, data) -> Tuple[dict, dict]:
|
||||
"""Converts np.ndarrays to lists that can be read by cmdstanpy."""
|
||||
cmdstanpy_data = {
|
||||
'T': data['T'],
|
||||
'S': data['S'],
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from unittest import TestCase, skipUnless
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
from prophet import Prophet
|
||||
from prophet.utilities import warm_start_params
|
||||
|
||||
|
||||
DATA = pd.read_csv(
|
||||
|
|
@ -892,50 +893,12 @@ class TestProphet(TestCase):
|
|||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_stan_init(m):
|
||||
"""Retrieve parameters from a trained model.
|
||||
|
||||
Retrieve parameters from a trained model in the format
|
||||
used to initialize a new Stan model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: A trained model of the Prophet class.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A Dictionary containing retrieved parameters of m.
|
||||
"""
|
||||
res = {}
|
||||
for pname in ['k', 'm', 'sigma_obs']:
|
||||
if m.mcmc_samples == 0:
|
||||
res[pname] = m.params[pname][0][0]
|
||||
else:
|
||||
res[pname] = np.mean(m.params[pname])
|
||||
for pname in ['delta', 'beta']:
|
||||
if m.mcmc_samples == 0:
|
||||
res[pname] = m.params[pname][0]
|
||||
else:
|
||||
res[pname] = np.mean(m.params[pname], axis=0)
|
||||
return res
|
||||
|
||||
def test_fit_warm_start(self):
|
||||
previous_df = DATA.iloc[:500]
|
||||
df = DATA.iloc[:510]
|
||||
m = Prophet()
|
||||
m = m.fit(previous_df)
|
||||
m_params = self.get_stan_init(m)
|
||||
m2 = Prophet()
|
||||
m2 = m2.fit(df, init=m_params)
|
||||
m = Prophet().fit(DATA.iloc[:500])
|
||||
m2 = Prophet().fit(DATA.iloc[:510], init=warm_start_params(m))
|
||||
self.assertEqual(len(m2.params['delta'][0]), 25)
|
||||
|
||||
def test_sampling_warm_start(self):
|
||||
previous_df = DATA.iloc[:500]
|
||||
df = DATA.iloc[:510]
|
||||
m = Prophet(mcmc_samples=100)
|
||||
m = m.fit(previous_df, show_progress=False)
|
||||
m_params = self.get_stan_init(m)
|
||||
m2 = Prophet(mcmc_samples=100)
|
||||
m2 = m2.fit(df, init=m_params, show_progress=False)
|
||||
m = Prophet(mcmc_samples=100).fit(DATA.iloc[:500], show_progress=False)
|
||||
m2 = Prophet(mcmc_samples=100).fit(DATA.iloc[:510], init=warm_start_params(m), show_progress=False)
|
||||
self.assertEqual(m2.params['delta'].shape, (200, 25))
|
||||
|
|
|
|||
|
|
@ -75,3 +75,31 @@ def regressor_coefficients(m):
|
|||
coefs.append(record)
|
||||
|
||||
return pd.DataFrame(coefs)
|
||||
|
||||
def warm_start_params(m):
    """
    Retrieve parameters from a trained model in the format used to initialize a new Stan model.
    Note that the new Stan model must have these same settings:
        n_changepoints, seasonality features, mcmc sampling
    for the retrieved parameters to be valid for the new model.

    Parameters
    ----------
    m: A trained model of the Prophet class.

    Returns
    -------
    A Dictionary containing retrieved parameters of m. The 'k', 'm' and
    'sigma_obs' entries are scalars; the 'delta' and 'beta' entries are
    arrays that do not share memory with m.params.
    """
    point_estimate = m.mcmc_samples == 0
    res = {}
    for pname in ['k', 'm', 'sigma_obs']:
        if point_estimate:
            # Single fitted value: take the only element.
            res[pname] = m.params[pname][0][0]
        else:
            # Average over all posterior samples.
            res[pname] = np.mean(m.params[pname])
    for pname in ['delta', 'beta']:
        if point_estimate:
            # Copy the row: indexing alone would hand back a view, and
            # editing the inits in place would then alter the trained model.
            res[pname] = np.array(m.params[pname][0])
        else:
            res[pname] = np.mean(m.params[pname], axis=0)
    return res
|
||||
|
|
|
|||
Loading…
Reference in a new issue