diff --git a/docs/_docs/additional_topics.md b/docs/_docs/additional_topics.md index 4f2638f..69fb2a1 100644 --- a/docs/_docs/additional_topics.md +++ b/docs/_docs/additional_topics.md @@ -82,11 +82,12 @@ A common setting for forecasting is fitting models that need to be updated as ad ```python # Python -def get_stan_init(m): - """Retrieve parameters from a trained model. - - Retrieve parameters from a trained model in the format - used to initialize a new Stan model. +def warm_start_params(m): + """ + Retrieve parameters from a trained model in the format used to initialize a new Stan model. + Note that the new Stan model must have these same settings: + n_changepoints, seasonality features, mcmc sampling + for the retrieved parameters to be valid for the new model. Parameters ---------- @@ -115,7 +116,7 @@ m1 = Prophet().fit(df1) # A model fit to all data except the last day %timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch -%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1 +%timeit m2 = Prophet().fit(df, init=warm_start_params(m1)) # Adding the last day, warm-starting from m1 ``` 1.33 s ± 55.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 185 ms ± 4.46 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) diff --git a/notebooks/additional_topics.ipynb b/notebooks/additional_topics.ipynb index 2c2c51d..cc57019 100644 --- a/notebooks/additional_topics.ipynb +++ b/notebooks/additional_topics.ipynb @@ -216,11 +216,12 @@ } ], "source": [ - "def get_stan_init(m):\n", - " \"\"\"Retrieve parameters from a trained model.\n", - "\n", - " Retrieve parameters from a trained model in the format\n", - " used to initialize a new Stan model.\n", + "def warm_start_params(m):\n", + " \"\"\"\n", + " Retrieve parameters from a trained model in the format used to initialize a new Stan model.\n", + " Note that the new Stan model must have these same settings:\n", + " n_changepoints, seasonality features, mcmc sampling\n", + " for the retrieved parameters to be valid for the new model.\n", "\n", " Parameters\n", " ----------\n", @@ -249,7 +250,7 @@ "\n", "\n", "%timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch\n", - "%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1" + "%timeit m2 = Prophet().fit(df, init=warm_start_params(m1)) # Adding the last day, warm-starting from m1" ] }, { @@ -275,7 +276,7 @@ "metadata": { "celltoolbar": "Edit Metadata", "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -289,7 +290,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.8.10 (default, Mar 25 2022, 22:18:25) \n[Clang 12.0.5 (clang-1205.0.22.11)]" + }, + "vscode": { + "interpreter": { + "hash": "9f764ff66e3236555d51a00749f5293db82082e341e153963b8b2deea93b52fc" + } } }, "nbformat": 4, diff --git a/python/prophet/models.py b/python/prophet/models.py index 55adb0d..3071f29 100644 --- a/python/prophet/models.py +++ b/python/prophet/models.py @@ -80,14 +80,14 @@ class CmdStanPyBackend(IStanBackend): def fit(self, stan_init, stan_data, **kwargs): if 'inits' not in kwargs and 'init' in kwargs: - kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data) + stan_init = self.sanitize_custom_inits(stan_init, kwargs['init']) del kwargs['init'] - stan_init, stan_data = self.prepare_data(stan_init, stan_data) + inits_list, data_list = self.prepare_data(stan_init, stan_data) args = dict( - data=stan_data, - inits=stan_init, - algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS', + data=data_list, + inits=inits_list, + algorithm='Newton' if data_list['T'] < 100 else 'LBFGS', iter=int(1e4), ) args.update(kwargs) @@ -109,22 +109,20 @@ class CmdStanPyBackend(IStanBackend): def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: if 'inits' not in kwargs and 'init' in kwargs: - kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data) + stan_init = self.sanitize_custom_inits(stan_init, kwargs['init']) del kwargs['init'] - stan_init, stan_data = self.prepare_data(stan_init, stan_data) + inits_list, data_list = self.prepare_data(stan_init, stan_data) args = dict( - data=stan_data, - inits=stan_init, + data=data_list, + inits=inits_list, ) - if 'chains' not in kwargs: kwargs['chains'] = 4 iter_half = samples // 2 kwargs['iter_sampling'] = iter_half if 'iter_warmup' not in kwargs: kwargs['iter_warmup'] = iter_half - args.update(kwargs) self.stan_fit = self.model.sample(**args) @@ -143,8 +141,25 @@ class CmdStanPyBackend(IStanBackend): return params + @staticmethod + def sanitize_custom_inits(default_inits, custom_inits): + """Validate that custom inits have the correct type and shape, otherwise use defaults.""" + sanitized = {} + for param in ['k', 'm', 'sigma_obs']: + try: + sanitized[param] = float(custom_inits.get(param)) + except Exception: + sanitized[param] = default_inits[param] + for param in ['delta', 'beta']: + if default_inits[param].shape == custom_inits[param].shape: + sanitized[param] = custom_inits[param] + else: + sanitized[param] = default_inits[param] + return sanitized + @staticmethod def prepare_data(init, data) -> Tuple[dict, dict]: + """Converts np.ndarrays to lists that can be read by cmdstanpy.""" cmdstanpy_data = { 'T': data['T'], 'S': data['S'], diff --git a/python/prophet/tests/test_prophet.py b/python/prophet/tests/test_prophet.py index d7536d5..ab619e8 100644 --- a/python/prophet/tests/test_prophet.py +++ b/python/prophet/tests/test_prophet.py @@ -15,6 +15,7 @@ from unittest import TestCase, skipUnless import numpy as np import pandas as pd from prophet import Prophet +from prophet.utilities import warm_start_params DATA = pd.read_csv( @@ -892,50 +893,12 @@ class TestProphet(TestCase): }, ) - @staticmethod - def get_stan_init(m): - """Retrieve parameters from a trained model. - - Retrieve parameters from a trained model in the format - used to initialize a new Stan model. - - Parameters - ---------- - m: A trained model of the Prophet class. - - Returns - ------- - A Dictionary containing retrieved parameters of m. - """ - res = {} - for pname in ['k', 'm', 'sigma_obs']: - if m.mcmc_samples == 0: - res[pname] = m.params[pname][0][0] - else: - res[pname] = np.mean(m.params[pname]) - for pname in ['delta', 'beta']: - if m.mcmc_samples == 0: - res[pname] = m.params[pname][0] - else: - res[pname] = np.mean(m.params[pname], axis=0) - return res - def test_fit_warm_start(self): - previous_df = DATA.iloc[:500] - df = DATA.iloc[:510] - m = Prophet() - m = m.fit(previous_df) - m_params = self.get_stan_init(m) - m2 = Prophet() - m2 = m2.fit(df, init=m_params) + m = Prophet().fit(DATA.iloc[:500]) + m2 = Prophet().fit(DATA.iloc[:510], init=warm_start_params(m)) self.assertEqual(len(m2.params['delta'][0]), 25) def test_sampling_warm_start(self): - previous_df = DATA.iloc[:500] - df = DATA.iloc[:510] - m = Prophet(mcmc_samples=100) - m = m.fit(previous_df, show_progress=False) - m_params = self.get_stan_init(m) - m2 = Prophet(mcmc_samples=100) - m2 = m2.fit(df, init=m_params, show_progress=False) + m = Prophet(mcmc_samples=100).fit(DATA.iloc[:500], show_progress=False) + m2 = Prophet(mcmc_samples=100).fit(DATA.iloc[:510], init=warm_start_params(m), show_progress=False) self.assertEqual(m2.params['delta'].shape, (200, 25)) diff --git a/python/prophet/utilities.py b/python/prophet/utilities.py index 662f230..1a56488 100644 --- a/python/prophet/utilities.py +++ b/python/prophet/utilities.py @@ -75,3 +75,31 @@ def regressor_coefficients(m): coefs.append(record) return pd.DataFrame(coefs) + +def warm_start_params(m): + """ + Retrieve parameters from a trained model in the format used to initialize a new Stan model. + Note that the new Stan model must have these same settings: + n_changepoints, seasonality features, mcmc sampling + for the retrieved parameters to be valid for the new model. + + Parameters + ---------- + m: A trained model of the Prophet class. + + Returns + ------- + A Dictionary containing retrieved parameters of m. + """ + res = {} + for pname in ['k', 'm', 'sigma_obs']: + if m.mcmc_samples == 0: + res[pname] = m.params[pname][0][0] + else: + res[pname] = np.mean(m.params[pname]) + for pname in ['delta', 'beta']: + if m.mcmc_samples == 0: + res[pname] = m.params[pname][0] + else: + res[pname] = np.mean(m.params[pname], axis=0) + return res