Properly allow warm starts (#2335)

This commit is contained in:
Cuong Duong 2023-01-12 03:11:15 +11:00 committed by GitHub
parent e2b4ef3a8e
commit eb71d0117b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 22 deletions

View file

@ -82,7 +82,7 @@ A common setting for forecasting is fitting models that need to be updated as ad
```python
# Python
def stan_init(m):
def get_stan_init(m):
"""Retrieve parameters from a trained model.
Retrieve parameters from a trained model in the format
@ -95,13 +95,18 @@ def stan_init(m):
Returns
-------
A Dictionary containing retrieved parameters of m.
"""
res = {}
for pname in ['k', 'm', 'sigma_obs']:
res[pname] = m.params[pname][0][0]
if m.mcmc_samples == 0:
res[pname] = m.params[pname][0][0]
else:
res[pname] = np.mean(m.params[pname])
for pname in ['delta', 'beta']:
res[pname] = m.params[pname][0]
if m.mcmc_samples == 0:
res[pname] = m.params[pname][0]
else:
res[pname] = np.mean(m.params[pname], axis=0)
return res
df = pd.read_csv('https://raw.githubusercontent.com/facebook/prophet/main/examples/example_wp_log_peyton_manning.csv')
@ -110,7 +115,7 @@ m1 = Prophet().fit(df1) # A model fit to all data except the last day
%timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch
%timeit m2 = Prophet().fit(df, init=stan_init(m1)) # Adding the last day, warm-starting from m1
%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1
```
1.33 s ± 55.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
185 ms ± 4.46 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

View file

@ -216,26 +216,31 @@
}
],
"source": [
"def stan_init(m):\n",
"def get_stan_init(m):\n",
" \"\"\"Retrieve parameters from a trained model.\n",
" \n",
"\n",
" Retrieve parameters from a trained model in the format\n",
" used to initialize a new Stan model.\n",
" \n",
"\n",
" Parameters\n",
" ----------\n",
" m: A trained model of the Prophet class.\n",
" \n",
"\n",
" Returns\n",
" -------\n",
" A Dictionary containing retrieved parameters of m.\n",
" \n",
" \"\"\"\n",
" res = {}\n",
" for pname in ['k', 'm', 'sigma_obs']:\n",
" res[pname] = m.params[pname][0][0]\n",
" if m.mcmc_samples == 0:\n",
" res[pname] = m.params[pname][0][0]\n",
" else:\n",
" res[pname] = np.mean(m.params[pname])\n",
" for pname in ['delta', 'beta']:\n",
" res[pname] = m.params[pname][0]\n",
" if m.mcmc_samples == 0:\n",
" res[pname] = m.params[pname][0]\n",
" else:\n",
" res[pname] = np.mean(m.params[pname], axis=0)\n",
" return res\n",
"\n",
"df = pd.read_csv('../examples/example_wp_log_peyton_manning.csv')\n",
@ -244,7 +249,7 @@
"\n",
"\n",
"%timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch\n",
"%timeit m2 = Prophet().fit(df, init=stan_init(m1)) # Adding the last day, warm-starting from m1"
"%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1"
]
},
{

View file

@ -79,11 +79,11 @@ class CmdStanPyBackend(IStanBackend):
return cmdstanpy.CmdStanModel(exe_file=model_file)
def fit(self, stan_init, stan_data, **kwargs):
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
if 'inits' not in kwargs and 'init' in kwargs:
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data)
del kwargs['init']
stan_init, stan_data = self.prepare_data(stan_init, stan_data)
args = dict(
data=stan_data,
inits=stan_init,
@ -108,11 +108,11 @@ class CmdStanPyBackend(IStanBackend):
return params
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
if 'inits' not in kwargs and 'init' in kwargs:
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data)
del kwargs['init']
stan_init, stan_data = self.prepare_data(stan_init, stan_data)
args = dict(
data=stan_data,
inits=stan_init,

View file

@ -76,7 +76,7 @@ class TestProphet(TestCase):
forecaster = Prophet(mcmc_samples=500)
# chains adjusted from 4 to 7 to satisfy test for cmdstanpy
forecaster.fit(train, seed=1237861298, chains=7)
forecaster.fit(train, seed=1237861298, chains=7, show_progress=False)
np.random.seed(876543987)
future = forecaster.make_future_dataframe(days, include_history=False)
future = forecaster.predict(future)
@ -110,9 +110,9 @@ class TestProphet(TestCase):
N = DATA.shape[0]
train = DATA.head(N // 2)
future = DATA.tail(N // 2)
forecaster = Prophet(n_changepoints=0, mcmc_samples=100)
forecaster.fit(train)
forecaster.fit(train, show_progress=False)
forecaster.predict(future)
def test_fit_changepoint_not_in_history(self):
@ -891,3 +891,51 @@ class TestProphet(TestCase):
'holidays',
},
)
@staticmethod
def get_stan_init(m):
    """Collect warm-start values from an already fitted model.

    Builds the dictionary that is handed to a new fit through its
    ``init`` argument.

    Parameters
    ----------
    m: A trained model of the Prophet class.

    Returns
    -------
    A dictionary with the keys 'k', 'm', 'sigma_obs', 'delta' and 'beta'.
    When ``m.mcmc_samples == 0`` the values come from the first row of
    each entry in ``m.params``; otherwise they are means over
    ``m.params`` ('delta' and 'beta' are averaged along axis 0, the
    others over the whole array).
    """
    # The fit mode does not change between parameters, so read it once.
    single_estimate = m.mcmc_samples == 0
    init = {}
    for name in ('k', 'm', 'sigma_obs'):
        values = m.params[name]
        init[name] = values[0][0] if single_estimate else np.mean(values)
    for name in ('delta', 'beta'):
        values = m.params[name]
        init[name] = values[0] if single_estimate else np.mean(values, axis=0)
    return init
def test_fit_warm_start(self):
    """Warm-start a default fit on 510 rows from a model fitted on 500 rows."""
    history = DATA.iloc[:500]
    extended = DATA.iloc[:510]
    base_model = Prophet().fit(history)
    warm_init = self.get_stan_init(base_model)
    warm_model = Prophet().fit(extended, init=warm_init)
    # 25 is presumably Prophet's default number of changepoints -- confirm.
    n_deltas = len(warm_model.params['delta'][0])
    self.assertEqual(n_deltas, 25)
def test_sampling_warm_start(self):
    """Warm-start an MCMC fit on 510 rows from an MCMC fit on 500 rows."""
    history = DATA.iloc[:500]
    extended = DATA.iloc[:510]
    base_model = Prophet(mcmc_samples=100).fit(history, show_progress=False)
    warm_init = self.get_stan_init(base_model)
    warm_model = Prophet(mcmc_samples=100).fit(
        extended, init=warm_init, show_progress=False
    )
    # NOTE(review): 200 rows presumably reflects the retained draws across
    # chains for mcmc_samples=100, and 25 the default changepoint count.
    delta_shape = warm_model.params['delta'].shape
    self.assertEqual(delta_shape, (200, 25))