From eb71d0117b7154e2b584e11fd2b7ea7aa13a2128 Mon Sep 17 00:00:00 2001 From: Cuong Duong Date: Thu, 12 Jan 2023 03:11:15 +1100 Subject: [PATCH] Properly allow warm starts (#2335) --- docs/_docs/additional_topics.md | 15 +++++--- notebooks/additional_topics.ipynb | 21 ++++++----- python/prophet/models.py | 12 +++---- python/prophet/tests/test_prophet.py | 54 ++++++++++++++++++++++++++-- 4 files changed, 80 insertions(+), 22 deletions(-) diff --git a/docs/_docs/additional_topics.md b/docs/_docs/additional_topics.md index 61ca88c..4f2638f 100644 --- a/docs/_docs/additional_topics.md +++ b/docs/_docs/additional_topics.md @@ -82,7 +82,7 @@ A common setting for forecasting is fitting models that need to be updated as ad ```python # Python -def stan_init(m): +def get_stan_init(m): """Retrieve parameters from a trained model. Retrieve parameters from a trained model in the format @@ -95,13 +95,18 @@ def stan_init(m): Returns ------- A Dictionary containing retrieved parameters of m. - """ res = {} for pname in ['k', 'm', 'sigma_obs']: - res[pname] = m.params[pname][0][0] + if m.mcmc_samples == 0: + res[pname] = m.params[pname][0][0] + else: + res[pname] = np.mean(m.params[pname]) for pname in ['delta', 'beta']: - res[pname] = m.params[pname][0] + if m.mcmc_samples == 0: + res[pname] = m.params[pname][0] + else: + res[pname] = np.mean(m.params[pname], axis=0) return res df = pd.read_csv('https://raw.githubusercontent.com/facebook/prophet/main/examples/example_wp_log_peyton_manning.csv') @@ -110,7 +115,7 @@ m1 = Prophet().fit(df1) # A model fit to all data except the last day %timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch -%timeit m2 = Prophet().fit(df, init=stan_init(m1)) # Adding the last day, warm-starting from m1 +%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1 ``` 1.33 s ± 55.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 185 ms ± 4.46 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) diff --git a/notebooks/additional_topics.ipynb b/notebooks/additional_topics.ipynb index 7b8080b..2c2c51d 100644 --- a/notebooks/additional_topics.ipynb +++ b/notebooks/additional_topics.ipynb @@ -216,26 +216,31 @@ } ], "source": [ - "def stan_init(m):\n", + "def get_stan_init(m):\n", " \"\"\"Retrieve parameters from a trained model.\n", - " \n", + "\n", " Retrieve parameters from a trained model in the format\n", " used to initialize a new Stan model.\n", - " \n", + "\n", " Parameters\n", " ----------\n", " m: A trained model of the Prophet class.\n", - " \n", + "\n", " Returns\n", " -------\n", " A Dictionary containing retrieved parameters of m.\n", - " \n", " \"\"\"\n", " res = {}\n", " for pname in ['k', 'm', 'sigma_obs']:\n", - " res[pname] = m.params[pname][0][0]\n", + " if m.mcmc_samples == 0:\n", + " res[pname] = m.params[pname][0][0]\n", + " else:\n", + " res[pname] = np.mean(m.params[pname])\n", " for pname in ['delta', 'beta']:\n", - " res[pname] = m.params[pname][0]\n", + " if m.mcmc_samples == 0:\n", + " res[pname] = m.params[pname][0]\n", + " else:\n", + " res[pname] = np.mean(m.params[pname], axis=0)\n", " return res\n", "\n", "df = pd.read_csv('../examples/example_wp_log_peyton_manning.csv')\n", @@ -244,7 +249,7 @@ "\n", "\n", "%timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch\n", - "%timeit m2 = Prophet().fit(df, init=stan_init(m1)) # Adding the last day, warm-starting from m1" + "%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1" ] }, { diff --git a/python/prophet/models.py b/python/prophet/models.py index 3c31eb5..55adb0d 100644 --- a/python/prophet/models.py +++ b/python/prophet/models.py @@ -79,11 +79,11 @@ class CmdStanPyBackend(IStanBackend): return cmdstanpy.CmdStanModel(exe_file=model_file) def fit(self, stan_init, stan_data, **kwargs): - (stan_init, stan_data) = self.prepare_data(stan_init, stan_data) - if 'inits' not in kwargs and 'init' in kwargs: - kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] + kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data) + del kwargs['init'] + stan_init, stan_data = self.prepare_data(stan_init, stan_data) args = dict( data=stan_data, inits=stan_init, @@ -108,11 +108,11 @@ class CmdStanPyBackend(IStanBackend): return params def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict: - (stan_init, stan_data) = self.prepare_data(stan_init, stan_data) - if 'inits' not in kwargs and 'init' in kwargs: - kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0] + kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data) + del kwargs['init'] + stan_init, stan_data = self.prepare_data(stan_init, stan_data) args = dict( data=stan_data, inits=stan_init, diff --git a/python/prophet/tests/test_prophet.py b/python/prophet/tests/test_prophet.py index 884ed96..d7536d5 100644 --- a/python/prophet/tests/test_prophet.py +++ b/python/prophet/tests/test_prophet.py @@ -76,7 +76,7 @@ class TestProphet(TestCase): forecaster = Prophet(mcmc_samples=500) # chains adjusted from 4 to 7 to satisfy test for cmdstanpy - forecaster.fit(train, seed=1237861298, chains=7) + forecaster.fit(train, seed=1237861298, chains=7, show_progress=False) np.random.seed(876543987) future = forecaster.make_future_dataframe(days, include_history=False) future = forecaster.predict(future) @@ -110,9 +110,9 @@ class TestProphet(TestCase): N = DATA.shape[0] train = DATA.head(N // 2) future = DATA.tail(N // 2) - + forecaster = Prophet(n_changepoints=0, mcmc_samples=100) - forecaster.fit(train) + forecaster.fit(train, show_progress=False) forecaster.predict(future) def test_fit_changepoint_not_in_history(self): @@ -891,3 +891,51 @@ class TestProphet(TestCase): 'holidays', }, ) + + @staticmethod + def get_stan_init(m): + """Retrieve parameters from a trained model. + + Retrieve parameters from a trained model in the format + used to initialize a new Stan model. + + Parameters + ---------- + m: A trained model of the Prophet class. + + Returns + ------- + A Dictionary containing retrieved parameters of m. + """ + res = {} + for pname in ['k', 'm', 'sigma_obs']: + if m.mcmc_samples == 0: + res[pname] = m.params[pname][0][0] + else: + res[pname] = np.mean(m.params[pname]) + for pname in ['delta', 'beta']: + if m.mcmc_samples == 0: + res[pname] = m.params[pname][0] + else: + res[pname] = np.mean(m.params[pname], axis=0) + return res + + def test_fit_warm_start(self): + previous_df = DATA.iloc[:500] + df = DATA.iloc[:510] + m = Prophet() + m = m.fit(previous_df) + m_params = self.get_stan_init(m) + m2 = Prophet() + m2 = m2.fit(df, init=m_params) + self.assertEqual(len(m2.params['delta'][0]), 25) + + def test_sampling_warm_start(self): + previous_df = DATA.iloc[:500] + df = DATA.iloc[:510] + m = Prophet(mcmc_samples=100) + m = m.fit(previous_df, show_progress=False) + m_params = self.get_stan_init(m) + m2 = Prophet(mcmc_samples=100) + m2 = m2.fit(df, init=m_params, show_progress=False) + self.assertEqual(m2.params['delta'].shape, (200, 25))